import functools
import inspect
from typing import Any, Callable, Dict, Set, Tuple
from defence360agent.contracts.config import UserType
from defence360agent.utils import Scope
from .exceptions import RpcError
_RPC_MARK = "__rpc_command"
class DuplicateHandlerError(Exception):
    """Raised when a command already has a handler registered for the same
    user type and a second one is being registered."""
class NotCoroutineError(Exception):
    """Raised when a function marked as an RPC handler is not a coroutine
    function."""
class Endpoints:
    """Endpoints class implements registration and lookup for functions
    implementing RPC calls.

    Subclasses list the user types allowed to call their handlers in
    ``APPLICABLE_USER_TYPES`` and mark coroutine methods with ``@bind(...)``.
    ``register_rpc_handlers`` records those methods in a per-user-type
    command map shared by every subclass. ``route_to_endpoint`` dispatches
    a request to the matching handler.
    """

    SCOPE = Scope.AV_IM360
    # User types whose command maps receive this class's handlers. It is
    # empty on the base class and overridden by subclasses.
    APPLICABLE_USER_TYPES = set()  # type: Set[str]
    # user type -> {command tuple: (endpoint class, method name)}.
    # The double underscore mangles the name to _Endpoints__COMMAND_MAP,
    # so every subclass reads and writes this single dict.
    __COMMAND_MAP = {
        UserType.ROOT: {},
        UserType.NON_ROOT: {},
    }  # type: Dict[str, Dict]
    # Every subclass, in definition order; filled by __init_subclass__.
    _subclasses = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Subclasses do not define their own _subclasses, so this appends to
        # the list on Endpoints and the base class sees every descendant.
        cls._subclasses.append(cls)

    @classmethod
    def get_active_endpoints(cls):
        """Return the subclasses that define at least one RPC handler.

        A subclass counts as active when any of its members (inherited ones
        included, as ``inspect.getmembers`` reports them) has a truthy
        ``_RPC_MARK`` attribute.
        """
        return [
            subcls
            for subcls in cls._subclasses
            if inspect.getmembers(
                subcls, lambda item: getattr(item, _RPC_MARK, None)
            )
        ]

    def __init__(self, sink):
        # route_to_endpoint constructs endpoints with the request's sink;
        # presumably handler methods use it — confirm against subclasses.
        self._sink = sink

    @classmethod
    async def route_to_endpoint(cls, request, sink, user=UserType.ROOT) -> Any:
        """Find appropriate class and function within that class that
        implements processing for request based on supplied 'command' within.
        Call that (async) function and return its result.

        :param request: mapping with a ``"command"`` sequence of strings and
            a ``"params"`` mapping passed as keyword arguments to the handler.
        :param sink: passed to the constructor of the endpoint class.
        :param user: user type whose command map is searched.
        :return: the awaited result of the handler.
        :raises RpcError: if no handler is registered for the command under
            *user*; an unrecognized user type is treated the same way.
        """
        command = request["command"]
        # .get on both levels: an unknown user type simply has no handlers
        # and must surface as RpcError rather than a stray KeyError.
        registered = cls.__COMMAND_MAP.get(user, {}).get(tuple(command))
        if registered is None:
            raise RpcError(
                'Endpoint not found for RPC method "%s"'
                % " ".join(request["command"])
            )
        cls_handler, handler_name = registered
        handler = getattr(cls_handler(sink), handler_name)
        return await handler(**request["params"])

    @classmethod
    def register_rpc_handlers(cls) -> None:
        """Registers RPC handlers for all functions within a class.

        Functions should be decorated with @bind('command', ...). Each public
        attribute carrying an RPC mark is added to the command map of every
        user type in ``APPLICABLE_USER_TYPES``. Every user type is checked
        for a conflict before any of them is modified. A duplicate therefore
        never leaves a handler registered for only some user types.

        :raises NotCoroutineError: if a marked attribute is not a coroutine
            function.
        :raises DuplicateHandlerError: if the command already has a handler
            for one of the applicable user types.
        """
        for name in dir(cls):
            if name.startswith("_"):
                continue
            attr = getattr(cls, name)
            command = getattr(attr, _RPC_MARK, None)
            if command is None:
                continue
            if not inspect.iscoroutinefunction(attr):
                raise NotCoroutineError("Must be a coroutine")
            # First pass: validate only. Set iteration order is arbitrary,
            # so writing while checking could register the handler for some
            # user types and then raise.
            for user_type in cls.APPLICABLE_USER_TYPES:
                if command in cls.__COMMAND_MAP[user_type]:
                    raise DuplicateHandlerError(
                        "Duplicate handlers for command {} ({}): {} and {}"
                        .format(
                            command,
                            user_type,
                            cls.__COMMAND_MAP[user_type][command],
                            attr,
                        )
                    )
            # Second pass: no conflicts, register for every user type.
            for user_type in cls.APPLICABLE_USER_TYPES:
                cls.__COMMAND_MAP[user_type][command] = (cls, name)

    @classmethod
    def reset_rpc_handlers(cls):
        """Clears all previously made registrations."""
        # Rebinds each user type to a fresh dict rather than clearing the
        # existing one in place.
        for user_type in {UserType.NON_ROOT, UserType.ROOT}:
            cls.__COMMAND_MAP[user_type] = {}
class CommonEndpoints(Endpoints):
    """Base for endpoints that both root and non-root users may call.

    Classes deriving from it register their handlers in the command maps of
    both user types.
    """

    APPLICABLE_USER_TYPES = {UserType.ROOT, UserType.NON_ROOT}  # type: Set[str]
class RootEndpoints(Endpoints):
    """Base for endpoints restricted to the root user.

    Classes deriving from it register their handlers only in the root
    command map.
    """

    APPLICABLE_USER_TYPES = {UserType.ROOT}  # type: Set[str]
class UserOnlyEndpoints(Endpoints):
    """Base for endpoints restricted to non-root users.

    Classes deriving from it register their handlers only in the non-root
    command map.
    """

    APPLICABLE_USER_TYPES = {UserType.NON_ROOT}  # type: Set[str]
# The standard functools wrapper attributes plus the RPC mark, so wrappers
# built with wraps() keep the wrapped handler's @bind command.
LOOKUP_ASSIGNMENTS = (*functools.WRAPPER_ASSIGNMENTS, _RPC_MARK)
def wraps(
    wrapped, assigned=LOOKUP_ASSIGNMENTS, updated=functools.WRAPPER_UPDATES
):
    """Decorator replacing functools.wraps for rpc handlers.

    Same as ``functools.wraps``, except that by default the RPC mark is
    copied too (via ``LOOKUP_ASSIGNMENTS``). A decorated wrapper is
    therefore still recognized as an RPC handler.
    """
    # functools.wraps is defined as partial(update_wrapper, wrapped=...,
    # assigned=..., updated=...), i.e. exactly the decorator built here.
    return functools.wraps(wrapped, assigned=assigned, updated=updated)
def bind(*command):
    """Mark a function as processing RPC calls for command.

    Returns a decorator. It stores the positional arguments, as a tuple,
    on the decorated function under the ``_RPC_MARK`` attribute and hands
    the function back unchanged.
    """

    def mark(handler):
        setattr(handler, _RPC_MARK, command)
        return handler

    return mark