import asyncio
from enum import Enum
from typing import List
from defence360agent.contracts.config import Core as CoreConfig
from defence360agent.utils import batched, batched_dict
class MessageNotFoundError(Exception):
    """Raised when a requested message class cannot be found."""
class UnknownMessage:
    """
    Used as stub for MessageType

    The class cannot be instantiated: its constructor always raises
    MessageNotFoundError.
    """

    def __init__(self):
        # Creating the stub means the requested message class does not exist.
        raise MessageNotFoundError("Message class is not found.")

    def __getattr__(self, name):
        # Fallback for every attribute that regular lookup does not find.
        return "Unknown"  # pragma: no cover
class MessageT:
    """
    Root of the message class hierarchy that keeps a registry of every
    class derived from it (directly or indirectly).
    """

    # A single list object shared by the whole hierarchy: subclasses never
    # rebind it, they only append themselves to it.
    _subclasses = []

    def __init_subclass__(cls, **kwargs):
        """Record the freshly created subclass in the shared registry."""
        super().__init_subclass__(**kwargs)
        registry = cls._subclasses
        registry.append(cls)

    @classmethod
    def get_subclasses(cls):
        """Return an immutable snapshot of all registered subclasses."""
        snapshot = tuple(cls._subclasses)
        return snapshot
class _MessageType:
    """
    Resolve a message class by its name through attribute access.
    For example,

    >>> _MessageType().ConfigUpdate
    <class 'defence360agent.contracts.messages.ConfigUpdate'>
    >>> _MessageType().NotExistMessage
    <class 'defence360agent.contracts.messages.UnknownMessage'>
    """

    def __getattr__(self, name):
        # it is supposed that all subclasses have different names,
        # so the first match is taken
        matching = (
            candidate
            for candidate in Message.get_subclasses()
            if candidate.__name__ == name
        )
        # fall back to the stub class when no registered class has that name
        return next(matching, UnknownMessage)


# Shared instance used for lookups such as ``MessageType.ConfigUpdate``.
MessageType = _MessageType()
class ReportTarget(Enum):
    """Possible values of the ``TARGET`` attribute of reportable messages."""

    API = "api"
    PERSISTENT_CONNECTION = "conn"
class Reportable(MessageT):
    """
    Mixin class for messages that should be sent to the server
    """

    # Where the message is reported to; subclasses may override it.
    TARGET = ReportTarget.PERSISTENT_CONNECTION

    @classmethod
    def get_subclass_with_method(cls, method: str):
        """
        Return a subclass with the same DEFAULT_METHOD as *method*.
        It can be used to detect report target from message method.

        Only direct subclasses are inspected. Subclasses that do not
        define DEFAULT_METHOD at all are skipped.

        NOTE: it is not guaranteed that the class with the *method* is unique,
        in this case the first subclass found is returned, but
        it is tested that all such subclasses have the same TARGET.

        :param method: value to compare with each subclass' DEFAULT_METHOD
        :return: the first matching subclass or None if there is no match
        """
        for subclass in cls.__subclasses__():
            try:
                default_method = subclass.DEFAULT_METHOD
            except AttributeError:
                # A subclass without DEFAULT_METHOD can never match;
                # do not let it abort the lookup for the remaining ones.
                continue
            if method == default_method:
                return subclass
        return None  # pragma: no cover
class Received(MessageT):
    """
    Mixin class for messages received from the server.
    These messages are created in the client360 plugin when receiving a
    request from imunify360.cloudlinux.com.
    """

    @classmethod
    def get_subclass_with_action(cls, action: str):
        """
        Return the first direct subclass that handles *action*.

        A subclass handles the actions listed in its RECEIVED_ACTIONS;
        when that attribute is missing or empty, its DEFAULT_METHOD is
        the only action it handles. Subclasses that define neither are
        skipped.

        :param action: action name to look up
        :raises MessageNotFoundError: if no subclass handles *action*
        """
        for subclass in cls.__subclasses__():
            received_actions = getattr(subclass, "RECEIVED_ACTIONS", None)
            if not received_actions:
                try:
                    received_actions = [subclass.DEFAULT_METHOD]
                except AttributeError:
                    # Nothing to match against; keep looking instead of
                    # aborting the lookup for the remaining subclasses.
                    continue
            if action in received_actions:
                return subclass
        raise MessageNotFoundError(
            'Message class is not found for "{}" action'.format(action)
        )
