400 lines
12 KiB
Python
400 lines
12 KiB
Python
"""Configuration loading, validation, and persistence for HAMeter.
|
|
|
|
All configuration is managed through the web UI and stored as JSON
|
|
at /data/config.json. A YAML config file can be imported as a
|
|
one-time migration.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
import yaml
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Protocol names accepted in a meter's "protocol" field (compared lower-cased).
VALID_PROTOCOLS = {"idm", "netidm", "r900", "r900bcd", "scm", "scm+"}

# Accepted values for a rate component's "type" field.
VALID_RATE_TYPES = {"fixed", "per_unit"}

# Fallback icon and unit, keyed by a meter's device_class. Used when the
# meter definition leaves "icon" or "unit_of_measurement" empty.
_METER_DEFAULTS = {
    "energy": {"icon": "mdi:flash", "unit": "kWh"},
    "gas": {"icon": "mdi:fire", "unit": "ft\u00b3"},  # U+00B3 = superscript 3
    "water": {"icon": "mdi:water", "unit": "gal"},
}

# Location of the persisted JSON config when no explicit path is given.
CONFIG_PATH = "/data/config.json"
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Dataclasses
|
|
# ------------------------------------------------------------------ #
|
|
|
|
@dataclass
class MqttConfig:
    """Settings stored under the "mqtt" section of the config.

    Only ``host`` is mandatory; every other field has a default.
    """

    # Broker host name or address (required, no default).
    host: str
    # Broker TCP port; loaders and validators accept 1-65535.
    port: int = 1883
    # Credentials; both default to the empty string.
    user: str = ""
    password: str = ""
    # Topic prefix (presumably for published readings - confirm with publisher).
    base_topic: str = "hameter"
    # Home Assistant auto-discovery switch and its topic prefix.
    ha_autodiscovery: bool = True
    ha_autodiscovery_topic: str = "homeassistant"
    # MQTT client identifier.
    client_id: str = "hameter"
|
|
|
|
|
|
@dataclass
class RateComponent:
    """One entry of a meter's ``cost_factors`` list."""

    # Display name of the component.
    name: str
    # Numeric rate; loaders convert it with float().
    rate: float
    # Either "per_unit" or "fixed" (see VALID_RATE_TYPES).
    type: str = "per_unit"
|
|
|
|
|
|
@dataclass
class MeterConfig:
    """Definition of a single meter as stored in the "meters" list."""

    # Numeric meter identifier; must be unique across the configured meters.
    id: int
    # Lower-cased protocol name, one of VALID_PROTOCOLS.
    protocol: str
    # Human-readable meter name (required).
    name: str
    # Unit string, e.g. "kWh"; loaders fall back to a device_class default.
    unit_of_measurement: str
    # Icon identifier; "mdi:gauge" when nothing more specific applies.
    icon: str = "mdi:gauge"
    # Device class used to pick default icon/unit ("energy", "gas", "water").
    device_class: str = ""
    state_class: str = "total_increasing"
    # Scale factor (presumably applied to raw readings - confirm with consumer).
    multiplier: float = 1.0
    # Rate components attached to this meter; a fresh list per instance.
    cost_factors: list[RateComponent] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class GeneralConfig:
    """Settings stored under the "general" section of the config.

    Every field has a default, so ``GeneralConfig()`` is a valid instance.
    """

    # Integer interval; NOTE(review): units/semantics not visible here.
    sleep_for: int = 0
    # Device selector kept as a string; loaders require a non-negative integer.
    device_id: str = "0"
    # rtl_tcp endpoint.
    rtl_tcp_host: str = "127.0.0.1"
    rtl_tcp_port: int = 1234
    # Upper-case logging level name (DEBUG/INFO/WARNING/ERROR/CRITICAL).
    log_level: str = "INFO"
    # Extra command-line arguments for rtlamr; a fresh list per instance.
    rtlamr_extra_args: list = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class HaMeterConfig:
    """Top-level configuration: the three sections of the config file."""

    # "general" section.
    general: GeneralConfig
    # "mqtt" section.
    mqtt: MqttConfig
    # "meters" section; may be empty.
    meters: list[MeterConfig]
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Config file operations
|
|
# ------------------------------------------------------------------ #
|
|
|
|
def config_exists(path: Optional[str] = None) -> bool:
    """Report whether a config file is present.

    Args:
        path: File to look for. When omitted (or empty), CONFIG_PATH is used.

    Returns:
        True only if the path exists and is a regular file.
    """
    target = path if path else CONFIG_PATH
    return os.path.isfile(target)
