"""
Simple Unix domain socket RPC server and client implementation.
"""
import asyncio
import codecs
import functools
import io
import json
import os
import select
import socket
import sys
import time
from contextlib import suppress
from logging import getLogger
from typing import Sequence

from psutil import Process
import sentry_sdk

from defence360agent.api import inactivity
from defence360agent.application import app
from defence360agent.contracts.config import SimpleRpc as Config
from defence360agent.feature_management.exceptions import (
FeatureManagementError,
)
from defence360agent.internals.auth_protocol import UnixSocketAuthProtocol
from defence360agent.model import tls_check
from defence360agent.model.simplification import run_in_executor
from defence360agent.utils import is_root_user
from defence360agent.utils.buffer import LineBuffer
from defence360agent.subsys.panels import hosting_panel
from defence360agent.subsys.panels.base import InvalidTokenException
from defence360agent.rpc_tools.exceptions import (
ResponseError,
ServiceStateError,
SocketError,
)
from defence360agent.rpc_tools.lookup import Endpoints, UserType
from defence360agent.rpc_tools.utils import is_running, rpc_is_running
from defence360agent.rpc_tools.validate import ValidationError
from defence360agent.rpc_tools import ERROR, SUCCESS, WARNING
# Module-level logger, named after this module (``__name__``).
logger = getLogger(__name__)
class RpcServiceState:
    """Agent-state requirements an RPC caller can declare.

    ``RpcClient`` compares its ``require_svc_is_running`` argument with
    these string values to decide whether to talk to the running agent
    over its socket or to execute the endpoint in-process.
    """

    # Needs the database and a running agent (e.g. an on-demand scan).
    RUNNING = "running"

    # The agent must be stopped.
    STOPPED = "stopped"

    # Either state is acceptable: if the agent is running its socket is
    # used, otherwise the endpoint is executed directly.
    ANY = "any"

    # Needs neither the database nor UI interaction; preferred over ANY
    # for operations that just run an external process
    # (e.g. enabling/disabling plugins or features).
    DIRECT = "direct"
def _exception_log_args(exc):
    """Split *exc* into a printf-style ``(template, args)`` pair.

    Exceptions handled by ``_execute_request`` conventionally carry a
    %-format template in ``args[0]`` and its arguments in ``args[1:]``.
    An OS-raised ``PermissionError`` does not follow that convention (its
    args are ``(errno, strerror)``), and a template may not match its
    arguments, so the pair is validated first. On any mismatch
    ``("%s", (str(exc),))`` is returned, which is always safe to format.

    :param BaseException exc: exception to describe
    :return: ``(template, args)`` such that ``template % args`` succeeds
    """
    if exc.args and isinstance(exc.args[0], str):
        template, *rest = exc.args
        args = tuple(rest)
        try:
            template % args
        except (TypeError, ValueError, KeyError):
            pass
        else:
            return template, args
    return "%s", (str(exc),)


async def _execute_request(coro, method):
    """Await an endpoint coroutine and wrap the outcome in a response dict.

    :param coro: awaitable producing the endpoint's result
    :param method: RPC method/command, used only for logging
    :return: dict with ``result`` (SUCCESS / WARNING / ERROR) and
        ``messages``; on success it also holds the awaited value in ``data``
    """
    try:
        result = await coro
    except ValidationError as e:
        # Validation problems are reported as warnings; extra fields carried
        # by the error are merged into the response.
        result = {
            "result": WARNING,
            "messages": e.errors,
        }
        result.update(e.extra_data)
        return result
    except (PermissionError, FeatureManagementError) as e:
        # Logged at error level without a traceback. The template/args split
        # is validated so that e.g. PermissionError(13, "...") raised by the
        # OS cannot blow up the formatting below.
        template, args = _exception_log_args(e)
        logger.error(template, *args)
        return {
            "result": ERROR,
            "messages": [template % args],
        }
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.error(
            "Something went wrong while processing %s (%s)", method, str(e)
        )
        # NOTE: "messages" is a plain string here (a list in the other
        # branches); kept as-is because clients may rely on the wire format.
        return {"result": ERROR, "messages": str(e)}
    else:
        return {"result": SUCCESS, "messages": [], "data": result}
