import asyncio
import collections
import io
import reprlib
import time
import weakref
import logging
from contextlib import suppress
from operator import attrgetter
from defence360agent.contracts.messages import Message, Reject
from defence360agent.contracts.plugins import BaseMessageProcessor
from defence360agent.utils.common import DAY, ServiceBase, rate_limit
from defence360agent.internals.global_scope import g
logger = logging.getLogger(__name__)

# Record attached to every processing task: the message being processed and
# the time.monotonic() timestamp taken when its task was created.
ProcessingMessage = collections.namedtuple(
    "ProcessingMessage", "message start_time"
)
class TheSink(BaseMessageProcessor):
    """Entry point of the message bus.

    Sinks are sorted by their ``PROCESSING_ORDER`` attribute. Every message
    handed to :meth:`process_message` is queued in a :class:`TaskManager`,
    which runs it through the sorted sinks via :class:`MessageProcessor`.
    The instance registers itself as ``g.sink`` when constructed.
    """

    def __init__(self, sink_list, loop):
        self._sinks_ordered = sorted(
            sink_list, key=attrgetter("PROCESSING_ORDER")
        )
        self._loop = loop
        self._task_manager = TaskManager(
            loop, MessageProcessor(self._sinks_ordered)
        )
        # make this instance reachable through the global scope object
        g.sink = self

    def __repr__(self):
        return "%s.%s" % (self.__class__.__module__, self.__class__.__name__)

    def decompose(self, classobj):
        """
        introspection: decompose a specific role

        :param classobj: class (or tuple of classes) to look for among sinks
        :return: the single sink that is an instance of *classobj*, or None
        :raises AssertionError: if more than one sink matches
        """
        options = [
            sink for sink in self._sinks_ordered if isinstance(sink, classobj)
        ]
        # explicit check instead of ``assert``: an assert statement is
        # stripped under ``python -O`` and the first match would then be
        # returned silently
        if len(options) > 1:
            raise AssertionError("Ambiguous request")
        return next(iter(options), None)

    def start(self):
        """
        Make sure to run message processing bus only
        when every MessageSource (or MessageSource+MessageSink mix)
        got initialized
        """
        self._task_manager.start()

    async def shutdown(self):
        """Stop the task manager.

        Request a stop, give in-flight messages up to 5 seconds to finish,
        then wait for the task manager itself.
        """
        logger.info("shutdown the sink started")
        self._task_manager.should_stop()
        logger.info("wait for current tasks")
        await self._task_manager.wait_current_tasks(timeout=5)
        logger.info("finish wait task")
        await self._task_manager.wait()

    async def process_message(self, message):
        """Queue *message* for processing by the sinks."""
        await self._task_manager.push_msg(message)
