513 lines
16 KiB
Python
513 lines
16 KiB
Python
"""Tests for hameter.config."""
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import yaml
|
|
|
|
from hameter.config import (
|
|
HaMeterConfig,
|
|
GeneralConfig,
|
|
MeterConfig,
|
|
MqttConfig,
|
|
RateComponent,
|
|
config_exists,
|
|
config_to_dict,
|
|
load_config_from_json,
|
|
load_config_from_yaml,
|
|
save_config,
|
|
validate_meter_config,
|
|
validate_mqtt_config,
|
|
validate_rate_component,
|
|
get_meter_defaults,
|
|
)
|
|
|
|
|
|
def _write_json_config(tmp_path: Path, data: dict) -> str:
    """Serialize *data* as JSON into ``config.json`` inside *tmp_path*.

    Returns the written file's location as a string, which is the form the
    config loaders are called with throughout these tests.
    """
    config_file = tmp_path.joinpath("config.json")
    serialized = json.dumps(data)
    config_file.write_text(serialized)
    return str(config_file)
|
|
|
|
|
|
def _write_yaml_config(tmp_path: Path, data: dict) -> str:
    """Dump *data* as YAML into ``hameter.yaml`` inside *tmp_path*.

    Returns the written file's location as a string for the YAML loader.
    """
    yaml_file = tmp_path.joinpath("hameter.yaml")
    rendered = yaml.dump(data)
    yaml_file.write_text(rendered)
    return str(yaml_file)
|
|
|
|
|
|
# A complete, valid configuration shared by the JSON and YAML loader tests:
# explicit general/MQTT settings plus one SCM energy meter.
VALID_CONFIG = dict(
    general=dict(sleep_for=0, device_id="0", log_level="DEBUG"),
    mqtt=dict(host="192.168.1.74", port=1883, base_topic="hameter"),
    meters=[
        dict(
            id=23040293,
            protocol="scm",
            name="Electric Meter",
            unit_of_measurement="kWh",
            multiplier=0.1156,
            device_class="energy",
        ),
    ],
)
|
|
|
|
|
|
class TestLoadConfigFromJson:
    """Exercise ``load_config_from_json`` against on-disk JSON fixtures.

    Covers fully specified configs, defaults filled in for omitted sections
    and fields, validation errors, and the icon/unit defaults derived from a
    meter's ``device_class``.
    """

    @staticmethod
    def _load(tmp_path, data):
        """Dump *data* to ``config.json`` under *tmp_path* and load it back."""
        return load_config_from_json(_write_json_config(tmp_path, data))

    @classmethod
    def _only_meter(cls, tmp_path, meter):
        """Load a config holding just *meter* plus a minimal MQTT host."""
        cfg = cls._load(
            tmp_path, {"mqtt": {"host": "10.0.0.1"}, "meters": [meter]}
        )
        return cfg.meters[0]

    def test_valid_config(self, tmp_path):
        cfg = self._load(tmp_path, VALID_CONFIG)

        general = cfg.general
        assert general.sleep_for == 0
        assert general.device_id == "0"
        assert general.log_level == "DEBUG"

        mqtt = cfg.mqtt
        assert mqtt.host == "192.168.1.74"
        assert mqtt.port == 1883
        assert mqtt.base_topic == "hameter"

        assert len(cfg.meters) == 1
        meter = cfg.meters[0]
        assert meter.id == 23040293
        assert meter.protocol == "scm"
        assert meter.name == "Electric Meter"
        assert meter.multiplier == 0.1156
        assert meter.device_class == "energy"

    def test_defaults_applied(self, tmp_path):
        cfg = self._load(
            tmp_path,
            {
                "mqtt": {"host": "10.0.0.1"},
                "meters": [{"id": 123, "protocol": "scm", "name": "Test"}],
            },
        )

        general = cfg.general
        assert general.sleep_for == 0
        assert general.device_id == "0"
        assert general.log_level == "INFO"
        assert general.rtl_tcp_host == "127.0.0.1"
        assert general.rtl_tcp_port == 1234

        mqtt = cfg.mqtt
        assert mqtt.port == 1883
        assert mqtt.base_topic == "hameter"
        assert mqtt.ha_autodiscovery is True

        meter = cfg.meters[0]
        assert meter.multiplier == 1.0
        assert meter.state_class == "total_increasing"
        assert meter.icon == "mdi:gauge"

    def test_missing_mqtt_host_raises(self, tmp_path):
        # Write the file outside the raises-block so only loading is checked.
        config_path = _write_json_config(
            tmp_path,
            {
                "mqtt": {"port": 1883},
                "meters": [{"id": 1, "protocol": "scm", "name": "X"}],
            },
        )
        with pytest.raises(ValueError, match="MQTT host"):
            load_config_from_json(config_path)

    def test_empty_meters_is_valid(self, tmp_path):
        cfg = self._load(tmp_path, {"mqtt": {"host": "10.0.0.1"}, "meters": []})
        assert cfg.meters == []

    def test_no_meters_key_is_valid(self, tmp_path):
        cfg = self._load(tmp_path, {"mqtt": {"host": "10.0.0.1"}})
        assert cfg.meters == []

    def test_invalid_protocol_raises(self, tmp_path):
        config_path = _write_json_config(
            tmp_path,
            {
                "mqtt": {"host": "10.0.0.1"},
                "meters": [{"id": 1, "protocol": "invalid", "name": "X"}],
            },
        )
        with pytest.raises(ValueError, match="invalid protocol"):
            load_config_from_json(config_path)

    def test_multiple_meters(self, tmp_path):
        cfg = self._load(
            tmp_path,
            {
                "mqtt": {"host": "10.0.0.1"},
                "meters": [
                    {"id": 111, "protocol": "scm", "name": "Electric", "multiplier": 0.5},
                    {"id": 222, "protocol": "r900", "name": "Water", "multiplier": 1.0},
                ],
            },
        )
        meters = cfg.meters
        assert len(meters) == 2
        first, second = meters[0], meters[1]
        assert first.id == 111
        assert first.multiplier == 0.5
        assert second.id == 222
        assert second.protocol == "r900"

    def test_smart_defaults_energy(self, tmp_path):
        meter = self._only_meter(
            tmp_path,
            {"id": 1, "protocol": "scm", "name": "E", "device_class": "energy"},
        )
        assert meter.icon == "mdi:flash"
        assert meter.unit_of_measurement == "kWh"

    def test_smart_defaults_gas(self, tmp_path):
        meter = self._only_meter(
            tmp_path,
            {"id": 1, "protocol": "scm", "name": "G", "device_class": "gas"},
        )
        assert meter.icon == "mdi:fire"
        assert meter.unit_of_measurement == "ft\u00b3"

    def test_smart_defaults_water(self, tmp_path):
        meter = self._only_meter(
            tmp_path,
            {"id": 1, "protocol": "r900", "name": "W", "device_class": "water"},
        )
        assert meter.icon == "mdi:water"
        assert meter.unit_of_measurement == "gal"

    def test_file_not_found_raises(self, tmp_path):
        missing = tmp_path.joinpath("nonexistent.json")
        with pytest.raises(FileNotFoundError):
            load_config_from_json(str(missing))
