"""Tests for hameter.config.""" import json import os from pathlib import Path import pytest import yaml from hameter.config import ( HaMeterConfig, GeneralConfig, MeterConfig, MqttConfig, RateComponent, config_exists, config_to_dict, load_config_from_json, load_config_from_yaml, save_config, validate_meter_config, validate_mqtt_config, validate_rate_component, get_meter_defaults, ) def _write_json_config(tmp_path: Path, data: dict) -> str: """Write a JSON config dict to a temp file and return the path.""" p = tmp_path / "config.json" p.write_text(json.dumps(data)) return str(p) def _write_yaml_config(tmp_path: Path, data: dict) -> str: """Write a YAML config dict to a temp file and return the path.""" p = tmp_path / "hameter.yaml" p.write_text(yaml.dump(data)) return str(p) VALID_CONFIG = { "general": { "sleep_for": 0, "device_id": "0", "log_level": "DEBUG", }, "mqtt": { "host": "192.168.1.74", "port": 1883, "base_topic": "hameter", }, "meters": [ { "id": 23040293, "protocol": "scm", "name": "Electric Meter", "unit_of_measurement": "kWh", "multiplier": 0.1156, "device_class": "energy", }, ], } class TestLoadConfigFromJson: """Tests for JSON-based configuration.""" def test_valid_config(self, tmp_path): path = _write_json_config(tmp_path, VALID_CONFIG) cfg = load_config_from_json(path) assert cfg.general.sleep_for == 0 assert cfg.general.device_id == "0" assert cfg.general.log_level == "DEBUG" assert cfg.mqtt.host == "192.168.1.74" assert cfg.mqtt.port == 1883 assert cfg.mqtt.base_topic == "hameter" assert len(cfg.meters) == 1 m = cfg.meters[0] assert m.id == 23040293 assert m.protocol == "scm" assert m.name == "Electric Meter" assert m.multiplier == 0.1156 assert m.device_class == "energy" def test_defaults_applied(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [ {"id": 123, "protocol": "scm", "name": "Test"}, ], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.general.sleep_for == 0 assert cfg.general.device_id == "0" assert cfg.general.log_level == "INFO" assert cfg.general.rtl_tcp_host == "127.0.0.1" assert cfg.general.rtl_tcp_port == 1234 assert cfg.mqtt.port == 1883 assert cfg.mqtt.base_topic == "hameter" assert cfg.mqtt.ha_autodiscovery is True m = cfg.meters[0] assert m.multiplier == 1.0 assert m.state_class == "total_increasing" assert m.icon == "mdi:gauge" def test_missing_mqtt_host_raises(self, tmp_path): data = { "mqtt": {"port": 1883}, "meters": [{"id": 1, "protocol": "scm", "name": "X"}], } path = _write_json_config(tmp_path, data) with pytest.raises(ValueError, match="MQTT host"): load_config_from_json(path) def test_empty_meters_is_valid(self, tmp_path): data = {"mqtt": {"host": "10.0.0.1"}, "meters": []} path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters == [] def test_no_meters_key_is_valid(self, tmp_path): data = {"mqtt": {"host": "10.0.0.1"}} path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters == [] def test_invalid_protocol_raises(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [{"id": 1, "protocol": "invalid", "name": "X"}], } path = _write_json_config(tmp_path, data) with pytest.raises(ValueError, match="invalid protocol"): load_config_from_json(path) def test_multiple_meters(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [ {"id": 111, "protocol": "scm", "name": "Electric", "multiplier": 0.5}, {"id": 222, "protocol": "r900", "name": "Water", "multiplier": 1.0}, ], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert len(cfg.meters) == 2 assert cfg.meters[0].id == 111 assert cfg.meters[0].multiplier == 0.5 assert cfg.meters[1].id == 222 assert cfg.meters[1].protocol == "r900" def test_smart_defaults_energy(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [ {"id": 1, "protocol": "scm", "name": "E", "device_class": "energy"}, ], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters[0].icon == "mdi:flash" assert cfg.meters[0].unit_of_measurement == "kWh" def test_smart_defaults_gas(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [ {"id": 1, "protocol": "scm", "name": "G", "device_class": "gas"}, ], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters[0].icon == "mdi:fire" assert cfg.meters[0].unit_of_measurement == "ft\u00b3" def test_smart_defaults_water(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [ {"id": 1, "protocol": "r900", "name": "W", "device_class": "water"}, ], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters[0].icon == "mdi:water" assert cfg.meters[0].unit_of_measurement == "gal" def test_file_not_found_raises(self, tmp_path): with pytest.raises(FileNotFoundError): load_config_from_json(str(tmp_path / "nonexistent.json")) class TestLoadConfigFromYaml: """Tests for YAML migration path.""" def test_valid_yaml(self, tmp_path): path = _write_yaml_config(tmp_path, VALID_CONFIG) cfg = load_config_from_yaml(path) assert cfg.mqtt.host == "192.168.1.74" assert len(cfg.meters) == 1 assert cfg.meters[0].id == 23040293 class TestSaveConfig: """Tests for atomic config saving.""" def test_save_and_reload(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[ MeterConfig( id=123, protocol="scm", name="Test", unit_of_measurement="kWh", ), ], ) path = str(tmp_path / "config.json") save_config(config, path) loaded = load_config_from_json(path) assert loaded.mqtt.host == "10.0.0.1" assert loaded.meters[0].id == 123 def test_creates_parent_directory(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[], ) path = str(tmp_path / "subdir" / "config.json") save_config(config, path) assert os.path.isfile(path) def test_no_temp_file_left_on_success(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[], ) path = str(tmp_path / "config.json") save_config(config, path) # Only the config file should exist, no .tmp files files = list(tmp_path.iterdir()) assert len(files) == 1 assert files[0].name == "config.json" class TestConfigToDict: """Tests for serialization.""" def test_round_trip(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(log_level="DEBUG", device_id="1"), mqtt=MqttConfig(host="10.0.0.1", port=1884), meters=[ MeterConfig( id=999, protocol="r900", name="Water", unit_of_measurement="gal", multiplier=2.5, ), ], ) d = config_to_dict(config) assert d["general"]["log_level"] == "DEBUG" assert d["mqtt"]["host"] == "10.0.0.1" assert d["meters"][0]["id"] == 999 assert d["meters"][0]["multiplier"] == 2.5 # Write and reload path = str(tmp_path / "config.json") with open(path, "w") as f: json.dump(d, f) loaded = load_config_from_json(path) assert loaded.general.log_level == "DEBUG" assert loaded.meters[0].multiplier == 2.5 class TestConfigExists: """Tests for config_exists.""" def test_exists(self, tmp_path): path = str(tmp_path / "config.json") Path(path).write_text("{}") assert config_exists(path) def test_not_exists(self, tmp_path): assert not config_exists(str(tmp_path / "missing.json")) class TestValidateMqttConfig: """Tests for validate_mqtt_config.""" def test_valid(self): ok, err = validate_mqtt_config({"host": "10.0.0.1", "port": 1883}) assert ok assert err == "" def test_missing_host(self): ok, err = validate_mqtt_config({"port": 1883}) assert not ok assert "host" in err.lower() def test_empty_host(self): ok, err = validate_mqtt_config({"host": "", "port": 1883}) assert not ok def test_invalid_port(self): ok, err = validate_mqtt_config({"host": "x", "port": 99999}) assert not ok assert "port" in err.lower() def test_string_port(self): ok, err = validate_mqtt_config({"host": "x", "port": "abc"}) assert not ok class TestValidateMeterConfig: """Tests for validate_meter_config.""" def test_valid(self): ok, err = validate_meter_config( {"id": 123, "protocol": "scm", "name": "Test"} ) assert ok def test_missing_id(self): ok, err = validate_meter_config( {"protocol": "scm", "name": "Test"} ) assert not ok def test_invalid_protocol(self): ok, err = validate_meter_config( {"id": 1, "protocol": "bad", "name": "Test"} ) assert not ok assert "protocol" in err.lower() def test_missing_name(self): ok, err = validate_meter_config( {"id": 1, "protocol": "scm"} ) assert not ok assert "name" in err.lower() def test_invalid_multiplier(self): ok, err = validate_meter_config( {"id": 1, "protocol": "scm", "name": "X", "multiplier": "bad"} ) assert not ok assert "multiplier" in err.lower() class TestCostFactors: """Tests for cost_factors serialization and deserialization.""" def test_round_trip_with_cost_factors(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[ MeterConfig( id=123, protocol="scm", name="Electric", unit_of_measurement="kWh", cost_factors=[ RateComponent(name="Supply", rate=0.14742, type="per_unit"), RateComponent(name="Customer Charge", rate=9.65, type="fixed"), ], ), ], ) path = str(tmp_path / "config.json") save_config(config, path) loaded = load_config_from_json(path) assert len(loaded.meters[0].cost_factors) == 2 cf0 = loaded.meters[0].cost_factors[0] assert cf0.name == "Supply" assert cf0.rate == 0.14742 assert cf0.type == "per_unit" cf1 = loaded.meters[0].cost_factors[1] assert cf1.name == "Customer Charge" assert cf1.rate == 9.65 assert cf1.type == "fixed" def test_no_cost_factors_defaults_empty(self, tmp_path): data = { "mqtt": {"host": "10.0.0.1"}, "meters": [{"id": 1, "protocol": "scm", "name": "Test"}], } path = _write_json_config(tmp_path, data) cfg = load_config_from_json(path) assert cfg.meters[0].cost_factors == [] def test_config_to_dict_includes_cost_factors(self): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[ MeterConfig( id=1, protocol="scm", name="Test", unit_of_measurement="kWh", cost_factors=[ RateComponent(name="Rate", rate=0.10, type="per_unit"), ], ), ], ) d = config_to_dict(config) assert len(d["meters"][0]["cost_factors"]) == 1 assert d["meters"][0]["cost_factors"][0] == { "name": "Rate", "rate": 0.10, "type": "per_unit", } def test_config_to_dict_empty_cost_factors(self): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[ MeterConfig( id=1, protocol="scm", name="Test", unit_of_measurement="kWh", ), ], ) d = config_to_dict(config) assert d["meters"][0]["cost_factors"] == [] def test_negative_rate(self, tmp_path): config = HaMeterConfig( general=GeneralConfig(), mqtt=MqttConfig(host="10.0.0.1"), meters=[ MeterConfig( id=1, protocol="scm", name="Test", unit_of_measurement="kWh", cost_factors=[ RateComponent(name="Credit", rate=-0.00077, type="per_unit"), ], ), ], ) path = str(tmp_path / "config.json") save_config(config, path) loaded = load_config_from_json(path) assert loaded.meters[0].cost_factors[0].rate == -0.00077 class TestValidateRateComponent: """Tests for validate_rate_component.""" def test_valid_per_unit(self): ok, err = validate_rate_component( {"name": "Supply", "rate": 0.14742, "type": "per_unit"} ) assert ok def test_valid_fixed(self): ok, err = validate_rate_component( {"name": "Customer Charge", "rate": 9.65, "type": "fixed"} ) assert ok def test_missing_name(self): ok, err = validate_rate_component({"rate": 0.10}) assert not ok assert "name" in err.lower() def test_invalid_rate(self): ok, err = validate_rate_component({"name": "X", "rate": "bad"}) assert not ok assert "rate" in err.lower() def test_invalid_type(self): ok, err = validate_rate_component( {"name": "X", "rate": 0.10, "type": "tiered"} ) assert not ok assert "type" in err.lower() def test_defaults_to_per_unit(self): ok, err = validate_rate_component({"name": "X", "rate": 0.10}) assert ok class TestGetMeterDefaults: """Tests for get_meter_defaults.""" def test_energy(self): d = get_meter_defaults("energy") assert d["icon"] == "mdi:flash" assert d["unit"] == "kWh" def test_gas(self): d = get_meter_defaults("gas") assert d["icon"] == "mdi:fire" def test_water(self): d = get_meter_defaults("water") assert d["icon"] == "mdi:water" assert d["unit"] == "gal" def test_unknown(self): d = get_meter_defaults("unknown") assert d == {}