import inspect
import logging
import os
import time
from datetime import datetime, timedelta
from peewee import (
BlobField,
CharField,
DateField,
ForeignKeyField,
IntegerField,
PeeweeException,
)
from defence360agent.model import instance, Model
#: seconds in a POSIX day
POSIX_DAY = 24 * 60 * 60
#: module-level logger, named after this module
logger = logging.getLogger(__name__)
class FilenameField(BlobField):
    """
    Class to store file names in database

    Names are written as bytes produced by ``os.fsencode()`` and read back
    with ``os.fsdecode()``, i.e. using the filesystem encoding and error
    handler. ``None`` (SQL NULL) is passed through unchanged in both
    directions so that the field can be used on nullable columns.
    """

    def db_value(self, value):
        """Convert a file name to the bytes stored in the database."""
        # os.fsencode(None) raises TypeError; NULL must stay NULL.
        if value is None:
            return None
        return os.fsencode(value)

    def python_value(self, value):
        """Convert the stored bytes back to a ``str`` file name."""
        # os.fsdecode(None) raises TypeError; NULL must stay None.
        if value is None:
            return None
        return os.fsdecode(value)
class ScanPathField(CharField):
    """
    Character field for a scan path.

    When the value handed to the database is a ``list``, the fixed
    placeholder :attr:`REALTIME_SCAN_PATH_STUB` is stored instead of it;
    any other value is stored as given.
    """

    REALTIME_SCAN_PATH_STUB = "list_of_files"

    def db_value(self, value):
        """Return the placeholder for list values, else *value* itself."""
        return (
            self.REALTIME_SCAN_PATH_STUB if isinstance(value, list) else value
        )
class ModelError(PeeweeException):
    """
    Model-layer exception.

    Other modules should use this class instead of referring to
    PeeweeException directly.
    """
async def run_in_executor(loop, cb, *args):
    """
    Fake run_in_executor() (DEF-4541)

    Calls ``cb(*args)`` directly inside the awaiting coroutine and returns
    its result. *loop* is accepted but not used.
    """
    outcome = cb(*args)
    return outcome
def remove_old_and_truncate(
    table: Model, num_days: int, max_count: int
) -> int:
    """
    Delete rows of *table* that are older than *num_days* days, plus any
    rows beyond the *max_count* most recent ones.

    Rows are selected by their ``timestamp`` value: everything whose
    timestamp is not among the timestamps of the rows to keep is deleted.

    :param table: model class that has a ``timestamp`` column
    :param num_days: maximum age, in days, of the rows to keep
    :param max_count: maximum number of rows to keep
    :return: number of rows deleted
    :raises ValueError: if *table* has no (truthy) ``timestamp`` attribute
    """
    if not getattr(table, "timestamp", False):
        raise ValueError("No 'timestamp' column in table {!r}".format(table))

    # anything with a timestamp at or before this moment is too old
    cutoff = time.time() - num_days * POSIX_DAY

    # timestamps of the newest *max_count* rows that are still young enough
    newest_first = table.select(table.timestamp).order_by(
        table.timestamp.desc()
    )
    survivors = newest_first.limit(max_count).where(table.timestamp > cutoff)

    purge = table.delete().where(table.timestamp.not_in(survivors))
    return purge.execute()
class Eula(Model):
    """Tracks revisions of the end user license agreement and whether each
    one has been accepted.

    A row with a NULL ``accepted`` value is a revision that has not been
    accepted yet; the EULA counts as accepted only when no such row exists.
    """

    class Meta:
        database = instance.db
        db_table = "eula"

    #: Date when EULA was updated.
    updated = DateField(primary_key=True)
    #: Timestamp when EULA was accepted.
    accepted = IntegerField(null=True, default=None)

    @classmethod
    def is_accepted(cls) -> bool:
        """Return True when there is no row with a NULL ``accepted``."""
        pending = (
            cls.select()
            .where(cls.accepted.is_null())
            .order_by(cls.updated)
            .limit(1)
        )
        # a single matching row is enough to know the answer
        for _unaccepted in pending:
            return False
        return True

    @classmethod
    def accept(cls) -> None:
        """Set ``accepted`` to the current time on every unaccepted row."""
        now = time.time()
        still_pending = cls.accepted.is_null()
        cls.update(accepted=now).where(still_pending).execute()
def get_models(module):
    """
    Collect the model classes defined in or imported into *module*.

    :param module: object whose members are inspected
    :return: list of classes that subclass ``Model``, excluding ``Model``
        itself, in the order produced by ``inspect.getmembers()``
    """

    def _is_concrete_model(candidate):
        # isclass() must come first: issubclass() rejects non-class objects
        return (
            inspect.isclass(candidate)
            and issubclass(candidate, Model)
            and candidate != Model
        )

    found = []
    for _name, member in inspect.getmembers(module, _is_concrete_model):
        found.append(member)
    return found
def create_tables(module):
    """
    Connect to the database and create the tables for every model found
    in *module* (see ``get_models()``).

    Tables are created with ``safe=True``.
    """
    instance.db.connect()
    models = get_models(module)
    instance.db.create_tables(models, safe=True)
class ApplyOrderBy:
    """
    Callable that applies a list of ordering requests to a query.

    Each request names a column, optionally as a dotted path
    (``"a.b.c"``) whose leading parts are followed through foreign keys.
    A model may define an inner ``OrderBy`` class whose attribute named
    after a column is called to obtain the nodes to order by for it.
    """

    @staticmethod
    def resolve_nodes(_model, column_name: str) -> tuple:
        """
        Resolve one column name against a model.

        :param _model: peewee.Model or peewee.ForeignKeyField
        :param column_name: str
        :return: tuple<peewee.Node>
        """
        # a foreign key stands for the model it points at
        if isinstance(_model, ForeignKeyField):
            target = _model.rel_model
        else:
            target = _model

        resolved = ()
        overrides = getattr(target, "OrderBy", None)
        if overrides is not None:
            # a custom ordering, when defined, is a callable returning nodes
            resolved = getattr(overrides, column_name, lambda: ())()
        if resolved:
            return resolved

        # no (non-empty) custom ordering: fall back to the plain attribute
        attribute = getattr(target, column_name, None)
        if attribute is not None:
            return (attribute,)  # type: ignore
        return resolved

    @staticmethod
    def get_nodes(model, column_names: list) -> list:
        """
        Resolve a path of column names, one per nesting level.

        :param model: peewee.Model or peewee.ForeignKeyField
        :param column_names: list<str>
        :return: list<peewee.Node>
        """
        head = column_names[0]
        tail = column_names[1:]
        resolved = ApplyOrderBy.resolve_nodes(model, head)

        if not tail:
            # last path component: these are the nodes themselves
            return list(resolved)

        # more components follow: treat each result as a model and descend
        collected = []
        for related in resolved:
            collected.extend(ApplyOrderBy.get_nodes(related, tail))
        return collected

    def __call__(self, order_by, model, query_builder):
        """
        :param order_by: list<OrderBy>
        :param model: peewee.Model or peewee.ForeignKeyField
        :param query_builder: peewee.Query
        :return: peewee.Query with applied order_by
        """
        terms = []
        for spec in order_by:
            path = spec.column_name.split(".")
            terms.extend(
                node.desc() if spec.desc else node
                for node in ApplyOrderBy.get_nodes(model, path)
            )
        return query_builder.order_by(*terms)
#: ready-to-use ApplyOrderBy instance:
#: ``apply_order_by(order_by, model, query_builder)``
apply_order_by = ApplyOrderBy()