import asyncio
import inspect
import logging
import subprocess
from abc import ABC, ABCMeta, abstractmethod
from contextlib import suppress
from functools import lru_cache, wraps
from defence360agent.contracts.messages import Message, MessageType
from defence360agent.utils import Scope
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)
class BasePlugin(object):
    """Common base class of every agent plugin.

    Each subclass is appended to ``_subclasses`` when it is created
    (see ``__init_subclass__``).
    """

    SCOPE = Scope.AV_IM360
    SHUTDOWN_PRIORITY = 100  # lower means shuts down first
    AVAILABLE_ON_FREEMIUM = True

    # Registry of subclasses.  __init_subclass__ appends to this very list
    # object (it is never rebound here), so it accumulates the whole
    # hierarchy unless a subclass rebinds the name.
    _subclasses = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._subclasses.append(cls)

    @classmethod
    def get_active_plugins(cls):
        """Return the registered subclasses that are not abstract.

        Every non-abstract subclass is considered active.  The result is
        read from the shared ``_subclasses`` list, so it does not depend on
        which class in the hierarchy this is called on.
        """
        active = []
        for candidate in cls._subclasses:
            if inspect.isabstract(candidate):
                continue
            active.append(candidate)
        return active

    async def shutdown(self):
        """Shut the plugin down; the base implementation does nothing.

        Overrides should stop the plugin's subsystems, cancel its running
        tasks and, for protector plugins, clean up iptables.

        It is safe to assume this runs after the matching create_source
        (where applicable).  It is only called from the shutdown task, which
        runs at most once, so shutdown() is never called twice.
        """

    def __repr__(self):
        klass = self.__class__
        return "{}.{}".format(klass.__module__, klass.__name__)
class MessageSource(BasePlugin, ABC):
    """Abstract plugin; concrete subclasses must implement create_source()."""

    @abstractmethod
    def create_source(self, loop, sink):
        """Create the source for *loop* and *sink*.

        This method is a coroutine.
        """
        ...
class Sensor(MessageSource, ABC):
    """Alias of MessageSource whose source is built by create_sensor().

    Concrete sensors implement create_sensor(); create_source() just
    forwards to it.
    """

    def create_source(self, loop, sink):
        """Return the awaitable produced by ``create_sensor(loop, sink)``.

        This method is a coroutine.
        """
        sensor = self.create_sensor(loop, sink)
        return sensor

    @abstractmethod
    def create_sensor(self, loop, sink):
        """This method is a coroutine."""
        ...
class LogStreamReader(Sensor, metaclass=ABCMeta):
    """Sensor that follows ``source_file`` with ``tail`` and hands the
    child's stdout stream to ``_infinite_read_and_proceed``.

    Subclasses set ``source_file`` and implement
    ``_infinite_read_and_proceed(stream_reader)``.  When ``source_file`` is
    falsy, no child process is started.
    """

    source_file = None
    # Limit of bytes consumed from stream
    # while trying to read one line (128 kB)
    _LIMIT = 2**17
    # argv of the tail child; None when nothing needs to be terminated
    _cmd = None
    # asyncio subprocess running tail; None until it has been spawned.
    # The class-level default lets shutdown() cope with a failed spawn.
    _child_process = None
    # Task running _infinite_read_and_proceed.  A reference is kept because
    # the event loop only holds weak references to tasks.
    _stream_reader_task = None

    async def create_sensor(self, loop, sink):
        """Spawn ``tail`` on ``source_file`` and start consuming its stdout.

        This method is a coroutine.  *loop* and *sink* are always stored.
        If ``source_file`` is not set, nothing else happens.
        """
        self._loop = loop
        self._sink = sink
        self._cmd = None
        if not self.source_file:
            return
        self._cmd = (
            "/usr/bin/tail",
            # follow beyond the end of the file
            "--follow=name",
            "-n0",
            # keep trying to open a file if it is inaccessible
            "--retry",
            self.source_file,
        )
        self._child_process = await asyncio.create_subprocess_exec(
            *self._cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            bufsize=0,
            limit=self._LIMIT,
        )
        self._stream_reader_task = loop.create_task(
            self._infinite_read_and_proceed(self._child_process.stdout)
        )

    async def shutdown(self):
        """Kill the ``tail`` child (if one was started) and reap it.

        Subsequent calls are no-ops because ``_cmd`` is cleared first.
        """
        if self._cmd is None:
            return
        cmd, self._cmd = self._cmd, None
        process = self._child_process
        if process is None:
            # The command was built but spawning it failed: nothing to kill.
            return
        logger.debug("Terminating child process [%s]", cmd)
        # child process dies from the same signal when agent
        # is run from console (not as --daemon)
        with suppress(ProcessLookupError):
            process.kill()
        rc = await process.wait()
        logger.debug("Terminated child process [%s] with code [%d]", cmd, rc)

    @abstractmethod
    async def _infinite_read_and_proceed(self, stream_reader):
        """Consume lines from *stream_reader* (the tail child's stdout)."""
        raise NotImplementedError
