"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Copyright © 2019 Cloud Linux Software Inc.
This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
from __future__ import annotations
import asyncio
import itertools
import os
from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path
from time import time
from typing import Dict, Iterable, List, Set, cast
from peewee import (
BooleanField,
Case,
CharField,
Check,
Expression,
FloatField,
ForeignKeyField,
IntegerField,
PrimaryKeyField,
SQL,
TextField,
fn,
)
from playhouse.shortcuts import model_to_dict
from defence360agent.contracts.config import UserType
from defence360agent.model import Model, instance
from defence360agent.model.simplification import (
FilenameField,
ScanPathField,
apply_order_by,
)
from defence360agent.utils import (
execute_iterable_expression,
get_abspath_from_user_dir,
get_results_iterable_expression,
split_for_chunk,
)
from imav.malwarelib.config import (
FAILED_TO_CLEANUP,
MalwareHitStatus,
MalwareScanResourceType,
MalwareScanType,
)
from imav.malwarelib.scan.crontab import get_crontab
class MalwareScan(Model):
    """A batch of files scanned for malware.

    Usually corresponds to a single AI-BOLIT execution.
    See :class:`.MalwareScanType` for the possible kinds of scans.
    """

    class Meta:
        database = instance.db
        db_table = "malware_scans"

    #: An id of a scan, unique per server.
    scanid = CharField(primary_key=True)
    #: Scan start timestamp.
    started = IntegerField(null=False)
    #: Scan completion timestamp.
    completed = IntegerField(null=True)
    #: Scan type - reflects how and why the files were scanned.
    #: Must be one of :class:`.MalwareScanType`.
    type = CharField(
        null=False,
        constraints=[
            Check(
                "type in {}".format(
                    (
                        MalwareScanType.ON_DEMAND,
                        MalwareScanType.REALTIME,
                        MalwareScanType.MALWARE_RESPONSE,
                        MalwareScanType.BACKGROUND,
                        MalwareScanType.RESCAN,
                        MalwareScanType.USER,
                        MalwareScanType.RESCAN_OUTDATED,
                    )
                )
            )
        ],
    )
    #: The number of resources scanned.
    total_resources = IntegerField(null=False, default=0)
    #: For some types of scan - the directory or a file that was scanned.
    path = ScanPathField(null=True, default="")
    #: If not `null`, the scan did not finish successfully.
    #: Can be one of :class:`.ExitDetachedScanType` if scan was aborted or
    #: stopped by user, or an arbitrary error message for other kinds
    #: of issues.
    error = TextField(null=True, default=None)
    #: The number of malicious files found
    total_malicious = IntegerField(null=False, default=0)
    #: Kind of resources scanned (files or databases).
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
    )
    #: user who started the scan (None for root user)
    initiator = CharField(null=True)

    @classmethod
    def ondemand_list(
        cls,
        since,
        to,
        limit,
        offset,
        order_by=None,
        *,
        types=(
            MalwareScanType.ON_DEMAND,
            MalwareScanType.BACKGROUND,
            MalwareScanType.USER,
        ),
        paths=None,
    ):
        """List scans of the given *types* started within ``[since, to]``.

        :param since: inclusive lower bound for :attr:`started`
        :param to: inclusive upper bound for :attr:`started`
        :param limit: page size
        :param offset: page offset
        :param order_by: optional ordering passed to :func:`apply_order_by`;
            otherwise scans are ordered by newest start first
        :param types: scan types to include
        :param paths: if truthy, only scans of these paths are included
        :return: ``(total_count_ignoring_pagination, list_of_row_dicts)``;
            the scan type is exposed under the ``scan_type`` key
        """
        selected_columns = (
            cls.total_resources,
            cls.path,
            cls.scanid,
            cls.started,
            cls.completed,
            cls.error,
            cls.total_malicious,
            cls.type.alias("scan_type"),
            cls.resource_type,
        )
        filters = [
            cls.type.in_(types),
            cls.started >= since,
            cls.started <= to,
        ]
        if paths:
            filters.append(cls.path.in_(paths))

        scans = cls.select(*selected_columns)
        for condition in filters:
            scans = scans.where(condition)

        grouping = (cls.total_resources, cls.path, cls.scanid, cls.started)
        scans = (
            scans.group_by(*grouping)
            .order_by(MalwareScan.started.desc())
            .limit(limit)
            .offset(offset)
        )
        if order_by is not None:
            scans = apply_order_by(order_by, cls, scans)

        total = scans.count(clear_limit=True)
        rows = list(scans.dicts())
        return total, rows
