"""Tests for hameter.state."""

import logging
import threading

import pytest

from hameter.meter import MeterReading
from hameter.state import AppState, CostState, PipelineStatus, WebLogHandler


def _make_reading(meter_id, raw_consumption, calibrated_consumption, timestamp):
    """Return an SCM ``MeterReading`` with an empty ``raw_message``."""
    return MeterReading(
        meter_id=meter_id,
        protocol="scm",
        raw_consumption=raw_consumption,
        calibrated_consumption=calibrated_consumption,
        timestamp=timestamp,
        raw_message={},
    )


class TestAppState:
    """Tests for the AppState shared state object."""

    def test_initial_status(self):
        """A freshly constructed AppState reports UNCONFIGURED."""
        state = AppState()
        assert state.status == PipelineStatus.UNCONFIGURED

    def test_set_status(self):
        """set_status without a message leaves status_message empty."""
        state = AppState()
        state.set_status(PipelineStatus.RUNNING)
        assert state.status == PipelineStatus.RUNNING
        assert state.status_message == ""

    def test_set_status_with_message(self):
        """set_status stores the accompanying message."""
        state = AppState()
        state.set_status(PipelineStatus.ERROR, "Something broke")
        assert state.status == PipelineStatus.ERROR
        assert state.status_message == "Something broke"

    def test_record_reading(self):
        """A recorded reading is retrievable and counted once."""
        state = AppState()
        state.record_reading(_make_reading(123, 1000, 100.0, "2026-01-01 00:00:00"))

        readings = state.get_last_readings()
        assert 123 in readings
        assert readings[123].raw_consumption == 1000

        counts = state.get_reading_counts()
        assert counts[123] == 1

    def test_multiple_readings_same_meter(self):
        """Later readings replace earlier ones while the count accumulates."""
        state = AppState()
        for i in range(5):
            state.record_reading(
                _make_reading(123, 1000 + i, 100.0 + i, f"2026-01-01 00:00:0{i}")
            )

        readings = state.get_last_readings()
        assert readings[123].raw_consumption == 1004  # Last one
        counts = state.get_reading_counts()
        assert counts[123] == 5

    def test_discovery_results(self):
        """Discovery results are stored per meter id."""
        state = AppState()
        state.record_discovery(111, {"protocol": "scm", "count": 1})
        state.record_discovery(222, {"protocol": "r900", "count": 3})

        results = state.get_discovery_results()
        assert len(results) == 2
        assert results[111]["protocol"] == "scm"
        assert results[222]["count"] == 3

    def test_clear_discovery_results(self):
        """clear_discovery_results empties the results mapping."""
        state = AppState()
        state.record_discovery(111, {"protocol": "scm", "count": 1})
        state.clear_discovery_results()
        assert state.get_discovery_results() == {}

    def test_log_buffer(self):
        """Logs are returned oldest-first."""
        state = AppState()
        state.add_log({"message": "hello"})
        state.add_log({"message": "world"})

        logs = state.get_recent_logs(10)
        assert len(logs) == 2
        assert logs[0]["message"] == "hello"
        assert logs[1]["message"] == "world"

    def test_log_buffer_max_size(self):
        """The log buffer is capped at 1000 entries."""
        state = AppState()
        for i in range(1500):
            state.add_log({"message": f"log {i}"})

        logs = state.get_recent_logs(2000)
        assert len(logs) == 1000  # maxlen

    def test_log_buffer_recent_count(self):
        """get_recent_logs(n) returns the newest n entries."""
        state = AppState()
        for i in range(50):
            state.add_log({"message": f"log {i}"})

        logs = state.get_recent_logs(10)
        assert len(logs) == 10
        assert logs[-1]["message"] == "log 49"

    def test_sse_subscribe_notify(self):
        """A status change sets every subscribed SSE event."""
        state = AppState()
        event = state.subscribe_sse()
        assert not event.is_set()

        state.set_status(PipelineStatus.RUNNING)
        assert event.is_set()

    def test_sse_unsubscribe(self):
        """Unsubscribing, even twice, must not raise.

        NOTE(review): this only checks that the calls succeed; it does not
        assert that an unsubscribed event stops being notified.
        """
        state = AppState()
        event = state.subscribe_sse()
        state.unsubscribe_sse(event)
        # Should not raise even if unsubscribing twice
        state.unsubscribe_sse(event)

    def test_config_ready_event(self):
        """config_ready starts cleared and can be set."""
        state = AppState()
        assert not state.config_ready.is_set()
        state.config_ready.set()
        assert state.config_ready.is_set()

    def test_restart_requested_event(self):
        """restart_requested starts cleared and can be set and cleared."""
        state = AppState()
        assert not state.restart_requested.is_set()
        state.restart_requested.set()
        assert state.restart_requested.is_set()
        state.restart_requested.clear()
        assert not state.restart_requested.is_set()

    def test_discovery_duration(self):
        """discovery_duration defaults to 120 and is assignable."""
        state = AppState()
        assert state.discovery_duration == 120
        state.discovery_duration = 60
        assert state.discovery_duration == 60

    def test_thread_safety(self):
        """Multiple threads recording readings simultaneously."""
        state = AppState()
        errors = []

        def record(meter_id):
            # An exception raised inside a worker thread is not re-raised by
            # join(), so collect it for the main thread to assert on.
            try:
                for i in range(100):
                    state.record_reading(
                        _make_reading(meter_id, i, float(i), "2026-01-01")
                    )
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=record, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []
        counts = state.get_reading_counts()
        for i in range(10):
            assert counts[i] == 100