class BaseMessageProcessor:
    """Mixin that dispatches messages to ``@expect``-decorated handlers.

    Handlers are public callables carrying the
    ``_decorated_for_process_message`` marker.  They are discovered through
    ``dir(self)`` (so they run in alphabetical order of attribute name), and
    the result is cached per instance.
    """

    # Per-instance cache of discovered handlers, populated on first use.
    # The leading underscore keeps it out of the discovery scan itself.
    _message_processors_cache = None

    def _message_processors(self):
        """Return this instance's bound handler methods.

        The list is built on the first call and reused afterwards.
        """
        # Cached on the instance rather than with functools.lru_cache.  A
        # method-level lru_cache keys on ``self``, so with maxsize=1 it
        # thrashes between processor instances, keeps the last one alive,
        # and fails on unhashable instances.
        processors = self._message_processors_cache
        if processors is None:
            processors = []
            for attr_str in dir(self):
                if attr_str.startswith("_"):
                    continue  # skip non-public attributes
                func = getattr(self, attr_str)
                if callable(func) and hasattr(
                    func, "_decorated_for_process_message"
                ):
                    processors.append(func)
            self._message_processors_cache = processors
        return processors

    async def process_message(self, message):
        """Feed *message* to each handler in turn.

        Returns the first handler result that is a ``Message``; the handlers
        after it are not called.  Returns None if no handler produced one.
        """
        logger.debug("Dispatching %r through %r...", message, self)
        for coro in self._message_processors():
            result = await coro(message)
            if isinstance(result, Message):
                return result
        return None
class MessageSink(BasePlugin, BaseMessageProcessor, ABC):
    """Abstract plugin that handles messages via ``@expect`` handlers.

    Concrete sinks implement create_sink().  ``PROCESSING_ORDER`` holds the
    sink's rank, taken from ``ProcessingOrder``.
    """

    class ProcessingOrder:
        """Named ranks for ``MessageSink.PROCESSING_ORDER``.

        Judging by the notes below (e.g. IMPORT_EXPORT_WBLIST=60 "runs
        before" IPSET_PROTECTOR=80), lower values presumably run earlier --
        confirm against the dispatcher.
        """

        # e.g. check is valid ipv4
        PRE_PROCESS_MESSAGE = 10
        # the lfd plugin must see lfd alerts before the other ignore plugins
        LFD = 18
        # e.g. for ignore_alert_with_whitelisted_ip
        IGNORE_MESSAGE = 20
        # has to come before the graylist check
        UNBLOCK_FROM_SUBNET = 30
        # check whether the ip is already in the graylist
        CHECK_IP_IN_GRAYLIST = 40
        # attach a ttl to the alert
        GRAYLIST_TIMEOUT = 50
        # persist the graylist to the db
        GRAYLIST_DB_FIXUP = 55
        # has to come before IPSET_PROTECTOR
        IMPORT_EXPORT_WBLIST = 60
        # ml prediction happens before lazy_init
        ML_PREDICTION = 70
        # the default rank
        DEFAULT = 80
        IPSET_PROTECTOR = DEFAULT
        WEBSHIELD_PROTECTOR = 81
        # has to come after ManageGrayList (DEFAULT)
        WHITELIST_UNBLOCKED = 90
        # synclist timestamp update
        SYNCLIST_UPDATE = 100
        # post action
        POST_ACTION = 120
        # event hook processing
        EVENT_HOOK = 150
        # iContact
        ICONTACT_SENT = 200
        # e.g. Accumulate
        POST_PROCESS_MESSAGE = 999

    # alias for DEFAULT
    PROCESSING_ORDER = ProcessingOrder.DEFAULT

    @abstractmethod
    async def create_sink(self, loop):
        """Set the sink up on *loop*; concrete sinks must override this."""
