import argparse
import ipaddress
import sys
from functools import lru_cache, partial
from itertools import chain
from typing import Any, Dict, Iterable, Iterator, Mapping, Tuple
from defence360agent.application import app
from defence360agent.contracts.config import Core as Config
from defence360agent.rpc_tools.utils import prepare_schema
from defence360agent.simple_rpc import RpcClient
from defence360agent.utils.cli import EXITCODE_NOT_FOUND
class SchemaToArgparse:
    """Translate one (non-boolean) schema parameter into argparse options.

    ``argname`` gives the name to pass to ``add_argument`` and ``options``
    gives the matching keyword arguments.
    """

    # NOTE: 'default' is a normalization rule, 'required' is a validation rule
    # Every option helper below yields zero or more (keyword, value) pairs.
    OptionType = Iterator[Tuple[str, Any]]

    def __init__(self, argument, options):
        get = options.get
        self._argument: str = argument
        self._allowed: Iterable = get("allowed")
        self._default: Any = get("default")
        self._envvar: str = get("envvar", False)
        self._help: str = get("help")
        self._positional: bool = get("positional", False)
        self._rename: str = get("rename")
        self._required: bool = get("required", False)
        self._type: str = get("type")

    @property
    def argname(self) -> str:
        """Bare name for positionals, ``--dashed-name`` for optionals."""
        return (
            self._argument
            if self._positional
            else "--{}".format(self._argument.replace("_", "-"))
        )

    @property
    def options(self):
        """Keyword arguments for ``ArgumentParser.add_argument``."""
        producers = (
            self.choices,
            self.default,
            self.help,
            self.metavar,
            self.nargs,
            self.required,
        )
        return dict(chain.from_iterable(produce() for produce in producers))

    def nargs(self) -> OptionType:
        if self._type == "list":
            # FIXME: all positional arguments are not required
            # to support `rename`
            at_least_one = (
                self._required and not self._positional and not self._envvar
            )
            yield "nargs", "+" if at_least_one else "*"
        elif self._positional and (self._envvar or self._default is None):
            # Optional positional: argparse yields None when it is omitted.
            yield "nargs", "?"

    def choices(self) -> OptionType:
        yield "choices", self._allowed

    def help(self) -> OptionType:
        yield "help", self._help

    def metavar(self) -> OptionType:
        if self._rename:
            label = self._rename
        elif self._type == "list":
            label = self._argument
        else:
            return
        yield "metavar", label.upper()

    def default(self) -> OptionType:
        # Env-provided values and missing defaults are never set here.
        if self._default is None or self._envvar:
            return
        # Non-list positionals do not receive an argparse default.
        if self._positional and self._type != "list":
            return
        yield "default", self._default

    def required(self):
        # 'required' is an invalid argument for positionals
        if self._positional or self._envvar or self._type == "list":
            return
        if self._required:
            yield "required", True
def schema_to_argparse(parser, argument, options):
    """Register the schema parameter *argument* on the argparse *parser*.

    Boolean parameters become a mutually exclusive ``--flag`` /
    ``--no-flag`` pair that writes to the same destination and defaults to
    ``options["default"]``. Every other type is translated through
    ``SchemaToArgparse``.

    :param parser: argparse parser (or group) to add the argument to
    :param argument: schema parameter name (underscores allowed)
    :param options: schema options for the parameter
    """
    if options.get("type") != "boolean":
        converter = SchemaToArgparse(argument, options)
        parser.add_argument(converter.argname, **converter.options)
        return

    flag = argument.replace("_", "-")
    # A parameter that may come from an environment variable must not be
    # mandatory on the command line.
    required = bool(
        options.get("required") and not options.get("envvar", False)
    )
    bool_parser = parser.add_mutually_exclusive_group(required=required)
    bool_parser.add_argument(
        "--" + flag,
        dest=argument,
        action="store_true",
        # Previously dropped: boolean flags showed no help text at all.
        help=options.get("help"),
    )
    bool_parser.add_argument(
        "--no-" + flag,
        dest=argument,
        action="store_false",
    )
    bool_parser.set_defaults(**{argument: options.get("default")})
