initial commit
This commit is contained in:
399
hameter/config.py
Normal file
399
hameter/config.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""Configuration loading, validation, and persistence for HAMeter.
|
||||
|
||||
All configuration is managed through the web UI and stored as JSON
|
||||
at /data/config.json. A YAML config file can be imported as a
|
||||
one-time migration.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_PROTOCOLS = {"scm", "scm+", "idm", "netidm", "r900", "r900bcd"}
|
||||
VALID_RATE_TYPES = {"per_unit", "fixed"}
|
||||
|
||||
# Default icons/unit per common meter type.
|
||||
_METER_DEFAULTS = {
|
||||
"energy": {"icon": "mdi:flash", "unit": "kWh"},
|
||||
"gas": {"icon": "mdi:fire", "unit": "ft\u00b3"},
|
||||
"water": {"icon": "mdi:water", "unit": "gal"},
|
||||
}
|
||||
|
||||
CONFIG_PATH = "/data/config.json"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Dataclasses
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@dataclass
|
||||
class MqttConfig:
|
||||
host: str
|
||||
port: int = 1883
|
||||
user: str = ""
|
||||
password: str = ""
|
||||
base_topic: str = "hameter"
|
||||
ha_autodiscovery: bool = True
|
||||
ha_autodiscovery_topic: str = "homeassistant"
|
||||
client_id: str = "hameter"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateComponent:
|
||||
name: str
|
||||
rate: float
|
||||
type: str = "per_unit"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MeterConfig:
|
||||
id: int
|
||||
protocol: str
|
||||
name: str
|
||||
unit_of_measurement: str
|
||||
icon: str = "mdi:gauge"
|
||||
device_class: str = ""
|
||||
state_class: str = "total_increasing"
|
||||
multiplier: float = 1.0
|
||||
cost_factors: list[RateComponent] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralConfig:
|
||||
sleep_for: int = 0
|
||||
device_id: str = "0"
|
||||
rtl_tcp_host: str = "127.0.0.1"
|
||||
rtl_tcp_port: int = 1234
|
||||
log_level: str = "INFO"
|
||||
rtlamr_extra_args: list = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HaMeterConfig:
|
||||
general: GeneralConfig
|
||||
mqtt: MqttConfig
|
||||
meters: list[MeterConfig]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Config file operations
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def config_exists(path: Optional[str] = None) -> bool:
|
||||
"""Check if a config file exists at the standard path."""
|
||||
return os.path.isfile(path or CONFIG_PATH)
|
||||
|
||||
|
||||
def load_config_from_json(path: Optional[str] = None) -> HaMeterConfig:
|
||||
"""Load configuration from a JSON file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file does not exist.
|
||||
ValueError: For validation errors.
|
||||
"""
|
||||
path = path or CONFIG_PATH
|
||||
with open(path) as f:
|
||||
raw = json.load(f)
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError(f"Config file is not a valid JSON object: {path}")
|
||||
|
||||
return _build_config_from_dict(raw)
|
||||
|
||||
|
||||
def load_config_from_yaml(path: str) -> HaMeterConfig:
|
||||
"""Load from YAML file (migration path from older versions).
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If YAML file does not exist.
|
||||
ValueError: For validation errors.
|
||||
"""
|
||||
with open(path) as f:
|
||||
raw = yaml.safe_load(f) or {}
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
raise ValueError(f"YAML file is not a valid mapping: {path}")
|
||||
|
||||
return _build_config_from_dict(raw)
|
||||
|
||||
|
||||
def save_config(config: HaMeterConfig, path: Optional[str] = None):
|
||||
"""Atomically write config to JSON file.
|
||||
|
||||
Writes to a temp file in the same directory, then os.replace()
|
||||
for atomic rename. This prevents corruption on power loss.
|
||||
"""
|
||||
path = path or CONFIG_PATH
|
||||
data = config_to_dict(config)
|
||||
|
||||
dir_path = os.path.dirname(path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_path or ".", suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
f.write("\n")
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, path)
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def config_to_dict(config: HaMeterConfig) -> dict:
|
||||
"""Serialize HaMeterConfig to a JSON-safe dict."""
|
||||
return {
|
||||
"general": {
|
||||
"sleep_for": config.general.sleep_for,
|
||||
"device_id": config.general.device_id,
|
||||
"rtl_tcp_host": config.general.rtl_tcp_host,
|
||||
"rtl_tcp_port": config.general.rtl_tcp_port,
|
||||
"log_level": config.general.log_level,
|
||||
"rtlamr_extra_args": config.general.rtlamr_extra_args,
|
||||
},
|
||||
"mqtt": {
|
||||
"host": config.mqtt.host,
|
||||
"port": config.mqtt.port,
|
||||
"user": config.mqtt.user,
|
||||
"password": config.mqtt.password,
|
||||
"base_topic": config.mqtt.base_topic,
|
||||
"ha_autodiscovery": config.mqtt.ha_autodiscovery,
|
||||
"ha_autodiscovery_topic": config.mqtt.ha_autodiscovery_topic,
|
||||
"client_id": config.mqtt.client_id,
|
||||
},
|
||||
"meters": [
|
||||
{
|
||||
"id": m.id,
|
||||
"protocol": m.protocol,
|
||||
"name": m.name,
|
||||
"unit_of_measurement": m.unit_of_measurement,
|
||||
"icon": m.icon,
|
||||
"device_class": m.device_class,
|
||||
"state_class": m.state_class,
|
||||
"multiplier": m.multiplier,
|
||||
"cost_factors": [
|
||||
{"name": cf.name, "rate": cf.rate, "type": cf.type}
|
||||
for cf in m.cost_factors
|
||||
],
|
||||
}
|
||||
for m in config.meters
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Validation helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def validate_mqtt_config(data: dict) -> tuple[bool, str]:
|
||||
"""Validate MQTT config fields, return (ok, error_message)."""
|
||||
if not data.get("host", "").strip():
|
||||
return False, "MQTT host is required"
|
||||
port = data.get("port", 1883)
|
||||
try:
|
||||
port = int(port)
|
||||
if not (1 <= port <= 65535):
|
||||
raise ValueError
|
||||
except (ValueError, TypeError):
|
||||
return False, f"Invalid port: {port}"
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_meter_config(data: dict) -> tuple[bool, str]:
|
||||
"""Validate a single meter config dict, return (ok, error_message)."""
|
||||
if not data.get("id"):
|
||||
return False, "Meter ID is required"
|
||||
try:
|
||||
int(data["id"])
|
||||
except (ValueError, TypeError):
|
||||
return False, f"Meter ID must be a number: {data['id']}"
|
||||
|
||||
protocol = data.get("protocol", "").lower()
|
||||
if protocol not in VALID_PROTOCOLS:
|
||||
return False, (
|
||||
f"Invalid protocol: {protocol}. "
|
||||
f"Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
|
||||
)
|
||||
|
||||
if not data.get("name", "").strip():
|
||||
return False, "Meter name is required"
|
||||
|
||||
multiplier = data.get("multiplier", 1.0)
|
||||
try:
|
||||
multiplier = float(multiplier)
|
||||
except (ValueError, TypeError):
|
||||
return False, f"Multiplier must be a number: {data.get('multiplier')}"
|
||||
if multiplier <= 0:
|
||||
return False, f"Multiplier must be positive, got {multiplier}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def validate_rate_component(data: dict) -> tuple[bool, str]:
|
||||
"""Validate a single rate component dict, return (ok, error_message)."""
|
||||
if not data.get("name", "").strip():
|
||||
return False, "Rate component name is required"
|
||||
try:
|
||||
float(data.get("rate", ""))
|
||||
except (ValueError, TypeError):
|
||||
return False, f"Rate must be a number: {data.get('rate')}"
|
||||
comp_type = data.get("type", "per_unit")
|
||||
if comp_type not in VALID_RATE_TYPES:
|
||||
return False, f"Invalid rate type: {comp_type}. Must be 'per_unit' or 'fixed'"
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_meter_defaults(device_class: str) -> dict:
|
||||
"""Get smart defaults (icon, unit) for a device class."""
|
||||
return dict(_METER_DEFAULTS.get(device_class, {}))
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal builders
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _build_config_from_dict(raw: dict) -> HaMeterConfig:
|
||||
"""Build HaMeterConfig from a raw dict (from JSON or YAML)."""
|
||||
general = _build_general_from_dict(raw)
|
||||
mqtt = _build_mqtt_from_dict(raw)
|
||||
meters = _build_meters_from_dict(raw)
|
||||
return HaMeterConfig(general=general, mqtt=mqtt, meters=meters)
|
||||
|
||||
|
||||
def _build_general_from_dict(raw: dict) -> GeneralConfig:
|
||||
"""Build GeneralConfig from raw dict."""
|
||||
g = raw.get("general", {})
|
||||
extra_args = g.get("rtlamr_extra_args", [])
|
||||
if isinstance(extra_args, str):
|
||||
extra_args = extra_args.split()
|
||||
|
||||
rtl_tcp_port = int(g.get("rtl_tcp_port", 1234))
|
||||
if not (1 <= rtl_tcp_port <= 65535):
|
||||
raise ValueError(
|
||||
f"rtl_tcp_port must be 1-65535, got {rtl_tcp_port}"
|
||||
)
|
||||
|
||||
log_level = str(g.get("log_level", "INFO")).upper()
|
||||
if log_level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
|
||||
raise ValueError(
|
||||
f"Invalid log_level '{log_level}'. "
|
||||
f"Valid: DEBUG, INFO, WARNING, ERROR, CRITICAL"
|
||||
)
|
||||
|
||||
device_id = str(g.get("device_id", "0"))
|
||||
try:
|
||||
if int(device_id) < 0:
|
||||
raise ValueError
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError(
|
||||
f"device_id must be a non-negative integer, got '{device_id}'"
|
||||
)
|
||||
|
||||
return GeneralConfig(
|
||||
sleep_for=int(g.get("sleep_for", 0)),
|
||||
device_id=device_id,
|
||||
rtl_tcp_host=g.get("rtl_tcp_host", "127.0.0.1"),
|
||||
rtl_tcp_port=rtl_tcp_port,
|
||||
log_level=log_level,
|
||||
rtlamr_extra_args=list(extra_args),
|
||||
)
|
||||
|
||||
|
||||
def _build_mqtt_from_dict(raw: dict) -> MqttConfig:
|
||||
"""Build MqttConfig from raw dict."""
|
||||
m = raw.get("mqtt") or {}
|
||||
|
||||
host = m.get("host", "")
|
||||
if not host:
|
||||
raise ValueError(
|
||||
"MQTT host not set. Configure it via the web UI."
|
||||
)
|
||||
|
||||
port = int(m.get("port", 1883))
|
||||
if not (1 <= port <= 65535):
|
||||
raise ValueError(f"MQTT port must be 1-65535, got {port}")
|
||||
|
||||
return MqttConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
user=m.get("user", ""),
|
||||
password=m.get("password", ""),
|
||||
base_topic=m.get("base_topic", "hameter"),
|
||||
ha_autodiscovery=m.get("ha_autodiscovery", True),
|
||||
ha_autodiscovery_topic=m.get("ha_autodiscovery_topic", "homeassistant"),
|
||||
client_id=m.get("client_id", "hameter"),
|
||||
)
|
||||
|
||||
|
||||
def _build_meters_from_dict(raw: dict) -> list[MeterConfig]:
|
||||
"""Build meter list from raw dict."""
|
||||
meters_raw = raw.get("meters")
|
||||
if not meters_raw:
|
||||
return []
|
||||
|
||||
meters = []
|
||||
seen_ids: set[int] = set()
|
||||
for i, m in enumerate(meters_raw):
|
||||
meter_id = m.get("id")
|
||||
if meter_id is None:
|
||||
raise ValueError(f"Meter #{i + 1} missing required 'id'")
|
||||
|
||||
mid_int = int(meter_id)
|
||||
if mid_int in seen_ids:
|
||||
raise ValueError(
|
||||
f"Meter #{i + 1} has duplicate id {mid_int}"
|
||||
)
|
||||
seen_ids.add(mid_int)
|
||||
|
||||
protocol = m.get("protocol", "").lower()
|
||||
if protocol not in VALID_PROTOCOLS:
|
||||
raise ValueError(
|
||||
f"Meter #{i + 1} (id={meter_id}) has invalid protocol "
|
||||
f"'{protocol}'. Valid: {', '.join(sorted(VALID_PROTOCOLS))}"
|
||||
)
|
||||
|
||||
name = m.get("name", "")
|
||||
if not name:
|
||||
raise ValueError(
|
||||
f"Meter #{i + 1} (id={meter_id}) missing required 'name'"
|
||||
)
|
||||
|
||||
# Apply smart defaults based on device_class.
|
||||
device_class = m.get("device_class", "")
|
||||
defaults = _METER_DEFAULTS.get(device_class, {})
|
||||
|
||||
cost_factors = [
|
||||
RateComponent(
|
||||
name=cf.get("name", ""),
|
||||
rate=float(cf.get("rate", 0.0)),
|
||||
type=cf.get("type", "per_unit"),
|
||||
)
|
||||
for cf in m.get("cost_factors", [])
|
||||
]
|
||||
|
||||
meters.append(MeterConfig(
|
||||
id=int(meter_id),
|
||||
protocol=protocol,
|
||||
name=name,
|
||||
unit_of_measurement=m.get("unit_of_measurement", "") or defaults.get("unit", ""),
|
||||
icon=m.get("icon", "") or defaults.get("icon", "mdi:gauge"),
|
||||
device_class=device_class,
|
||||
state_class=m.get("state_class", "total_increasing"),
|
||||
multiplier=float(m.get("multiplier", 1.0)),
|
||||
cost_factors=cost_factors,
|
||||
))
|
||||
|
||||
return meters
|
||||
Reference in New Issue
Block a user