"""Core module for rules and sets managing."""

import json
from pathlib import Path
import logging
from typing import Iterable, List, Optional, Set, Tuple

from defence360agent.internals.global_scope import g
from im360.contracts.config import (
    NetworkInterface,
    UnifiedAccessLogger,
    DOS,
    EnhancedDOS,
)
from im360.internals.core.ipset.port_deny import (
    InputPortBlockingDenyModeIPSet,
    OutputPortBlockingDenyModeIPSet,
)
from defence360agent.utils.validate import IPVersion
from im360.internals.core.firewall import Iptables

from . import ip_versions
from .firewall import (
    FirewallRules,
    RuleDef,
    firewall_logging_enabled,
    is_nat_available,
)
from .ipset import IP_SET_PREFIX, libipset
from .ipset.country import IPSetCountry
from .ipset.ip import IPSet
from .ipset.libipset import IPSetCmdBuilder, IPSetRestoreCmd
from .ipset.port import IPSetIgnoredByPort, IPSetPort
from .ipset.redirect import (
    IPSetNoRedirectPort,
    IPSetWebshieldPort,
)
from .ipset.sync import IPSetSyncIPListPurpose, IPSetSyncIPListRecords

logger = logging.getLogger(__name__)

# Per-IP-version JSON file that remembers ipsets which could not be destroyed.
FAILED_IPSETS_FILE = "/var/imunify360/failed_ipsets_{ip_version}.json"