|
|
|
|
|
|
class TestLoadConfigFromYaml:
    """Tests for the YAML migration path via ``load_config_from_yaml``."""

    def test_valid_yaml(self, tmp_path):
        cfg = load_config_from_yaml(_write_yaml_config(tmp_path, VALID_CONFIG))
        assert cfg.mqtt.host == "192.168.1.74"
        meters = cfg.meters
        assert len(meters) == 1
        assert meters[0].id == 23040293
|
|
|
|
|
|
class TestSaveConfig:
    """Tests for ``save_config``: reload round-trip, parent-directory
    creation, and no temporary files left behind after a successful save."""

    @staticmethod
    def _config(meters):
        """Build a config with default general settings and MQTT host 10.0.0.1."""
        return HaMeterConfig(
            general=GeneralConfig(),
            mqtt=MqttConfig(host="10.0.0.1"),
            meters=meters,
        )

    def test_save_and_reload(self, tmp_path):
        meter = MeterConfig(
            id=123, protocol="scm", name="Test", unit_of_measurement="kWh",
        )
        target = str(tmp_path / "config.json")
        save_config(self._config([meter]), target)

        reloaded = load_config_from_json(target)
        assert reloaded.mqtt.host == "10.0.0.1"
        assert reloaded.meters[0].id == 123

    def test_creates_parent_directory(self, tmp_path):
        target = tmp_path / "subdir" / "config.json"
        save_config(self._config([]), str(target))
        assert target.is_file()

    def test_no_temp_file_left_on_success(self, tmp_path):
        target = tmp_path / "config.json"
        save_config(self._config([]), str(target))
        # Only the config file should exist, no .tmp files
        leftovers = [entry.name for entry in tmp_path.iterdir()]
        assert leftovers == ["config.json"]
