Files
HAMeter/tests/test_state.py
2026-03-06 12:25:27 -05:00

262 lines
8.8 KiB
Python

"""Tests for hameter.state."""
import threading
import pytest
from hameter.meter import MeterReading
from hameter.state import AppState, CostState, PipelineStatus, WebLogHandler
class TestAppState:
    """Behavioural checks for the AppState shared state object."""

    def test_initial_status(self):
        """A freshly built AppState reports the UNCONFIGURED status."""
        app = AppState()
        assert app.status == PipelineStatus.UNCONFIGURED

    def test_set_status(self):
        """Setting a status without a message leaves the message empty."""
        app = AppState()
        app.set_status(PipelineStatus.RUNNING)
        assert app.status == PipelineStatus.RUNNING
        assert app.status_message == ""

    def test_set_status_with_message(self):
        """Setting a status with a message stores both values."""
        app = AppState()
        app.set_status(PipelineStatus.ERROR, "Something broke")
        assert app.status == PipelineStatus.ERROR
        assert app.status_message == "Something broke"

    def test_record_reading(self):
        """One recorded reading is retrievable by meter id and counted once."""
        app = AppState()
        app.record_reading(
            MeterReading(
                meter_id=123,
                protocol="scm",
                raw_consumption=1000,
                calibrated_consumption=100.0,
                timestamp="2026-01-01 00:00:00",
                raw_message={},
            )
        )

        latest = app.get_last_readings()
        assert 123 in latest
        assert latest[123].raw_consumption == 1000

        tally = app.get_reading_counts()
        assert tally[123] == 1

    def test_multiple_readings_same_meter(self):
        """Repeated readings for one meter keep the newest and count them all."""
        app = AppState()
        for offset in range(5):
            app.record_reading(
                MeterReading(
                    meter_id=123,
                    protocol="scm",
                    raw_consumption=1000 + offset,
                    calibrated_consumption=100.0 + offset,
                    timestamp=f"2026-01-01 00:00:0{offset}",
                    raw_message={},
                )
            )

        latest = app.get_last_readings()
        # Only the final reading (offset 4) should remain visible.
        assert latest[123].raw_consumption == 1004

        tally = app.get_reading_counts()
        assert tally[123] == 5

    def test_discovery_results(self):
        """Discovery entries are stored per meter id with their payload."""
        app = AppState()
        app.record_discovery(111, {"protocol": "scm", "count": 1})
        app.record_discovery(222, {"protocol": "r900", "count": 3})

        found = app.get_discovery_results()
        assert len(found) == 2
        assert found[111]["protocol"] == "scm"
        assert found[222]["count"] == 3

    def test_clear_discovery_results(self):
        """Clearing discovery results leaves an empty mapping."""
        app = AppState()
        app.record_discovery(111, {"protocol": "scm", "count": 1})
        app.clear_discovery_results()
        assert app.get_discovery_results() == {}

    def test_log_buffer(self):
        """Log entries come back in the order they were added."""
        app = AppState()
        app.add_log({"message": "hello"})
        app.add_log({"message": "world"})

        entries = app.get_recent_logs(10)
        assert len(entries) == 2
        assert entries[0]["message"] == "hello"
        assert entries[1]["message"] == "world"

    def test_log_buffer_max_size(self):
        """After 1500 additions only 1000 entries are retained."""
        app = AppState()
        for index in range(1500):
            app.add_log({"message": f"log {index}"})

        entries = app.get_recent_logs(2000)
        # The test expects the buffer to cap at 1000 entries.
        assert len(entries) == 1000

    def test_log_buffer_recent_count(self):
        """Requesting N recent logs returns N entries ending with the newest."""
        app = AppState()
        for index in range(50):
            app.add_log({"message": f"log {index}"})

        entries = app.get_recent_logs(10)
        assert len(entries) == 10
        assert entries[-1]["message"] == "log 49"

    def test_sse_subscribe_notify(self):
        """A status change sets the event handed out by subscribe_sse."""
        app = AppState()
        signal = app.subscribe_sse()
        assert not signal.is_set()

        app.set_status(PipelineStatus.RUNNING)
        assert signal.is_set()

    def test_sse_unsubscribe(self):
        """Unsubscribing the same event twice must not raise."""
        app = AppState()
        signal = app.subscribe_sse()
        app.unsubscribe_sse(signal)
        # The second call checks that a repeated unsubscribe is tolerated.
        app.unsubscribe_sse(signal)

    def test_config_ready_event(self):
        """config_ready starts unset and can be set."""
        app = AppState()
        ready = app.config_ready
        assert not ready.is_set()

        app.config_ready.set()
        assert app.config_ready.is_set()

    def test_restart_requested_event(self):
        """restart_requested starts unset and can be set, then cleared."""
        app = AppState()
        assert not app.restart_requested.is_set()

        app.restart_requested.set()
        assert app.restart_requested.is_set()

        app.restart_requested.clear()
        assert not app.restart_requested.is_set()

    def test_discovery_duration(self):
        """discovery_duration starts at 120 and accepts a new value."""
        app = AppState()
        assert app.discovery_duration == 120

        app.discovery_duration = 60
        assert app.discovery_duration == 60

    def test_thread_safety(self):
        """Multiple threads recording readings simultaneously."""
        app = AppState()
        failures = []

        def worker(meter_id):
            # Collect any exception so the main thread can assert on it;
            # an exception raised inside a thread would otherwise be lost.
            try:
                for value in range(100):
                    app.record_reading(
                        MeterReading(
                            meter_id=meter_id,
                            protocol="scm",
                            raw_consumption=value,
                            calibrated_consumption=float(value),
                            timestamp="2026-01-01",
                            raw_message={},
                        )
                    )
            except Exception as exc:
                failures.append(exc)

        workers = [
            threading.Thread(target=worker, args=(meter_id,))
            for meter_id in range(10)
        ]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert failures == []

        tally = app.get_reading_counts()
        # Each of the ten meters must have all 100 readings counted.
        for meter_id in range(10):
            assert tally[meter_id] == 100
