import logging
import warnings
from functools import wraps
from defence360agent.contracts import eula
from defence360agent.contracts.config import Core, UserType
from defence360agent.contracts.license import LicenseCLN
from defence360agent.contracts.messages import MessageType
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
def add_license(f):
    """Decorate coroutine *f* so its dict result carries license details.

    The value returned by ``LicenseCLN.license_info()`` is stored in the
    result under the ``"license"`` key.

    :param f: coroutine function whose result must be a dict
    :return: the wrapped coroutine function
    :raises AssertionError: if *f* does not return a dict (unless
        assertions are disabled with ``-O``)
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        # Wrap ``result`` in a 1-tuple: ``"%s" % result`` with a tuple result
        # raises TypeError and hides the intended assertion message.
        assert isinstance(result, dict), (
            "Result should be a dictionary %s" % (result,)
        )
        result["license"] = LicenseCLN.license_info()
        return result

    return wrapper
def add_license_user(f):
    """Decorate coroutine *f* so its dict result carries a reduced license view.

    Only the ``"status"`` and ``"license_type"`` fields of
    ``LicenseCLN.license_info()`` are exposed, under the ``"license"`` key.
    ``"status"`` is required. A missing ``"license_type"`` becomes ``None``.

    :param f: coroutine function whose result must be a dict
    :return: the wrapped coroutine function
    :raises AssertionError: if *f* does not return a dict (unless
        assertions are disabled with ``-O``)
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        # Wrap ``result`` in a 1-tuple: ``"%s" % result`` with a tuple result
        # raises TypeError and hides the intended assertion message.
        assert isinstance(result, dict), (
            "Result should be a dictionary %s" % (result,)
        )
        info = LicenseCLN.license_info()
        result["license"] = {
            "status": info["status"],
            "license_type": info.get("license_type"),
        }
        return result

    return wrapper
def add_eula(f):
    """Decorate coroutine *f* so its dict result carries an ``"eula"`` key.

    The key holds ``None`` unless the license is valid, is not the free
    version, and the EULA has not been accepted yet. In that case it holds
    the EULA ``message``, ``text`` and ``updated`` values. If reading them
    raises ``OSError``, it holds placeholder text describing the failure.

    :param f: coroutine function whose result must be a dict
    :return: the wrapped coroutine function
    :raises AssertionError: if *f* does not return a dict (unless
        assertions are disabled with ``-O``)
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        # Wrap ``result`` in a 1-tuple: ``"%s" % result`` with a tuple result
        # raises TypeError and hides the intended assertion message.
        assert isinstance(result, dict), (
            "Result should be a dictionary %s" % (result,)
        )
        eula_dict = None
        # do not show eula if not registered or using free AV version
        if LicenseCLN.is_valid() and (not LicenseCLN.is_free()):
            if not await eula.is_accepted():
                try:
                    eula_dict = {
                        "message": eula.message(),
                        "text": eula.text(),
                        "updated": eula.updated(),
                    }
                except OSError as e:
                    # Still report an unaccepted EULA, with the read error
                    # in place of the text.
                    eula_dict = {
                        "message": "Failed to read EULA",
                        "text": "Failed to read EULA: {}".format(str(e)),
                        "updated": "",
                    }
        result["eula"] = eula_dict
        return result

    return wrapper
def add_version(f):
    """Decorate coroutine *f* so its dict result carries ``Core.VERSION``.

    The version is stored under the ``"version"`` key.

    :param f: coroutine function whose result must be a dict
    :return: the wrapped coroutine function
    :raises AssertionError: if *f* does not return a dict (unless
        assertions are disabled with ``-O``)
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        # Wrap ``result`` in a 1-tuple: ``"%s" % result`` with a tuple result
        # raises TypeError and hides the intended assertion message.
        assert isinstance(result, dict), (
            "Result should be a dictionary %s" % (result,)
        )
        result["version"] = Core.VERSION
        return result

    return wrapper