def _apply_middleware(method, user):
    """Wrap the endpoint router with the middleware registered for *method*.

    Only list/tuple commands receive middleware; any other *method* value
    gets the bare ``Endpoints.route_to_endpoint``. Entries registered under
    the ``None`` key come first, then the command-specific ones. An entry is
    skipped when *user* is not among its allowed users or when it is listed
    in ``app.MIDDLEWARE_EXCLUDE`` for the command.

    :return: the (possibly wrapped) endpoint callable
    """
    handler = Endpoints.route_to_endpoint
    if not isinstance(method, (list, tuple)):
        return handler
    key = tuple(method)
    registered = app.MIDDLEWARE.get(None, []) + app.MIDDLEWARE.get(key, [])
    skipped = app.MIDDLEWARE_EXCLUDE.get(key, [])
    # Wrap from the last entry backwards so the first registered
    # middleware ends up as the outermost wrapper.
    for middleware, allowed_users in reversed(registered):
        if user not in allowed_users or middleware in skipped:
            continue
        logger.debug("Applying middleware %s", middleware.__name__)
        handler = middleware(handler)
    return handler
def _find_uds_inodes(
    socket_path: str, proc_net_unix: str = "/proc/net/unix"
) -> Sequence[str]:
    """Find inodes corresponding to the unix domain socket path.

    Every data row of the table has the layout
    ``Num RefCount Protocol Flags Type St Inode [Path]``. A row matches when
    its path *ends with* ``socket_path``, so passing ``/run/...`` also
    matches a socket the kernel reports as ``/var/run/...``. Unlike a plain
    substring test, this does not pick up longer unrelated paths such as
    ``<socket_path>.old``.

    :param socket_path: socket path (or path suffix) to look for
    :param proc_net_unix: table to read; overridable for testing
    :return: inode numbers (as strings) of the matching sockets
    """
    inodes = []
    with open(
        proc_net_unix,
        encoding=sys.getfilesystemencoding(),
        errors=sys.getfilesystemencodeerrors(),
    ) as file:
        for line in file:
            # maxsplit keeps a path containing spaces in a single field;
            # rows without a path (unnamed sockets) have only 7 fields.
            fields = line.rstrip("\n").split(maxsplit=7)
            if len(fields) == 8 and fields[7].endswith(socket_path):
                inodes.append(fields[6])
    return inodes
class _RpcServerProtocol(UnixSocketAuthProtocol):
    """asyncio protocol serving newline-delimited JSON RPC requests.

    Each complete line received is parsed as JSON, authenticated through the
    hosting panel, routed through middleware and dispatched as its own task.
    Every response is written back as a single JSON line.
    """

    def __init__(self, loop, sink, user):
        self._loop = loop
        self._sink = sink
        self.user = user
        # NOTE(review): presumably assigned by the base class when the
        # connection is made; reset to None in connection_lost().
        self._transport = None
        self._buf = LineBuffer()
        # Stream data arrives in arbitrary chunks, so a multi-byte UTF-8
        # character may be split across two data_received() calls; an
        # incremental decoder carries the partial bytes over instead of
        # failing like a per-chunk bytes.decode() would.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def preprocess_data(self, data: str):
        """Parse and authenticate one request line.

        :param data: a single JSON-encoded request
        :return: the decoded request dict; ``params["user"]`` is set when
            authentication yields a user name, and ``calling_process`` holds
            the peer's command line (or the error text if unavailable)
        """
        decoded = json.loads(data)
        user_type, user_name = hosting_panel.HostingPanel().authenticate(
            self, decoded
        )
        self.user = user_type
        if user_name is not None:
            decoded["params"]["user"] = user_name
        # Record the calling process; self._pid is assumed to be the peer
        # PID provided by the base class (TODO confirm).
        try:
            calling_process = Process(self._pid).cmdline()
        except Exception as e:
            calling_process = [str(e)]
        decoded["calling_process"] = calling_process
        return decoded

    def data_received(self, data):
        """Buffer incoming bytes and dispatch every complete request line."""
        self._buf.append(self._decoder.decode(data))
        for msg in self._buf:
            try:
                result = self.preprocess_data(msg)
                method = result["command"]
                params = result["params"]
                logger.debug("Data received: %r", msg)
                cb = _apply_middleware(method, self.user)
                # TODO: fix that there is no json flag in params
                self._loop.create_task(
                    self._dispatch(
                        method, params, cb(result, self._sink, self.user)
                    )
                )
            except InvalidTokenException as e:
                # without events in Sentry
                logger.warning("Incorrect token provided")
                self._write_response({"result": ERROR, "messages": str(e)})
            except Exception as e:
                # Log the failing line itself: re-decoding the raw chunk
                # could fail on a split multi-byte character.
                logger.exception(
                    "Something went wrong before processing %s", msg
                )
                self._write_response({"result": ERROR, "messages": str(e)})

    async def _dispatch(self, method, params, coro):
        """Run one request under inactivity tracking and send the response."""
        with inactivity.track.task("rpc_{}".format(method)):
            # route and save result to 'result'
            response = await _execute_request(coro, method)
            logger.info(
                "Response: method - %s, data - %s", method, response
            )
            self._write_response(response)

    def connection_lost(self, transport):
        """Forget the transport so late responses are dropped.

        asyncio passes the exception (or None) that closed the connection;
        the argument is unused here.
        """
        self._transport = None

    def _write_response(self, data):
        """Serialize *data* as one JSON line and write it to the peer.

        The response is dropped with a warning if the connection is gone.
        """
        if self._transport is None:
            logger.warning("Cannot send RPC response: connection lost.")
            return
        try:
            self._transport.write((json.dumps(data) + "\n").encode())
        except Exception as e:
            logger.exception("Cannot send RPC response: %s", e)