class TestCostState:
    """Tests for cost state tracking in AppState."""

    def test_initial_cost_states_empty(self):
        """No cost state exists before any is recorded."""
        state = AppState()
        assert state.get_cost_states() == {}
        assert state.get_cost_state(123) is None

    def test_update_cost_state(self):
        """update_cost_state stores a CostState retrievable by meter id."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0, last_calibrated_reading=1000.0)
        state.update_cost_state(123, cs)

        result = state.get_cost_state(123)
        assert result is not None
        assert result.cumulative_cost == 50.0
        assert result.last_calibrated_reading == 1000.0

    def test_get_cost_states_multiple(self):
        """get_cost_states returns every tracked meter."""
        state = AppState()
        state.update_cost_state(111, CostState(cumulative_cost=10.0))
        state.update_cost_state(222, CostState(cumulative_cost=20.0))

        states = state.get_cost_states()
        assert len(states) == 2
        assert states[111].cumulative_cost == 10.0
        assert states[222].cumulative_cost == 20.0

    def test_reset_cost_state(self):
        """reset_cost_state zeroes costs and starts a new billing period."""
        state = AppState()
        cs = CostState(
            cumulative_cost=100.0,
            last_calibrated_reading=5000.0,
            fixed_charges_applied=9.65,
        )
        state.update_cost_state(123, cs)
        state.reset_cost_state(123, "2026-04-01T00:00:00Z")

        result = state.get_cost_state(123)
        assert result.cumulative_cost == 0.0
        assert result.last_calibrated_reading is None
        assert result.billing_period_start == "2026-04-01T00:00:00Z"
        assert result.fixed_charges_applied == 0.0

    def test_add_fixed_charges(self):
        """Fixed charges are added to both cumulative and applied totals."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0, last_calibrated_reading=1000.0)
        state.update_cost_state(123, cs)
        state.add_fixed_charges(123, 9.65, "2026-03-05T00:00:00Z")

        result = state.get_cost_state(123)
        # Sums of non-dyadic decimals are not exact in binary floating point.
        assert result.cumulative_cost == pytest.approx(59.65)
        assert result.fixed_charges_applied == pytest.approx(9.65)

    def test_add_fixed_charges_accumulates(self):
        """Repeated fixed charges accumulate."""
        state = AppState()
        cs = CostState(cumulative_cost=50.0)
        state.update_cost_state(123, cs)
        state.add_fixed_charges(123, 5.0, "2026-03-01")
        state.add_fixed_charges(123, 3.0, "2026-03-02")

        result = state.get_cost_state(123)
        assert result.cumulative_cost == pytest.approx(58.0)
        assert result.fixed_charges_applied == pytest.approx(8.0)

    def test_add_fixed_charges_no_cost_state(self):
        """Charging an untracked meter is a no-op."""
        state = AppState()
        # Should not raise even if no cost state exists
        state.add_fixed_charges(999, 10.0, "2026-03-01")
        assert state.get_cost_state(999) is None

    def test_cost_state_notifies_sse(self):
        """Updating a cost state sets subscribed SSE events."""
        state = AppState()
        event = state.subscribe_sse()
        state.update_cost_state(123, CostState(cumulative_cost=10.0))
        assert event.is_set()


class TestWebLogHandler:
    """Tests for the WebLogHandler."""

    def test_emits_to_app_state(self):
        """Records logged through the handler land in AppState's log buffer."""
        state = AppState()
        handler = WebLogHandler(state)

        logger = logging.getLogger("test.webloghandler")
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        try:
            logger.info("Test message")
        finally:
            # getLogger returns a process-wide logger; always detach the
            # handler so it cannot leak into other tests.
            logger.removeHandler(handler)

        logs = state.get_recent_logs()
        assert len(logs) >= 1
        last = logs[-1]
        assert last["level"] == "INFO"
        assert last["message"] == "Test message"
        assert "timestamp" in last