"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Copyright © 2019 Cloud Linux Software Inc.
This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import shutil
import time
from logging import getLogger
from typing import Dict, Optional, Union
from defence360agent.contracts.hook_events import HookEvent
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
MessageSink,
MessageSource,
expect,
)
from defence360agent.utils import Scope
from imav.malwarelib.config import (
MalwareScanResourceType,
MalwareScanType,
)
from imav.malwarelib.model import MalwareScan as MalwareScanModel
from imav.malwarelib.scan import (
ScanAlreadyCompleteError,
ScanInfoError,
)
from imav.malwarelib.scan.ai_bolit.detached import (
AiBolitDetachedScan,
)
from imav.malwarelib.scan.mds.detached import MDSDetachedScan
from imav.malwarelib.scan.queue_supervisor_sync import QueueSupervisorSync
from imav.malwarelib.scan.scan_result import aggregate_result
from imav.malwarelib.utils.user_list import fill_results_owner
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
class DetachedScanPlugin(MessageSink, MessageSource):
    """Finalise ``MessageType.MalwareScan`` messages of detached scans.

    For detached scans the summary and the results reach this plugin as
    separate messages.  Results that arrive while the summary is not yet in
    the DB are cached in ``results_cache`` and merged when the summary
    message comes in.  Once a scan is complete the plugin fills in
    ``total_malicious``, removes the scan from the scan queue and emits a
    ``MalwareScanningFinished`` hook event.
    """

    PROCESSING_ORDER = MessageSink.ProcessingOrder.PRE_PROCESS_MESSAGE
    SCOPE = Scope.AV
    loop, sink = None, None
    # Results waiting for their summary, keyed by scan id.
    # NOTE: class-level attribute, so it is shared by every instance and
    # by subclasses.
    results_cache = {}  # type: Dict[str, dict]

    # Performance-related summary keys copied into the hook's ``stats``.
    _PERF_METRIC_KEYS = frozenset(
        (
            "scan_time",
            "scan_time_hs",
            "scan_time_preg",
            "smart_time_hs",
            "smart_time_preg",
            "finder_time",
            "cas_time",
            "deobfuscate_time",
            "mem_peak",
        )
    )

    async def create_source(self, loop, sink):
        self.loop = loop
        self.sink = sink

    async def create_sink(self, loop):
        pass

    @expect(MessageType.MalwareScan, async_lock=True)
    async def complete_scan(self, message):
        """Dispatch a MalwareScan message to the appropriate handler.

        Non-detached scans only get ``total_malicious`` filled in.  Detached
        scans are routed to the summary or the results handler.
        """
        message_type = MalwareScanMessageInfo(message)
        if not message_type.is_detached:
            total_malicious = await self._count_total_malicious(message)
            message["summary"]["total_malicious"] = total_malicious
            return message
        if message_type.is_summary:
            return await self._handle_summary(message)
        # message_type.is_result
        return await self._handle_results(message)

    async def _handle_summary(self, message):
        """Handle the summary part of a detached scan.

        When results for this scan were cached earlier, merge them into
        the message, finish the scan and emit the hook event.  Otherwise
        the message is returned untouched: it carries no results
        (``results is None``), so nothing can be counted yet.
        """
        summary = message["summary"]
        scan_id = summary["scanid"]
        # If summary arrives after results, results are read from cache
        if scan_id in self.results_cache:
            summary["completed"] = time.time()
            message["results"] = self.results_cache.pop(scan_id)
            total_malicious = await self._count_total_malicious(message)
            summary["total_malicious"] = total_malicious
            queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
            if queued_scan:
                QueueSupervisorSync.queue.remove(queued_scan)
            await self._call_scan_finished_hook(
                summary, queued_scan.args if queued_scan else {}
            )
        return message

    async def _handle_results(self, message):
        """Handle the results part of a detached scan.

        If the scan summary is already in the DB, the summary is completed
        from the DB row and the finished message is returned.  If it is
        not, a failed scan (``path`` or ``error`` set) is finished
        immediately.  Any other results are cached for
        :meth:`_handle_summary`, and ``None`` is returned.
        """
        message = await self.aggregate_result(message)
        message_type = MalwareScanMessageInfo(message)
        summary = message["summary"]
        scan_id = summary["scanid"]
        logger.info("Scan stopped")
        queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
        scan = message_type.summary_from_db
        if scan is None:
            if queued_scan:
                self._release_queued_scan(summary, queued_scan)
            if summary.get("path") or summary.get("error"):
                # Scan failed
                summary["total_malicious"] = 0
                await self._call_scan_finished_hook(summary, scan_args={})
                return message
            # Summary is not in DB yet, save results to cache
            self.results_cache[scan_id] = message["results"]
            # Report an error to Sentry if cache grows
            cache_size = len(self.results_cache)
            if cache_size > 1:
                logger.error("MalwareScan cache size is %d", cache_size)
            return None

        # Complete the summary from the DB row.
        summary["scanid"] = scan.scanid
        summary["path"] = scan.path
        summary["started"] = scan.started
        summary["completed"] = time.time()
        if summary.get("total_files") is None:
            summary["total_files"] = scan.total_resources
        summary["type"] = scan.type
        # Make sure the key exists even when no error was reported.
        summary["error"] = summary.get("error", None)
        summary["total_malicious"] = await self._count_total_malicious(message)
        if queued_scan:
            self._release_queued_scan(summary, queued_scan)
        await self._call_scan_finished_hook(
            summary, queued_scan.args if queued_scan else {}
        )
        return message

    @staticmethod
    def _release_queued_scan(summary, queued_scan) -> None:
        """Copy the queued scan's pattern args into *summary* and dequeue it."""
        args = queued_scan.args
        summary["file_patterns"] = args["file_patterns"]
        summary["exclude_patterns"] = args["exclude_patterns"]
        QueueSupervisorSync.queue.remove(queued_scan)

    @staticmethod
    async def _count_total_malicious(message) -> int:
        """Count result entries whose first hit has ``suspicious is False``.

        Raises ``IndexError`` when an entry has an empty ``hits`` list.
        """
        return sum(
            1
            for entry in message["results"].values()
            if entry["hits"][0]["suspicious"] is False
        )

    async def _call_scan_finished_hook(self, summary, scan_args) -> None:
        """Emit ``MalwareScanningFinished`` for *summary*, then request a
        scan-queue recheck.

        ``stats`` holds whichever ``_PERF_METRIC_KEYS`` are present in
        *summary*, plus ``total_files``.
        """
        stats = {
            key: value
            for key, value in summary.items()
            if key in self._PERF_METRIC_KEYS
        }
        stats["total_files"] = summary["total_files"]
        error = summary.get("error")
        scan_finished = HookEvent.MalwareScanningFinished(
            scan_id=summary["scanid"],
            scan_type=summary["type"],
            path=summary["path"],
            started=summary["started"],
            total_files=summary["total_files"],
            total_malicious=summary["total_malicious"],
            error=error,
            status="failed" if error else "ok",
            scan_params=scan_args,
            stats=stats,
        )
        await self.sink.process_message(scan_finished)
        await self._recheck_scan_queue()

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return the detached scan object for *scan_id*.

        The AV plugin only knows AI-Bolit file scans, so *resource_type* is
        ignored here; subclasses may dispatch on it.
        """
        return AiBolitDetachedScan(scan_id)

    @expect(MessageType.MalwareScanComplete)
    async def complete_detached_scan(self, message):
        """Collect the outcome of a finished detached scan.

        The scan's detached directory is always removed.  The resulting
        scan message is forwarded to the sink unless completion failed with
        ``ScanAlreadyCompleteError`` or ``ScanInfoError``; those are logged
        and swallowed.
        """
        scan_id = message.get("scan_id")
        resource_type = message.get("resource_type")
        detached_scan = self._get_detached_scan(resource_type, scan_id)
        try:
            scan_message = await detached_scan.complete()
        except ScanAlreadyCompleteError as err:
            # This happens when AV is woken up by AiBolit. See DEF-11078.
            logger.warning(
                "Cannot complete scan %s, assuming it is already complete"
                ":\n%s",
                scan_id,
                err,
            )
            return
        except ScanInfoError as err:
            logger.error(
                "Cannot complete %s scan %s, assuming it was not started:\n%s",
                detached_scan.RESOURCE_TYPE.value,
                scan_id,
                err,
            )
            return
        finally:
            shutil.rmtree(str(detached_scan.detached_dir), ignore_errors=True)
        await self.sink.process_message(scan_message)

    @classmethod
    async def aggregate_result(cls, message):
        """Aggregate ``message["results"]`` in place and fill in the file
        owners; returns the same message object."""
        message["results"] = aggregate_result(message["results"])
        await fill_results_owner(message["results"])
        return message

    async def _recheck_scan_queue(self):
        """Ask the scan queue to look for the next scan to run."""
        await self.sink.process_message(MessageType.MalwareScanQueueRecheck())
class MalwareScanMessageInfo:
    """Read-only helper that answers questions about a MalwareScan message.

    It tells whether the scan is detached and whether the message carries
    only a summary.  It also lazily loads (and then memoises) the matching
    ``MalwareScan`` DB row.
    """

    def __init__(self, message):
        self.message = message
        self.scan_id = message["summary"]["scanid"]
        # Filled on first successful lookup in ``summary_from_db``.
        self._summary_from_db = None

    @property
    def is_detached(self):
        """True when the summary's scan ``type`` is one run detached."""
        scan_type = self.message["summary"].get("type")
        detached_types = (
            MalwareScanType.ON_DEMAND,
            MalwareScanType.BACKGROUND,
            MalwareScanType.USER,
            None,
        )
        return scan_type in detached_types

    @property
    def is_summary(self):
        """True when the message carries no results (summary only)."""
        return self.message["results"] is None

    @property
    def summary_from_db(self):
        """The ``MalwareScan`` row for ``scan_id``, or ``None`` if absent.

        A found row is cached; a miss is retried on the next access.
        """
        if self._summary_from_db:
            return self._summary_from_db
        rows = (
            MalwareScanModel.select()
            .where(MalwareScanModel.scanid == self.scan_id)
            .limit(1)
        )
        if rows:
            self._summary_from_db = rows[0]
        return self._summary_from_db