class TaskManager(ServiceBase):
    """Queue messages and feed them to *msg_processor* with bounded concurrency.

    :meth:`push_msg` wraps every message in :class:`MessageComparable` and
    puts it into a bounded :class:`MessageQueue`. When the queue is full the
    message is dropped and a throttled error is logged. :meth:`_run` takes
    messages off the queue and runs ``msg_processor(message)`` as a separate
    task, keeping at most ``CONCURRENCY`` such tasks in flight.
    """

    # max queue message size
    MAXSIZE = 100000
    # number of concurrently processed messages
    CONCURRENCY = 5
    # how long to wait for a free processing slot before logging that a
    # message hasn't been processed (waiting continues afterwards)
    TIMEOUT = 3600  # seconds

    def __init__(self, loop, msg_processor):
        super().__init__(loop)
        self._queue = MessageQueue(maxsize=self.MAXSIZE)
        self._concurrency = self.CONCURRENCY
        self._process_message_timeout = self.TIMEOUT
        self._msg_processor = msg_processor
        # weak view of the processing tasks, used for introspection
        self.tasks = weakref.WeakSet()
        # strong references to processing tasks until they finish. The event
        # loop itself keeps only weak references to tasks, so a task held
        # solely by the WeakSet above could be garbage-collected mid-run, and
        # its done callbacks (including the semaphore release) would never
        # fire.
        self._pending_tasks = set()
        self._throttled_logger = rate_limit(period=DAY, on_drop=logger.warning)
        self.throttled_log_error = self._throttled_logger(logger.error)

    async def push_msg(self, msg):
        """Push message unless the queue is full."""
        if not self._queue.full():
            await self._queue.put(MessageComparable(msg))
        else:
            if self._throttled_logger.should_be_called:  # send to Sentry
                args = (
                    (
                        "Message queue is full %s. "
                        "Current processing messages: %s. Message ignored: %s"
                    ),
                    self._queue,
                    self.current_processing_messages,
                    msg,
                )
            else:  # don't serialize the queue on each log warning entry
                args = (
                    (
                        "Message queue is full. Queue size: %s "
                        "Current processing messages: %s. Message ignored: %s"
                    ),
                    self._queue.qsize(),
                    self.current_processing_messages,
                    msg,
                )
            self.throttled_log_error(*args)

    @property
    def current_processing_messages(self):
        """Return ``(message, seconds_in_processing)`` for unfinished tasks."""
        # the loop should be safe (no ref should be removed from the weak
        # set while iterating)
        # https://stackoverflow.com/questions/12428026/safely-iterating-over-weakkeydictionary-and-weakvaluedictionary # noqa
        # the loop is cheap: at most ``self._concurrency`` tasks are
        # unfinished at a time (finished ones only linger until collected)
        return tuple(
            (
                task.processing_msg.message,
                round(time.monotonic() - task.processing_msg.start_time, 4),
            )
            for task in self.tasks
            if not task.done()
        )

    async def wait_current_tasks(self, timeout=None):
        """Wait up to *timeout* seconds for unfinished processing tasks."""
        # snapshot the unfinished tasks: the weak set may shrink at any time,
        # and asyncio.wait() raises ValueError on an empty collection
        pending = [task for task in self.tasks if not task.done()]
        if not pending:
            return
        msg_to_process = [
            (m.get("method"), m.get("message_id"), lasting)
            for m, lasting in self.current_processing_messages
        ]
        logger.info(
            "Waiting for %r processing to finish",
            msg_to_process,
        )
        await asyncio.wait(pending, timeout=timeout)

    async def _run(self):
        """Main loop: dispatch queued messages until asked to stop."""
        semaphore = asyncio.BoundedSemaphore(self._concurrency)
        try:
            while not self._should_stop:
                logger.debug("Message queue size: %s", self._queue.qsize())
                try:
                    # take a free slot first, then a message, so a message
                    # only leaves the queue when it can start right away
                    await self.__limit_concurrency(semaphore)
                    msg_comparable = await self._queue.get()
                except asyncio.CancelledError:
                    break
                self._start_processing(msg_comparable.msg, semaphore)
            unprocessed = self._queue.qsize()
            if unprocessed:
                logger.warning(
                    "There is still %s unprocessed messages in the queue",
                    unprocessed,
                )
        except Exception:
            logger.exception("Error during message processing:")

    def _start_processing(self, message, semaphore):
        """Run *message* through the processor in a new task.

        The caller must already hold one *semaphore* slot. The slot is
        released when the task finishes.
        """
        task = self._loop.create_task(
            self._msg_processor(message)
        )  # type: asyncio.Task
        task.processing_msg = ProcessingMessage(message, time.monotonic())
        task.add_done_callback(lambda _: semaphore.release())
        task.add_done_callback(self._on_msg_processed)
        task.add_done_callback(self._pending_tasks.discard)
        self.tasks.add(task)
        self._pending_tasks.add(task)
        return task

    async def __limit_concurrency(self, semaphore):
        """Try to acquire *semaphore* in a loop, log error on timeout."""
        while True:
            try:
                return await asyncio.wait_for(
                    semaphore.acquire(),
                    timeout=self._process_message_timeout,
                )
            except asyncio.TimeoutError:
                self.throttled_log_error(
                    "Message hasn't been processed in %s seconds",
                    self._process_message_timeout,
                )

    @staticmethod
    def _on_msg_processed(future):
        """Log the exception a finished processing task ended with, if any."""
        # Future.exception() raises CancelledError for a cancelled future,
        # which would otherwise escape from this done callback
        if future.cancelled():
            return
        e = future.exception()
        if e is not None:
            logger.exception("Error during message processing:", exc_info=e)
async def cancel_task(task):
    """Cancel *task* unless it already finished, and wait for it to unwind.

    The ``CancelledError`` produced by awaiting the cancelled task is
    swallowed. A task that is already done is left untouched.
    """
    if task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
