# Files
# HAMeter/tests/test_config.py
# 2026-03-06 12:25:27 -05:00

# 513 lines
# 16 KiB
# Python

"""Tests for hameter.config."""
import json
import os
from pathlib import Path
import pytest
import yaml
from hameter.config import (
HaMeterConfig,
GeneralConfig,
MeterConfig,
MqttConfig,
RateComponent,
config_exists,
config_to_dict,
load_config_from_json,
load_config_from_yaml,
save_config,
validate_meter_config,
validate_mqtt_config,
validate_rate_component,
get_meter_defaults,
)
def _write_json_config(tmp_path: Path, data: dict) -> str:
    """Serialize *data* as JSON into ``config.json`` inside *tmp_path*.

    Returns the path of the written file as a ``str``, the form the tests
    pass to ``load_config_from_json``.
    """
    target = tmp_path.joinpath("config.json")
    serialized = json.dumps(data)
    target.write_text(serialized)
    return str(target)
def _write_yaml_config(tmp_path: Path, data: dict) -> str:
    """Dump *data* as YAML into ``hameter.yaml`` inside *tmp_path*.

    Returns the path of the written file as a ``str``.
    """
    target = tmp_path.joinpath("hameter.yaml")
    target.write_text(yaml.dump(data))
    return str(target)
VALID_CONFIG = {
"general": {
"sleep_for": 0,
"device_id": "0",
"log_level": "DEBUG",
},
"mqtt": {
"host": "192.168.1.74",
"port": 1883,
"base_topic": "hameter",
},
"meters": [
{
"id": 23040293,
"protocol": "scm",
"name": "Electric Meter",
"unit_of_measurement": "kWh",
"multiplier": 0.1156,
"device_class": "energy",
},
],
}
class TestLoadConfigFromJson:
    """Tests for JSON-based configuration loading via load_config_from_json."""

    def test_valid_config(self, tmp_path):
        """Every field set in VALID_CONFIG is present after loading from JSON."""
        path = _write_json_config(tmp_path, VALID_CONFIG)
        cfg = load_config_from_json(path)
        assert cfg.general.sleep_for == 0
        assert cfg.general.device_id == "0"
        assert cfg.general.log_level == "DEBUG"
        assert cfg.mqtt.host == "192.168.1.74"
        assert cfg.mqtt.port == 1883
        assert cfg.mqtt.base_topic == "hameter"
        assert len(cfg.meters) == 1
        m = cfg.meters[0]
        assert m.id == 23040293
        assert m.protocol == "scm"
        assert m.name == "Electric Meter"
        assert m.unit_of_measurement == "kWh"
        assert m.multiplier == 0.1156
        assert m.device_class == "energy"

    def test_defaults_applied(self, tmp_path):
        """Fields omitted from general, mqtt and meter sections get defaults."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [
                {"id": 123, "protocol": "scm", "name": "Test"},
            ],
        }
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.general.sleep_for == 0
        assert cfg.general.device_id == "0"
        assert cfg.general.log_level == "INFO"
        assert cfg.general.rtl_tcp_host == "127.0.0.1"
        assert cfg.general.rtl_tcp_port == 1234
        assert cfg.mqtt.port == 1883
        assert cfg.mqtt.base_topic == "hameter"
        assert cfg.mqtt.ha_autodiscovery is True
        m = cfg.meters[0]
        assert m.multiplier == 1.0
        assert m.state_class == "total_increasing"
        assert m.icon == "mdi:gauge"

    def test_missing_mqtt_host_raises(self, tmp_path):
        """A config without an MQTT host is rejected with ValueError."""
        data = {
            "mqtt": {"port": 1883},
            "meters": [{"id": 1, "protocol": "scm", "name": "X"}],
        }
        path = _write_json_config(tmp_path, data)
        with pytest.raises(ValueError, match="MQTT host"):
            load_config_from_json(path)

    def test_empty_meters_is_valid(self, tmp_path):
        """An explicitly empty meters list loads without error."""
        data = {"mqtt": {"host": "10.0.0.1"}, "meters": []}
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.meters == []

    def test_no_meters_key_is_valid(self, tmp_path):
        """Omitting the meters key entirely yields an empty meter list."""
        data = {"mqtt": {"host": "10.0.0.1"}}
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.meters == []

    def test_invalid_protocol_raises(self, tmp_path):
        """An unknown meter protocol is rejected with ValueError."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [{"id": 1, "protocol": "invalid", "name": "X"}],
        }
        path = _write_json_config(tmp_path, data)
        with pytest.raises(ValueError, match="invalid protocol"):
            load_config_from_json(path)

    def test_multiple_meters(self, tmp_path):
        """Several meters load in file order, each with its own settings."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [
                {"id": 111, "protocol": "scm", "name": "Electric", "multiplier": 0.5},
                {"id": 222, "protocol": "r900", "name": "Water", "multiplier": 1.0},
            ],
        }
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert len(cfg.meters) == 2
        assert cfg.meters[0].id == 111
        assert cfg.meters[0].multiplier == 0.5
        assert cfg.meters[1].id == 222
        assert cfg.meters[1].protocol == "r900"

    def test_smart_defaults_energy(self, tmp_path):
        """device_class 'energy' supplies the flash icon and kWh unit."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [
                {"id": 1, "protocol": "scm", "name": "E", "device_class": "energy"},
            ],
        }
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.meters[0].icon == "mdi:flash"
        assert cfg.meters[0].unit_of_measurement == "kWh"

    def test_smart_defaults_gas(self, tmp_path):
        """device_class 'gas' supplies the fire icon and cubic-feet unit."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [
                {"id": 1, "protocol": "scm", "name": "G", "device_class": "gas"},
            ],
        }
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.meters[0].icon == "mdi:fire"
        assert cfg.meters[0].unit_of_measurement == "ft\u00b3"

    def test_smart_defaults_water(self, tmp_path):
        """device_class 'water' supplies the water icon and gallon unit."""
        data = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [
                {"id": 1, "protocol": "r900", "name": "W", "device_class": "water"},
            ],
        }
        path = _write_json_config(tmp_path, data)
        cfg = load_config_from_json(path)
        assert cfg.meters[0].icon == "mdi:water"
        assert cfg.meters[0].unit_of_measurement == "gal"

    def test_file_not_found_raises(self, tmp_path):
        """A missing config file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            load_config_from_json(str(tmp_path / "nonexistent.json"))