class RuleSet:
    """Managing iptables rules and ipsets."""

    # Sub-chains created in (and removed from) the FILTER table.
    _CHAINS = [
        FirewallRules.COUNTRY_WHITELIST_CHAIN,
        FirewallRules.COUNTRY_BLACKLIST_CHAIN,
        FirewallRules.BP_INPUT_CHAIN,
        FirewallRules.LOG_BLACKLIST_CHAIN,
        FirewallRules.LOG_GRAYLIST_CHAIN,
        FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
        FirewallRules.WEBSHIELD_PORTS_INPUT_CHAIN,
        FirewallRules.LOG_BLOCK_PORT_CHAIN,
    ]

    # Since DB and ipset are updated at different times,
    # check relative value instead of compare absolute values.
    # Use a large enough relative number to avoid false positives,
    # 20% difference looks reasonable for this.
    _IPSET_COUNT_TO_RECREATE_THRESHOLD = 0.2

    # Rule jumping to the CT target; it is inserted into, checked in and
    # deleted from the raw/PREROUTING chain, so keep a single definition.
    # fmt: off
    _CONNTRACK_RULE = (
        "-m", "comment", "--comment", '"Connection tracking for Imunify360."',
        "-j", "CT",
    )
    # fmt: on

    def __init__(self):
        self.entities = (
            InputPortBlockingDenyModeIPSet(),
            OutputPortBlockingDenyModeIPSet(),
            IPSetPort(),
            IPSet(),
            # Order is important here,
            # Ensure IPSetSyncIPListRecords is created before
            # IPSetSyncIPListPurpose
            IPSetSyncIPListRecords(),
            IPSetSyncIPListPurpose(),
            IPSetCountry(),
            IPSetIgnoredByPort(),
            IPSetNoRedirectPort(),
            IPSetWebshieldPort(),
        )

    @staticmethod
    def targets(ip_version: IPVersion) -> List[Tuple]:
        """
        Returns tables & chains that Imunify360 will use
        in firewall management

        The second target is NAT/PREROUTING when NAT is available for
        *ip_version*, MANGLE/PREROUTING otherwise.

        :param ip_version: IPv4 or IPv6
        :return: List[Tuple]: (table, chain) pairs
        """
        return [
            (FirewallRules.FILTER, "INPUT"),
            (
                (FirewallRules.NAT, "PREROUTING")
                if is_nat_available(ip_version)
                else (FirewallRules.MANGLE, "PREROUTING")
            ),
        ]

    @staticmethod
    def _apply_ignored_interfaces(action, interface_conf, *args, **kwargs):
        """Yield ``action(...)`` for an ACCEPT rule per skipped interface.

        :param Callable action: action to perform with interface
        :param interface_conf: interface configuration; its
            ``NetworkInterface.DEVICE_SKIP`` entry is iterated
        :param args: extra positional arguments forwarded to *action*
        :param kwargs: extra keyword arguments forwarded to *action*
        """
        for interface in interface_conf[NetworkInterface.DEVICE_SKIP]:
            yield action(
                FirewallRules.compose_rule(
                    FirewallRules.interface(interface),
                    action=FirewallRules.compose_action(FirewallRules.ACCEPT),
                ),
                chain=FirewallRules.IMUNIFY_INPUT_CHAIN,
                priority=0,  # max priority for firewalld
                *args,
                **kwargs,
            )

    @staticmethod
    def _compose_rule(ip_version: IPVersion, interface_conf: dict) -> RuleDef:
        """Compose rule based on NetworkInterface config"""
        target_interface = interface_conf[ip_version]
        action = FirewallRules.compose_action(
            FirewallRules.IMUNIFY_INPUT_CHAIN
        )
        if target_interface:
            # restrict the jump to the configured interface
            rule = FirewallRules.compose_rule(
                FirewallRules.interface(target_interface), action=action
            )
        else:
            rule = action
        return rule

    async def ipset_create_commands(self, ip_version: IPVersion) -> List[str]:
        """Collect ipset create operations from every entity."""
        names = []  # type: List[str]
        for entity in self.entities:
            names.extend(entity.gen_ipset_create_ops(ip_version))
        return names

    async def ipset_flush_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to flush *existing* ipsets.

        If *existing* is None, the currently existing ipsets are queried.
        Only entities that provide ``gen_ipset_flush_ops`` contribute.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        # get entity specific flush commands
        cmds = []
        needed_entities = [
            entity
            for entity in self.entities
            if hasattr(entity, "gen_ipset_flush_ops")
        ]
        for entity in needed_entities:
            cmds += entity.gen_ipset_flush_ops(ip_version, existing)
        return cmds

    async def ipset_destroy_commands(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ) -> Iterable[IPSetRestoreCmd]:
        """Generate ipset restore commands to destroy *existing* ipsets.

        If *existing* is None, the currently existing ipsets are queried.
        """
        if existing is None:
            existing = await self.existing_ipsets(ip_version)
        # get entity specific destroy commands
        cmds: dict[str, IPSetRestoreCmd] = {}
        for entity in self.entities:
            entity_cmds = entity.gen_ipset_destroy_ops(ip_version, existing)
            cmds.update(entity_cmds)
        # generic destroy
        for ipset_name in existing:
            if ipset_name not in cmds:
                # ipset is not special, remove using a generic destroy command
                cmds[ipset_name] = IPSetCmdBuilder.get_destroy_cmd(ipset_name)
        return cmds.values()

    async def create_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Return a list of firewall commands to create all required rules."""
        actions = []

        # input chains
        for table, chain in self.targets(ip_version):
            # main chain and rule
            actions.extend(
                [
                    firewall.create_chain(
                        table=table, chain=FirewallRules.IMUNIFY_INPUT_CHAIN
                    ),
                    firewall.insert_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.insert_rule, interface_conf, table=table
                    ),
                ]
            )
        # subchains
        actions.extend(
            [
                firewall.create_chain(table=FirewallRules.FILTER, chain=chain)
                for chain in self._CHAINS
            ]
        )
        # log block rules
        actions.extend(self._log_block_rules(firewall.append_rule, ip_version))

        # output chains
        # main chain and rule
        actions.extend(
            [
                firewall.create_chain(
                    table=FirewallRules.FILTER,
                    chain=FirewallRules.IMUNIFY_OUTPUT_CHAIN,
                ),
                firewall.insert_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    chain="OUTPUT",
                ),
            ]
        )
        # subchains
        actions.extend(
            [
                firewall.create_chain(table=FirewallRules.FILTER, chain=chain)
                for chain in [FirewallRules.BP_OUTPUT_CHAIN]
            ]
        )

        # ipsets rules (can be in NAT or FILTER table)
        actions.extend(
            [
                firewall.append_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )

        if DOS.ENABLED or EnhancedDOS.ENABLED:
            # Add connection tracking rule.
            actions.append(
                firewall.insert_rule(
                    self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
                )
            )
        return actions

    def destroy_commands(
        self, firewall, interface_conf: dict, ip_version: IPVersion
    ) -> Iterable[list]:
        """Returns an iterable over list of commands to destroy firewall rules.

        Each list should be executed as a separate
        firewall commit operation."""
        # input chains
        for table, chain in self.targets(ip_version):
            # delete main rule
            yield [
                firewall.delete_rule(
                    self._compose_rule(ip_version, interface_conf),
                    table=table,
                    chain=chain,
                )
            ]
            # flush and delete main chain
            yield [
                firewall.flush_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
                firewall.delete_chain(
                    FirewallRules.IMUNIFY_INPUT_CHAIN, table=table
                ),
            ]
        for chain in self._CHAINS:
            yield [
                firewall.flush_chain(chain, table=FirewallRules.FILTER),
                firewall.delete_chain(chain, table=FirewallRules.FILTER),
            ]

        # output chains
        # delete main rule
        yield [
            firewall.delete_rule(
                FirewallRules.compose_action(
                    FirewallRules.IMUNIFY_OUTPUT_CHAIN
                ),
                chain="OUTPUT",
            )
        ]
        # flush and delete main chain
        yield [
            firewall.flush_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
            firewall.delete_chain(FirewallRules.IMUNIFY_OUTPUT_CHAIN),
        ]
        for chain in [FirewallRules.BP_OUTPUT_CHAIN]:
            yield [firewall.flush_chain(chain), firewall.delete_chain(chain)]

        # Delete connection tracking rule.
        # Unlike creation, this is not conditioned on the DOS settings.
        yield [
            firewall.delete_rule(
                self._CONNTRACK_RULE,
                table="raw",
                chain="PREROUTING",
            )
        ]

    def required_ipsets(self, ip_version: IPVersion) -> Set[str]:
        """Return the union of ``get_all_ipsets()`` over all entities."""
        names = set()  # type: Set[str]
        for entity in self.entities:
            names.update(entity.get_all_ipsets(ip_version))
        return names

    async def check_commands(
        self, firewall: Iptables, interface_conf: dict, ip_version: IPVersion
    ) -> list:
        """Returns a list of firewall commands to check for firewall rules."""
        actions = []
        for table, chain in self.targets(ip_version):
            actions.extend(
                [
                    firewall.has_rule(
                        self._compose_rule(ip_version, interface_conf),
                        table=table,
                        chain=chain,
                    ),
                    *self._apply_ignored_interfaces(
                        firewall.has_rule, interface_conf, table=table
                    ),
                ]
            )
        actions.extend(self._log_block_rules(firewall.has_rule, ip_version))
        actions.extend(
            [
                firewall.has_rule(
                    FirewallRules.compose_action(
                        FirewallRules.IMUNIFY_OUTPUT_CHAIN
                    ),
                    table=FirewallRules.FILTER,
                    chain="OUTPUT",
                ),
            ]
        )
        actions.extend(
            [
                firewall.has_rule(**rule)
                for rule in await self._collect_ipset_rules(ip_version)
            ]
        )
        if DOS.ENABLED or EnhancedDOS.ENABLED:
            actions.append(
                firewall.has_rule(
                    self._CONNTRACK_RULE, table="raw", chain="PREROUTING"
                )
            )
        return actions

    def _log_block_rules(self, predicate, ip_version: IPVersion):
        """Apply *predicate* to every rule of the logging chains.

        :param predicate: callable invoked as
            ``predicate(rule, table=FirewallRules.FILTER, chain=chain)``
        :param ip_version: IPv4 or IPv6
        :return: list of the values returned by *predicate*
        """
        rules = []
        for chain, prefix, action in (
            (
                FirewallRules.LOG_BLACKLIST_CHAIN,
                UnifiedAccessLogger.BLACKLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_GRAYLIST_CHAIN,
                UnifiedAccessLogger.GRAYLIST,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN,
                UnifiedAccessLogger.BLACKLIST_COUNTRY,
                FirewallRules.compose_action(FirewallRules.DROP),
            ),
            (
                FirewallRules.LOG_BLOCK_PORT_CHAIN,
                UnifiedAccessLogger.BLOCKED_BY_PORT,
                FirewallRules.compose_action(FirewallRules.REJECT),
            ),
        ):
            # At the moment, stateful packets processing is enabled
            # for blacklisted countries only.
            stateful = chain == FirewallRules.LOG_BLACKLISTED_COUNTRY_CHAIN
            rules.extend(
                predicate(rule, table=FirewallRules.FILTER, chain=chain)
                for rule in self._log_drop_rules(
                    ip_version, prefix, action, stateful
                )
            )
        return rules

    async def _collect_ipset_rules(self, ip_version: IPVersion) -> List[dict]:
        """Gather rules of all entities sorted by (chain, priority)."""
        rules = []  # type: List[dict]
        for entity in self.entities:
            rules.extend(entity.get_rules(ip_version))
        rules.sort(key=lambda r: (r["chain"], r["priority"]))
        return rules

    async def fill_ipsets(
        self, ip_version: IPVersion, missing: Set[str]
    ) -> None:
        """Fills all ipsets with required elements.

        Create operations are emitted only for ipsets named in *missing*.
        """
        create_and_restore_cmds = []
        for entity in self.entities:
            for ip_set in entity.get_all_ipset_instances(ip_version):
                if ip_set.gen_ipset_name_for_ip_version(ip_version) in missing:
                    create_and_restore_cmds.extend(
                        ip_set.gen_ipset_create_ops(ip_version)
                    )
                # NOTE(review): indentation of this file was lost; restore
                # ops are assumed to apply to every ipset (per the docstring),
                # not only to the missing ones - confirm against VCS.
                create_and_restore_cmds.extend(
                    await ip_set.gen_ipset_restore_ops(ip_version)
                )
        await libipset.restore(create_and_restore_cmds)
        logger.info("IP sets content restored from database")

    @staticmethod
    async def existing_ipsets(ip_version: IPVersion) -> Set[str]:
        """Return names of listed ipsets starting with our per-version prefix."""
        prefix = ".".join([IP_SET_PREFIX, ip_version])
        return {s for s in await libipset.list_set() if s.startswith(prefix)}

    async def _flush_ipsets(self, to_flush: set[str], ip_version: IPVersion):
        """Flush *to_flush* ipsets; a missing ipset is only logged."""
        logger.info("Flushing ipsets: %s", to_flush)
        try:
            await libipset.restore(
                await self.ipset_flush_commands(ip_version, to_flush)
            )
        except libipset.IPSetNotFoundError:
            logger.warning(
                "Failed to flush ipsets: %s",
                ", ".join(to_flush),
            )

    def has_ipset_to_destroy(
        self, ip_version: IPVersion, existing: set[str] | None
    ) -> bool:
        """Tell whether *existing* has ipsets not recorded as failed before.

        Returns False when *existing* is None.
        """
        if existing is None:
            return False
        prev_failed = self._get_prev_failed(ip_version)
        return bool(existing - prev_failed)

    def ipsets_to_refill(
        self, ip_version: IPVersion, existing: set[str], required: set[str]
    ) -> set[str]:
        """Check if ipsets need to be refilled.

        Returns ipsets that exist, are required and were previously
        recorded as failed to destroy.
        """
        prev_failed = self._get_prev_failed(ip_version)
        return existing & prev_failed & required

    async def destroy_ipsets(
        self,
        ip_version: IPVersion,
        existing: set[str] | None = None,
        force: bool = False,
    ) -> None:
        """Destroys ipsets with given names.

        The ipsets are flushed first, then destroyed as a group (with
        retries) and finally one by one; whatever is left is logged and
        saved to the failed ipsets file.

        Args:
            ip_version: IP version to destroy ipsets for
            existing: Set of ipsets to destroy.
                If None, all existing ipsets will be destroyed
            force: If True, ignore previously failed ipsets and
                try to destroy them again
        """
        logger.info(
            "Destroying ipsets for %s existing: %s force: %s",
            ip_version,
            existing,
            force,
        )
        # work on a copy: the set is mutated while destroying
        to_destroy = (
            existing.copy()
            if existing is not None
            else await self.existing_ipsets(ip_version)
        )
        await self._flush_ipsets(to_destroy, ip_version)
        prev_failed = self._get_prev_failed(ip_version) if not force else set()
        # previously failed ipsets without references can be retried now
        prev_failed -= await self._sets_without_references(prev_failed)
        failed_ipsets = await self._destroy_ipsets_group(
            to_destroy, prev_failed, ip_version
        )
        failed_ipsets = await self._destroy_ipsets_one_by_one(failed_ipsets)
        if failed_ipsets:
            ipset_with_members = {
                ipset: await libipset.get_ipset_members(ipset)
                for ipset in failed_ipsets
            }
            references = {
                ipset: await libipset.get_ipset_references(ipset)
                for ipset in failed_ipsets
            }
            logger.error(
                "Failed to destroy ipsets: %s",
                ", ".join(
                    f"{ipset=}: {members=} refs: {references[ipset]}"
                    for ipset, members in ipset_with_members.items()
                ),
            )
        # NOTE(review): indentation of this file was lost; the save is assumed
        # to run unconditionally so the file always reflects the latest
        # state - confirm against VCS.
        self._save_failed_ipsets(failed_ipsets | prev_failed, ip_version)

    def clean_previously_failed_ipsets(self, ip_version: IPVersion) -> None:
        """Clean previously failed ipsets from the file."""
        Path(FAILED_IPSETS_FILE.format(ip_version=ip_version)).unlink(
            missing_ok=True
        )

    async def _sets_without_references(self, ipsets: set[str]) -> set[str]:
        """Return a set of ipsets that have no references (safe to destroy)."""
        res = set()
        for ipset in ipsets:
            if await libipset.get_ipset_references(ipset) == 0:
                res.add(ipset)
        return res

    def _get_prev_failed(self, ip_version: IPVersion) -> set[str]:
        """Load ipset names saved by a previous failed destroy attempt.

        Returns an empty set when the file does not exist or
        cannot be read/parsed (the latter is logged as an error).
        """
        path = Path(FAILED_IPSETS_FILE.format(ip_version=ip_version))
        prev_failed: set[str] = set()
        try:
            with path.open("r") as f:
                prev_failed = set(json.load(f))
        except FileNotFoundError:
            # no dump file means nothing has failed previously
            pass
        except (ValueError, TypeError, OSError):
            # ValueError covers json.JSONDecodeError and decoding errors,
            # TypeError covers valid JSON that is not a list of strings
            logger.error("Failed to read or parse dump file: %s", path)
        if g.get("DEBUG"):
            logger.info("Previous failed ipsets: %s", prev_failed)
        return prev_failed

    async def _destroy_ipsets_group(
        self,
        to_destroy: set[str],
        prev_failed: set[str],
        ip_version: IPVersion,
    ) -> set[str]:
        """Try to destroy *to_destroy* with a single restore call.

        Up to three attempts are made. *to_destroy* is modified in place:
        *prev_failed* and no longer existing ipsets are removed from it
        before each attempt.

        :return: empty set on success, otherwise ipsets left to destroy
        """
        max_tries = 3
        attempt = 0
        while to_destroy and attempt < max_tries:
            to_destroy -= prev_failed
            to_destroy &= await self.existing_ipsets(ip_version)
            try:
                await libipset.restore(
                    await self.ipset_destroy_commands(ip_version, to_destroy)
                )
                return set()
            except (
                libipset.IPSetNotFoundError,
                libipset.IPSetCannotBeDestroyedError,
            ):
                attempt += 1
                logger.warning(
                    "Failed to destroy ipsets: %s, retrying: %s",
                    ", ".join(to_destroy),
                    attempt,
                )
        # return failed to destroy ipsets
        return to_destroy

    async def _destroy_ipsets_one_by_one(
        self, to_destroy: set[str]
    ) -> set[str]:
        """Destroy ipsets individually, return those that cannot be destroyed."""
        if g.get("DEBUG"):
            logger.info("Destroying ipsets: %s", to_destroy)
        failed_ipsets = set()
        for ipset_name in to_destroy:
            try:
                await libipset.restore(
                    [IPSetCmdBuilder.get_destroy_cmd(ipset_name)]
                )
            except libipset.IPSetNotFoundError:
                # If ipset doesn't exist, we can consider it destroyed
                pass
            except libipset.IPSetCannotBeDestroyedError as e:
                logger.warning(
                    "Failed to destroy ipset %s: %s", ipset_name, str(e)
                )
                failed_ipsets.add(ipset_name)
        return failed_ipsets

    def _save_failed_ipsets(
        self, failed_ipsets: set[str], ip_version: IPVersion
    ):
        """Write *failed_ipsets* as a JSON list to the failed ipsets file.

        Best effort: any error is logged and not propagated.
        """
        path = FAILED_IPSETS_FILE.format(ip_version=ip_version)
        logger.info(
            "Saving failed ipsets: %s to file: %s",
            failed_ipsets,
            path,
        )
        try:
            with open(path, "w") as f:
                json.dump(list(failed_ipsets), f)
        except Exception as e:
            logger.error("Failed to save failed ipsets to file: %s", e)

    async def _recreate_ipsets(
        self, ip_version: IPVersion, existing: Optional[Set[str]] = None
    ):
        """Reset all ipsets, create them again and fill with IPs
        for given ip version."""
        for entity in self.entities:
            await entity.reset(ip_version, existing)

    async def recreate_ipsets(
        self,
        ip_version: Optional[IPVersion] = None,
        existing: Optional[Set[str]] = None,
    ):
        """Recreate existing ipsets (or given).

        If *ip_version* is None, recreate ipsets for all enabled ip versions.
        """
        if ip_version:
            await self._recreate_ipsets(ip_version, existing)
        else:
            for ip_version in ip_versions.enabled():
                await self._recreate_ipsets(ip_version, existing)

    @staticmethod
    def _log_drop_rules(ip_version: IPVersion, prefix, action, stateful: bool):
        """Build the ordered rules of a logging chain.

        :param ip_version: IPv4 or IPv6, selects the nflog group
        :param prefix: nflog prefix
        :param action: final rule of the chain
        :param stateful: if True, first accept ESTABLISHED,RELATED packets
        :return: list of rules; the nflog rule is present only when
            firewall logging is enabled
        """
        rules = []
        if stateful:
            rules.append(
                # fmt: off
                (
                    "-m", "conntrack", "--ctstate", "ESTABLISHED,RELATED",
                    "-j", "ACCEPT",
                ),
                # fmt: on
            )
        if firewall_logging_enabled():
            rules.append(
                FirewallRules.compose_rule(
                    action=FirewallRules.nflog_action(
                        group=FirewallRules.nflog_group(ip_version),
                        prefix=prefix,
                    )
                )
            )
        rules.append(action)
        return rules

    async def get_outdated_ipsets(self, ip_version: IPVersion) -> list:
        """
        Return list of ipsets the contents of which do not match the database
        """
        outdated: list = []
        for entity in self.entities:
            all_ipsets = await entity.get_ipsets_count(ip_version)
            # an ipset is outdated when its size differs from the DB count
            outdated.extend(
                ipset
                for ipset in all_ipsets
                if ipset.ipset_count != ipset.db_count
            )
        return outdated