|
|
|
|
|
|
class TestConfigToDict:
    """Tests for ``config_to_dict`` and reloading its output as JSON."""

    def test_round_trip(self, tmp_path):
        water = MeterConfig(
            id=999, protocol="r900", name="Water",
            unit_of_measurement="gal", multiplier=2.5,
        )
        config = HaMeterConfig(
            general=GeneralConfig(log_level="DEBUG", device_id="1"),
            mqtt=MqttConfig(host="10.0.0.1", port=1884),
            meters=[water],
        )

        as_dict = config_to_dict(config)
        assert as_dict["general"]["log_level"] == "DEBUG"
        assert as_dict["mqtt"]["host"] == "10.0.0.1"
        first_meter = as_dict["meters"][0]
        assert first_meter["id"] == 999
        assert first_meter["multiplier"] == 2.5

        # The serialized dict must be accepted by the JSON loader unchanged.
        target = tmp_path / "config.json"
        target.write_text(json.dumps(as_dict))
        reloaded = load_config_from_json(str(target))
        assert reloaded.general.log_level == "DEBUG"
        assert reloaded.meters[0].multiplier == 2.5
|
|
|
|
|
|
class TestConfigExists:
    """Tests for ``config_exists`` on present and absent files."""

    def test_exists(self, tmp_path):
        target = tmp_path / "config.json"
        target.write_text("{}")
        assert config_exists(str(target))

    def test_not_exists(self, tmp_path):
        missing = str(tmp_path / "missing.json")
        assert not config_exists(missing)
|
|
|
|
|
|
class TestValidateMqttConfig:
    """Tests for ``validate_mqtt_config``, which yields an ``(ok, err)`` pair."""

    @staticmethod
    def _rejected(settings):
        """Validate *settings*, require rejection, and return the error text."""
        ok, err = validate_mqtt_config(settings)
        assert not ok
        return err

    def test_valid(self):
        ok, err = validate_mqtt_config({"host": "10.0.0.1", "port": 1883})
        assert ok
        assert err == ""

    def test_missing_host(self):
        message = self._rejected({"port": 1883})
        assert "host" in message.lower()

    def test_empty_host(self):
        self._rejected({"host": "", "port": 1883})

    def test_invalid_port(self):
        message = self._rejected({"host": "x", "port": 99999})
        assert "port" in message.lower()

    def test_string_port(self):
        self._rejected({"host": "x", "port": "abc"})
|
|
|
|
|
|
class TestValidateMeterConfig:
    """Tests for ``validate_meter_config``, which yields an ``(ok, err)`` pair."""

    @staticmethod
    def _rejected(meter):
        """Validate *meter*, require rejection, and return the error text."""
        ok, err = validate_meter_config(meter)
        assert not ok
        return err

    def test_valid(self):
        ok, _err = validate_meter_config(
            {"id": 123, "protocol": "scm", "name": "Test"}
        )
        assert ok

    def test_missing_id(self):
        self._rejected({"protocol": "scm", "name": "Test"})

    def test_invalid_protocol(self):
        message = self._rejected({"id": 1, "protocol": "bad", "name": "Test"})
        assert "protocol" in message.lower()

    def test_missing_name(self):
        message = self._rejected({"id": 1, "protocol": "scm"})
        assert "name" in message.lower()

    def test_invalid_multiplier(self):
        message = self._rejected(
            {"id": 1, "protocol": "scm", "name": "X", "multiplier": "bad"}
        )
        assert "multiplier" in message.lower()