def max_count(f):
    """Turn a ``(count, items)`` pair returned by coroutine *f* into a dict.

    The result has the shape ``{"max_count": count, "items": items}``.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        total, entries = await f(*args, **kwargs)
        return dict(max_count=total, items=entries)

    return wrapper
def counts(f):
    """Turn a ``(max_count, counts, items)`` triple from coroutine *f* into a dict.

    The result has the shape
    ``{"max_count": ..., "counts": ..., "items": ...}``.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        maximum, tallies, entries = await f(*args, **kwargs)
        return dict(max_count=maximum, counts=tallies, items=entries)

    return wrapper
def collect_warnings(f):
    """Decorate coroutine *f* so warnings it emits are returned to the caller.

    Warnings emitted while *f* runs are recorded instead of displayed.
    ``DeprecationWarning`` is forced to ``"always"`` so repeats are not
    suppressed. Other categories follow the active filters. The messages are
    stored in *f*'s dict result under ``"warnings"`` as a list of strings.
    Each string is the warning's args joined with spaces.
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        with warnings.catch_warnings(record=True) as warns:
            # Set the filter inside the context manager so it is undone on
            # exit instead of permanently changing the process-wide filters.
            warnings.simplefilter("always", DeprecationWarning)
            result = await f(*args, **kwargs)
        # NOTE(review): catch_warnings swaps process-global state. If other
        # coroutines run while *f* is suspended, their warnings may be
        # recorded here as well.
        # str() each arg: warnings created with non-str args would otherwise
        # make join() raise TypeError.
        result["warnings"] = [
            " ".join(str(arg) for arg in w.message.args) for w in warns
        ]
        return result

    return wrapper
# Needed only for backward compatibility.
def default_to_items(f):
    """Wrap a non-dict result of coroutine *f* as ``{"items": result}``.

    Dict results are passed through unchanged (same object).
    """

    @wraps(f)
    async def wrapper(*args, **kwargs):
        outcome = await f(*args, **kwargs)
        return outcome if isinstance(outcome, dict) else {"items": outcome}

    return wrapper
def preserve_remote_addr(f):
    """Copy ``request["params"]["remote_addr"]`` to ``request["client_addr"]``.

    The send_command_invoke middleware may remove ``remote_addr`` from the
    request params. Endpoints whose logic needs the client address can read
    ``client_addr`` instead. It is set to ``None`` when the params have no
    ``remote_addr``.

    :param f: coroutine taking the request dict as its first argument
    :return: the wrapped coroutine function
    """

    @wraps(f)
    async def wrapper(request, *args, **kwargs):
        params = request["params"]
        request["client_addr"] = params.get("remote_addr")
        return await f(request, *args, **kwargs)

    return wrapper
def send_command_invoke_message(coro):
    """Report each call of *coro* as a ``CommandInvoke`` message.

    The sink is the first positional argument after the request, or the
    ``sink`` keyword. When a sink is present, a shallow copy of
    ``request["params"]`` is sent via ``sink.process_message``. In that copy,
    ``user`` is set to ``True`` for non-root callers and ``password`` is
    masked. ``calling_process`` is popped from the request into the message.
    Whether or not a sink is present, ``remote_addr`` is then removed from
    the request params before *coro* runs.
    """

    @wraps(coro)
    async def wrapper(request, *args, **kwargs):
        # Locate the sink: first positional argument, else the keyword.
        if args:
            sink = args[0]
        else:
            sink = kwargs.get("sink")

        if sink is not None:
            # Shallow copy, so masking below never touches the request
            # params that are passed on to *coro*.
            reported = dict(request["params"])
            if "user" not in reported:
                # The user type (root/non-root) decides access rights. It is
                # the second positional argument, else the "user" keyword.
                if len(args) > 1:
                    caller_type = args[1]
                else:
                    caller_type = kwargs.get("user")
                if caller_type == UserType.NON_ROOT:
                    reported["user"] = True
            # Never leak credentials into the outgoing message.
            if "password" in reported:
                reported["password"] = "***"
            await sink.process_message(
                MessageType.CommandInvoke(
                    command=request["command"],
                    params=reported,
                    calling_process=request.pop("calling_process", None),
                )
            )

        # remote_addr is stripped before the wrapped coroutine sees it.
        request["params"].pop("remote_addr", None)
        return await coro(request, *args, **kwargs)

    return wrapper