import json
import logging
import os
import pwd
from abc import abstractmethod
from contextlib import suppress
from textwrap import dedent
from typing import Mapping, Optional, Protocol
import sentry_sdk
import yaml
from defence360agent.utils import atomic_rewrite
logger = logging.getLogger(__name__)
# Upper bound, in bytes, on the config file size; larger files are not read.
_MAX_CONFIG_SIZE = 1024 * 1024  # 1 MiB
class IConfigProvider(Protocol):
    """Structural interface shared by config readers/writers.

    Every method is abstract; concrete semantics are defined by the
    implementing classes.
    """

    @abstractmethod
    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ):
        """Read the config (see implementations for flag semantics)."""
        raise NotImplementedError()

    @abstractmethod
    def write_config_file(self, config: Mapping) -> None:
        """Persist *config*."""
        raise NotImplementedError()

    @abstractmethod
    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Report whether the config changed after *timestamp*."""
        raise NotImplementedError()
class ConfigError(Exception):
    """Raised when a config file cannot be read or its content is invalid."""
class JsonMessage:
    """Render *obj* as compact JSON with sorted keys when converted to str.

    Meant to be passed as a lazy logging argument, so serialization only
    happens if the record is actually formatted. Example:

        logging.info("object: %s", JsonMessage(obj))

    Values that JSON cannot encode natively are rendered with ``str``; if
    the object still cannot be serialized, ``repr`` is used instead, so the
    log record is never lost to a serialization error.
    """

    def __init__(self, obj):
        self._obj = obj

    def __str__(self):
        try:
            # default=str: YAML-loaded configs may contain values json cannot
            # encode natively (e.g. datetime.date for unquoted dates, set).
            return json.dumps(self._obj, sort_keys=True, default=str)
        except (TypeError, ValueError):
            # sort_keys fails on keys of mixed types (e.g. int and str);
            # non-scalar keys and circular references fail regardless.
            return repr(self._obj)
def diff_section(prev_section: Optional[dict], section: Optional[dict]):
    """Compare two config sections option by option.

    A falsy section (e.g. ``None``) is treated as empty. The result has
    three entries:

    * ``"-"``: options only in *prev_section*, mapped to their old values;
    * ``"+"``: options only in *section*, mapped to their new values;
    * ``"?"``: options in both whose values differ, mapped to
      ``(old_value, new_value)`` tuples.
    """
    old = prev_section or {}
    new = section or {}
    old_keys, new_keys = old.keys(), new.keys()
    removed = {key: old[key] for key in old_keys - new_keys}
    added = {key: new[key] for key in new_keys - old_keys}
    modified = {}
    for key in old_keys & new_keys:
        before, after = old[key], new[key]
        if before != after:
            modified[key] = (before, after)
    return {"-": removed, "+": added, "?": modified}
def diff_config(prev_conf: dict, conf: dict):
    """Compare *prev_conf* with the current *conf*, section by section.

    Generator that yields exactly three dicts, in this order:

    1. sections only in *prev_conf*, mapped to their old contents;
    2. sections only in *conf*, mapped to their new contents;
    3. sections in both whose contents differ, mapped to the
       ``diff_section`` result for that section.
    """
    # Key views are live, so they reflect the dicts at each resumption.
    old_names, new_names = prev_conf.keys(), conf.keys()
    yield {name: prev_conf[name] for name in old_names - new_names}
    yield {name: conf[name] for name in new_names - old_names}
    shared = old_names & new_names
    yield {
        name: diff_section(prev_conf[name], conf[name])
        for name in shared
        if prev_conf[name] != conf[name]
    }
def exclude_equals(*, main_conf: dict, base_conf: dict) -> dict:
    """
    Return the parts of *main_conf* that differ from *base_conf*.

    Sections missing from *base_conf* are copied whole (same object).
    For sections present in both, only options that are new or whose value
    differs are kept. Sections equal in both are omitted. A section whose
    only difference is options removed from *main_conf* appears as ``{}``.
    Result sections follow the order of *main_conf*.

    For example,

    >>> base_conf = {
    ...     "SECTION1": {"OPTION1": "default", "OPTION2": "default"},
    ...     "SECTION2": {"OPTION1": "default"},
    ... }
    >>> main_conf = {
    ...     "SECTION1": {"OPTION1": "value", "OPTION2": "default"},
    ...     "SECTION2": {"OPTION1": "default"},
    ... }
    >>> exclude_equals(main_conf=main_conf, base_conf=base_conf)
    {'SECTION1': {'OPTION1': 'value'}}
    """
    _removed, added, changed = diff_config(base_conf, main_conf)
    result = {}
    for name, content in main_conf.items():
        if name in added:
            result[name] = content
        if name in changed:
            section_diff = changed[name]
            merged = result.setdefault(name, {})
            merged.update(section_diff["+"])
            merged.update(
                {option: pair[1] for option, pair in section_diff["?"].items()}
            )
    return result
class ConfigReader:
    """Read and write a YAML config file located at *path*.

    This reader does no caching: every :meth:`read_config_file` call reads
    the file from disk again.
    """

    def __init__(self, path, disclaimer="", permissions=None):
        # path: location of the YAML config file.
        # disclaimer: text prepended (after dedent) to the file on write.
        # permissions: forwarded to atomic_rewrite() when writing.
        self.path = path
        self.disclaimer = disclaimer
        self.permissions = permissions

    def __repr__(self):
        return "<{classname}({path})>".format(
            classname=self.__class__.__qualname__, path=self.path
        )

    def __str__(self):
        return f"ConfigReader at {self.path}"

    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ) -> dict:
        """Read the config file and parse it.

        :param force_read: unused by this class; accepted for interface
            compatibility
        :param ignore_errors: if true, content that fails to parse is
            logged and treated as an empty config; otherwise the
            ConfigError is re-raised
        :return: the parsed config; ``{}`` if the file does not exist or
            the document is empty
        :raises ConfigError: if the file is larger than _MAX_CONFIG_SIZE
            or is not valid UTF-8 (regardless of *ignore_errors*), or if
            the content is invalid and *ignore_errors* is false
        """
        try:
            if os.path.getsize(self.path) > _MAX_CONFIG_SIZE:
                raise ConfigError("Config file is too large")
            # Decode explicitly as UTF-8 instead of relying on the process
            # locale, which may be ASCII/Latin-1 under daemons or cron.
            with open(self.path, "r", encoding="utf-8") as config_file:
                logger.info("Reading config file %s", self.path)
                text = config_file.read()
        except UnicodeDecodeError as e:
            raise ConfigError("Unable to decode config file") from e
        except FileNotFoundError:
            return {}
        try:
            return self.load_config_body(text)
        except ConfigError as e:
            logger.error(e)
            if ignore_errors:
                return {}
            raise

    def load_config_body(self, text: str) -> dict:
        """Parse *text* as a YAML mapping.

        :return: the parsed mapping; ``{}`` for an empty document
        :raises ConfigError: if *text* is not valid YAML or its top level
            is not a mapping
        """
        try:
            config = yaml.safe_load(text)
        except yaml.YAMLError as e:
            raise ConfigError(
                f"Imunify360 config is not valid YAML document ({e})"
            ) from e
        if config is None:
            return {}
        if not isinstance(config, dict):
            raise ConfigError(
                "Imunify360 config is invalid or empty"
                ": path={!r}, text={!r}".format(self.path, text)
            )
        return config

    def _pre_write(self):
        """Hook run before the file is written; no-op by default."""

    def _post_write(self):
        """Hook run after the file is written; no-op by default."""

    def write_config_file(self, config) -> str:
        """Serialize *config* as YAML and rewrite the file via atomic_rewrite.

        The (dedented) disclaimer, if any, is prepended to the YAML text.

        :return: the full text that was written
        """
        self._pre_write()
        config_text = ""
        if self.disclaimer:
            config_text += dedent(self.disclaimer)
            config_text += "\n"
        config_text += yaml.dump(config, default_flow_style=False)
        atomic_rewrite(
            self.path, config_text, backup=False, permissions=self.permissions
        )
        self._post_write()
        return config_text

    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Always True: without a cache the file is always considered new."""
        return True
