"""Configuration loading, validation, and persistence for HAMeter. All configuration is managed through the web UI and stored as JSON at /data/config.json. A YAML config file can be imported as a one-time migration. """ import json import logging import os import tempfile from dataclasses import dataclass, field from typing import Optional import yaml logger = logging.getLogger(__name__) VALID_PROTOCOLS = {"scm", "scm+", "idm", "netidm", "r900", "r900bcd"} VALID_RATE_TYPES = {"per_unit", "fixed"} # Default icons/unit per common meter type. _METER_DEFAULTS = { "energy": {"icon": "mdi:flash", "unit": "kWh"}, "gas": {"icon": "mdi:fire", "unit": "ft\u00b3"}, "water": {"icon": "mdi:water", "unit": "gal"}, } CONFIG_PATH = "/data/config.json" # ------------------------------------------------------------------ # # Dataclasses # ------------------------------------------------------------------ # @dataclass class MqttConfig: host: str port: int = 1883 user: str = "" password: str = "" base_topic: str = "hameter" ha_autodiscovery: bool = True ha_autodiscovery_topic: str = "homeassistant" client_id: str = "hameter" @dataclass class RateComponent: name: str rate: float type: str = "per_unit" @dataclass class MeterConfig: id: int protocol: str name: str unit_of_measurement: str icon: str = "mdi:gauge" device_class: str = "" state_class: str = "total_increasing" multiplier: float = 1.0 cost_factors: list[RateComponent] = field(default_factory=list) @dataclass class GeneralConfig: sleep_for: int = 0 device_id: str = "0" rtl_tcp_host: str = "127.0.0.1" rtl_tcp_port: int = 1234 log_level: str = "INFO" rtlamr_extra_args: list = field(default_factory=list) @dataclass class HaMeterConfig: general: GeneralConfig mqtt: MqttConfig meters: list[MeterConfig] # ------------------------------------------------------------------ # # Config file operations # ------------------------------------------------------------------ # def config_exists(path: Optional[str] = None) -> bool: """Check if a config file exists at the standard path.""" return os.path.isfile(path or CONFIG_PATH) def load_config_from_json(path: Optional[str] = None) -> HaMeterConfig: """Load configuration from a JSON file. Raises: FileNotFoundError: If config file does not exist. ValueError: For validation errors. """ path = path or CONFIG_PATH with open(path) as f: raw = json.load(f) if not isinstance(raw, dict): raise ValueError(f"Config file is not a valid JSON object: {path}") return _build_config_from_dict(raw) def load_config_from_yaml(path: str) -> HaMeterConfig: """Load from YAML file (migration path from older versions). Raises: FileNotFoundError: If YAML file does not exist. ValueError: For validation errors. """ with open(path) as f: raw = yaml.safe_load(f) or {} if not isinstance(raw, dict): raise ValueError(f"YAML file is not a valid mapping: {path}") return _build_config_from_dict(raw) def save_config(config: HaMeterConfig, path: Optional[str] = None): """Atomically write config to JSON file. Writes to a temp file in the same directory, then os.replace() for atomic rename. This prevents corruption on power loss. """ path = path or CONFIG_PATH data = config_to_dict(config) dir_path = os.path.dirname(path) if dir_path: os.makedirs(dir_path, exist_ok=True) fd, tmp_path = tempfile.mkstemp(dir=dir_path or ".", suffix=".tmp") try: with os.fdopen(fd, "w") as f: json.dump(data, f, indent=2) f.write("\n") f.flush() os.fsync(f.fileno()) os.replace(tmp_path, path) except Exception: try: os.unlink(tmp_path) except OSError: pass raise def config_to_dict(config: HaMeterConfig) -> dict: """Serialize HaMeterConfig to a JSON-safe dict.""" return { "general": { "sleep_for": config.general.sleep_for, "device_id": config.general.device_id, "rtl_tcp_host": config.general.rtl_tcp_host, "rtl_tcp_port": config.general.rtl_tcp_port, "log_level": config.general.log_level, "rtlamr_extra_args": config.general.rtlamr_extra_args, }, "mqtt": { "host": config.mqtt.host, "port": config.mqtt.port, "user": config.mqtt.user, "password": config.mqtt.password, "base_topic": config.mqtt.base_topic, "ha_autodiscovery": config.mqtt.ha_autodiscovery, "ha_autodiscovery_topic": config.mqtt.ha_autodiscovery_topic, "client_id": config.mqtt.client_id, }, "meters": [ { "id": m.id, "protocol": m.protocol, "name": m.name, "unit_of_measurement": m.unit_of_measurement, "icon": m.icon, "device_class": m.device_class, "state_class": m.state_class, "multiplier": m.multiplier, "cost_factors": [ {"name": cf.name, "rate": cf.rate, "type": cf.type} for cf in m.cost_factors ], } for m in config.meters ], } # ------------------------------------------------------------------ # # Validation helpers # ------------------------------------------------------------------ # def validate_mqtt_config(data: dict) -> tuple[bool, str]: """Validate MQTT config fields, return (ok, error_message).""" if not data.get("host", "").strip(): return False, "MQTT host is required" port = data.get("port", 1883) try: port = int(port) if not (1 <= port <= 65535): raise ValueError except (ValueError, TypeError): return False, f"Invalid port: {port}" return True, "" def validate_meter_config(data: dict) -> tuple[bool, str]: """Validate a single meter config dict, return (ok, error_message).""" if not data.get("id"): return False, "Meter ID is required" try: int(data["id"]) except (ValueError, TypeError): return False, f"Meter ID must be a number: {data['id']}" protocol = data.get("protocol", "").lower() if protocol not in VALID_PROTOCOLS: return False, ( f"Invalid protocol: {protocol}. " f"Valid: {', '.join(sorted(VALID_PROTOCOLS))}" ) if not data.get("name", "").strip(): return False, "Meter name is required" multiplier = data.get("multiplier", 1.0) try: multiplier = float(multiplier) except (ValueError, TypeError): return False, f"Multiplier must be a number: {data.get('multiplier')}" if multiplier <= 0: return False, f"Multiplier must be positive, got {multiplier}" return True, "" def validate_rate_component(data: dict) -> tuple[bool, str]: """Validate a single rate component dict, return (ok, error_message).""" if not data.get("name", "").strip(): return False, "Rate component name is required" try: float(data.get("rate", "")) except (ValueError, TypeError): return False, f"Rate must be a number: {data.get('rate')}" comp_type = data.get("type", "per_unit") if comp_type not in VALID_RATE_TYPES: return False, f"Invalid rate type: {comp_type}. Must be 'per_unit' or 'fixed'" return True, "" def get_meter_defaults(device_class: str) -> dict: """Get smart defaults (icon, unit) for a device class.""" return dict(_METER_DEFAULTS.get(device_class, {})) # ------------------------------------------------------------------ # # Internal builders # ------------------------------------------------------------------ # def _build_config_from_dict(raw: dict) -> HaMeterConfig: """Build HaMeterConfig from a raw dict (from JSON or YAML).""" general = _build_general_from_dict(raw) mqtt = _build_mqtt_from_dict(raw) meters = _build_meters_from_dict(raw) return HaMeterConfig(general=general, mqtt=mqtt, meters=meters) def _build_general_from_dict(raw: dict) -> GeneralConfig: """Build GeneralConfig from raw dict.""" g = raw.get("general", {}) extra_args = g.get("rtlamr_extra_args", []) if isinstance(extra_args, str): extra_args = extra_args.split() rtl_tcp_port = int(g.get("rtl_tcp_port", 1234)) if not (1 <= rtl_tcp_port <= 65535): raise ValueError( f"rtl_tcp_port must be 1-65535, got {rtl_tcp_port}" ) log_level = str(g.get("log_level", "INFO")).upper() if log_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"): raise ValueError( f"Invalid log_level '{log_level}'. " f"Valid: DEBUG, INFO, WARNING, ERROR, CRITICAL" ) device_id = str(g.get("device_id", "0")) try: if int(device_id) < 0: raise ValueError except (ValueError, TypeError): raise ValueError( f"device_id must be a non-negative integer, got '{device_id}'" ) return GeneralConfig( sleep_for=int(g.get("sleep_for", 0)), device_id=device_id, rtl_tcp_host=g.get("rtl_tcp_host", "127.0.0.1"), rtl_tcp_port=rtl_tcp_port, log_level=log_level, rtlamr_extra_args=list(extra_args), ) def _build_mqtt_from_dict(raw: dict) -> MqttConfig: """Build MqttConfig from raw dict.""" m = raw.get("mqtt") or {} host = m.get("host", "") if not host: raise ValueError( "MQTT host not set. Configure it via the web UI." ) port = int(m.get("port", 1883)) if not (1 <= port <= 65535): raise ValueError(f"MQTT port must be 1-65535, got {port}") return MqttConfig( host=host, port=port, user=m.get("user", ""), password=m.get("password", ""), base_topic=m.get("base_topic", "hameter"), ha_autodiscovery=m.get("ha_autodiscovery", True), ha_autodiscovery_topic=m.get("ha_autodiscovery_topic", "homeassistant"), client_id=m.get("client_id", "hameter"), ) def _build_meters_from_dict(raw: dict) -> list[MeterConfig]: """Build meter list from raw dict.""" meters_raw = raw.get("meters") if not meters_raw: return [] meters = [] seen_ids: set[int] = set() for i, m in enumerate(meters_raw): meter_id = m.get("id") if meter_id is None: raise ValueError(f"Meter #{i + 1} missing required 'id'") mid_int = int(meter_id) if mid_int in seen_ids: raise ValueError( f"Meter #{i + 1} has duplicate id {mid_int}" ) seen_ids.add(mid_int) protocol = m.get("protocol", "").lower() if protocol not in VALID_PROTOCOLS: raise ValueError( f"Meter #{i + 1} (id={meter_id}) has invalid protocol " f"'{protocol}'. Valid: {', '.join(sorted(VALID_PROTOCOLS))}" ) name = m.get("name", "") if not name: raise ValueError( f"Meter #{i + 1} (id={meter_id}) missing required 'name'" ) # Apply smart defaults based on device_class. device_class = m.get("device_class", "") defaults = _METER_DEFAULTS.get(device_class, {}) cost_factors = [ RateComponent( name=cf.get("name", ""), rate=float(cf.get("rate", 0.0)), type=cf.get("type", "per_unit"), ) for cf in m.get("cost_factors", []) ] meters.append(MeterConfig( id=int(meter_id), protocol=protocol, name=name, unit_of_measurement=m.get("unit_of_measurement", "") or defaults.get("unit", ""), icon=m.get("icon", "") or defaults.get("icon", "mdi:gauge"), device_class=device_class, state_class=m.get("state_class", "total_increasing"), multiplier=float(m.get("multiplier", 1.0)), cost_factors=cost_factors, )) return meters