class MalwareHit(Model):
    """Represents a malicious or suspicious file or database resource."""

    class Meta:
        database = instance.db
        db_table = "malware_hits"

    #: Auto-incremented primary key of the hit.
    id = PrimaryKeyField()
    #: A reference to :class:`MalwareScan`.
    scanid = ForeignKeyField(
        MalwareScan, null=False, related_name="hits", on_delete="CASCADE"
    )
    #: The owner of the file.
    owner = CharField(null=False)
    #: The user a file belongs to (is in user's home but owned by another user)
    user = CharField(null=False)
    #: The original path to the file.
    orig_file = FilenameField(null=False)
    #: The type of infection (signature).
    type = CharField(null=False)
    #: Whether the file is malicious or just suspicious.
    #: Suspicious files are not displayed in UI but sent for analysis to MRS.
    malicious = BooleanField(null=False, default=False)
    #: The hash of the files as provided by AI-BOLIT.
    hash = CharField(null=True)
    #: The size of the file.
    size = CharField(null=True)
    #: The exact timestamp when AI-BOLIT has detected the file.
    #:
    #: FIXME: unused? It looks like it was intended to resolve some possible
    #: race conditions with parallel scans, but we don't actually use it
    #: from the DB - we only compare the value in scan report
    #: with :attr:`cleaned_at`.
    timestamp = FloatField(null=True)
    #: The current status of the file.
    #: Must be one of :class:`.MalwareHitStatus`.
    status = CharField(default=MalwareHitStatus.FOUND)
    #: Timestamp when the file was last cleaned.
    cleaned_at = FloatField(null=True)
    #: Kind of the detected resource (file or database).
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
    )
    #: Database-hit details (presumably left unset for file hits).
    app_name = CharField(null=True)
    db_host = CharField(null=True)
    db_port = CharField(null=True)
    db_name = CharField(null=True)
    snippet = CharField(null=True)

    @property
    def orig_file_path(self):
        """:attr:`orig_file` as a :class:`pathlib.Path`."""
        orig_file = cast(str, self.orig_file)
        return Path(orig_file)

    class OrderBy:
        """Custom ordering expressions for hit listings."""

        @staticmethod
        def status():
            """Order by status: cleanup in progress first, then found,
            then cleaned/removed; any other status sorts last (100).

            Returns a 1-tuple containing the ``CASE`` expression.
            """
            return (
                Case(
                    MalwareHit.status,
                    (
                        (MalwareHitStatus.CLEANUP_PENDING, 0),
                        (MalwareHitStatus.CLEANUP_STARTED, 1),
                        (MalwareHitStatus.FOUND, 2),
                        (MalwareHitStatus.CLEANUP_DONE, 4),
                        (MalwareHitStatus.CLEANUP_REMOVED, 5),
                    ),
                    100,
                ),
            )

    @classmethod
    def _hits_list(
        cls,
        clauses,
        since=0,
        to=None,
        limit=None,
        offset=None,
        search=None,
        by_scan_id=None,
        user=None,
        order_by=None,
        by_status=None,
        ids=None,
        **kwargs,
    ):
        """Return ``(max_count, rows)`` for hits matching *clauses*.

        Hits are joined with their scan and filtered by the scan start time
        in ``[since, to]`` (``to`` defaults to now).

        :param clauses: base peewee expression selecting the hits
        :param search: substring matched against the file path or user
        :param by_scan_id: restrict to a single scan
        :param user: restrict to a single user
        :param by_status: restrict to these statuses
        :param ids: restrict the returned rows to these hit ids
        :return: ``max_count`` - number of matching hits ignoring
            pagination and ``ids``; ``rows`` - :meth:`as_dict` of one page
        """
        hits = cls.select(cls, MalwareScan).join(MalwareScan)
        to = to or time()
        started = (MalwareScan.started >= since) & (MalwareScan.started <= to)
        full_clauses = clauses & started
        if search is not None:
            pattern = "%{}%".format(search)
            # orig_file may be stored as a blob, hence the explicit CAST.
            full_clauses &= SQL(
                "CAST(orig_file AS TEXT) LIKE ?", (pattern,)
            ) | (cls.user**pattern)
        if user is not None:
            full_clauses &= MalwareHit.user == user
        if by_scan_id is not None:
            full_clauses &= MalwareScan.scanid == by_scan_id
        if by_status is not None:
            full_clauses &= MalwareHit.status << by_status
        # `max_count` is used for pagination, must not include `ids`
        max_count_clauses = full_clauses
        if ids is not None:
            full_clauses &= MalwareHit.id.in_(ids)
        ordered = hits.where(full_clauses).limit(limit).offset(offset)
        if order_by is not None:
            ordered = apply_order_by(order_by, MalwareHit, ordered)
        max_count = cls._hits_num(max_count_clauses)
        result = [row.as_dict() for row in ordered]
        return max_count, result

    @classmethod
    def suspicious_list(cls, *args, **kwargs):
        """:meth:`_hits_list` restricted to suspicious (non-malicious) hits."""
        return cls._hits_list(cls.is_suspicious(), *args, **kwargs)

    @classmethod
    def _hits_num(
        cls, clauses=None, since=None, to=None, user=None, order_by=None
    ):
        """Count hits (joined with their scans) matching the given filters.

        :param clauses: base peewee expression, or None for no base filter
        :param since: inclusive lower bound for ``MalwareScan.started``
        :param to: inclusive upper bound for ``MalwareScan.started``
        :param user: restrict to hits of this user
        :param order_by: optional ordering passed to :func:`apply_order_by`
        :return: the number of matching hits
        """
        conditions = [] if clauses is None else [clauses]
        # Each bound is applied on its own so that ``since=0`` (the usual
        # "from the beginning" value) does not discard the ``to`` bound.
        if since is not None:
            conditions.append(MalwareScan.started >= since)
        if to is not None:
            conditions.append(MalwareScan.started <= to)
        if user is not None:
            conditions.append(cls.user == user)
        q = cls.select(fn.COUNT(cls.id)).join(MalwareScan)
        for condition in conditions:
            q = q.where(condition)
        if order_by is not None:
            q = apply_order_by(order_by, MalwareHit, q)
        return q.scalar()

    @classmethod
    def malicious_num(cls, since, to, user=None):
        """Count malicious hits not in a cleanup status within the range."""
        return cls._hits_num(
            (cls.status.not_in(MalwareHitStatus.CLEANUP) & cls.malicious),
            since,
            to,
            user,
        )

    @classmethod
    def malicious_list(cls, *args, ignore_cleaned=False, **kwargs):
        """:meth:`_hits_list` restricted to malicious hits.

        :param ignore_cleaned: also exclude hits in a cleanup status
        """
        clauses = cls.malicious
        if ignore_cleaned:
            clauses &= cls.status.not_in(MalwareHitStatus.CLEANUP)
        return cls._hits_list(clauses, *args, **kwargs)

    @classmethod
    def set_status(cls, hits, status, cleaned_at=None):
        """Set *status* (and optionally *cleaned_at*) on the given hits.

        The update is executed in chunks of ids via
        :func:`execute_iterable_expression`.
        """
        hits = [row.id for row in hits]

        def expression(ids, cls, status, cleaned_at):
            fields_to_update = {
                "status": status,
            }
            if cleaned_at is not None:
                fields_to_update["cleaned_at"] = cleaned_at
            return cls.update(**fields_to_update).where(cls.id.in_(ids))

        return execute_iterable_expression(
            expression, hits, cls, status, cleaned_at
        )

    @classmethod
    def delete_instances(cls, to_delete: list):
        """Delete the given hits by id (in chunks)."""
        to_delete = [row.id for row in to_delete]

        def expression(ids):
            return cls.delete().where(cls.id.in_(ids))

        return execute_iterable_expression(expression, to_delete)

    @classmethod
    def update_instances(cls, to_update: list):
        """Apply field updates to hit instances and save them.

        :param to_update: list of ``{hit: {field_name: value, ...}}`` dicts
        """
        # The loop variable is not called ``instance`` so that it does not
        # shadow the module-level database ``instance`` import.
        for data in to_update:
            for hit, new_fields_data in data.items():
                for field, value in new_fields_data.items():
                    setattr(hit, field, value)
                hit.save()

    @classmethod
    def is_infected(cls) -> Expression:
        """Expression selecting malicious hits with status FOUND."""
        clauses = (
            cls.status.in_(
                [
                    MalwareHitStatus.FOUND,
                ]
            )
            & cls.malicious
        )
        return clauses

    @classmethod
    def is_suspicious(cls):
        """Expression selecting suspicious (non-malicious) hits."""
        return ~cls.malicious

    @classmethod
    def malicious_select(
        cls, ids=None, user=None, cleanup=False, restore=False, **kwargs
    ):
        """Return malicious hits as a list.

        When *ids* is given only those hits are selected (in chunks);
        otherwise *cleanup* excludes hits in a cleanup status and *restore*
        keeps only restorable ones. *user* may be a name or a list of names.
        """

        def expression(chunk_of_ids, cls, user):
            clauses = cls.malicious
            if chunk_of_ids is not None:
                clauses &= cls.id.in_(chunk_of_ids)
            elif cleanup:
                clauses &= cls.status.not_in(MalwareHitStatus.CLEANUP)
            elif restore:
                clauses &= cls.status.in_(MalwareHitStatus.RESTORABLE)
            if user is not None:
                if isinstance(user, str):
                    user = [user]
                clauses &= cls.user.in_(user)
            return cls.select().where(clauses)

        return list(
            get_results_iterable_expression(
                expression, ids, cls, user, exec_expr_with_empty_iter=True
            )
        )

    @classmethod
    def get_hits(cls, files, *, statuses=None):
        """Return hits for the given file paths, optionally by status."""

        def expression(files):
            clauses = cls.orig_file.in_(files)
            if statuses:
                clauses &= cls.status.in_(statuses)
            return cls.select().where(clauses)

        return get_results_iterable_expression(expression, files)

    @classmethod
    def get_db_hits(cls, hits_info: Set):
        """Return hits matching the exact ``(path, app_name)`` pairs.

        :param hits_info: entries exposing ``path`` and ``app_name``
        :return: list of matching :class:`MalwareHit`
        """
        if not hits_info:
            return []
        paths = {entry.path for entry in hits_info}
        apps = {entry.app_name for entry in hits_info}
        paths_apps = {(entry.path, entry.app_name) for entry in hits_info}
        query = (
            cls.select()
            .where(cls.orig_file.in_(paths))
            .where(cls.app_name.in_(apps))
        )
        # The SQL filter matches the cross product of paths and apps;
        # narrow it down to the exact pairs requested.
        return [
            hit for hit in query if (hit.orig_file, hit.app_name) in paths_apps
        ]

    @classmethod
    def delete_hits(cls, files):
        """Delete all hits for the given file paths (in chunks)."""

        def expression(files):
            return cls.delete().where(cls.orig_file.in_(files))

        return execute_iterable_expression(expression, files)

    def refresh(self):
        """Re-read this hit from the database and return the new object."""
        return type(self).get(self._pk_expr())

    @classmethod
    def refresh_hits(cls, hits: Iterable[MalwareHit], include_scan_info=False):
        """Re-read the given hits from the database.

        :param include_scan_info: also load the related scan in the same query
        """

        def expression(hits):
            query = cls.select()
            if include_scan_info:  # use a single query to get scan info
                query = cls.select(cls, MalwareScan).join(MalwareScan)
            return query.where(cls.id.in_([hit.id for hit in hits]))

        return list(get_results_iterable_expression(expression, hits))

    @classmethod
    def db_hits(cls):
        """Query selecting hits whose resource type is DB."""
        return cls.select().where(
            cls.resource_type == MalwareScanResourceType.DB.value
        )

    @classmethod
    def db_hits_pending_cleanup(cls) -> Expression:
        """Return db hits that are in queue for cleanup"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_PENDING,
        )

    @classmethod
    def db_hits_under_cleanup(cls) -> Expression:
        """Return db hits for which the cleanup is in progress"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_STARTED
        )

    @classmethod
    def db_hits_under_restoration(cls) -> Expression:
        """Return db hits for which the restore is in progress"""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_STARTED
        )

    @classmethod
    def db_hits_under_cleanup_in(cls, hit_info_set):
        """
        Return db hits for which the cleanup is in progress
        specified by the provided set of MalwareDatabaseHitInfo
        """
        # FIXME: Use peewee.ValuesList when peewee is updated
        # to obtain all hits using one query without additional processing
        path_set = {hit_info.path for hit_info in hit_info_set}
        app_name_set = {hit_info.app_name for hit_info in hit_info_set}
        path_app_name_set = {
            (hit_info.path, hit_info.app_name) for hit_info in hit_info_set
        }
        query = (
            cls.db_hits_under_cleanup()
            .where(cls.orig_file.in_(path_set))
            .where(cls.app_name.in_(app_name_set))
        )
        return [
            hit
            for hit in query
            if (hit.orig_file, hit.app_name) in path_app_name_set
        ]

    @classmethod
    def db_hits_pending_cleanup_restore(cls):
        """Return db hits queued for cleanup restore."""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_PENDING
        )

    @classmethod
    def db_hits_under_cleanup_restore(cls):
        """Return db hits for which the cleanup restore is in progress."""
        return cls.db_hits().where(
            cls.status == MalwareHitStatus.CLEANUP_RESTORE_STARTED
        )

    @staticmethod
    def group_by_attribute(
        *hit_list_list: List["MalwareHit"], attribute: str
    ) -> Dict[str, List["MalwareHit"]]:
        """Merge the given hit lists and group them by *attribute*."""
        # itertools.groupby needs input sorted by the same key.
        hit_list = sorted(
            itertools.chain.from_iterable(hit_list_list),
            key=attrgetter(attribute),
        )
        return {
            attr_value: list(hits)
            for attr_value, hits in itertools.groupby(
                hit_list,
                key=attrgetter(attribute),
            )
        }

    def as_dict(self):
        """Serialize the hit (with its scan info) for API responses.

        For DB hits ``table_fields`` lists the infected table/field/row
        records found in :class:`MalwareHistory` for the same scan.
        """
        return {
            "id": self.id,
            "username": self.user,
            "file": self.orig_file,
            "created": self.scanid.started,
            "scan_id": self.scanid_id,
            "scan_type": self.scanid.type,
            "resource_type": self.resource_type,
            "type": self.type,
            "hash": self.hash,
            "size": self.size,
            "malicious": self.malicious,
            "status": self.status,
            "cleaned_at": self.cleaned_at,
            "extra_data": {},
            "db_name": self.db_name,
            "app_name": self.app_name,
            "db_host": self.db_host,
            "db_port": self.db_port,
            "snippet": self.snippet,
            "table_fields": (
                list(
                    MalwareHistory.select(
                        MalwareHistory.table_name,
                        MalwareHistory.table_field,
                        MalwareHistory.table_row_inf,
                    )
                    .where(
                        MalwareHistory.app_name == self.app_name,
                        MalwareHistory.db_host == self.db_host,
                        MalwareHistory.db_port == self.db_port,
                        MalwareHistory.db_name == self.db_name,
                        MalwareHistory.path == self.orig_file,
                        MalwareHistory.resource_type == self.resource_type,
                        # scan_id is a plain CharField: compare with the raw
                        # foreign-key value rather than the related object.
                        MalwareHistory.scan_id == self.scanid_id,
                        MalwareHistory.table_name.is_null(False),
                        MalwareHistory.table_field.is_null(False),
                        MalwareHistory.table_row_inf.is_null(False),
                    )
                    .dicts()
                )
                if self.resource_type == MalwareScanResourceType.DB.value
                else []
            ),
        }

    def __repr__(self):
        if self.app_name:
            return "%s(orig_file=%r, app_name=%r)" % (
                self.__class__.__name__,
                self.orig_file,
                self.app_name,
            )
        return "%s(orig_file=%r)" % (self.__class__.__name__, self.orig_file)
