168 lines
6.4 KiB
Python
168 lines
6.4 KiB
Python
"""Unit tests for the HaMeterMQTT class."""
|
|
|
|
import json
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from hameter.mqtt_client import HaMeterMQTT
|
|
from hameter.config import MqttConfig, MeterConfig, RateComponent
|
|
from hameter.meter import MeterReading
|
|
|
|
|
|
def _mqtt_config(**kw):
    """Return an ``MqttConfig`` aimed at the test broker ``broker.test:1883``.

    Defaults: no credentials, base topic ``hameter``, Home Assistant
    autodiscovery enabled under ``homeassistant``, client id ``hameter``.
    Any keyword argument replaces the matching default field.
    """
    return MqttConfig(**{
        "host": "broker.test",
        "port": 1883,
        "user": "",
        "password": "",
        "base_topic": "hameter",
        "ha_autodiscovery": True,
        "ha_autodiscovery_topic": "homeassistant",
        "client_id": "hameter",
        # Caller overrides win over the defaults above.
        **kw,
    })
|
|
|
|
|
|
def _meter(**kw):
    """Return a ``MeterConfig`` for meter 100 (``scm``, "Electric", kWh).

    The default has no cost factors; keyword arguments override any field.
    """
    base_fields = {
        "id": 100,
        "protocol": "scm",
        "name": "Electric",
        "unit_of_measurement": "kWh",
        # A fresh list on every call, so tests never share it.
        "cost_factors": [],
    }
    merged = {**base_fields, **kw}
    return MeterConfig(**merged)
|
|
|
|
|
|
def _reading(**kw):
    """Return a ``MeterReading`` from meter 100 (protocol ``SCM``).

    Defaults: raw consumption 50000, calibrated consumption 500.0,
    timestamp 2026-03-05T12:00:00Z, and a matching raw message dict.
    Keyword arguments override any field.
    """
    fields = {
        "meter_id": 100,
        "protocol": "SCM",
        "raw_consumption": 50000,
        "calibrated_consumption": 500.0,
        "timestamp": "2026-03-05T12:00:00Z",
        "raw_message": {"ID": 100, "Consumption": 50000},
    }
    for key, value in kw.items():
        fields[key] = value
    return MeterReading(**fields)
|
|
|
|
|
|
def _topic(call):
    """Return the MQTT topic of one recorded ``publish`` call.

    The topic is taken from the first positional argument, or from a
    ``topic=`` keyword if the call was made that way; ``""`` if neither.
    """
    args, kwargs = call
    return args[0] if args else kwargs.get("topic", "")


def _calls_to(mock_client, fragment):
    """Return the recorded ``publish`` calls whose topic contains *fragment*.

    Only the topic is inspected, not ``str(call)``, so a payload that merely
    mentions another topic string (e.g. a discovery config's state topic)
    is not counted as a publish to that topic.
    """
    return [c for c in mock_client.publish.call_args_list if fragment in _topic(c)]