def _check_socket_folder_permissions(socket_path):
    """Make sure the directory that will hold *socket_path* is usable.

    Missing parent directories are created, and the directory's mode is
    (re)set to 0o755 even if it already existed.
    """
    directory, _ = os.path.split(socket_path)
    os.makedirs(directory, exist_ok=True)
    mode = 0o755  # rwxr-xr-x: others may list and traverse the directory
    os.chmod(directory, mode)
class RpcServer:
    """RPC server that binds its own unix socket at ``SOCKET_PATH``.

    Subclasses customise ``SOCKET_PATH``, ``USER`` and ``SOCKET_MODE``.
    """

    SOCKET_PATH = Config.SOCKET_PATH
    USER = UserType.ROOT
    SOCKET_MODE = 0o700

    @classmethod
    async def create(cls, loop, sink):
        """Bind ``cls.SOCKET_PATH`` and start serving RPC connections.

        An existing file at the socket path is removed first; once bound,
        the socket's permissions are set to ``cls.SOCKET_MODE``.

        :return: the asyncio server object
        """
        path = cls.SOCKET_PATH
        _check_socket_folder_permissions(path)
        if os.path.exists(path):
            os.unlink(path)

        def protocol_factory():
            return _RpcServerProtocol(loop, sink, cls.USER)

        server = await loop.create_unix_server(protocol_factory, path)
        os.chmod(path, cls.SOCKET_MODE)
        return server
class RpcServerAV:
    """RPC server that adopts an already-open listening socket.

    Instead of binding ``SOCKET_PATH`` itself, ``create`` finds the socket
    among this process's open file descriptors.
    """

    USER = UserType.ROOT
    SOCKET_PATH = Config.SOCKET_PATH
    PROTOCOL_CLASS = _RpcServerProtocol

    @classmethod
    async def create(cls, loop, sink):
        """Serve on the inherited socket that is bound to ``cls.SOCKET_PATH``.

        The socket's inode is looked up in /proc/net/unix and compared with
        the link targets of the descriptors in /proc/self/fd, e.g.::

            $ ls -l /proc/[pid]/fd
            lrwx------ 1 root root 64 Apr 11 07:20 4 -> socket:[2866765]
            $ cat /proc/net/unix
            Num RefCount Protocol Flags Type St Inode Path
            ffff880054c0a4c0: 00000002 00000000 00010000 0001 01 2866765 /var/run/defence360agent/simple_rpc.sock # noqa

        :return: the asyncio server object
        :raises SocketError: if no descriptor matches any of the inodes
        """

        def link_target(path):
            """Return where symlink *path* points, or "" if unreadable."""
            try:
                return os.readlink(path)
            except OSError:
                return ""

        requested_path = cls.SOCKET_PATH
        _check_socket_folder_permissions(requested_path)
        lookup_path = requested_path
        if lookup_path.startswith("/var/run"):
            # remove /var prefix, see DEF-16201
            lookup_path = lookup_path[len("/var") :]
        inodes = _find_uds_inodes(lookup_path)

        # The first inode that has a matching descriptor wins; an OSError
        # while scanning one inode just moves on to the next one.
        socket_fd = None
        last_error = None
        for inode in inodes:
            wanted = "socket:[{}]".format(inode)
            try:
                with os.scandir("/proc/self/fd") as entries:
                    socket_fd = next(
                        (
                            int(entry.name)
                            for entry in entries
                            if link_target(entry.path) == wanted
                        ),
                        None,
                    )
            except OSError as e:
                last_error = e
                continue
            if socket_fd is not None:
                break

        if socket_fd is None:
            raise SocketError(
                "[{}] Socket {!r} for {} not found.".format(
                    "inode" * (not inodes), cls.SOCKET_PATH, cls.USER
                )
            ) from last_error

        listener = socket.fromfd(
            socket_fd,
            socket.AF_UNIX,
            socket.SOCK_STREAM | socket.SOCK_NONBLOCK,
        )
        return await loop.create_unix_server(
            lambda: cls.PROTOCOL_CLASS(loop, sink, cls.USER), sock=listener
        )