class EnvParser:
    """Read command parameters that can be supplied via environment variables."""

    @staticmethod
    def format_help(envvar_parameter_options: Mapping):
        """Return a help epilog listing the env variables, or "" if none.

        :param envvar_parameter_options: parameter name -> schema options;
            each options dict has an ``"envvar"`` key and optionally ``"help"``
        """
        if not envvar_parameter_options:
            return ""

        def format_arg(options):
            if "help" in options:
                return f"{options['envvar']}\t\t{options['help']}"
            return options["envvar"]

        return "\nenvironment variables: \n {}".format(
            "\n ".join(
                format_arg(options)
                for options in envvar_parameter_options.values()
            )
        )

    @staticmethod
    def _validate(envvar, value, options):
        """Return an error message if *value* violates *options*, else None."""
        # Fixed: the message used to be wrapped in a one-element tuple by a
        # stray trailing comma, so users saw "('error: ...',)".
        if "isascii" in options and not value.isascii():
            return f"error: {envvar}={value} must only contain ascii symbols"
        return None

    @classmethod
    def parse(
        cls,
        environ: Mapping,
        command,
        envvar_parameter_options,
        exclude: Iterable[str],
    ) -> Dict[str, Any]:
        """Collect parameter values from *environ*.

        Parameters in *exclude* are skipped. A missing variable falls back
        to the schema ``"default"``; if there is none and the parameter is
        required, an error is printed to stderr and the process exits with
        ``EXITCODE_NOT_FOUND``. The same happens when validation fails.

        :returns: parameter name -> value (env string or schema default)
        """
        kwargs = {}
        for parameter, options in envvar_parameter_options.items():
            if parameter in exclude:
                continue
            envvar_name = options["envvar"]
            try:
                value = environ[envvar_name]
            except KeyError:
                if "default" in options:
                    kwargs[parameter] = options["default"]
                elif options.get("required"):
                    cls._fail(
                        command,
                        envvar_parameter_options,
                        "error: environment variable {} is not defined".format(
                            envvar_name
                        ),
                    )
                continue
            kwargs[parameter] = value
            if err := cls._validate(envvar_name, value, options):
                cls._fail(command, envvar_parameter_options, err)
        return kwargs

    @classmethod
    def _fail(cls, command, envvar_parameter_options, msg):
        """Print *msg* with the env-variable help to stderr and exit."""
        print(
            cls._format_error(command, envvar_parameter_options, msg),
            file=sys.stderr,
        )
        sys.exit(EXITCODE_NOT_FOUND)

    @classmethod
    def _format_error(cls, command, envvar_parameter_options, msg):
        """Build "<command words>:\\n<env help>\\n\\n<msg>"."""
        return "{command}:\n{help}\n\n{message}".format(
            command=" ".join(command),
            help=cls.format_help(envvar_parameter_options),
            message=msg,
        )
def is_valid_ipv4_addr(addr):
    """Return True when *addr* is accepted by ``ipaddress.IPv4Address``."""
    try:
        ipaddress.IPv4Address(addr)
    except ipaddress.AddressValueError:
        valid = False
    else:
        valid = True
    return valid
def _filter_user(schema, user):
for key, values in schema.items():
if user in values.get("cli", {}).get("users", []):
yield key, values
def rpc_endpoint(command, require_rpc, **params):
    """Invoke *command* through a fresh ``RpcClient`` with *params*.

    The client is only constructed here, i.e. when the selected command
    actually runs; *require_rpc* is forwarded as ``require_svc_is_running``.
    Returns whatever the client's command callable returns.
    """
    client = RpcClient(require_svc_is_running=require_rpc)
    call = client.cmd(*command)
    return call(**params)
def generate_endpoint_params(arg_parser_namespace, arguments):
    """Build RPC keyword arguments from a parsed argparse namespace.

    Keys keep the schema spelling of each name in *arguments*. The value is
    looked up with dashes replaced by underscores, and names whose value is
    None (or missing from the namespace) are omitted.
    """
    candidates = (
        (name, getattr(arg_parser_namespace, name.replace("-", "_"), None))
        for name in arguments
    )
    return {name: value for name, value in candidates if value is not None}