class Lockable(MessageT):
    """
    Mixin that gives a message class a class-level asyncio lock.
    The lock object is created on the first acquire() call.
    """

    _lock = None

    @classmethod
    async def acquire(cls) -> None:
        """Acquire the class lock, creating it first if necessary."""
        lock = cls._lock
        if lock is None:
            lock = cls._lock = asyncio.Lock()
        await lock.acquire()

    @classmethod
    def locked(cls) -> bool:
        """Return True if the lock exists and is currently held."""
        if cls._lock is None:
            return False
        return cls._lock.locked()

    @classmethod
    def release(cls) -> None:
        """Release the class lock; do nothing if it was never created."""
        if cls._lock is None:
            return
        cls._lock.release()
class Message(dict, MessageT):
    """
    Dictionary-based base class for messages to be passed as
    a parameter to plugins.MessageSink.process_message()

    Items of the message can also be read as attributes.
    """

    # Default method='...' to send to the Server
    DEFAULT_METHOD = ""
    PRIORITY = 10
    PROCESSING_TIME_THRESHOLD = 60  # 1 min
    #: fold collections' repr with more than the threshold number of items
    _FOLD_LIST_THRESHOLD = 100
    #: shorten strings longer than the threshold characters
    _SHORTEN_STR_THRESHOLD = 320

    def __init__(self, *args, **kwargs) -> None:
        # The default method is stored first, so data passed by the caller
        # is applied on top of it.
        default_method = self.DEFAULT_METHOD
        if default_method:
            self["method"] = default_method
        super().__init__(*args, **kwargs)

    @property
    def payload(self):
        """All items of the message except "method"."""
        content = dict(self)
        content.pop("method", None)
        return content

    def __getattr__(self, name):
        """
        Invoked only after regular attribute lookup has failed.

        Serves as a shortcut to read an item of the dict; a missing key
        is reported as AttributeError.
        """
        try:
            value = self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc
        return value

    def __repr__(self):
        """
        Do not flood console.log with large values.

        Every sized value holding more than _FOLD_LIST_THRESHOLD items
        is replaced in the output: a string is passed through
        _shorten_str() with _SHORTEN_STR_THRESHOLD, anything else is
        shown as its number of items. The message itself is not
        collapsed.
        """
        folded_msg = self.copy()
        for key, value in self.items():
            if not hasattr(value, "__len__"):
                continue
            if len(value) <= self._FOLD_LIST_THRESHOLD:
                continue
            if isinstance(value, str):
                folded_msg[key] = _shorten_str(
                    value, self._SHORTEN_STR_THRESHOLD
                )
            else:
                folded_msg[key] = f"<{len(value)} item(s)>"
        class_name = self.__class__.__qualname__
        return "{}({})".format(class_name, folded_msg)

    def __str__(self):
        # str() and repr() of a message are the same text
        return self.__repr__()
class MessageList(Message):
    """Message that keeps a list of messages under its "list" item."""

    def __init__(self, msg_list):
        super().__init__(list=msg_list)

    @property
    def payload(self):
        """The wrapped list itself."""
        return self.list

    def __repr__(self):
        """
        Do not flood console.log with full MessageList:
        only the number of wrapped items is shown.
        """
        folded = "<{} item(s)>".format(len(self.get("list", [])))
        return "{}({})".format(self.__class__.__qualname__, folded)
class ShortenReprListMixin:
    """
    Do not flood console.log with large sequences.

    Mixin for dict-based messages that carry their content under
    the "items" key: repr() shows the number of those items instead
    of the items themselves.
    """

    def __repr__(self: dict):  # type: ignore
        folded = "<{} item(s)>".format(len(self.get("items", [])))
        return "{}({})".format(self.__class__.__qualname__, folded)
class Accumulatable(Message):
    """Messages of this class will be grouped into a list of LIST_CLASS
    message instance by Accumulate plugin. Messages whose do_accumulate()
    call returns False will not be added to list."""

    # Message class used for the resulting list; subclasses may override it.
    LIST_CLASS = MessageList

    def do_accumulate(self) -> bool:
        """Return True if this message is worth collecting, False otherwise.

        This base implementation always returns True.
        """
        return True
