Edit File: country.py
import logging
from typing import FrozenSet, Iterable, Iterator, List

from defence360agent.contracts.config import CountryInfo
from defence360agent.utils import timeit
from im360.contracts.config import UnifiedAccessLogger
from im360.model.country import CountryList
from defence360agent.utils.validate import IP, IPVersion

from .. import ip_versions
from ..firewall import FirewallRules, firewall_logging_enabled, get_firewall
from . import (
    IP_SET_PREFIX,
    AbstractIPSet,
    IPSetAtomicRestoreBase,
    IPSetCount,
    get_ipset_family,
    libipset,
)
from .libipset import IPSetCmdBuilder

logger = logging.getLogger(__name__)


def ips_for_country(country_code):
    """Yield the stripped lines of the subnets file for *country_code*.

    The path comes from ``CountryInfo.country_subnets_file``. If the file
    does not exist, an error is logged and nothing is yielded.
    """
    subnets_file = CountryInfo.country_subnets_file(country_code)
    try:
        with open(subnets_file, encoding="utf-8") as f:
            for line in f:
                yield line.strip()
    except FileNotFoundError:
        logger.error("Can't find subnets file %s", subnets_file)
        return


def _country_subnets(
    country_code: str, ip_version: IPVersion, *, log_invalid: bool = True
) -> Iterator[str]:
    """Yield the valid *ip_version* entries of a country's subnets file.

    Blank lines are skipped silently. Entries rejected by ``IP.type_of``
    with ``ValueError`` are skipped, and logged when *log_invalid* is true.
    Both the ipset restore path and the expected-count path use this, so
    they always agree on which entries belong in the set.
    """
    for ip in ips_for_country(country_code):
        if not ip:
            continue
        try:
            version = IP.type_of(ip)
        except ValueError:
            if log_invalid:
                logger.error("%s is neither IPv4 nor IPv6 valid address", ip)
            continue
        if version == ip_version:
            yield ip


class IPSetCountryBlack:
    """Rule builder for the country blacklist.

    Traffic matching a country set is sent to the blacklisted-country
    logging chain.
    """

    CHAIN = FirewallRules.COUNTRY_BLACKLIST_CHAIN
    PRIORITY = FirewallRules.BLACKLIST_PRIORITY

    def single_entry_rules(self, set_name, _):
        """Return the rules for one country set (IP version is unused)."""
        return [
            FirewallRules.ipset_rule(
                set_name, FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN
            )
        ]


class IPSetCountryWhite:
    """Rule builder for the country whitelist.

    Traffic matching a country set is accepted. When firewall logging is
    enabled, it is first passed to an NFLOG rule.
    """

    CHAIN = FirewallRules.COUNTRY_WHITELIST_CHAIN
    PRIORITY = FirewallRules.WHITELIST_PRIORITY

    def single_entry_rules(self, set_name, ip_version: IPVersion):
        """Return the (optional NFLOG +) ACCEPT rules for one country set."""
        result = []
        if firewall_logging_enabled():
            # The log rule must precede ACCEPT, otherwise it is never reached.
            result.append(
                FirewallRules.compose_rule(
                    FirewallRules.ipset(set_name),
                    action=FirewallRules.nflog_action(
                        group=FirewallRules.nflog_group(ip_version),
                        prefix=UnifiedAccessLogger.WHITELIST_COUNTRY,
                    ),
                )
            )
        result.append(FirewallRules.ipset_rule(set_name, FirewallRules.ACCEPT))
        return result