@dataclass(frozen=True)
class MalwareHitAlternate:
    """Lightweight, immutable stand-in for :class:`MalwareHit`.

    Only file hits are represented by this class.
    """

    scanid: str
    orig_file: str
    # File hits never belong to an application: always None.
    app_name: None
    owner: str
    user: str
    size: int
    hash: str
    type: str
    timestamp: int
    malicious: bool

    @classmethod
    def create(cls, scanid, filename, data):
        """Build an instance from a per-file scan report entry.

        Only the first element of ``data["hits"]`` provides the signature,
        detection timestamp and maliciousness.
        """
        file_owner = data["owner"]
        file_user = data["user"]
        file_size = data["size"]
        file_hash = data["hash"]
        signature = data["hits"][0]["matches"]
        detected_at = data["hits"][0]["timestamp"]
        is_suspicious = data["hits"][0]["suspicious"]
        return cls(
            scanid=scanid,
            orig_file=filename,
            app_name=None,
            owner=file_owner,
            user=file_user,
            size=file_size,
            hash=file_hash,
            type=signature,
            timestamp=detected_at,
            malicious=not is_suspicious,
        )

    @property
    def orig_file_path(self):
        """:attr:`orig_file` decoded with the filesystem encoding, as a Path."""
        decoded = os.fsdecode(self.orig_file)
        return Path(decoded)
