Files
HAMeter/hameter/config.py
2026-03-06 12:25:27 -05:00

400 lines
12 KiB
Python

"""Configuration loading, validation, and persistence for HAMeter.
All configuration is managed through the web UI and stored as JSON
at /data/config.json. A YAML config file can be imported as a
one-time migration.
"""
import json
import logging
import math
import os
import tempfile
from dataclasses import dataclass, field
from typing import Optional

import yaml
logger = logging.getLogger(__name__)

# Protocol names accepted for a meter's "protocol" setting. Incoming values
# are lower-cased before being checked against this set.
VALID_PROTOCOLS = {
    "scm",
    "scm+",
    "idm",
    "netidm",
    "r900",
    "r900bcd",
}

# Accepted values for a rate component's "type" setting.
VALID_RATE_TYPES = {
    "per_unit",
    "fixed",
}

# Fallback icon and unit, keyed by a meter's device_class. Used when a meter
# entry leaves "icon" or "unit_of_measurement" empty.
_METER_DEFAULTS = {
    "energy": {
        "icon": "mdi:flash",
        "unit": "kWh",
    },
    "gas": {
        "icon": "mdi:fire",
        "unit": "ft\u00b3",
    },
    "water": {
        "icon": "mdi:water",
        "unit": "gal",
    },
}

# JSON file where the web UI persists the configuration.
CONFIG_PATH = "/data/config.json"
# ------------------------------------------------------------------ #
# Dataclasses
# ------------------------------------------------------------------ #
@dataclass
class MqttConfig:
    """Connection and topic settings for the MQTT broker.

    ``host`` is the only required field; everything else has a default.
    ``password`` is excluded from the generated ``repr`` so that logging this
    object (or any object whose repr embeds it) does not leak the credential.
    """

    host: str
    port: int = 1883
    user: str = ""
    # repr=False keeps the secret out of logs and tracebacks. The field still
    # takes part in __eq__ and is an ordinary attribute for serialization.
    password: str = field(default="", repr=False)
    base_topic: str = "hameter"
    ha_autodiscovery: bool = True
    ha_autodiscovery_topic: str = "homeassistant"
    client_id: str = "hameter"
@dataclass
class RateComponent:
    """One named line item used when computing a meter's cost.

    The dataclass does not check ``type``; the module's validators accept
    ``"per_unit"`` or ``"fixed"``.
    """

    name: str
    rate: float
    type: str = field(default="per_unit")
@dataclass
class MeterConfig:
    """Settings for a single meter.

    ``id``, ``protocol``, ``name`` and ``unit_of_measurement`` must be
    supplied. The remaining fields have defaults, and ``cost_factors`` gets a
    fresh empty list per instance.
    """

    # Required fields.
    id: int
    protocol: str
    name: str
    unit_of_measurement: str
    # Optional fields.
    icon: str = field(default="mdi:gauge")
    device_class: str = field(default="")
    state_class: str = field(default="total_increasing")
    multiplier: float = field(default=1.0)
    cost_factors: list[RateComponent] = field(default_factory=list)
@dataclass
class GeneralConfig:
    """General runtime settings; every field has a default.

    ``rtlamr_extra_args`` gets a fresh empty list per instance.
    """

    sleep_for: int = field(default=0)
    device_id: str = field(default="0")
    rtl_tcp_host: str = field(default="127.0.0.1")
    rtl_tcp_port: int = field(default=1234)
    log_level: str = field(default="INFO")
    rtlamr_extra_args: list = field(default_factory=list)
@dataclass
class HaMeterConfig:
    """Complete application configuration.

    Bundles the general runtime settings, the MQTT connection settings and
    the list of configured meters. All three fields are required.
    """

    general: GeneralConfig
    mqtt: MqttConfig
    meters: list[MeterConfig]
# ------------------------------------------------------------------ #
# Config file operations
# ------------------------------------------------------------------ #
def config_exists(path: Optional[str] = None) -> bool:
    """Return True if *path* names an existing regular file.

    A falsy *path* (None or "") falls back to CONFIG_PATH.
    """
    target = path if path else CONFIG_PATH
    return os.path.isfile(target)
