from collections import defaultdict
import json
import os
import subprocess
import sys
import time
import yaml
# Keyword arguments for json.dumps() used when verbose JSON output is wanted.
PRETTY_JSON_ARGS = dict(sort_keys=True, indent=2, separators=(",", ": "))

# Numeric exit codes.
EXITCODE_NOT_FOUND = 2
EXITCODE_WARNING = 3
EXITCODE_GENERAL_ERROR = 11

# Pager executables probed, in this order, by pager().
PAGERS = ["/bin/less", "/bin/more"]

# Result status strings; see simple_rpc
SUCCESS = "success"
WARNING = "warnings"
ERROR = "error"

# Text prepended to each message printed by print_error(), per status.
_CLI_MSG_PREFIX = {WARNING: "WARNING", ERROR: "ERROR"}

# Exit code associated with each result status.
EXIT_CODES = dict(
    [
        (SUCCESS, 0),
        (WARNING, EXITCODE_WARNING),
        (ERROR, EXITCODE_GENERAL_ERROR),
    ]
)
def pager(data):
    """Show *data* through a pager program, or print it if none is usable.

    The command comes from the ``PAGER`` environment variable.  When that
    is unset or empty, the first entry of ``PAGERS`` that exists as a file
    is used instead.  If no pager is found, or it cannot be started, the
    text is written with ``print()``.

    Args:
        data: text to display; it is encoded with ``str.encode()`` and fed
            to the pager's standard input.
    """
    import shlex  # local import: nothing else in this module needs it

    raw_command = os.environ.get("PAGER", "")
    try:
        # $PAGER often carries options ("less -R"); split it into argv.
        command = shlex.split(raw_command)
    except ValueError:
        # Unbalanced quotes: fall back to treating it as one program name.
        command = [raw_command]

    if not command:
        default_pager = next((p for p in PAGERS if os.path.isfile(p)), None)
        if default_pager is not None:
            command = [default_pager]

    if not command:
        print(data)
        return

    try:
        subprocess.run(command, input=data.encode(), stdout=sys.stdout)
    except OSError:
        # The pager is missing or not executable; still show the text.
        print(data)
class TablePrinter:
    """Print a sequence of dict-like items as a plain-text table.

    Per-field display options (header text, value mappers, maximum width,
    alignment) are registered with :meth:`set_field_properties`; the table
    itself is written by :meth:`print`.
    """

    def __init__(self):
        self._headers = {}  # field -> header text
        self._mappers = defaultdict(list)  # field -> callables applied in order
        self._right_aligned = {}  # field -> bool
        self._widths = {}  # field -> maximum cell width

    def set_field_properties(
        self,
        field,
        mappers=None,
        max_width=None,
        right_align=False,
        header=None,
    ):
        """Register display options for *field*.

        Args:
            field: key looked up in every item.
            mappers: callables applied, in order, to the raw value before
                it is converted with ``str()``.
            max_width: cell values longer than this are truncated.
            right_align: right-justify the column's cells when true.
            header: column title; defaults to ``field.upper()``.
        """
        if mappers:
            self._mappers[field] = mappers
        if max_width:
            self._widths[field] = max_width
        self._right_aligned[field] = right_align
        self._headers[field] = header if header else field.upper()

    def print(self, fields, items, file=None):
        """Write a header row followed by one row per item.

        Args:
            fields: field names, one per column, in display order.
            items: iterable of objects providing ``.get(field)``.
            file: stream to write to.  ``None`` (the default) lets the
                built-in ``print`` use the current ``sys.stdout``.
        """
        headers = [self._headers.get(field, field.upper()) for field in fields]
        widths = [len(header) for header in headers]
        # One alignment flag per column rather than one for the whole table.
        alignment = [self._right_aligned.get(field, False) for field in fields]

        rows = []
        for item in items:
            row = []
            for i, field in enumerate(fields):
                value = item.get(field)
                # .get() avoids inserting keys into the defaultdict on read.
                for mapper in self._mappers.get(field, ()):
                    value = mapper(value)
                text = self._truncate(str(value), self._widths.get(field))
                # Only ever grow: a truncated cell must not shrink the
                # column below the header or an earlier cell.
                widths[i] = max(widths[i], len(text))
                row.append(text)
            rows.append(row)

        # The header row is always left-aligned.
        print(self._format_row(headers, widths, False), file=file)
        for row in rows:
            print(self._format_row(row, widths, alignment), file=file)

    @staticmethod
    def _truncate(text, max_width):
        """Cut *text* to *max_width* characters, marking the cut with "..."."""
        if not max_width or len(text) <= max_width:
            return text
        if max_width <= 3:
            # No room for an ellipsis within the limit.
            return text[:max_width]
        return text[: max_width - 3] + "..."

    @staticmethod
    def _add_padding(value, width, right_align):
        """Pad *value* with spaces to *width* on the chosen side."""
        if right_align:
            return value.rjust(width)
        return value.ljust(width)

    @staticmethod
    def _format_row(columns, widths, right_aligned):
        """Join padded cells with a single space.

        *right_aligned* is either one flag applied to every column or a
        list/tuple holding one flag per column.
        """
        if isinstance(right_aligned, (list, tuple)):
            flags = right_aligned
        else:
            flags = [right_aligned] * len(columns)
        cols = [
            TablePrinter._add_padding(value, widths[i], flags[i])
            for i, value in enumerate(columns)
        ]
        return " ".join(cols)