def apply_parser(subparsers, schema):
    """Register one argparse sub-command for every entry of *schema*.

    *schema* maps a command path (a tuple of command words, e.g.
    ``("malware", "on-demand", "queue", "put")``) to its definition. The
    definition keys read here are ``"help"``, ``"schema"`` (parameter name
    -> parameter options) and ``"cli"`` (for ``"require_rpc"``).

    For each command, a chain of nested sub-parsers is created, reusing the
    intermediate ones already built. The parameters are then added with
    ``schema_to_argparse``. Finally, ``set_defaults`` stores a deferred RPC
    endpoint, a parameter extractor, the env-variable options and the
    command path on the parsed namespace.
    """
    # Command-path prefix (tuple) -> subparsers action holding its children,
    # so sibling commands are attached to one shared parent parser.
    _subparsers = {}
    # Sorting processes every path before any longer path it prefixes
    # (a tuple sorts before its own extensions).
    commands = sorted(schema.keys())
    for methods in commands:
        values = schema[methods]
        assert isinstance(methods, (tuple, list))
        parser = None
        # generate subparsers
        subparser = subparsers
        for i, command in enumerate(methods):
            # last element
            if i == len(methods) - 1:
                # The leaf gets the real parser that receives the arguments.
                parser = subparser.add_parser(
                    name=command,
                    help=values.get("help"),
                    formatter_class=argparse.RawDescriptionHelpFormatter,
                )
                # If another command extends this path, the leaf also needs
                # a subparsers action for those children.
                if any(
                    (c != methods and methods == c[: len(methods)])
                    for c in commands
                ):
                    _subparsers[methods] = parser.add_subparsers(
                        help="Available commands"
                    )
            else:
                # Need to reuse created subparsers for sub-commands, otherwise
                # they will be overwritten.
                #
                # Example:
                # For both of the commands:
                # * malware on-demand queue put
                # * malware on-demand queue remove
                # only one subparser is created. We should add `queue`
                # subparser only once in order to keep both `put` and `remove`.
                hashable = tuple(methods[: i + 1])
                exists_subparser = _subparsers.get(hashable)
                if not exists_subparser:
                    # NOTE(review): reached only when the intermediate path has
                    # no schema entry of its own; it is then labelled with the
                    # help of the current (deeper) command — confirm intended.
                    subparser = _subparsers[hashable] = subparser.add_parser(
                        name=command,
                        help=values.get("help"),
                    ).add_subparsers(help="Available commands")
                else:
                    subparser = exists_subparser
        assert parser, "parser is not defined"
        # generate arguments
        envvar_parameter_options = {}
        for argument, options in values.get("schema", {}).items():
            if "envvar" in options:
                envvar_parameter_options[argument] = options
            if options.get("envvar_only", False):
                # Settable only through the environment; no CLI argument.
                continue
            if "rename" in options:
                # Copy in the options of the parameter named by "rename", but
                # always expose this one as an optional, non-positional flag.
                # NOTE(review): this mutates the options dict inside *schema*
                # in place.
                options.update(**values["schema"][options["rename"]])
                options["required"] = False
                options["positional"] = False
            schema_to_argparse(parser, argument, options)
        parser.epilog = EnvParser.format_help(envvar_parameter_options)
        parser.add_argument(
            "--json", action="store_true", help="return data in JSON format"
        )
        parser.add_argument("--verbose", "-v", action="count")
        require_rpc = values.get("cli", {}).get("require_rpc", "running")
        parser.set_defaults(
            # Initializing `RpcClient` here for each command will
            # inevitably lead to the `ServiceStateError`,
            # because some endpoints require the agent to be stopped and
            # some require it to be running. So we use `partial` to
            # defer initialization until the command is selected.
            endpoint=partial(rpc_endpoint, methods, require_rpc),
            generate_endpoint_params=partial(
                generate_endpoint_params,
                arguments=values.get("schema", {}).keys(),
            ),
            envvar_parameter_options=envvar_parameter_options,
            command=methods,
        )
def _apply_subparsers(subparsers, user):
    """Register every schema command whose ``cli.users`` lists *user*."""
    full_schema = prepare_schema(app.SCHEMA_PATHS)
    user_schema = {
        command: definition
        for command, definition in _filter_user(full_schema, user)
    }
    apply_parser(subparsers, user_schema)
@lru_cache(maxsize=1)
def create_cli_parser():
    """Build the top-level CLI parser with all commands available to root.

    Because of ``lru_cache``, the parser is built on the first call and the
    same object is returned afterwards.
    """

    def remote_addr(ip):
        # Invalid addresses become None instead of raising an argparse error.
        return ip if is_valid_ipv4_addr(ip) else None

    log_levels = ["ERROR", "WARNING", "INFO", "DEBUG"]
    parser = argparse.ArgumentParser(description="CLI for %s." % Config.NAME)
    parser.add_argument("--log-config", help="logging config filename")
    parser.add_argument(
        "--console-log-level",
        choices=log_levels,
        help="Level of logging input to the console",
    )
    parser.add_argument(
        "--remote-addr",
        type=remote_addr,
        help="Client's IP address for adding it to the whitelist",
    )
    _apply_subparsers(parser.add_subparsers(help="Available commands"), "root")
    return parser