def load_config_from_json(path: Optional[str] = None) -> "HaMeterConfig":
    """Load configuration from a JSON file.

    Args:
        path: File to read. A falsy value (None or "") means CONFIG_PATH.

    Returns:
        The validated HaMeterConfig.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the file is not valid JSON (json.JSONDecodeError is a
            ValueError subclass), its top level is not a JSON object, or a
            field fails validation.
    """
    path = path or CONFIG_PATH
    # Decode explicitly as UTF-8 rather than with the locale-dependent default.
    # "utf-8-sig" also strips a leading BOM that some editors add when the
    # file is hand-edited; plain json.load rejects such a file.
    with open(path, encoding="utf-8-sig") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        raise ValueError(f"Config file is not a valid JSON object: {path}")
    return _build_config_from_dict(raw)
def load_config_from_yaml(path: str) -> HaMeterConfig:
    """Load configuration from a YAML file (migration path from older versions).

    The file is opened in binary mode so that PyYAML detects the encoding
    itself (UTF-8, or UTF-16 with a BOM) rather than relying on the
    platform's locale-dependent text encoding. ``yaml.safe_load`` is used, so
    the document cannot construct arbitrary Python objects.

    Args:
        path: YAML file to read.

    Returns:
        The validated HaMeterConfig.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        yaml.YAMLError: If the file is not well-formed YAML. This is not a
            ValueError subclass, so callers must catch it separately.
        ValueError: If the top level is not a mapping or a field fails
            validation.
    """
    with open(path, "rb") as f:
        # An empty document loads as None; treat it like an empty mapping.
        raw = yaml.safe_load(f) or {}
    if not isinstance(raw, dict):
        raise ValueError(f"YAML file is not a valid mapping: {path}")
    return _build_config_from_dict(raw)