class TestLoadConfigFromYaml:
    """Tests for the YAML migration path (load_config_from_yaml)."""

    def test_valid_yaml(self, tmp_path):
        """The shared VALID_CONFIG fixture loads from a YAML file."""
        yaml_path = _write_yaml_config(tmp_path, VALID_CONFIG)
        loaded = load_config_from_yaml(yaml_path)

        assert loaded.mqtt.host == "192.168.1.74"
        assert len(loaded.meters) == 1
        first_meter = loaded.meters[0]
        assert first_meter.id == 23040293
class TestSaveConfig:
    """Tests for atomic config saving."""

    @staticmethod
    def _config(meters=None):
        """Build a config with default general settings and MQTT host 10.0.0.1."""
        return HaMeterConfig(
            general=GeneralConfig(),
            mqtt=MqttConfig(host="10.0.0.1"),
            meters=[] if meters is None else meters,
        )

    def test_save_and_reload(self, tmp_path):
        """A saved config can be read back by load_config_from_json."""
        meter = MeterConfig(
            id=123, protocol="scm", name="Test",
            unit_of_measurement="kWh",
        )
        destination = str(tmp_path / "config.json")

        save_config(self._config([meter]), destination)
        reloaded = load_config_from_json(destination)

        assert reloaded.mqtt.host == "10.0.0.1"
        assert reloaded.meters[0].id == 123

    def test_creates_parent_directory(self, tmp_path):
        """Saving into a not-yet-existing subdirectory creates it."""
        destination = tmp_path / "subdir" / "config.json"
        save_config(self._config(), str(destination))
        assert destination.is_file()

    def test_no_temp_file_left_on_success(self, tmp_path):
        """After a successful save only config.json remains (no .tmp files)."""
        save_config(self._config(), str(tmp_path / "config.json"))
        leftover = [entry.name for entry in tmp_path.iterdir()]
        assert leftover == ["config.json"]
class TestConfigToDict:
    """Tests for serialization."""

    def test_round_trip(self, tmp_path):
        """config_to_dict output reloads through load_config_from_json intact."""
        config = HaMeterConfig(
            general=GeneralConfig(log_level="DEBUG", device_id="1"),
            mqtt=MqttConfig(host="10.0.0.1", port=1884),
            meters=[
                MeterConfig(
                    id=999, protocol="r900", name="Water",
                    unit_of_measurement="gal", multiplier=2.5,
                ),
            ],
        )
        d = config_to_dict(config)
        assert d["general"]["log_level"] == "DEBUG"
        assert d["mqtt"]["host"] == "10.0.0.1"
        assert d["mqtt"]["port"] == 1884
        assert d["meters"][0]["id"] == 999
        assert d["meters"][0]["multiplier"] == 2.5

        # Write through the shared helper so the file is produced the same
        # way as in the loader tests (json.dumps into tmp_path/config.json).
        path = _write_json_config(tmp_path, d)
        loaded = load_config_from_json(path)

        # Check every non-default value set above, not just a sample.
        assert loaded.general.log_level == "DEBUG"
        assert loaded.general.device_id == "1"
        assert loaded.mqtt.host == "10.0.0.1"
        assert loaded.mqtt.port == 1884
        assert loaded.meters[0].id == 999
        assert loaded.meters[0].protocol == "r900"
        assert loaded.meters[0].unit_of_measurement == "gal"
        assert loaded.meters[0].multiplier == 2.5