class NonRootRpcServerAV(RpcServerAV):
    """``RpcServerAV`` bound to the non-root socket path and user type."""

    SOCKET_PATH = Config.NON_ROOT_SOCKET_PATH
    USER = UserType.NON_ROOT
class NonRootRpcServer(RpcServer):
    """``RpcServer`` for the non-root socket path and user type.

    The socket file is made world-accessible (mode 0o777), so any local
    user can connect to it.
    """

    USER = UserType.NON_ROOT
    SOCKET_PATH = Config.NON_ROOT_SOCKET_PATH
    SOCKET_MODE = 0o777
class _RpcClientImpl:
    """Client side of the RPC protocol over a connected unix socket.

    Each call sends one newline-terminated JSON request and waits (via
    ``select``) for a newline-terminated JSON response.
    """

    def __init__(self, socket_path):
        """Connect to the RPC server socket at *socket_path*.

        :raises ServiceStateError: if the connection is refused, the path
            does not exist, or the non-blocking connect would block
        """
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK
        )
        try:
            sock.connect(socket_path)
        except BaseException as e:
            # Never leak the descriptor when the connection fails.
            sock.close()
            if isinstance(
                e, (ConnectionRefusedError, FileNotFoundError, BlockingIOError)
            ):
                raise ServiceStateError() from e
            raise
        self._sock = sock

    def dispatch(self, method, params):
        """Send one request and return the decoded JSON response.

        :raises ResponseError: if the peer resets the connection or the
            response is not valid JSON
        :raises SocketError: on timeout, select() error or an empty read
        """
        request = json.dumps({"command": method, "params": params}) + "\n"
        # NOTE(review): sendall() on a non-blocking socket can raise
        # BlockingIOError if the request does not fit into the socket's
        # send buffer; requests are assumed to be small.
        self._sock.sendall(request.encode())
        try:
            data = self._sock_recv_until(terminator_byte=b"\n")
        except ConnectionResetError as e:
            raise ResponseError("Connection reset: {}".format(e)) from e
        try:
            response = json.loads(data.decode())
        except Exception as e:
            raise ResponseError(
                "Error parsing RPC response {!r}".format(data)
            ) from e
        return response

    def _sock_recv_until(self, terminator_byte):
        """Read until a received chunk contains *terminator_byte*.

        Every wait for readability is bounded by ``Config.CLIENT_TIMEOUT``
        seconds; this is a per-read limit, not an overall deadline.

        :return: all received bytes, concatenated
        :raises SocketError: on timeout, select() error or EOF
        """
        assert not self._sock.getblocking()
        chunks = []
        # Earlier chunks were already checked, and the terminator is a
        # single byte, so only the newest chunk needs inspecting.
        while (not chunks) or (terminator_byte not in chunks[-1]):
            fdread_list = [self._sock.fileno()]
            rwx_fdlist = select.select(
                fdread_list,
                [],
                [],
                # naive timeout for one-shot response
                # scenario
                Config.CLIENT_TIMEOUT,
            )
            fdready_list = rwx_fdlist[0]
            if self._sock.fileno() not in fdready_list:
                if any(rwx_fdlist):
                    raise SocketError(
                        "select() = {!r} resulted in error".format(rwx_fdlist)
                    )
                else:
                    raise SocketError("request timeout")
            chunk = self._sock.recv(io.DEFAULT_BUFFER_SIZE)
            if len(chunk) == 0:
                raise SocketError("Empty response from socket.recv()")
            chunks.append(chunk)
        return b"".join(chunks)