def n_a(value):
    """Return ``"n/a"`` in place of ``None``; pass other values through."""
    if value is None:
        return "n/a"
    return value
def to_int(value):
    """Convert *value* with ``int()``, leaving ``None`` untouched."""
    if value is None:
        return value
    return int(value)
def extract_field(field):
    """Build a mapper that pulls *field* out of dict values.

    The returned callable yields ``value.get(field)`` for a dict and
    returns any non-dict value unchanged.
    """

    def extractor(value):
        return value.get(field) if isinstance(value, dict) else value

    return extractor
def print_table(data, field_props):
    """Print *data* as a table described by *field_props*.

    Each entry of *field_props* is a sequence whose first element is the
    field name; the whole entry is passed positionally to
    ``TablePrinter.set_field_properties``.
    """
    printer = TablePrinter()
    for props in field_props:
        printer.set_field_properties(*props)
    column_names = [props[0] for props in field_props]
    printer.print(column_names, data)
def print_incidents(data):
    """Print incident records as a table.

    The timestamp goes through ``to_int``, the country column shows the
    ``"code"`` entry of a dict value, and the remaining columns show
    ``"n/a"`` for ``None``.
    """
    country_code = extract_field("code")
    print_table(
        data,
        (
            ("timestamp", [to_int]),
            ("abuser", [n_a]),
            ("country", [country_code]),
            ("times", [n_a]),
            ("name", [n_a]),
            ("severity", [n_a]),
        ),
    )
def add_ttl(data):
    """Add a ``"ttl"`` entry to every item of *data*, in place.

    ``ttl`` is ``expiration - now`` (whole seconds) when the item has a
    positive ``"expiration"``; otherwise it is ``0``.  A missing or
    ``None`` expiration counts as ``0``.  For an expiration that already
    lies in the past the resulting ``ttl`` is negative.

    Args:
        data: iterable of mutable dicts.
    """
    now = int(time.time())
    for item in data:
        # "or 0" also covers an explicit None, which .get()'s default
        # would let through and which cannot be compared with 0.
        expiration = item.get("expiration") or 0
        item["ttl"] = expiration - now if expiration > 0 else 0
def print_graylist(data):
    """Print graylist entries as a table of ip, ttl and country code.

    ``add_ttl`` is applied to *data* first, so the items are modified in
    place.
    """
    add_ttl(data)
    columns = (
        ("ip",),
        ("ttl",),
        ("country", [extract_field("code")]),
    )
    print_table(data, columns)
def print_bwlist(data):
    """Print black/white list entries as a table.

    Columns are ip, ttl, country code, imported_from and comment.
    ``add_ttl`` is applied to *data* first, so the items are modified in
    place.
    """
    add_ttl(data)
    columns = (
        ("ip",),
        ("ttl",),
        ("country", [extract_field("code")]),
        ("imported_from",),
        ("comment",),
    )
    print_table(data, columns)
def guess_printer(data):
    """Print *data* in a form chosen from its shape.

    * Anything that is not a list or tuple is printed as is.
    * An empty list/tuple prints nothing.
    * A list/tuple whose first element is a dict is printed as a table
      with the sorted keys of that first element as columns.
    * Any other list/tuple is printed one element per line.
    """
    if not isinstance(data, (list, tuple)):
        print(data)
        return
    if len(data) == 0:
        return

    table = TablePrinter()
    if not isinstance(data[0], dict):
        for entry in data:
            print(entry)
        return

    columns = sorted(data[0].keys())
    # A dict in the "country" column is shown by its "code" entry.
    table.set_field_properties("country", mappers=[extract_field("code")])
    table.print(columns, data)