|
|
|
|
|
|
def load_config_from_json(path: Optional[str] = None) -> HaMeterConfig:
    """Load configuration from a JSON file.

    Args:
        path: File to read. Defaults to CONFIG_PATH when omitted or empty.

    Returns:
        The parsed and validated configuration.

    Raises:
        FileNotFoundError: If config file does not exist.
        ValueError: For validation errors, a top-level value that is not a
            JSON object, or malformed JSON (json.JSONDecodeError is a
            ValueError subclass).
    """
    path = path or CONFIG_PATH
    # Decode explicitly as UTF-8 instead of the platform's locale encoding;
    # "utf-8-sig" additionally tolerates a BOM left by some editors.
    with open(path, encoding="utf-8-sig") as f:
        raw = json.load(f)

    if not isinstance(raw, dict):
        raise ValueError(f"Config file is not a valid JSON object: {path}")

    return _build_config_from_dict(raw)
|
|
|
|
|
|
def load_config_from_yaml(path: str) -> HaMeterConfig:
    """Load from YAML file (migration path from older versions).

    An empty YAML document is treated as an empty mapping, so it fails
    later with the usual "MQTT host not set" validation error.

    Args:
        path: YAML file to read.

    Returns:
        The parsed and validated configuration.

    Raises:
        FileNotFoundError: If YAML file does not exist.
        ValueError: For validation errors or a non-mapping top-level value.
    """
    # Decode explicitly as UTF-8 (BOM tolerated) instead of relying on the
    # platform's locale encoding.
    with open(path, encoding="utf-8-sig") as f:
        raw = yaml.safe_load(f) or {}

    if not isinstance(raw, dict):
        raise ValueError(f"YAML file is not a valid mapping: {path}")

    return _build_config_from_dict(raw)
|
|
|
|
|
|
def save_config(config: HaMeterConfig, path: Optional[str] = None):
    """Persist *config* as JSON using a write-then-rename sequence.

    The serialized config goes to a temporary file created next to the
    destination; once flushed and fsync'ed, os.replace() moves it over the
    destination so readers never observe a half-written file.

    Args:
        config: Configuration to serialize.
        path: Destination file. Defaults to CONFIG_PATH when omitted or empty.

    Raises:
        Whatever the underlying write/rename raises; the temporary file is
        removed (best effort) before the exception propagates.
    """
    target = path or CONFIG_PATH
    payload = config_to_dict(config)

    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Same directory as the target so the final rename stays on one filesystem.
    handle, scratch = tempfile.mkstemp(dir=parent or ".", suffix=".tmp")
    try:
        with os.fdopen(handle, "w") as out:
            json.dump(payload, out, indent=2)
            out.write("\n")
            out.flush()
            os.fsync(out.fileno())
        os.replace(scratch, target)
    except Exception:
        # Best-effort cleanup; the original error is what the caller sees.
        try:
            os.unlink(scratch)
        except OSError:
            pass
        raise
|
|
|
|
|
|
def config_to_dict(config: HaMeterConfig) -> dict:
    """Convert a HaMeterConfig into plain dicts/lists ready for json.dump.

    The result has the three top-level keys "general", "mqtt" and "meters",
    mirroring the layout the loaders read back.
    """
    general = config.general
    mqtt = config.mqtt

    general_section = {
        "sleep_for": general.sleep_for,
        "device_id": general.device_id,
        "rtl_tcp_host": general.rtl_tcp_host,
        "rtl_tcp_port": general.rtl_tcp_port,
        "log_level": general.log_level,
        "rtlamr_extra_args": general.rtlamr_extra_args,
    }

    mqtt_section = {
        "host": mqtt.host,
        "port": mqtt.port,
        "user": mqtt.user,
        "password": mqtt.password,
        "base_topic": mqtt.base_topic,
        "ha_autodiscovery": mqtt.ha_autodiscovery,
        "ha_autodiscovery_topic": mqtt.ha_autodiscovery_topic,
        "client_id": mqtt.client_id,
    }

    meter_sections = []
    for meter in config.meters:
        rates = [
            {"name": component.name, "rate": component.rate, "type": component.type}
            for component in meter.cost_factors
        ]
        meter_sections.append({
            "id": meter.id,
            "protocol": meter.protocol,
            "name": meter.name,
            "unit_of_measurement": meter.unit_of_measurement,
            "icon": meter.icon,
            "device_class": meter.device_class,
            "state_class": meter.state_class,
            "multiplier": meter.multiplier,
            "cost_factors": rates,
        })

    return {
        "general": general_section,
        "mqtt": mqtt_section,
        "meters": meter_sections,
    }
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Validation helpers
|
|
# ------------------------------------------------------------------ #
|
|
|
|
def validate_mqtt_config(data: dict) -> tuple[bool, str]:
    """Validate MQTT config fields, return (ok, error_message).

    Only "host" and "port" are inspected.

    Args:
        data: Raw MQTT settings, e.g. as submitted by the web UI.

    Returns:
        (True, "") when valid, otherwise (False, <human-readable reason>).
    """
    host = data.get("host")
    # A null or non-string host is reported as missing instead of crashing
    # with AttributeError on .strip().
    if not isinstance(host, str) or not host.strip():
        return False, "MQTT host is required"

    port = data.get("port", 1883)
    try:
        port_number = int(port)
    except (ValueError, TypeError):
        return False, f"Invalid port: {port}"
    if not (1 <= port_number <= 65535):
        return False, f"Invalid port: {port_number}"
    return True, ""