class CachedConfigReader(ConfigReader):
    """ConfigReader that keeps the parsed config in memory.

    The file is re-read only when :meth:`modified_since` reports a change
    in its modification time or size, or when a read is forced.
    """

    def __init__(self, path, disclaimer="", permissions=None):
        super().__init__(path, disclaimer, permissions)
        # (mtime, size) of the file as of the last read attempt: None until
        # the first read, 0.0 if the file did not exist at that time.
        self.mtime: Optional[float] = None
        self.size: Optional[float] = None
        # Config returned by the last read that did not raise ConfigError.
        self._config = {}

    def __str__(self):
        template = "{classname} <'{path}', modified at {mtime}, {size} bytes>"
        return template.format(
            classname=self.__class__.__qualname__,
            path=self.path,
            mtime=self.mtime,
            size=self.size,
        )

    def _stat_file(self):
        """Return ``(st_mtime, st_size)`` of the file, or zeros if missing."""
        try:
            stat = os.stat(self.path)
        except FileNotFoundError:
            return 0.0, 0.0
        return stat.st_mtime, stat.st_size

    def read_config_file(
        self, force_read: bool = False, ignore_errors: bool = True
    ):
        """Return the config, re-reading the file only if it has changed.

        :param force_read: re-read even if the file looks unchanged
        :param ignore_errors: forwarded to ConfigReader.read_config_file;
            when false, a ConfigError is re-raised after being reported
        :return: the cached config; if reading raised ConfigError, the
            previous config is kept and returned

        NOTE(review): with ignore_errors=True, ConfigReader.read_config_file
        itself swallows parse errors and returns {}, so invalid YAML replaces
        the cached config with {} instead of keeping the previous one.
        Only oversized / undecodable files take the "previous settings"
        path. Confirm this is intended.
        """
        if not (force_read or self.modified_since(self.mtime)):
            return self._config
        # Snapshot mtime/size *before* reading: if the file changes after
        # this point, the stored values will be stale and the next call
        # re-reads it. Stat-ing after the read could record the new mtime
        # while caching the old content, hiding the change indefinitely.
        mtime, size = self._stat_file()
        prev_config = self._config
        try:
            self._config = super().read_config_file(
                ignore_errors=ignore_errors
            )
        except ConfigError as error:
            sentry_sdk.capture_exception(error)
            logger.warning(
                "%s is invalid, using previous settings: %s",
                self,
                JsonMessage(self._config),
            )
            if not ignore_errors:
                raise
        else:
            if self.mtime is not None:  # don't log on startup
                diffs = list(diff_config(prev_config, self._config))
                if any(diffs):
                    # content has changed, log it
                    logger.info(
                        "%s modified: removed=%s, added=%s, changed=%s",
                        self,
                        *map(JsonMessage, diffs),
                    )
        self.mtime, self.size = mtime, size
        return self._config

    def modified_since(self, timestamp: Optional[float]) -> bool:
        """Whether the config has updated since *timestamp*.

        A change is detected when the file's mtime is newer than
        *timestamp* or its size differs from the last recorded size.
        A missing file counts as mtime 0.0 and size 0.0.

        :param timestamp: None means that the file has never been read
            before (treated as 0.0)
        """
        if timestamp is None:
            timestamp = 0.0
        st_mtime, st_size = self._stat_file()
        return st_mtime > timestamp or st_size != self.size