class SingleIpSetCountry(IPSetAtomicRestoreBase):
    """The ipset holding one country's subnets for a given IP version."""

    # Default set name; ``custom_ipset_name`` (presumably provided by
    # IPSetAtomicRestoreBase -- not set in this class) takes precedence.
    _NAME = "{prefix}.{ip_version}.country-{country_code}"
    # Passed as ``maxelem`` when the set is created.
    MAX_ELEM = 524288

    def __init__(self, country_code: str):
        super().__init__(country_code)
        self.country_code = country_code

    def gen_ipset_name_for_ip_version(self, ip_version: IPVersion) -> str:
        """Return the set name for *ip_version* (country code lower-cased)."""
        return self.custom_ipset_name or self._NAME.format(
            prefix=IP_SET_PREFIX,
            ip_version=ip_version,
            country_code=self.country_code.lower(),
        )

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ``ipset restore`` line that creates this set."""
        ipset_options = self._get_ipset_create_options(ip_version)
        return [
            IPSetCmdBuilder.get_create_cmd(
                self.gen_ipset_name_for_ip_version(ip_version), **ipset_options
            )
        ]

    def gen_ipset_destroy_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ``ipset restore`` line that destroys this set."""
        ipset_name = self.gen_ipset_name_for_ip_version(ip_version)
        return [IPSetCmdBuilder.get_destroy_cmd(ipset_name)]

    def gen_ipset_flush_ops(self, ip_version: IPVersion) -> List[str]:
        """Return the ``ipset restore`` line that flushes this set."""
        return [
            IPSetCmdBuilder.get_flush_cmd(
                self.gen_ipset_name_for_ip_version(ip_version)
            )
        ]

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """Return ``ipset restore`` lines adding this country's subnets.

        Only entries of *ip_version* are included. Invalid entries are
        logged and skipped, and blank lines are ignored.
        """
        # The target set is the same for every entry: compute its name once.
        set_name = self.gen_ipset_name_for_ip_version(ip_version=ip_version)
        # "-exist" makes re-adding an already present entry a no-op.
        add_template = "add {set_name} {ip_net} -exist"
        return [
            add_template.format(set_name=set_name, ip_net=ip)
            for ip in _country_subnets(self.country_code, ip_version)
        ]

    def _get_ipset_create_options(self, ip_version: IPVersion):
        """Return the keyword options used to create this set."""
        return dict(
            family=get_ipset_family(ip_version),
            maxelem=self.MAX_ELEM,
        )