def expect(*message_type, async_lock=None, **expect_fields):
    """Mark a public ``async def handler(self, message)`` for dispatch by
    ``MessageSink.process_message()``.

    The wrapped handler runs when *message* is an instance of one of
    *message_type* and ``message.get(key) == value`` holds for every item
    of *expect_fields*.  Otherwise the wrapper returns None.

    ``@expect`` decorators can be stacked.  They combine with logical OR,
    and the outermost one is checked first::

        @expect(MessageType.SensorAlert)                        # -- OR --
        @expect(MessageType.SensorIncident, plugin_id='ossec')
        async def protect(self, message): ...

    *async_lock* controls the message lock around the handler call:

    * ``True``  -- ``await message.acquire()`` before the call.  The lock is
      left held when the handler succeeds.
    * ``False`` -- after a successful call, release the message if it is a
      locked ``MessageType.Lockable``.
    * ``None``  -- leave the lock alone on success.

    Whatever the value, if the handler raises (including
    ``asyncio.CancelledError``, a BaseException since Python 3.8), a locked
    ``Lockable`` message is released before the exception propagates.

    Raises:
        TypeError: at decoration time, if the handler's name starts with
            "_"; non-public attributes are skipped by handler discovery.
    """

    def decorate(coro):
        if getattr(coro, "__name__", "").startswith("_"):
            raise TypeError("{coro} is not public".format(coro=coro))

        def is_stacked(func):
            return hasattr(func, "_decorated_for_process_message")

        # Resolved once at decoration time: whether another @expect sits
        # below this one, and the undecorated handler at the bottom.
        stacked = is_stacked(coro)
        handler = coro
        while is_stacked(handler):
            handler = handler._decorated_for_process_message

        def matches(message):
            return isinstance(message, message_type) and all(
                message.get(k) == v for k, v in expect_fields.items()
            )

        def release_if_locked(message):
            if isinstance(message, MessageType.Lockable) and message.locked():
                message.release()

        @wraps(coro)
        async def decorated(self, message):
            if matches(message):
                if async_lock is True:
                    await message.acquire()
                try:
                    result = await handler(self, message)
                except BaseException:
                    # BaseException so a cancelled handler does not leave the
                    # message locked forever.
                    release_if_locked(message)
                    raise
                if async_lock is False:
                    release_if_locked(message)
                return result
            if stacked:
                # Give the next decorator a chance: logical OR
                return await coro(self, message)
            return None

        decorated._decorated_for_process_message = coro
        return decorated

    return decorate
# Classes registered through the @thisguy decorator.
_plugin_registry = set()


def thisguy(plugincls):
    """Register *plugincls* as a plugin and hand it back unchanged.

    Meant to be used as a class decorator::

        @thisguy
        class ConcreteSink(MessageSink):
            ...
    """
    registry = _plugin_registry
    registry.add(plugincls)
    return plugincls


def theseguys():
    """Return the registered plugin classes.

    This is the live registry set itself, not a copy.
    """
    registry = _plugin_registry
    return registry