class ServerConnected(Message):
    """
    Message that adds nothing to the base Message class.

    NOTE(review): presumably emitted when a connection to the server
    is established -- confirm against the code that creates it.
    """

    pass
# Behaves exactly like ServerConnected; the separate name exists only
# for better client code readability.
class ServerReconnected(ServerConnected):
    """Subclass of ServerConnected that adds nothing of its own."""

    pass
class Ping(Message, Reportable):
    """
    Sent on connected and reconnected events
    to tell the central server which agent version is running
    """

    DEFAULT_METHOD = "PING"
    PRIORITY = 0

    def __init__(self):
        # the agent version is the only item besides the method
        super().__init__(version=CoreConfig.VERSION)
class Ack(Message, Reportable):
    """
    Tells the Server that the Agent has received the persistent message
    with the given *seq_number*.
    """

    DEFAULT_METHOD = "ACK"

    def __init__(self, seq_number, **kwargs):
        super().__init__(**kwargs)
        # set after the base initializer, so it wins over a "_meta"
        # passed in kwargs
        self["_meta"] = {"per_seq": seq_number}
class Noop(Message):
    """
    Sending NOOP to the agent to track the message in agent logs.
    """

    # stored under the "method" key of every instance by Message.__init__
    DEFAULT_METHOD = "NOOP"
class ServerConfig(Message, Reportable):
    """
    Information about server environment
    """

    DEFAULT_METHOD = "SERVER_CONFIG"
    TARGET = ReportTarget.API

    def __repr__(self):
        # only the class name is shown, the content is left out
        class_name = self.__class__.__qualname__
        return "{}()".format(class_name)
class DomainList(Message, Reportable):
    """
    Information about server domains
    """

    DEFAULT_METHOD = "DOMAIN_LIST"
    TARGET = ReportTarget.API

    def __repr__(self):
        # only the class name is shown, the content is left out
        class_name = self.__class__.__qualname__
        return "{}()".format(class_name)
class FilesUpdated(Message):
    """
    To consume products of files.update()
    """

    def __init__(self, files_type, files_index):
        """
        :param files_type: files.Type
        :param files_index: files.LocalIndex
        """
        # Run the base initializer so this message is set up like any
        # other Message (it stores "method" when DEFAULT_METHOD is set,
        # e.g. by a subclass).
        super().__init__()
        # explicit is better than implicit
        self["files_type"] = files_type
        self["files_index"] = files_index

    def __repr__(self):
        """
        Do not flood console.log with large sequences:
        only the two known items are rendered.
        """
        return "{}({{'files_type':'{}', 'files_index':{}}})".format(
            self.__class__.__qualname__,
            self["files_type"],
            self["files_index"],
        )
class UpdateFiles(Message, Received):
    """
    Update files by getting message from the server
    """

    # No RECEIVED_ACTIONS is declared here, so
    # Received.get_subclass_with_action() matches this class by this value.
    DEFAULT_METHOD = "UPDATE"
class ConfigUpdate(Message):
    """Message whose default method is ``CONFIG_UPDATE``."""

    DEFAULT_METHOD = "CONFIG_UPDATE"
class Reject(Exception):
    """
    Message filtering facility.

    Raise it to stop a message from being processed by further plugins.
    The reason of the reject is passed as the argument.
    """
class Health(Message):
    """Message whose default method is ``HEALTH``."""

    DEFAULT_METHOD = "HEALTH"
class CommandInvoke(Message, Reportable):
    """Reportable message whose default method is ``COMMAND_INVOKE``."""

    DEFAULT_METHOD = "COMMAND_INVOKE"
class ScanFailed(Message, Reportable):
    """Reportable message whose default method is ``SCAN_FAILED``."""

    DEFAULT_METHOD = "SCAN_FAILED"
class CleanupFailed(Message, Reportable):
    """Reportable message whose default method is ``CLEANUP_FAILED``."""

    DEFAULT_METHOD = "CLEANUP_FAILED"
class RestoreFromBackupTask(Message):
    """
    Creates a task to restore files from backup
    """

    # stored under the "method" key of every instance by Message.__init__
    DEFAULT_METHOD = "MALWARE_RESTORE_FROM_BACKUP"