@patch("hameter.mqtt_client.mqtt.Client")
class TestHaMeterMQTT(unittest.TestCase):
    """Tests for ``HaMeterMQTT`` with ``hameter.mqtt_client.mqtt.Client`` mocked.

    The class-level ``@patch`` hands the mocked ``Client`` class to every
    ``test_*`` method as ``MockClient``; ``MockClient.return_value`` is the
    mock instance on which the tests observe publish/subscribe/connect calls.

    unittest ``assert*`` methods are used instead of bare ``assert`` so the
    checks are not stripped when the suite runs under ``python -O``.
    """

    def test_connect_with_credentials(self, MockClient):
        """Non-empty user/password are passed to ``username_pw_set`` on construction."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(user="u", password="p"), [_meter()])
        mock_inst.username_pw_set.assert_called_once_with("u", "p")

    def test_connect_without_credentials(self, MockClient):
        """Empty credentials leave ``username_pw_set`` uncalled."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(user="", password=""), [_meter()])
        mock_inst.username_pw_set.assert_not_called()

    def test_connect_calls_broker(self, MockClient):
        """``connect()`` dials the configured host/port and starts the network loop."""
        mock_inst = MockClient.return_value
        m = HaMeterMQTT(_mqtt_config(), [_meter()])
        m.connect()
        mock_inst.connect.assert_called_once_with("broker.test", 1883, keepalive=60)
        mock_inst.loop_start.assert_called_once()

    def test_publish_reading(self, MockClient):
        """A reading is published once to the meter's state topic as JSON."""
        mock_inst = MockClient.return_value
        m = HaMeterMQTT(_mqtt_config(), [_meter(id=100)])
        m.publish_reading(_reading())
        state_calls = _calls_to(mock_inst, "100/state")
        self.assertEqual(len(state_calls), 1)
        payload = json.loads(state_calls[0][0][1])
        self.assertEqual(payload["reading"], 500.0)
        self.assertEqual(payload["raw_reading"], 50000)

    def test_publish_cost(self, MockClient):
        """A cost value is published once to the meter's cost topic as JSON."""
        mock_inst = MockClient.return_value
        m = HaMeterMQTT(_mqtt_config(), [_meter(id=200)])
        m.publish_cost(200, 42.75)
        cost_calls = _calls_to(mock_inst, "200/cost")
        self.assertEqual(len(cost_calls), 1)
        payload = json.loads(cost_calls[0][0][1])
        self.assertEqual(payload["cost"], 42.75)

    def test_disconnect_publishes_offline(self, MockClient):
        """``disconnect()`` publishes "offline" on a status topic, then stops the client."""
        mock_inst = MockClient.return_value
        m = HaMeterMQTT(_mqtt_config(), [_meter()])
        m.disconnect()
        offline_calls = _calls_to(mock_inst, "status")
        self.assertGreaterEqual(len(offline_calls), 1)
        self.assertEqual(offline_calls[-1][0][1], "offline")
        mock_inst.loop_stop.assert_called_once()
        mock_inst.disconnect.assert_called_once()

    def test_on_connect_success(self, MockClient):
        """A successful connect (reason code 0) publishes "online" and subscribes."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(), [_meter()])
        # The constructor is expected to install its callback on the client
        # instance; we fetch it from there and invoke it directly.
        on_connect = mock_inst.on_connect
        mock_inst.publish.reset_mock()
        mock_inst.subscribe.reset_mock()
        # NOTE(review): 5-argument form assumed to be paho-mqtt's v2
        # (client, userdata, flags, reason_code, properties) callback.
        on_connect(mock_inst, None, MagicMock(), 0, None)
        online = [c for c in mock_inst.publish.call_args_list if c[0][1] == "online"]
        self.assertGreaterEqual(len(online), 1)
        mock_inst.subscribe.assert_called_once()

    def test_on_connect_failure(self, MockClient):
        """A failed connect (non-zero reason code) publishes and subscribes nothing."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(), [_meter()])
        on_connect = mock_inst.on_connect
        mock_inst.publish.reset_mock()
        mock_inst.subscribe.reset_mock()
        on_connect(mock_inst, None, MagicMock(), 5, None)
        mock_inst.publish.assert_not_called()
        mock_inst.subscribe.assert_not_called()

    def test_on_disconnect_clean(self, MockClient):
        """Smoke test: a clean disconnect (reason code 0) is handled without raising."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(), [_meter()])
        on_disc = mock_inst.on_disconnect
        # No observable effect is asserted; the test fails only if this raises.
        on_disc(mock_inst, None, MagicMock(), 0, None)

    def test_on_message_ha_online(self, MockClient):
        """Home Assistant announcing "online" triggers sensor discovery publishes."""
        mock_inst = MockClient.return_value
        HaMeterMQTT(_mqtt_config(), [_meter()])
        on_msg = mock_inst.on_message
        mock_inst.publish.reset_mock()
        msg = MagicMock()
        msg.topic = "homeassistant/status"
        msg.payload = b"online"
        on_msg(mock_inst, None, msg)
        disco = _calls_to(mock_inst, "homeassistant/sensor/")
        self.assertGreaterEqual(len(disco), 3)

    def test_discovery_without_cost_factors(self, MockClient):
        """Without cost factors, discovery publishes three sensors and no cost sensor."""
        mock_inst = MockClient.return_value
        m = HaMeterMQTT(_mqtt_config(), [_meter(cost_factors=[])])
        mock_inst.publish.reset_mock()
        m._publish_discovery()
        disco = _calls_to(mock_inst, "homeassistant/sensor/")
        self.assertEqual(len(disco), 3)
        self.assertFalse(any("cost/config" in _topic(c) for c in disco))

    def test_discovery_with_cost_factors(self, MockClient):
        """With a cost factor, discovery adds a fourth, monetary cost sensor."""
        mock_inst = MockClient.return_value
        meter = _meter(cost_factors=[
            RateComponent(name="Supply", rate=0.10, type="per_unit"),
        ])
        m = HaMeterMQTT(_mqtt_config(), [meter])
        mock_inst.publish.reset_mock()
        m._publish_discovery()
        disco = _calls_to(mock_inst, "homeassistant/sensor/")
        self.assertEqual(len(disco), 4)
        cost_calls = [c for c in disco if "cost/config" in _topic(c)]
        self.assertEqual(len(cost_calls), 1)
        payload = json.loads(cost_calls[0][0][1])
        self.assertEqual(payload["device_class"], "monetary")
|