class _NoRpcImpl:
    """Execute RPC endpoints in-process instead of over the agent socket.

    Requests pass through the same middleware as socket requests (always
    as ``UserType.ROOT``) and run to completion on the current event loop.
    """

    def __init__(self, sink=None):
        self._sink = sink
        # tls_check.reset is run through simplification.run_in_executor();
        # an earlier note here said run_in_executor() executes on the main
        # thread. suppress() keeps the reset idempotent: OverridingReset is
        # presumably raised when a reset was already done (TODO confirm).
        with suppress(tls_check.OverridingReset):
            event_loop = asyncio.get_event_loop()
            event_loop.run_until_complete(
                run_in_executor(event_loop, tls_check.reset)
            )

    def dispatch(self, method, params):
        """Run *method* with *params* in-process; return the response dict."""
        event_loop = asyncio.get_event_loop()
        logger.info("Executing %s, params: %s", method, params)
        payload = {"command": method, "params": params}
        handler = _apply_middleware(method, user=UserType.ROOT)
        pending = _execute_request(handler(payload, self._sink), method)
        return event_loop.run_until_complete(pending)
class RpcClient:
    """Synchronous client for the agent's RPC endpoints.

    One instance can be reused for multiple calls. Attribute access
    (``client.some_method(**params)``) calls a string-named method, and
    ``client.cmd("a", "b")(**params)`` calls a command given as a tuple.

    :param RpcServiceState require_svc_is_running: agent state the call
        requires; decides between the running agent's socket and in-process
        execution via ``_NoRpcImpl``
    :param int reconnect_with_timeout: seconds to wait between connection
        retries (no retries when falsy)
    :param int num_retries: number of connection retries
    :raises ServiceStateError: if STOPPED is required but the agent is
        running, or RUNNING is required but no connection can be made
    """

    def __init__(
        self,
        *,
        require_svc_is_running=RpcServiceState.RUNNING,
        reconnect_with_timeout=None,
        num_retries=1
    ):
        self._impl = None
        # Root uses Config.SOCKET_PATH; everybody else the non-root socket.
        self._socket_path = (
            Config.SOCKET_PATH
            if is_root_user()
            else Config.NON_ROOT_SOCKET_PATH
        )
        if (
            require_svc_is_running == RpcServiceState.STOPPED
            and rpc_is_running()
        ):
            raise ServiceStateError(RpcServiceState.RUNNING)
        if require_svc_is_running in (
            RpcServiceState.ANY,
            RpcServiceState.RUNNING,
        ):
            try:
                if reconnect_with_timeout:
                    self._impl = self._reconnect_with_timeout(
                        reconnect_with_timeout, num_retries
                    )
                else:
                    self._impl = _RpcClientImpl(self._socket_path)
                return
            except ServiceStateError:
                # ANY may fall back to in-process execution; RUNNING may not.
                if require_svc_is_running == RpcServiceState.RUNNING:
                    raise
        if self._impl is None:
            # In other cases (ANY, STOPPED, DIRECT) need to use _NoRpcImpl
            assert (
                is_root_user()
            ), "_NoRpcImpl is not available for non root user"
            self._impl = _NoRpcImpl()

    def __getattr__(self, method):
        """Return a callable invoking RPC *method* with keyword params.

        Dunder names are not RPC methods: protocols such as
        ``copy.deepcopy`` probe instances for optional hooks (e.g.
        ``__deepcopy__``) and must get AttributeError rather than a stub
        that would send a bogus request.
        """
        if method.startswith("__") and method.endswith("__"):
            raise AttributeError(method)
        return functools.partial(self._dispatch, method)

    def cmd(self, *command):
        """Return a callable invoking the tuple *command* with keyword params."""
        return functools.partial(self._dispatch, command)

    def _dispatch(self, method, **params):
        """Send *method* with *params* and unpack the response.

        Tuple/list commands return ``(result, messages)`` on ERROR/WARNING
        and ``(result, data)`` on success. String methods raise
        ResponseError on ERROR/WARNING and otherwise return the data.
        """
        response = self._impl.dispatch(method, params)
        if isinstance(method, (list, tuple)):
            if response["result"] in (ERROR, WARNING):
                return response["result"], response["messages"]
            else:
                assert response["result"] == SUCCESS
                return response["result"], response["data"]
        else:
            if response["result"] in (ERROR, WARNING):
                raise ResponseError(response["messages"])
            return response["data"]

    def _reconnect_with_timeout(self, timeout, num_retries):
        """Connect, retrying up to *num_retries* times, *timeout* s apart.

        :raises ServiceStateError: if the last attempt fails as well
        """
        while True:
            try:
                return _RpcClientImpl(self._socket_path)
            except ServiceStateError:
                if num_retries:
                    # %s, not %d: a fractional timeout such as 0.5 would
                    # otherwise be logged as 0.
                    logger.info(
                        "Waiting %s second(s) before retry...", timeout
                    )
                    time.sleep(timeout)
                    num_retries -= 1
                else:
                    raise