class DetachedScanPluginIm360(DetachedScanPlugin):
    """Imunify360-scope variant of the detached scan plugin.

    It can also complete detached database scans (``MDSDetachedScan``) and
    handles ``MalwareDatabaseScan`` messages.
    """

    SCOPE = Scope.IM360

    @staticmethod
    def _get_detached_scan(
        resource_type: Optional[Union[str, MalwareScanResourceType]], scan_id
    ):
        """Return an MDS scan for the DB resource type, AI-Bolit otherwise."""
        is_db_scan = (
            resource_type is not None
            and MalwareScanResourceType(resource_type)
            is MalwareScanResourceType.DB
        )
        scan_cls = MDSDetachedScan if is_db_scan else AiBolitDetachedScan
        return scan_cls(scan_id)

    @expect(MessageType.MalwareDatabaseScan)
    async def complete_scan_db(self, message):
        """Drop a finished DB scan from the queue, emit the hook event and
        request a queue recheck."""
        scan_id = message["scan_id"]
        queued_scan = QueueSupervisorSync.queue.find(scanid=scan_id)
        if queued_scan:
            QueueSupervisorSync.queue.remove(queued_scan)
        await self.sink.process_message(
            HookEvent.MalwareScanningFinished(
                scan_id=scan_id,
                scan_type=message["type"],
                path=message["path"],
            )
        )
        await self._recheck_scan_queue()