|
|
|
|
|
|
def validate_meter_config(data: dict) -> tuple[bool, str]:
    """Validate a single meter config dict, return (ok, error_message).

    Checks, in order: "id" is present and integer-convertible, "protocol"
    is one of VALID_PROTOCOLS (case-insensitive), "name" is a non-blank
    string, and "multiplier" (default 1.0) is a positive finite number.

    Args:
        data: Raw meter settings, e.g. as submitted by the web UI.

    Returns:
        (True, "") when valid, otherwise (False, <human-readable reason>).
    """
    meter_id = data.get("id")
    # Falsy ids (missing, None, "", 0) are all reported as missing.
    if not meter_id:
        return False, "Meter ID is required"
    try:
        int(meter_id)
    except (ValueError, TypeError):
        return False, f"Meter ID must be a number: {meter_id}"

    # `or ""` and str() keep a null / non-string protocol from raising
    # AttributeError on .lower(); it is simply reported as invalid.
    protocol = str(data.get("protocol") or "").lower()
    if protocol not in VALID_PROTOCOLS:
        return False, (
            f"Invalid protocol: {protocol}. "
            f"Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
        )

    name = data.get("name")
    if not isinstance(name, str) or not name.strip():
        return False, "Meter name is required"

    raw_multiplier = data.get("multiplier", 1.0)
    try:
        multiplier = float(raw_multiplier)
    except (ValueError, TypeError):
        return False, f"Multiplier must be a number: {raw_multiplier}"
    # float() accepts "nan"/"inf"; both comparisons are False for NaN, so
    # this single range test rejects NaN and +/-infinity.
    if not (float("-inf") < multiplier < float("inf")):
        return False, f"Multiplier must be a finite number, got {multiplier}"
    if multiplier <= 0:
        return False, f"Multiplier must be positive, got {multiplier}"

    return True, ""
|
|
|
|
|
|
def validate_rate_component(data: dict) -> tuple[bool, str]:
    """Validate a single rate component dict, return (ok, error_message).

    Checks, in order: "name" is a non-blank string, "rate" converts to a
    finite float, and "type" (default "per_unit") is in VALID_RATE_TYPES.

    Args:
        data: Raw rate component settings, e.g. as submitted by the web UI.

    Returns:
        (True, "") when valid, otherwise (False, <human-readable reason>).
    """
    name = data.get("name")
    # A null or non-string name is reported as missing instead of crashing
    # with AttributeError on .strip().
    if not isinstance(name, str) or not name.strip():
        return False, "Rate component name is required"

    raw_rate = data.get("rate")
    try:
        rate = float(raw_rate)
    except (ValueError, TypeError):
        return False, f"Rate must be a number: {raw_rate}"
    # float() accepts "nan"/"inf"; both comparisons are False for NaN, so
    # this single range test rejects NaN and +/-infinity.
    if not (float("-inf") < rate < float("inf")):
        return False, f"Rate must be a finite number: {raw_rate}"

    comp_type = data.get("type", "per_unit")
    # The isinstance guard avoids TypeError when an unhashable value
    # (e.g. a list) is tested for set membership.
    if not isinstance(comp_type, str) or comp_type not in VALID_RATE_TYPES:
        return False, f"Invalid rate type: {comp_type}. Must be 'per_unit' or 'fixed'"
    return True, ""