class WriteOnlyConfigReader(CachedConfigReader):
    """Cached reader whose in-memory config only changes when it writes.

    Reads never touch the file; they return what was last written.
    """

    def read_config_file(self, *args, **kwargs):
        # force_read / ignore_errors are accepted for interface
        # compatibility and deliberately ignored.
        return self._config

    def write_config_file(self, config):
        written = super().write_config_file(config)
        # Re-parse exactly the text that was written, so the cache reflects
        # the file content as serialized by yaml.dump.
        self._config = self.load_config_body(written)
        return written
class UserConfigReader(CachedConfigReader):
    """Cached reader for the config file of a single user.

    On write, the file and its parent directory are owned by root with the
    user's primary group (dir 0o750, file 0o640), so group members can read
    but not modify them.
    """

    def __init__(self, path, username):
        super().__init__(path)
        self.username = username

    def __str__(self):
        return f"Config of user {self.username}"

    def _user_gid(self) -> int:
        """Return the primary group id of ``self.username``."""
        return pwd.getpwnam(self.username).pw_gid

    def _pre_write(self):
        # Ensure the parent directory exists and has the expected owner and
        # mode; ownership/mode are (re)applied even if it already existed.
        confdir = os.path.dirname(self.path)
        with suppress(FileExistsError):
            os.mkdir(confdir)
        os.chown(confdir, 0, self._user_gid())
        os.chmod(confdir, 0o750)

    def _post_write(self):
        os.chown(self.path, 0, self._user_gid())
        os.chmod(self.path, 0o640)