class IPSetCountry(AbstractIPSet):
    """Country black/white lists, backed by one ipset per listed country."""

    _LISTNAME = _CHAIN = _PRIORITY = None
    # Rule builders keyed by country list name.
    _IP_SETS = {
        CountryList.BLACK: IPSetCountryBlack(),
        CountryList.WHITE: IPSetCountryWhite(),
    }

    async def block(self, country_code, *args, **kwargs):
        """Create the country's ipsets and rules, then fill the sets.

        For every enabled IP version, this creates the hash set, appends the
        list-specific rules to that list's chain and collects the ``add``
        commands. All collected commands are applied with a single
        ``libipset.restore`` call.

        :param country_code: ISO 3166-1 alpha-2 code
        :param kwargs: must contain ``listname`` (a key of ``_IP_SETS``);
            a missing key raises ``KeyError``
        :return: None
        """
        ipset = self._IP_SETS[kwargs["listname"]]
        commands = []
        for ip_version in ip_versions.enabled():
            ip_set = SingleIpSetCountry(country_code)
            async with await get_firewall(ip_version) as fw:
                set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
                await libipset.create_hash_set(
                    set_name, **ip_set._get_ipset_create_options(ip_version)
                )
                await fw.commit(
                    [
                        fw.append_rule(r, chain=ipset.CHAIN)
                        for r in ipset.single_entry_rules(set_name, ip_version)
                    ]
                )
                commands.extend(await ip_set.gen_ipset_restore_ops(ip_version))
        await libipset.restore(commands)

    async def unblock(self, country_code, *args, **kwargs):
        """Delete the country's rules, then delete its ipsets.

        :param country_code: ISO 3166-1 alpha-2 code
        :param kwargs: must contain ``listname`` (a key of ``_IP_SETS``);
            a missing key raises ``KeyError``
        :return: None
        """
        ipset = self._IP_SETS[kwargs["listname"]]
        for ip_version in ip_versions.enabled():
            ip_set = SingleIpSetCountry(country_code)
            async with await get_firewall(ip_version) as fw:
                set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
                # Rules referencing the set are removed before the set itself.
                await fw.commit(
                    [
                        fw.delete_rule(
                            rule, chain=ipset.CHAIN, ip_version=ip_version
                        )
                        for rule in ipset.single_entry_rules(
                            set_name, ip_version
                        )
                    ]
                )
                await libipset.delete_set(set_name)

    def gen_ipset_create_ops(self, ip_version: IPVersion) -> List[str]:
        """
        Generate list of commands to create all ip sets

        :return: list of ipset commands to use with ipset restore
        """
        result = []
        for ip_set in self.get_all_ipset_instances(ip_version):
            result.extend(ip_set.gen_ipset_create_ops(ip_version))
        return result

    async def gen_ipset_restore_ops(self, ip_version: IPVersion) -> List[str]:
        """
        Generate list of commands to fill all ip sets

        Each set is flushed before its entries are added, so the restore
        replaces the set contents.

        :return: list of ipset commands to use with ipset restore
        """
        commands = []  # type: List[str]
        for ipset in self.get_all_ipset_instances(ip_version):
            commands.append(
                IPSetCmdBuilder.get_flush_cmd(
                    ipset.gen_ipset_name_for_ip_version(ip_version)
                )
            )
            commands.extend(await ipset.gen_ipset_restore_ops(ip_version))
        return commands

    def _fetch(self):
        """Return ``(country_code, listname)`` pairs from ``CountryList``."""
        return [
            (row["country"]["code"], row["listname"])
            for row in CountryList.fetch()
        ]

    def get_all_ipsets(self, ip_version: IPVersion) -> FrozenSet[str]:
        """Return the names of every listed country's set for *ip_version*."""
        return frozenset(
            ipset.gen_ipset_name_for_ip_version(ip_version)
            for ipset in self.get_all_ipset_instances(ip_version)
        )

    def get_all_ipset_instances(
        self, ip_version: IPVersion
    ) -> List[IPSetAtomicRestoreBase]:
        """Return one ``SingleIpSetCountry`` per listed country.

        *ip_version* is accepted for interface compatibility but not used.
        """
        return [
            SingleIpSetCountry(country_code)
            for country_code, _ in self._fetch()
        ]

    def get_rules(self, ip_version: IPVersion, **kwargs) -> Iterable[dict]:
        """Return the firewall rules for the country lists.

        The result contains:

        * one jump from the Imunify input chain into each list chain, using
          that list's priority;
        * the per-country rules inside the list chains, with priorities
          0, 1, 2, ... in fetch order.
        """
        result = [
            dict(
                rule=FirewallRules.compose_action(ipset.CHAIN),
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                table=FirewallRules.FILTER,
                priority=ipset.PRIORITY,
            )
            for ipset in self._IP_SETS.values()
        ]
        priority = 0
        for country, listname in self._fetch():
            list_ipset = self._IP_SETS[listname]
            set_name = SingleIpSetCountry(
                country
            ).gen_ipset_name_for_ip_version(ip_version)
            for rule in list_ipset.single_entry_rules(set_name, ip_version):
                result.append(
                    dict(
                        rule=rule,
                        chain=list_ipset.CHAIN,
                        table=FirewallRules.FILTER,
                        priority=priority,
                    )
                )
                priority += 1
        return result

    async def restore(self, ip_version: IPVersion) -> None:
        """Flush and refill all country sets for *ip_version* (timed)."""
        with timeit("ipset_restore", logger):
            await libipset.restore(
                await self.gen_ipset_restore_ops(ip_version)
            )

    async def get_ipsets_count(
        self, ip_version: IPVersion
    ) -> List[IPSetCount]:
        """Compare the expected and actual entry counts of each country set.

        ``db_count`` is the number of valid *ip_version* entries in the
        country's subnets file. These are exactly the entries that
        ``gen_ipset_restore_ops`` adds, so malformed or blank lines are
        skipped here instead of raising ``ValueError``.
        """
        ipsets = []
        for country, _ in self._fetch():
            # Invalid entries are already reported when the set is restored,
            # so counting stays quiet.
            expected_count = sum(
                1
                for _ip in _country_subnets(
                    country, ip_version, log_invalid=False
                )
            )
            ip_set = SingleIpSetCountry(country)
            set_name = ip_set.gen_ipset_name_for_ip_version(ip_version)
            ipset_count = await libipset.get_ipset_count(set_name)
            ipsets.append(
                IPSetCount(
                    name=set_name,
                    db_count=expected_count,
                    ipset_count=ipset_count,
                )
            )
        return ipsets

    def gen_ipset_flush_ops(self, ip_version, existing_ipsets):
        """Return (lazily) the flush commands for *existing_ipsets*."""
        return self._gen_flush_id_cmds(ip_version, existing_ipsets)

    def _gen_flush_id_cmds(self, ip_version, existing_ipsets):
        """Log the sets being flushed and return a generator of flush cmds.

        The log line is emitted when this is called, not when the returned
        generator is consumed.
        """
        logger.info("Flushing ipsets: %s", existing_ipsets)
        return (
            libipset.IPSetCmdBuilder.get_flush_cmd(ipset_name)
            for ipset_name in existing_ipsets
        )