|
|
|
|
|
|
def get_meter_defaults(device_class: str) -> dict:
    """Look up the default icon and unit for *device_class*.

    Returns:
        A new dict with "icon" and "unit" keys for a known device class,
        or an empty dict for an unknown one.
    """
    known = _METER_DEFAULTS.get(device_class, {})
    # Hand back a copy so callers cannot mutate the shared defaults table.
    return dict(known)
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Internal builders
|
|
# ------------------------------------------------------------------ #
|
|
|
|
def _build_config_from_dict(raw: dict) -> HaMeterConfig:
    """Assemble a HaMeterConfig from a parsed JSON/YAML mapping.

    Sections are built in the order general, mqtt, meters, so the first
    invalid section in that order is the one whose ValueError surfaces.
    """
    return HaMeterConfig(
        general=_build_general_from_dict(raw),
        mqtt=_build_mqtt_from_dict(raw),
        meters=_build_meters_from_dict(raw),
    )
|
|
|
|
|
|
def _build_general_from_dict(raw: dict) -> GeneralConfig:
    """Build GeneralConfig from raw dict.

    Reads the optional "general" section; a missing or null section yields
    all defaults.

    Args:
        raw: Top-level config mapping parsed from JSON or YAML.

    Returns:
        The validated general settings.

    Raises:
        ValueError: If the section is not a mapping, or if rtl_tcp_port,
            log_level, device_id, sleep_for or rtlamr_extra_args is invalid.
    """
    # `or {}` (as the mqtt builder does) so an explicit null section, e.g. a
    # bare "general:" line in YAML, does not crash on .get().
    g = raw.get("general") or {}
    if not isinstance(g, dict):
        raise ValueError("'general' section must be a mapping")

    extra_args = g.get("rtlamr_extra_args") or []
    if isinstance(extra_args, str):
        # A single string is split on whitespace into separate arguments.
        extra_args = extra_args.split()
    elif not isinstance(extra_args, (list, tuple)):
        raise ValueError(
            "rtlamr_extra_args must be a list or a string, "
            f"got {type(extra_args).__name__}"
        )

    raw_port = g.get("rtl_tcp_port", 1234)
    try:
        rtl_tcp_port = int(raw_port)
    except (TypeError, ValueError):
        # Report null/garbage values as the documented ValueError.
        raise ValueError(
            f"rtl_tcp_port must be 1-65535, got {raw_port!r}"
        ) from None
    if not (1 <= rtl_tcp_port <= 65535):
        raise ValueError(
            f"rtl_tcp_port must be 1-65535, got {rtl_tcp_port}"
        )

    log_level = str(g.get("log_level", "INFO")).upper()
    if log_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
        raise ValueError(
            f"Invalid log_level '{log_level}'. "
            f"Valid: DEBUG, INFO, WARNING, ERROR, CRITICAL"
        )

    # Kept as a string, but it must parse as a non-negative integer.
    device_id = str(g.get("device_id", "0"))
    try:
        if int(device_id) < 0:
            raise ValueError
    except (ValueError, TypeError):
        raise ValueError(
            f"device_id must be a non-negative integer, got '{device_id}'"
        ) from None

    raw_sleep = g.get("sleep_for", 0)
    try:
        sleep_for = int(raw_sleep)
    except (TypeError, ValueError):
        raise ValueError(
            f"sleep_for must be an integer, got {raw_sleep!r}"
        ) from None

    return GeneralConfig(
        sleep_for=sleep_for,
        device_id=device_id,
        rtl_tcp_host=g.get("rtl_tcp_host", "127.0.0.1"),
        rtl_tcp_port=rtl_tcp_port,
        log_level=log_level,
        rtlamr_extra_args=list(extra_args),
    )
|
|
|
|
|
|
def _build_mqtt_from_dict(raw: dict) -> MqttConfig:
    """Build MqttConfig from raw dict.

    Reads the "mqtt" section; a missing or null section is treated as
    empty and therefore fails the host check.

    Args:
        raw: Top-level config mapping parsed from JSON or YAML.

    Returns:
        The validated MQTT settings.

    Raises:
        ValueError: If the section is not a mapping, the host is missing or
            blank, or the port is not an integer in 1-65535.
    """
    m = raw.get("mqtt") or {}
    if not isinstance(m, dict):
        raise ValueError("'mqtt' section must be a mapping")

    host = m.get("host", "")
    # Match validate_mqtt_config: whitespace-only or non-string hosts count
    # as "not set" rather than being stored and failing later.
    if not isinstance(host, str) or not host.strip():
        raise ValueError(
            "MQTT host not set. Configure it via the web UI."
        )

    raw_port = m.get("port", 1883)
    try:
        port = int(raw_port)
    except (TypeError, ValueError):
        # Report null/garbage values as the documented ValueError.
        raise ValueError(
            f"MQTT port must be 1-65535, got {raw_port!r}"
        ) from None
    if not (1 <= port <= 65535):
        raise ValueError(f"MQTT port must be 1-65535, got {port}")

    return MqttConfig(
        host=host,
        port=port,
        user=m.get("user", ""),
        password=m.get("password", ""),
        base_topic=m.get("base_topic", "hameter"),
        ha_autodiscovery=m.get("ha_autodiscovery", True),
        ha_autodiscovery_topic=m.get("ha_autodiscovery_topic", "homeassistant"),
        client_id=m.get("client_id", "hameter"),
    )
