diff --git a/stacosys/conf/config.py b/stacosys/conf/config.py index 4e77fb1..f88fcbf 100644 --- a/stacosys/conf/config.py +++ b/stacosys/conf/config.py @@ -3,7 +3,7 @@ from enum import Enum -import profig +import configparser class ConfigParameter(Enum): @@ -32,32 +32,51 @@ class ConfigParameter(Enum): class Config: def __init__(self): - self._params = dict() + self._cfg = configparser.ConfigParser() @classmethod def load(cls, config_pathname): - cfg = profig.Config(config_pathname) - cfg.sync() config = cls() - config._params.update(cfg) + config._cfg.read(config_pathname) return config + def _split_key(self, key: ConfigParameter): + section, param = str(key.value).split(".") + if not param: + param = section + section = None + return (section, param) + def exists(self, key: ConfigParameter): - return key.value in self._params + section, param = self._split_key(key) + return self._cfg.has_option(section, param) def get(self, key: ConfigParameter): - return self._params[key.value] if key.value in self._params else None + section, param = self._split_key(key) + return ( + self._cfg.get(section, param) + if self._cfg.has_option(section, param) + else None + ) def put(self, key: ConfigParameter, value): - self._params[key.value] = value + section, param = self._split_key(key) + if section and not self._cfg.has_section(section): + self._cfg.add_section(section) + self._cfg.set(section, param, str(value)) def get_int(self, key: ConfigParameter): - return int(self._params[key.value]) + value = self.get(key) + return int(value) if value else 0 def get_bool(self, key: ConfigParameter): - value = self._params[key.value].lower() + value = self.get(key) assert value in ("yes", "true", "no", "false") return value in ("yes", "true") def __repr__(self): - return self._params.__repr__() + d = dict() + for section in self._cfg.sections(): + for option in self._cfg.options(section): + d[".".join([section, option])] = self._cfg.get(section, option) + return str(d) diff --git a/tests/test_config.py b/tests/test_config.py index 07797c4..e098965 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -23,9 +23,10 @@ class ConfigTestCase(unittest.TestCase): self.assertEqual( self.conf.get(ConfigParameter.DB_SQLITE_FILE), EXPECTED_DB_SQLITE_FILE ) - self.assertEqual(self.conf.get(ConfigParameter.HTTP_PORT), EXPECTED_HTTP_PORT) self.assertIsNone(self.conf.get(ConfigParameter.HTTP_HOST)) - self.assertEqual(self.conf.get(ConfigParameter.HTTP_PORT), EXPECTED_HTTP_PORT) + self.assertEqual( + self.conf.get(ConfigParameter.HTTP_PORT), str(EXPECTED_HTTP_PORT) + ) self.assertEqual(self.conf.get_int(ConfigParameter.HTTP_PORT), 8080) try: self.conf.get_bool(ConfigParameter.DB_SQLITE_FILE)