class TestConfigExists:
    """Tests for config_exists."""

    def test_exists(self, tmp_path):
        """A file present on disk is reported as existing."""
        config_file = tmp_path / "config.json"
        config_file.write_text("{}")
        assert config_exists(str(config_file))

    def test_not_exists(self, tmp_path):
        """A path that was never written is reported as missing."""
        absent = str(tmp_path / "missing.json")
        assert not config_exists(absent)
class TestValidateMqttConfig:
    """Tests for validate_mqtt_config, which returns an (ok, message) pair."""

    def test_valid(self):
        """A host plus an in-range port passes with an empty message."""
        valid, message = validate_mqtt_config({"host": "10.0.0.1", "port": 1883})
        assert valid
        assert message == ""

    def test_missing_host(self):
        """Omitting the host fails and the message mentions the host."""
        valid, message = validate_mqtt_config({"port": 1883})
        assert not valid
        assert "host" in message.lower()

    def test_empty_host(self):
        """An empty host string fails validation."""
        valid, _ = validate_mqtt_config({"host": "", "port": 1883})
        assert not valid

    def test_invalid_port(self):
        """An out-of-range port (99999) fails and the message mentions the port."""
        valid, message = validate_mqtt_config({"host": "x", "port": 99999})
        assert not valid
        assert "port" in message.lower()

    def test_string_port(self):
        """A non-numeric port string fails validation."""
        valid, _ = validate_mqtt_config({"host": "x", "port": "abc"})
        assert not valid
class TestValidateMeterConfig:
    """Tests for validate_meter_config, which returns an (ok, message) pair."""

    def test_valid(self):
        """id, a known protocol and a name are sufficient."""
        meter = {"id": 123, "protocol": "scm", "name": "Test"}
        valid, _ = validate_meter_config(meter)
        assert valid

    def test_missing_id(self):
        """A meter without an id fails validation."""
        meter = {"protocol": "scm", "name": "Test"}
        valid, _ = validate_meter_config(meter)
        assert not valid

    def test_invalid_protocol(self):
        """An unknown protocol fails and the message mentions the protocol."""
        meter = {"id": 1, "protocol": "bad", "name": "Test"}
        valid, message = validate_meter_config(meter)
        assert not valid
        assert "protocol" in message.lower()

    def test_missing_name(self):
        """A meter without a name fails and the message mentions the name."""
        meter = {"id": 1, "protocol": "scm"}
        valid, message = validate_meter_config(meter)
        assert not valid
        assert "name" in message.lower()

    def test_invalid_multiplier(self):
        """A non-numeric multiplier fails and the message mentions it."""
        meter = {"id": 1, "protocol": "scm", "name": "X", "multiplier": "bad"}
        valid, message = validate_meter_config(meter)
        assert not valid
        assert "multiplier" in message.lower()