class MessageProcessor(object):
    """Run a message through an ordered chain of sinks.

    Each sink's ``process_message`` is awaited in turn. If it returns a
    :class:`Message`, that message replaces the current one for the rest of
    the chain. Messages with the same truthy ``attackers_ip`` are processed
    one at a time.
    """

    # per-sink processing timeout, seconds
    TIMEOUT_TO_SINK_PROCESS = 3600

    def __init__(self, sinks):
        self.sinks = sinks
        # one lock per attacker IP; an entry disappears once nothing
        # references its lock any more
        self.locks = weakref.WeakValueDictionary()
        self.throttled_log_error = rate_limit(period=60 * 60)(
            logger.error
        )  # send event to Sentry once an hour

    async def __call__(self, msg):
        """Process *msg*, serialised per ``attackers_ip`` when present."""
        ip = msg.get("attackers_ip")
        if ip:
            lock = self.locks.setdefault(ip, asyncio.Lock())
            async with lock:
                await self._call_unlocked(msg)
        else:
            await self._call_unlocked(msg)

    async def _call_unlocked(self, msg):
        """Pass *msg* through every sink in order, then log the elapsed time.

        The chain stops early on cancellation, on ``Reject``, on a sink
        timeout, or on any other exception raised by a sink.
        """
        start = time.monotonic()
        for sink in self.sinks:
            # reset on every iteration so that, if creating the task fails,
            # ``finally`` neither hits an unbound name nor re-cancels the
            # previous sink's task
            process_message_task = None
            try:
                process_message_task = asyncio.create_task(
                    sink.process_message(msg)
                )
                processed = await asyncio.wait_for(
                    # shielded only for debug DEF-18627,
                    # it should intercept `CancelledError` that
                    # `asyncio.wait_for` send to `process_message_task`
                    # in case timeout
                    asyncio.shield(process_message_task),
                    timeout=self.TIMEOUT_TO_SINK_PROCESS,
                )
            except asyncio.CancelledError:
                break
            except Reject as e:
                logger.info("Rejected: %s -> %r", str(e), msg)
                return
            except asyncio.TimeoutError:
                self._log_sink_timeout(msg, sink, process_message_task)
                return
            except Exception:
                logger.exception("Error processing %r in %r", msg, sink)
                return
            else:
                if isinstance(processed, Message):
                    msg = processed
            finally:
                if process_message_task is not None:
                    await cancel_task(process_message_task)
        processing_time = time.monotonic() - start
        logger.info("%s processed in %.4f seconds", msg, processing_time)
        if processing_time > msg.PROCESSING_TIME_THRESHOLD:
            # send to Sentry
            self.throttled_log_error(
                "%s message took longer to process than expected "
                "(%.4f sec > %.4f sec)",
                msg,
                processing_time,
                msg.PROCESSING_TIME_THRESHOLD,
            )

    def _log_sink_timeout(self, msg, sink, task):
        """Log that *sink* did not finish *msg* in time, with *task*'s stack."""
        # debug for DEF-18627, it's supposed that during this exception
        # handling we will get call stack in logs and see last await
        # that hang out coroutine was made,
        # may be it's give us some hint about problem
        stack = io.StringIO()
        if task is not None:
            task.print_stack(file=stack)
        logger.error(
            "Message %r was not processed in the %r plugin in %ss; "
            "Traceback: %s",
            msg,
            sink,
            self.TIMEOUT_TO_SINK_PROCESS,
            stack.getvalue(),
        )
class MessageComparable(object):
    """Wrap a message so it orders by ``(msg.PRIORITY, wrap sequence)``.

    Smaller ``PRIORITY`` values sort first. Messages with equal priority
    keep the order in which they were wrapped.
    """

    # last sequence number handed out; a class-level counter used as the
    # tie-breaker that keeps insertion order between equal priorities
    index = -1

    def __init__(self, msg):
        owner = type(self)
        owner.index += 1
        self.msg = msg
        self.priority = (msg.PRIORITY, owner.index)

    def __lt__(self, other):
        return self.priority.__lt__(other.priority)

    def __repr__(self):
        name = type(self).__name__
        return f"<{name}({self.msg!r}), priority={self.priority}>"
class MessageQueue(asyncio.PriorityQueue):
    """Priority queue of message wrappers with a compact ``str()``.

    ``str(queue)`` summarises the queue as per-message-class counts instead
    of listing every queued item, so it can be logged without flooding the
    log. Items are expected to expose the wrapped message as ``.msg``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._repr = reprlib.Repr()
        # truncate long message class names
        self._repr.maxstring = 50
        self._repr.maxtuple = 2000
        # the counts rendered in __str__ form a *list*, and reprlib limits
        # lists via ``maxlist`` (default 6), not ``maxtuple``; without this
        # only 6 message classes would ever be shown
        self._repr.maxlist = 2000

    async def put(self, item: "MessageComparable"):
        return await super().put(item)

    def __str__(self):
        # NOTE: do not flood console.log with full queue
        # most_common() yields (class name, count) pairs, largest count first
        msg_counts = collections.Counter(
            item.msg.__class__.__qualname__ for item in self._queue
        ).most_common()
        return (
            f"<PriorityQueue maxsize={self.maxsize}; "
            f"queue_size={self.qsize()} "
            f"queue_counter={self._repr.repr(msg_counts)}>"
        )