def yaml_printer(data):
    """Print *data* as block-style YAML; strings are printed verbatim."""
    if isinstance(data, str):
        text = data
    else:
        text = yaml.dump(data, default_flow_style=False)
    print(text)
def json_printer(data):
    """Print *data* serialized as JSON; strings are printed verbatim."""
    text = data if isinstance(data, str) else json.dumps(data)
    print(text)
def hook_printer(data):
    """Print hook command results.

    A dict is reported by its ``"status"`` entry.  Anything else is
    iterated as a collection of hooks, each printed on its own line with
    its event, its path and a " native" suffix when ``hook["native"]`` is
    true.
    """
    if isinstance(data, dict):
        print("Status: {}".format(data["status"]))
        return

    lines = [
        "Event: {}, Path: {}{}".format(
            hook["event"],
            hook["path"],
            " native" if hook["native"] else "",
        )
        for hook in data
    ]
    print("\n".join(lines))
# Printer callable for each CLI method, keyed by the method's tuple of
# words.  print_response() uses guess_printer for methods not listed here.
PRINTERS = {
    ("config", "show"): json_printer,
    ("eula", "show"): pager,
    ("get",): print_incidents,
    # White/black lists: the bare command and "ip list" share a printer.
    **{
        key: print_bwlist
        for name in ("whitelist", "blacklist")
        for key in ((name,), (name, "ip", "list"))
    },
    **{
        key: print_graylist
        for key in (("graylist",), ("graylist", "ip", "list"))
    },
    ("malware", "on-demand", "status"): yaml_printer,
    **{
        ("feature-management", action): yaml_printer
        for action in ("defaults", "show", "enable", "disable", "get")
    },
    **{
        ("hook", action): hook_printer
        for action in ("add", "delete", "list", "add-native")
    },
}
def print_response(method, result, is_json=False, is_verbose=False):
    """Print the result of a successful call.

    Args:
        method: key used to pick a printer from ``PRINTERS``;
            ``guess_printer`` is used when it is not listed.
        result: mapping returned by the call.
        is_json: dump the whole *result* as JSON instead of using a
            printer.
        is_verbose: with *is_json*, use the ``PRETTY_JSON_ARGS`` layout.

    Without *is_json* the printer receives ``result["items"]``, or the
    string ``"OK"`` when that entry is missing or ``None``.
    """
    if is_json:
        if is_verbose:
            text = json.dumps(result, **PRETTY_JSON_ARGS)
        else:
            text = json.dumps(result)
        print(text)
        return

    printer = PRINTERS.get(method, guess_printer)
    if result.get("items") is None:
        printer("OK")
    else:
        printer(result["items"])
def print_warnings(data: dict):
    """Write every entry of ``data["warnings"]`` to stderr, one per line.

    Nothing is printed when *data* is not a dict or has no ``"warnings"``
    entry.
    """
    # A non-dict can arrive here, for example, if validation of cli args
    # fails; there is nothing to report in that case.
    if isinstance(data, dict):
        for warning in data.get("warnings", []):
            print(warning, file=sys.stderr)
def print_error(
    result, messages, is_json=False, is_verbose=False, *, file=sys.stderr
):
    """Report *messages* for a non-successful *result*.

    Args:
        result: status string; used as the JSON key, or to look up the
            line prefix in ``_CLI_MSG_PREFIX``.
        messages: a list/tuple of messages, or a single object printed
            as is.
        is_json: print ``{result: messages}`` as JSON.  This output goes
            through a plain ``print()`` call, i.e. not to *file*.
        is_verbose: with *is_json*, use the ``PRETTY_JSON_ARGS`` layout.
        file: stream for the plain-text output.
    """
    if is_json:
        payload = {result: messages}
        if is_verbose:
            print(json.dumps(payload, **PRETTY_JSON_ARGS))
        else:
            print(json.dumps(payload))
        return

    if not isinstance(messages, (list, tuple)):
        print(messages, file=file)
        return

    for msg in messages:
        # The prefix is looked up per message, so an empty sequence never
        # touches _CLI_MSG_PREFIX.
        print("%s: %s" % (_CLI_MSG_PREFIX[result], msg), file=file)