|
|
|
|
|
|
class TestCostFactors:
    """Tests for serializing and deserializing a meter's ``cost_factors``."""

    @staticmethod
    def _config_with(meter):
        """Wrap a single *meter* in a config with default general settings."""
        return HaMeterConfig(
            general=GeneralConfig(),
            mqtt=MqttConfig(host="10.0.0.1"),
            meters=[meter],
        )

    @staticmethod
    def _save_and_load(config, tmp_path):
        """Persist *config* to ``config.json`` and read it back."""
        target = str(tmp_path / "config.json")
        save_config(config, target)
        return load_config_from_json(target)

    def test_round_trip_with_cost_factors(self, tmp_path):
        meter = MeterConfig(
            id=123, protocol="scm", name="Electric",
            unit_of_measurement="kWh",
            cost_factors=[
                RateComponent(name="Supply", rate=0.14742, type="per_unit"),
                RateComponent(name="Customer Charge", rate=9.65, type="fixed"),
            ],
        )
        loaded = self._save_and_load(self._config_with(meter), tmp_path)

        factors = loaded.meters[0].cost_factors
        assert len(factors) == 2
        supply = factors[0]
        assert supply.name == "Supply"
        assert supply.rate == 0.14742
        assert supply.type == "per_unit"
        customer = factors[1]
        assert customer.name == "Customer Charge"
        assert customer.rate == 9.65
        assert customer.type == "fixed"

    def test_no_cost_factors_defaults_empty(self, tmp_path):
        config_path = _write_json_config(
            tmp_path,
            {
                "mqtt": {"host": "10.0.0.1"},
                "meters": [{"id": 1, "protocol": "scm", "name": "Test"}],
            },
        )
        assert load_config_from_json(config_path).meters[0].cost_factors == []

    def test_config_to_dict_includes_cost_factors(self):
        meter = MeterConfig(
            id=1, protocol="scm", name="Test",
            unit_of_measurement="kWh",
            cost_factors=[RateComponent(name="Rate", rate=0.10, type="per_unit")],
        )
        serialized = config_to_dict(self._config_with(meter))["meters"][0]["cost_factors"]
        assert len(serialized) == 1
        assert serialized[0] == {"name": "Rate", "rate": 0.10, "type": "per_unit"}

    def test_config_to_dict_empty_cost_factors(self):
        # cost_factors is deliberately omitted so the constructor default is used.
        meter = MeterConfig(
            id=1, protocol="scm", name="Test", unit_of_measurement="kWh",
        )
        serialized = config_to_dict(self._config_with(meter))
        assert serialized["meters"][0]["cost_factors"] == []

    def test_negative_rate(self, tmp_path):
        meter = MeterConfig(
            id=1, protocol="scm", name="Test",
            unit_of_measurement="kWh",
            cost_factors=[RateComponent(name="Credit", rate=-0.00077, type="per_unit")],
        )
        loaded = self._save_and_load(self._config_with(meter), tmp_path)
        assert loaded.meters[0].cost_factors[0].rate == -0.00077
|
|
|
|
|
|
class TestValidateRateComponent:
    """Tests for ``validate_rate_component``, which yields an ``(ok, err)`` pair."""

    @staticmethod
    def _accepted(component):
        """Require that *component* passes validation."""
        ok, _err = validate_rate_component(component)
        assert ok

    @staticmethod
    def _rejected(component):
        """Validate *component*, require rejection, and return the error text."""
        ok, err = validate_rate_component(component)
        assert not ok
        return err

    def test_valid_per_unit(self):
        self._accepted({"name": "Supply", "rate": 0.14742, "type": "per_unit"})

    def test_valid_fixed(self):
        self._accepted({"name": "Customer Charge", "rate": 9.65, "type": "fixed"})

    def test_missing_name(self):
        message = self._rejected({"rate": 0.10})
        assert "name" in message.lower()

    def test_invalid_rate(self):
        message = self._rejected({"name": "X", "rate": "bad"})
        assert "rate" in message.lower()

    def test_invalid_type(self):
        message = self._rejected({"name": "X", "rate": 0.10, "type": "tiered"})
        assert "type" in message.lower()

    def test_defaults_to_per_unit(self):
        # Only acceptance of a component without "type" is checked here; the
        # name suggests it falls back to "per_unit".
        self._accepted({"name": "X", "rate": 0.10})
|
|
|
|
|
|
class TestGetMeterDefaults:
    """Tests for ``get_meter_defaults`` (icon/unit lookup by device_class)."""

    def test_energy(self):
        d = get_meter_defaults("energy")
        assert d["icon"] == "mdi:flash"
        assert d["unit"] == "kWh"

    def test_gas(self):
        d = get_meter_defaults("gas")
        assert d["icon"] == "mdi:fire"
        # Pin the unit as the energy/water cases do; test_smart_defaults_gas
        # expects a gas meter loaded without a unit to resolve to "ft³".
        assert d["unit"] == "ft\u00b3"

    def test_water(self):
        d = get_meter_defaults("water")
        assert d["icon"] == "mdi:water"
        assert d["unit"] == "gal"

    def test_unknown(self):
        # An unrecognized device_class yields no defaults at all.
        d = get_meter_defaults("unknown")
        assert d == {}
|