class TestCostFactors:
    """Tests for cost_factors serialization and deserialization."""

    @staticmethod
    def _single_meter_config(meter):
        """Wrap *meter* in a config with default general settings and MQTT host 10.0.0.1."""
        return HaMeterConfig(
            general=GeneralConfig(),
            mqtt=MqttConfig(host="10.0.0.1"),
            meters=[meter],
        )

    @staticmethod
    def _save_and_reload(config, tmp_path):
        """Save *config* to tmp_path/config.json and load it back."""
        destination = str(tmp_path / "config.json")
        save_config(config, destination)
        return load_config_from_json(destination)

    def test_round_trip_with_cost_factors(self, tmp_path):
        """Per-unit and fixed components keep name, rate and type across save/load."""
        meter = MeterConfig(
            id=123, protocol="scm", name="Electric",
            unit_of_measurement="kWh",
            cost_factors=[
                RateComponent(name="Supply", rate=0.14742, type="per_unit"),
                RateComponent(name="Customer Charge", rate=9.65, type="fixed"),
            ],
        )
        reloaded = self._save_and_reload(self._single_meter_config(meter), tmp_path)

        observed = [
            (component.name, component.rate, component.type)
            for component in reloaded.meters[0].cost_factors
        ]
        assert observed == [
            ("Supply", 0.14742, "per_unit"),
            ("Customer Charge", 9.65, "fixed"),
        ]

    def test_no_cost_factors_defaults_empty(self, tmp_path):
        """A meter loaded from JSON without cost_factors gets an empty list."""
        raw = {
            "mqtt": {"host": "10.0.0.1"},
            "meters": [{"id": 1, "protocol": "scm", "name": "Test"}],
        }
        loaded = load_config_from_json(_write_json_config(tmp_path, raw))
        assert loaded.meters[0].cost_factors == []

    def test_config_to_dict_includes_cost_factors(self):
        """config_to_dict emits each component as a name/rate/type dict."""
        meter = MeterConfig(
            id=1, protocol="scm", name="Test",
            unit_of_measurement="kWh",
            cost_factors=[
                RateComponent(name="Rate", rate=0.10, type="per_unit"),
            ],
        )
        serialized = config_to_dict(self._single_meter_config(meter))
        components = serialized["meters"][0]["cost_factors"]
        assert len(components) == 1
        assert components[0] == {
            "name": "Rate", "rate": 0.10, "type": "per_unit",
        }

    def test_config_to_dict_empty_cost_factors(self):
        """A meter without cost factors serializes to an empty list."""
        meter = MeterConfig(
            id=1, protocol="scm", name="Test",
            unit_of_measurement="kWh",
        )
        serialized = config_to_dict(self._single_meter_config(meter))
        assert serialized["meters"][0]["cost_factors"] == []

    def test_negative_rate(self, tmp_path):
        """Negative rates (credits) survive a save/load round trip."""
        meter = MeterConfig(
            id=1, protocol="scm", name="Test",
            unit_of_measurement="kWh",
            cost_factors=[
                RateComponent(name="Credit", rate=-0.00077, type="per_unit"),
            ],
        )
        reloaded = self._save_and_reload(self._single_meter_config(meter), tmp_path)
        assert reloaded.meters[0].cost_factors[0].rate == -0.00077
class TestValidateRateComponent:
    """Tests for validate_rate_component, which returns an (ok, message) pair."""

    def test_valid_per_unit(self):
        """A named per_unit component with a numeric rate is accepted."""
        component = {"name": "Supply", "rate": 0.14742, "type": "per_unit"}
        valid, _ = validate_rate_component(component)
        assert valid

    def test_valid_fixed(self):
        """A named fixed component with a numeric rate is accepted."""
        component = {"name": "Customer Charge", "rate": 9.65, "type": "fixed"}
        valid, _ = validate_rate_component(component)
        assert valid

    def test_missing_name(self):
        """A component without a name fails and the message mentions it."""
        valid, message = validate_rate_component({"rate": 0.10})
        assert not valid
        assert "name" in message.lower()

    def test_invalid_rate(self):
        """A non-numeric rate fails and the message mentions the rate."""
        valid, message = validate_rate_component({"name": "X", "rate": "bad"})
        assert not valid
        assert "rate" in message.lower()

    def test_invalid_type(self):
        """An unsupported type ('tiered') fails and the message mentions it."""
        component = {"name": "X", "rate": 0.10, "type": "tiered"}
        valid, message = validate_rate_component(component)
        assert not valid
        assert "type" in message.lower()

    def test_defaults_to_per_unit(self):
        """A component with no type is accepted (presumably treated as per_unit)."""
        valid, _ = validate_rate_component({"name": "X", "rate": 0.10})
        assert valid
class TestGetMeterDefaults:
    """Tests for get_meter_defaults (per-device_class icon and unit)."""

    def test_energy(self):
        """Energy meters default to the flash icon and kWh."""
        d = get_meter_defaults("energy")
        assert d["icon"] == "mdi:flash"
        assert d["unit"] == "kWh"

    def test_gas(self):
        """Gas meters default to the fire icon and cubic feet."""
        d = get_meter_defaults("gas")
        assert d["icon"] == "mdi:fire"
        # Same unit the JSON loader applies as the gas smart default; the
        # energy and water tests already pin their units.
        assert d["unit"] == "ft\u00b3"

    def test_water(self):
        """Water meters default to the water icon and gallons."""
        d = get_meter_defaults("water")
        assert d["icon"] == "mdi:water"
        assert d["unit"] == "gal"

    def test_unknown(self):
        """An unrecognized device_class yields no defaults."""
        d = get_meter_defaults("unknown")
        assert d == {}