def save_config(config: HaMeterConfig, path: Optional[str] = None) -> None:
    """Atomically write *config* to a JSON file.

    The data goes to a temporary file in the destination directory. That
    file is flushed and fsync'd, then moved over *path* with ``os.replace()``,
    so readers see either the old file or the complete new one, never a
    partial write. Afterwards the containing directory is fsync'd on a
    best-effort basis so the rename itself is persisted across power loss.

    Args:
        config: Configuration to serialize with ``config_to_dict``.
        path: Destination file. A falsy value (None or "") means CONFIG_PATH.

    Raises:
        OSError: If the directory, temporary file, write or rename fails.
            The temporary file is removed before the error propagates.
    """
    path = path or CONFIG_PATH
    # Serialize first, so a conversion error never leaves a temp file behind.
    data = config_to_dict(config)
    dir_path = os.path.dirname(path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    # mkstemp creates the file readable/writable by the owner only (0o600 on
    # POSIX). os.replace keeps that mode, which suits a file that stores the
    # MQTT password.
    fd, tmp_path = tempfile.mkstemp(dir=dir_path or ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            f.write("\n")
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # BaseException so that KeyboardInterrupt/SystemExit do not leave
        # stray *.tmp files next to the config; the error is always re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    _fsync_dir(dir_path or ".")


def _fsync_dir(dir_path: str) -> None:
    """Best-effort fsync of *dir_path* so a completed rename is durable."""
    try:
        dir_fd = os.open(dir_path, os.O_RDONLY)
    except OSError:
        # Some platforms (e.g. Windows) cannot open a directory this way. The
        # new file is already in place, so this is not worth failing over.
        return
    try:
        os.fsync(dir_fd)
    except OSError:
        # Some filesystems reject fsync on directories; same reasoning.
        pass
    finally:
        os.close(dir_fd)
def config_to_dict(config: "HaMeterConfig") -> dict:
    """Serialize *config* to a plain, JSON-safe dict.

    The result uses the same section and key names that the loaders in this
    module read ("general", "mqtt", "meters", ...). Every list in the result
    is a new object, so mutating the returned dict never changes *config*.
    """
    general = config.general
    mqtt = config.mqtt
    return {
        "general": {
            "sleep_for": general.sleep_for,
            "device_id": general.device_id,
            "rtl_tcp_host": general.rtl_tcp_host,
            "rtl_tcp_port": general.rtl_tcp_port,
            "log_level": general.log_level,
            # Copy: returning the config's own list would let callers that
            # edit this dict silently mutate the live configuration.
            "rtlamr_extra_args": list(general.rtlamr_extra_args),
        },
        "mqtt": {
            "host": mqtt.host,
            "port": mqtt.port,
            "user": mqtt.user,
            "password": mqtt.password,
            "base_topic": mqtt.base_topic,
            "ha_autodiscovery": mqtt.ha_autodiscovery,
            "ha_autodiscovery_topic": mqtt.ha_autodiscovery_topic,
            "client_id": mqtt.client_id,
        },
        "meters": [
            {
                "id": m.id,
                "protocol": m.protocol,
                "name": m.name,
                "unit_of_measurement": m.unit_of_measurement,
                "icon": m.icon,
                "device_class": m.device_class,
                "state_class": m.state_class,
                "multiplier": m.multiplier,
                "cost_factors": [
                    {"name": cf.name, "rate": cf.rate, "type": cf.type}
                    for cf in m.cost_factors
                ],
            }
            for m in config.meters
        ],
    }
# ------------------------------------------------------------------ #
# Validation helpers
# ------------------------------------------------------------------ #
def validate_mqtt_config(data: dict) -> tuple[bool, str]:
    """Validate MQTT settings submitted as a dict.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, message)`` describing
        the first problem found.
    """
    # ``or ""`` + str(): a null or non-string host must yield a validation
    # message, not an AttributeError from .strip().
    if not str(data.get("host") or "").strip():
        return False, "MQTT host is required"
    port = data.get("port", 1883)
    try:
        port = int(port)
        if not (1 <= port <= 65535):
            raise ValueError
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")); json.loads accepts "Infinity".
        return False, f"Invalid port: {port}"
    return True, ""
def validate_meter_config(data: dict) -> tuple[bool, str]:
    """Validate a single meter submitted as a dict.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, message)`` describing
        the first problem found.
    """
    # Truthiness check: a missing id, "" and 0 are all reported as missing.
    if not data.get("id"):
        return False, "Meter ID is required"
    try:
        int(data["id"])
    except (ValueError, TypeError, OverflowError):
        return False, f"Meter ID must be a number: {data['id']}"
    # ``or ""`` + str(): a null protocol/name yields a message, not a crash.
    protocol = str(data.get("protocol") or "").lower()
    if protocol not in VALID_PROTOCOLS:
        return False, (
            f"Invalid protocol: {protocol}. "
            f"Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
        )
    if not str(data.get("name") or "").strip():
        return False, "Meter name is required"
    multiplier = data.get("multiplier", 1.0)
    try:
        multiplier = float(multiplier)
    except (ValueError, TypeError, OverflowError):
        multiplier = math.nan
    # NaN compares False with everything, so "<= 0" alone would accept it;
    # inf is likewise not a usable scale factor.
    if not math.isfinite(multiplier):
        return False, f"Multiplier must be a number: {data.get('multiplier')}"
    if multiplier <= 0:
        return False, f"Multiplier must be positive, got {multiplier}"
    return True, ""
def validate_rate_component(data: dict) -> tuple[bool, str]:
    """Validate a single rate component submitted as a dict.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, message)`` describing
        the first problem found.
    """
    # ``or ""`` + str(): a null name yields a message, not an AttributeError.
    if not str(data.get("name") or "").strip():
        return False, "Rate component name is required"
    try:
        rate = float(data.get("rate", ""))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        rate = math.nan
    # "nan"/"inf" parse as floats but are not usable rates.
    if not math.isfinite(rate):
        return False, f"Rate must be a number: {data.get('rate')}"
    comp_type = data.get("type", "per_unit")
    if comp_type not in VALID_RATE_TYPES:
        return False, f"Invalid rate type: {comp_type}. Must be 'per_unit' or 'fixed'"
    return True, ""
def get_meter_defaults(device_class: str) -> dict:
    """Return the default ``icon``/``unit`` entries for *device_class*.

    Unknown device classes give an empty dict. The result is always a new
    dict, so callers may modify it without touching the shared defaults.
    """
    template = _METER_DEFAULTS.get(device_class)
    if template is None:
        return {}
    return dict(template)
# ------------------------------------------------------------------ #
# Internal builders
# ------------------------------------------------------------------ #
def _build_config_from_dict(raw: dict) -> HaMeterConfig:
    """Assemble a validated HaMeterConfig from a plain mapping (JSON or YAML).

    Sections are built in the order general, mqtt, meters, so a ValueError
    from the first invalid section is the one that propagates.
    """
    return HaMeterConfig(
        general=_build_general_from_dict(raw),
        mqtt=_build_mqtt_from_dict(raw),
        meters=_build_meters_from_dict(raw),
    )
def _build_general_from_dict(raw: dict) -> GeneralConfig:
"""Build GeneralConfig from raw dict."""
g = raw.get("general", {})
extra_args = g.get("rtlamr_extra_args", [])
if isinstance(extra_args, str):
extra_args = extra_args.split()
rtl_tcp_port = int(g.get("rtl_tcp_port", 1234))
if not (1 <= rtl_tcp_port <= 65535):
raise ValueError(
f"rtl_tcp_port must be 1-65535, got {rtl_tcp_port}"
)
log_level = str(g.get("log_level", "INFO")).upper()
if log_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
raise ValueError(
f"Invalid log_level '{log_level}'. "
f"Valid: DEBUG, INFO, WARNING, ERROR, CRITICAL"
)
device_id = str(g.get("device_id", "0"))
try:
if int(device_id) < 0:
raise ValueError
except (ValueError, TypeError):
raise ValueError(
f"device_id must be a non-negative integer, got '{device_id}'"
)
return GeneralConfig(
sleep_for=int(g.get("sleep_for", 0)),
device_id=device_id,
rtl_tcp_host=g.get("rtl_tcp_host", "127.0.0.1"),
rtl_tcp_port=rtl_tcp_port,
log_level=log_level,
rtlamr_extra_args=list(extra_args),
)
def _build_mqtt_from_dict(raw: dict) -> MqttConfig:
"""Build MqttConfig from raw dict."""
m = raw.get("mqtt") or {}
host = m.get("host", "")
if not host:
raise ValueError(
"MQTT host not set. Configure it via the web UI."
)
port = int(m.get("port", 1883))
if not (1 <= port <= 65535):
raise ValueError(f"MQTT port must be 1-65535, got {port}")
return MqttConfig(
host=host,
port=port,
user=m.get("user", ""),
password=m.get("password", ""),
base_topic=m.get("base_topic", "hameter"),
ha_autodiscovery=m.get("ha_autodiscovery", True),
ha_autodiscovery_topic=m.get("ha_autodiscovery_topic", "homeassistant"),
client_id=m.get("client_id", "hameter"),
)
def _build_meters_from_dict(raw: dict) -> list[MeterConfig]:
"""Build meter list from raw dict."""
meters_raw = raw.get("meters")
if not meters_raw:
return []
meters = []
seen_ids: set[int] = set()
for i, m in enumerate(meters_raw):
meter_id = m.get("id")
if meter_id is None:
raise ValueError(f"Meter #{i + 1} missing required 'id'")
mid_int = int(meter_id)
if mid_int in seen_ids:
raise ValueError(
f"Meter #{i + 1} has duplicate id {mid_int}"
)
seen_ids.add(mid_int)
protocol = m.get("protocol", "").lower()
if protocol not in VALID_PROTOCOLS:
raise ValueError(
f"Meter #{i + 1} (id={meter_id}) has invalid protocol "
f"'{protocol}'. Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
)
name = m.get("name", "")
if not name:
raise ValueError(
f"Meter #{i + 1} (id={meter_id}) missing required 'name'"
)
# Apply smart defaults based on device_class.
device_class = m.get("device_class", "")
defaults = _METER_DEFAULTS.get(device_class, {})
cost_factors = [
RateComponent(
name=cf.get("name", ""),
rate=float(cf.get("rate", 0.0)),
type=cf.get("type", "per_unit"),
)
for cf in m.get("cost_factors", [])
]
meters.append(MeterConfig(
id=int(meter_id),
protocol=protocol,
name=name,
unit_of_measurement=m.get("unit_of_measurement", "") or defaults.get("unit", ""),
icon=m.get("icon", "") or defaults.get("icon", "mdi:gauge"),
device_class=device_class,
state_class=m.get("state_class", "total_increasing"),
multiplier=float(m.get("multiplier", 1.0)),
cost_factors=cost_factors,
))
return meters