class MalwareIgnorePath(Model):
    """A path that must be excluded from all scans"""

    class Meta:
        database = instance.db
        db_table = "malware_ignore_path"
        indexes = ((("path", "resource_type"), True),)  # True refers to unique

    #: Class-level cache of all rows (dicts ordered by path), filled lazily
    #: by :meth:`is_path_ignored` and reset by :meth:`create`/:meth:`delete`.
    CACHE = None

    id = PrimaryKeyField()
    #: The path itself. Wildcards or patterns are NOT supported.
    path = CharField()
    resource_type = CharField(
        null=False, constraints=[Check("resource_type in ('file','db')")]
    )
    #: Timestamp when it was added.
    added_date = IntegerField(null=False, default=lambda: int(time()))

    @classmethod
    def _update_cache(cls):
        """Reload :attr:`CACHE` from the database."""
        items = list(cls.select().order_by(cls.path).dicts())
        cls.CACHE = items

    @classmethod
    def create(cls, **kwargs):
        """Create a row and invalidate the cache."""
        cls.CACHE = None
        return super(MalwareIgnorePath, cls).create(**kwargs)

    @classmethod
    def delete(cls):
        """Build a DELETE query and invalidate the cache."""
        cls.CACHE = None
        return super(MalwareIgnorePath, cls).delete()

    @classmethod
    def paths_count_and_list(
        cls,
        limit=None,
        offset=None,
        search=None,
        resource_type: str = None,
        user=None,
        since=None,
        to=None,
        order_by=None,
    ):
        """Return ``(total_count, rows)`` of ignored paths matching filters.

        :param search: substring the path must contain
        :param resource_type: ``'file'`` or ``'db'``
        :param user: only paths inside the user's home directory, the home
            directory itself, or the user's crontab file
        :param since: inclusive lower bound for :attr:`added_date`
        :param to: inclusive upper bound for :attr:`added_date`
        :return: total ignoring limit/offset, and one page of row dicts
        """
        q = cls.select().order_by(cls.path)
        if since is not None:
            q = q.where(cls.added_date >= since)
        if to is not None:
            q = q.where(cls.added_date <= to)
        if search is not None:
            q = q.where(cls.path.contains(search))
        if resource_type is not None:
            q = q.where(cls.resource_type == resource_type)
        if offset is not None:
            q = q.offset(offset)
        if limit is not None:
            q = q.limit(limit)
        if order_by is not None:
            q = apply_order_by(order_by, cls, q)
        if user is not None:
            user_home = str(get_abspath_from_user_dir(user))
            home_prefix = user_home + "/"
            # Exact, case-sensitive prefix comparison. A LIKE-based
            # ``startswith`` would treat "_"/"%" in the home path (e.g.
            # "/home/a_b") as wildcards and match other users' paths.
            q = q.where(
                (fn.SUBSTR(cls.path, 1, len(home_prefix)) == home_prefix)
                | (cls.path == user_home)
                | (cls.path == str(get_crontab(user)))
            )
        max_count = q.count(clear_limit=True)
        return (
            max_count,
            [model_to_dict(row) for row in q],
        )

    @classmethod
    def path_list(cls, *args, **kwargs) -> List[str]:
        """Like :meth:`paths_count_and_list` but return only the paths."""
        _, path_list = cls.paths_count_and_list(*args, **kwargs)
        return [row["path"] for row in path_list]

    @classmethod
    async def is_path_ignored(cls, check_path):
        """Check whether *check_path* is excluded by an ignored path.

        A path is ignored when it equals a cached ignored path or lies
        inside one (the ignored path is one of its parents). Patterns and
        wildcards are not supported.

        :param str check_path: path to check
        :return: bool: is ignored according MalwareIgnorePath
        """
        if cls.CACHE is None:
            cls._update_cache()
        path = Path(check_path)
        for p in cls.CACHE:
            # Yield to the event loop so long ignore lists don't block it.
            await asyncio.sleep(0)
            ignored_path = Path(p["path"])
            if (path == ignored_path) or (ignored_path in path.parents):
                return True
        return False