class TestCostState:
    """Tests for cost state tracking in AppState."""

    def test_initial_cost_states_empty(self):
        """A new AppState has no cost states and returns None for any meter."""
        state = AppState()
        assert state.get_cost_states() == {}
        assert state.get_cost_state(123) is None

    def test_update_cost_state(self):
        """A stored CostState is returned with the values it was built with."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0, last_calibrated_reading=1000.0)
        state.update_cost_state(123, cs)

        result = state.get_cost_state(123)
        assert result is not None
        # Values are stored and read back without arithmetic, so exact
        # equality is safe here.
        assert result.cumulative_cost == 50.0
        assert result.last_calibrated_reading == 1000.0

    def test_get_cost_states_multiple(self):
        """Cost states for different meters are kept independently."""
        state = AppState()
        state.update_cost_state(111, CostState(cumulative_cost=10.0))
        state.update_cost_state(222, CostState(cumulative_cost=20.0))

        states = state.get_cost_states()
        assert len(states) == 2
        assert states[111].cumulative_cost == 10.0
        assert states[222].cumulative_cost == 20.0

    def test_reset_cost_state(self):
        """Resetting zeroes the costs and records the new billing period start."""
        state = AppState()
        cs = CostState(
            cumulative_cost=100.0,
            last_calibrated_reading=5000.0,
            fixed_charges_applied=9.65,
        )
        state.update_cost_state(123, cs)
        state.reset_cost_state(123, "2026-04-01T00:00:00Z")

        result = state.get_cost_state(123)
        # Fail with a clear assertion rather than an AttributeError if the
        # reset dropped the entry.
        assert result is not None
        assert result.cumulative_cost == 0.0
        assert result.last_calibrated_reading is None
        assert result.billing_period_start == "2026-04-01T00:00:00Z"
        assert result.fixed_charges_applied == 0.0

    def test_add_fixed_charges(self):
        """A fixed charge is expected to raise both the total and the charge tally."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0, last_calibrated_reading=1000.0)
        state.update_cost_state(123, cs)
        state.add_fixed_charges(123, 9.65, "2026-03-05T00:00:00Z")

        result = state.get_cost_state(123)
        assert result is not None
        # 9.65 has no exact binary representation, so compare the sums with a
        # tolerance instead of exact float equality.
        assert result.cumulative_cost == pytest.approx(59.65)
        assert result.fixed_charges_applied == pytest.approx(9.65)

    def test_add_fixed_charges_accumulates(self):
        """Successive fixed charges are expected to add up."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0)
        state.update_cost_state(123, cs)
        state.add_fixed_charges(123, 5.0, "2026-03-01")
        state.add_fixed_charges(123, 3.0, "2026-03-02")

        result = state.get_cost_state(123)
        assert result is not None
        # Accumulated floats: tolerate rounding rather than pin an exact value.
        assert result.cumulative_cost == pytest.approx(58.0)
        assert result.fixed_charges_applied == pytest.approx(8.0)

    def test_add_fixed_charges_no_cost_state(self):
        """Adding charges for an unknown meter is a no-op that creates nothing."""
        state = AppState()
        # Should not raise even if no cost state exists
        state.add_fixed_charges(999, 10.0, "2026-03-01")
        assert state.get_cost_state(999) is None

    def test_cost_state_notifies_sse(self):
        """Updating a cost state sets the event of an SSE subscriber."""
        state = AppState()
        event = state.subscribe_sse()
        state.update_cost_state(123, CostState(cumulative_cost=10.0))
        assert event.is_set()
class TestWebLogHandler:
    """Tests for the WebLogHandler."""

    def test_emits_to_app_state(self):
        """A record logged through WebLogHandler appears in AppState's logs."""
        import logging

        state = AppState()
        handler = WebLogHandler(state)

        logger = logging.getLogger("test.webloghandler")
        previous_level = logger.level
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        try:
            logger.info("Test message")
        finally:
            # Named loggers are process-wide singletons: always detach the
            # handler and restore the level so a failure here cannot leak
            # state into other tests.
            logger.removeHandler(handler)
            logger.setLevel(previous_level)

        logs = state.get_recent_logs()
        assert len(logs) >= 1

        last = logs[-1]
        assert last["level"] == "INFO"
        assert last["message"] == "Test message"
        assert "timestamp" in last