class cPanelEvent(Message):
    """Message describing a panel hook event."""

    DEFAULT_METHOD = "PANEL_EVENT"
    # lower-cased names of the hook fields that are copied into "data"
    ALLOWED_FIELDS = {
        "new_pkg",
        "plan",
        "exclude",
        "imunify360_proactive",
        "imunify360_av",
    }

    @classmethod
    def from_hook_event(
        cls, username: str, hook: str, ts: float, fields: dict
    ):
        """
        Build the message from the data of a hook call.

        :param username: stored as "username"
        :param hook: hook name, stored as "hook"
        :param ts: stored as "timestamp"
        :param fields: hook fields; only those whose lower-cased name is
            in ALLOWED_FIELDS are kept (under the lower-cased name)
        """
        data = {}
        for field_name, field_value in fields.items():
            lowered = field_name.lower()
            if lowered in cls.ALLOWED_FIELDS:
                data[lowered] = field_value

        # Check for user rename
        is_rename = (
            hook == "Modify"
            and "user" in fields
            and "newuser" in fields
            and fields["user"] != fields["newuser"]
        )
        if is_rename:
            data["old_username"] = fields["user"]

        event = {
            "username": username,
            "hook": hook,
            "data": data,
            "timestamp": ts,
        }
        return cls(event)
class IContactSent(Message, Reportable):
    """Reportable message whose default method is ``ICONTACT_SENT``."""

    DEFAULT_METHOD = "ICONTACT_SENT"
def _shorten_str(s: str, limit: int) -> str:
"""Shorten *s* string if its length exceeds *limit*."""
assert limit > 4
return f"{s[:limit//2-1]}...{s[-limit//2+2:]}" if len(s) > limit else s
class BackupInfo(Message, Reportable):
    """Information about enabled backup backend"""

    # stored under the "method" key of every instance by Message.__init__
    DEFAULT_METHOD = "BACKUP_INFO"
class MDSReportList(ShortenReprListMixin, Message, Reportable):
    """
    Reportable message whose default method is ``MDS_SCAN_LIST``.

    Because of ShortenReprListMixin its repr shows only the number of
    entries stored under the "items" key.
    """

    DEFAULT_METHOD = "MDS_SCAN_LIST"
class MDSReport(Accumulatable):
    """Accumulatable message that is grouped into MDSReportList."""

    LIST_CLASS = MDSReportList
class Splittable:
    """
    A message list could be split into multiple batches.
    The split is possible for a list itself along with internal resources.
    """

    # Number of messages per resulting list; when falsy, the number of
    # the incoming messages is used instead.
    LIST_SIZE = None
    # Maximum number of BATCH_FIELD entries per message.
    BATCH_SIZE = None
    # Name of the message item whose content is split; both BATCH_FIELD
    # and BATCH_SIZE have to be set to enable the split.
    BATCH_FIELD = None

    @classmethod
    def _split_items(cls, messages: List[Accumulatable]):
        """
        Split messages' internal lists of things into batches.
        A field that is meant to split is defined by `BATCH_FIELD`.

        Yields the messages unchanged when the split is not configured.
        Otherwise a message without the field (or with None in it) is
        yielded as is, and any other message is replaced with one copy
        per batch of its field content.
        """
        if cls.BATCH_FIELD and cls.BATCH_SIZE:
            for message in messages:
                if (items := message.get(cls.BATCH_FIELD)) is None:
                    yield message
                else:
                    message_class = type(message)
                    # dicts need a batcher that produces dicts
                    batcher = (
                        batched_dict if isinstance(items, dict) else batched
                    )
                    # NOTE(review): if the batcher yields nothing for an
                    # empty collection, such a message is dropped here --
                    # confirm this is intended.
                    for batch in batcher(items, cls.BATCH_SIZE):
                        data = message.copy()
                        data[cls.BATCH_FIELD] = batch
                        new_message = message_class(data)
                        yield new_message
        else:
            yield from iter(messages)

    @classmethod
    def batched(cls, messages: List[Accumulatable]):
        """
        Yield *messages*, split by `_split_items()`, in batches of
        `LIST_SIZE` (or of ``len(messages)`` when `LIST_SIZE` is not set).

        Nothing is yielded for an empty *messages*.
        """
        if not messages:
            # Without LIST_SIZE the batch size would be len(messages) == 0,
            # which is not a meaningful batch size.
            return
        list_size = cls.LIST_SIZE or len(messages)
        split = cls._split_items(messages)
        yield from batched(split, list_size)