class MalwareHistory(Model):
    """Audit log: one row per event related to a :class:`MalwareHit`."""

    class Meta:
        database = instance.db
        db_table = "malware_history"

    #: The path of the file.
    path = FilenameField(null=False)
    app_name = CharField(null=True)
    resource_type = CharField(
        null=False,
        constraints=[
            Check(
                "resource_type in {}".format(
                    (
                        MalwareScanResourceType.DB.value,
                        MalwareScanResourceType.FILE.value,
                    )
                )
            )
        ],
        default=MalwareScanResourceType.FILE.value,
    )
    #: What happened with the file. Should be one of :class:`.MalwareEvent`.
    event = CharField(null=False)
    #: What kind of scan has detected the file, or `manual` for manual actions.
    #: See :class:`.MalwareScanType`.
    cause = CharField(null=False)
    #: The name of the user who has triggered the event.
    initiator = CharField(null=False)
    #: A snapshot of :attr:`MalwareHit.owner`
    file_owner = CharField(null=False)
    #: A snapshot of :attr:`MalwareHit.user`
    file_user = CharField(null=False)
    #: Timestamp when the event took place.
    ctime = IntegerField(null=False, default=lambda: int(time()))
    #: Database host name (for db type scan).
    db_host = CharField(null=True)
    #: Database port (for db type scan).
    db_port = CharField(null=True)
    #: Database name (for db type scan).
    db_name = CharField(null=True)
    #: Infected table name (for db type scan)
    table_name = CharField(null=True)
    #: Infected field name (for db type scan)
    table_field = CharField(null=True)
    #: Infected table row id (for db type scan)
    table_row_inf = IntegerField(null=True)
    #: Scan ID reference (for generating `table_fields`)
    scan_id = CharField(null=True)

    @classmethod
    def get_history(
        cls, since, to, limit, offset, user=None, search=None, order_by=None
    ):
        """Return ``(total_count, rows)`` of events with ctime in [since, to].

        :param search: if truthy, keep events whose ``event`` or ``path``
            contains it
        :param user: if truthy, keep events of this file user only
        :return: total ignoring limit/offset, and one page of row dicts
        """
        criteria = (cls.ctime >= since) & (cls.ctime <= to)
        if search:
            event_matches = cls.event.contains(search)
            path_matches = SQL("(INSTR(path, ?))", (search,))
            criteria &= event_matches | path_matches
        if user:
            criteria &= cls.file_user == user

        page = (
            cls.select()
            .where(criteria)
            .limit(limit)
            .offset(offset)
            .dicts()
        )
        if order_by is not None:
            page = apply_order_by(order_by, MalwareHistory, page)

        rows = list(page)
        total = page.count(clear_limit=True)
        return total, rows

    @classmethod
    def save_event(cls, **kwargs):
        """Insert one event, filling falsy initiator/cause/resource_type
        with root, manual and file respectively."""
        who = kwargs.pop("initiator", None) or UserType.ROOT
        why = kwargs.pop("cause", None) or MalwareScanType.MANUAL
        kind = (
            kwargs.pop("resource_type", None)
            or MalwareScanResourceType.FILE.value
        )
        cls.insert(
            initiator=who,
            cause=why,
            resource_type=kind,
            **kwargs,
        ).execute()

    @classmethod
    def save_events(cls, hits: List[dict]):
        """Bulk-insert event dicts inside a single transaction."""
        # SQLite caps bound parameters per statement at
        # SQLITE_LIMIT_VARIABLE_NUMBER (compile-time, default 999), so each
        # INSERT carries at most 999 // <number of columns> rows.
        rows_per_insert = 999 // len(cls._meta.columns)
        with instance.db.atomic():
            for batch in split_for_chunk(hits, chunk_size=rows_per_insert):
                cls.insert_many(batch).execute()

    @classmethod
    def get_failed_cleanup_events_count(cls, paths: list, *, since: int):
        """Query of ``(path, count)`` tuples of FAILED_TO_CLEANUP events
        for *paths* with ctime >= *since*."""
        failed_since = (
            cls.path.in_(paths)
            & (cls.event == FAILED_TO_CLEANUP)
            & (cls.ctime >= since)
        )
        counted = cls.select(cls.path, fn.COUNT()).where(failed_since)
        return counted.group_by(cls.path).tuples()