|
|
|
|
|
|
def _build_cost_factors(m: dict, label: str) -> list[RateComponent]:
    """Build the RateComponent list for one meter; *label* prefixes errors."""
    factors = []
    # `or []` so an explicit null "cost_factors" behaves like a missing key.
    for cf in m.get("cost_factors") or []:
        if not isinstance(cf, dict):
            raise ValueError(f"{label} has a cost factor that is not a mapping")
        raw_rate = cf.get("rate", 0.0)
        try:
            rate = float(raw_rate)
        except (TypeError, ValueError):
            raise ValueError(
                f"{label} has non-numeric cost factor rate {raw_rate!r}"
            ) from None
        factors.append(RateComponent(
            name=cf.get("name", ""),
            rate=rate,
            type=cf.get("type", "per_unit"),
        ))
    return factors


def _build_meters_from_dict(raw: dict) -> list[MeterConfig]:
    """Build meter list from raw dict.

    Reads the optional "meters" list. Each entry needs a unique,
    integer-convertible "id", a "protocol" from VALID_PROTOCOLS
    (case-insensitive) and a non-empty "name". Missing icon/unit fall back
    to the defaults for the entry's "device_class".

    Args:
        raw: Top-level config mapping parsed from JSON or YAML.

    Returns:
        The meters in file order; an empty list when none are configured.

    Raises:
        ValueError: If "meters" is not a list, or any entry is malformed
            (missing/duplicate/non-numeric id, invalid protocol, missing
            name, multiplier not a positive finite number, bad cost factor).
    """
    meters_raw = raw.get("meters")
    if not meters_raw:
        return []
    if not isinstance(meters_raw, (list, tuple)):
        raise ValueError("'meters' must be a list of meter definitions")

    meters = []
    seen_ids: set[int] = set()
    for i, m in enumerate(meters_raw):
        if not isinstance(m, dict):
            raise ValueError(f"Meter #{i + 1} must be a mapping")

        meter_id = m.get("id")
        if meter_id is None:
            raise ValueError(f"Meter #{i + 1} missing required 'id'")

        try:
            mid_int = int(meter_id)
        except (TypeError, ValueError):
            # Name the offending meter instead of surfacing int()'s message.
            raise ValueError(
                f"Meter #{i + 1} has non-numeric id {meter_id!r}"
            ) from None
        if mid_int in seen_ids:
            raise ValueError(
                f"Meter #{i + 1} has duplicate id {mid_int}"
            )
        seen_ids.add(mid_int)

        # `or ""` and str() keep a null / non-string protocol from raising
        # AttributeError on .lower(); it is reported as invalid below.
        protocol = str(m.get("protocol") or "").lower()
        if protocol not in VALID_PROTOCOLS:
            raise ValueError(
                f"Meter #{i + 1} (id={meter_id}) has invalid protocol "
                f"'{protocol}'. Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
            )

        name = m.get("name", "")
        if not name:
            raise ValueError(
                f"Meter #{i + 1} (id={meter_id}) missing required 'name'"
            )

        label = f"Meter #{i + 1} (id={meter_id})"

        # Same rule as validate_meter_config: positive and finite. The
        # chained comparison is False for NaN, so NaN is rejected too.
        raw_multiplier = m.get("multiplier", 1.0)
        try:
            multiplier = float(raw_multiplier)
        except (TypeError, ValueError):
            raise ValueError(
                f"{label} has non-numeric multiplier {raw_multiplier!r}"
            ) from None
        if not (0 < multiplier < float("inf")):
            raise ValueError(
                f"{label} multiplier must be a positive finite number, "
                f"got {multiplier}"
            )

        # Apply smart defaults based on device_class.
        device_class = m.get("device_class") or ""
        defaults = _METER_DEFAULTS.get(device_class, {})

        cost_factors = _build_cost_factors(m, label)

        meters.append(MeterConfig(
            id=mid_int,
            protocol=protocol,
            name=name,
            unit_of_measurement=m.get("unit_of_measurement", "") or defaults.get("unit", ""),
            icon=m.get("icon", "") or defaults.get("icon", "mdi:gauge"),
            device_class=device_class,
            state_class=m.get("state_class", "total_increasing"),
            multiplier=multiplier,
            cost_factors=cost_factors,
        ))

    return meters
|