"""Thread-safe shared state between web server and pipeline.""" import logging import threading import time from collections import deque from dataclasses import dataclass from enum import Enum from typing import Optional from hameter.meter import MeterReading @dataclass class CostState: """Tracks cumulative cost for a single meter across a billing period.""" cumulative_cost: float = 0.0 last_calibrated_reading: Optional[float] = None billing_period_start: str = "" last_updated: str = "" fixed_charges_applied: float = 0.0 class PipelineStatus(Enum): UNCONFIGURED = "unconfigured" STOPPED = "stopped" STARTING = "starting" RUNNING = "running" RESTARTING = "restarting" DISCOVERY = "discovery" ERROR = "error" class AppState: """Thread-safe shared state for the HAMeter application. Accessed by: - Main thread: Pipeline reads/writes pipeline_status, last_readings - Flask thread: Web routes read status, trigger restart/discovery """ def __init__(self): self._lock = threading.Lock() # Pipeline state self._status: PipelineStatus = PipelineStatus.UNCONFIGURED self._status_message: str = "" # Config object (set once loaded) self._config = None # Meter readings (most recent per meter ID) self._last_readings: dict[int, MeterReading] = {} self._reading_counts: dict[int, int] = {} # Cost tracking per meter self._cost_states: dict[int, CostState] = {} # Discovery results self._discovery_results: dict[int, dict] = {} # Log ring buffer for web UI streaming self._log_buffer: deque[dict] = deque(maxlen=1000) # SSE subscribers (list of threading.Event per subscriber) self._sse_events: list[threading.Event] = [] # Signals from web -> pipeline self._restart_requested = threading.Event() self._discovery_requested = threading.Event() self._discovery_duration: int = 120 self._stop_discovery = threading.Event() # Pipeline startup gate: set once config is valid self._config_ready = threading.Event() # --- Status --- @property def status(self) -> PipelineStatus: with self._lock: return self._status def set_status(self, status: PipelineStatus, message: str = ""): with self._lock: self._status = status self._status_message = message self._notify_sse() @property def status_message(self) -> str: with self._lock: return self._status_message # --- Config --- @property def config(self): with self._lock: return self._config def set_config(self, config): with self._lock: self._config = config # --- Readings --- def record_reading(self, reading: MeterReading): with self._lock: self._last_readings[reading.meter_id] = reading self._reading_counts[reading.meter_id] = ( self._reading_counts.get(reading.meter_id, 0) + 1 ) self._notify_sse() def get_last_readings(self) -> dict[int, MeterReading]: with self._lock: return dict(self._last_readings) def get_reading_counts(self) -> dict[int, int]: with self._lock: return dict(self._reading_counts) def clear_readings(self, meter_id: Optional[int] = None): """Clear cached readings. If meter_id given, clear only that meter.""" with self._lock: if meter_id is not None: self._last_readings.pop(meter_id, None) self._reading_counts.pop(meter_id, None) else: self._last_readings.clear() self._reading_counts.clear() self._notify_sse() # --- Cost state --- def get_cost_states(self) -> dict[int, CostState]: with self._lock: return dict(self._cost_states) def get_cost_state(self, meter_id: int) -> Optional[CostState]: with self._lock: return self._cost_states.get(meter_id) def update_cost_state(self, meter_id: int, cost_state: CostState): with self._lock: self._cost_states[meter_id] = cost_state self._notify_sse() def reset_cost_state(self, meter_id: int, timestamp: str): """Reset cost tracking for a new billing period.""" with self._lock: self._cost_states[meter_id] = CostState( cumulative_cost=0.0, last_calibrated_reading=None, billing_period_start=timestamp, last_updated=timestamp, fixed_charges_applied=0.0, ) self._notify_sse() def add_fixed_charges(self, meter_id: int, amount: float, timestamp: str): """Add fixed charges to the cumulative cost for a meter.""" with self._lock: cs = self._cost_states.get(meter_id) if cs: cs.cumulative_cost = round(cs.cumulative_cost + amount, 4) cs.fixed_charges_applied = round( cs.fixed_charges_applied + amount, 4 ) cs.last_updated = timestamp self._notify_sse() def remove_cost_state(self, meter_id: int): """Remove cost state for a meter (e.g. when cost_factors are cleared).""" with self._lock: self._cost_states.pop(meter_id, None) # --- Discovery --- def record_discovery(self, meter_id: int, info: dict): with self._lock: self._discovery_results[meter_id] = info self._notify_sse() def get_discovery_results(self) -> dict[int, dict]: with self._lock: return dict(self._discovery_results) def clear_discovery_results(self): with self._lock: self._discovery_results.clear() # --- Log buffer --- def add_log(self, record: dict): with self._lock: self._log_buffer.append(record) self._notify_sse() def get_recent_logs(self, count: int = 200) -> list[dict]: with self._lock: items = list(self._log_buffer) return items[-count:] # --- SSE notification --- def subscribe_sse(self) -> threading.Event: event = threading.Event() with self._lock: self._sse_events.append(event) return event def unsubscribe_sse(self, event: threading.Event): with self._lock: try: self._sse_events.remove(event) except ValueError: pass def _notify_sse(self): with self._lock: events = list(self._sse_events) for event in events: event.set() # --- Signals --- @property def restart_requested(self) -> threading.Event: return self._restart_requested @property def discovery_requested(self) -> threading.Event: return self._discovery_requested @property def discovery_duration(self) -> int: with self._lock: return self._discovery_duration @discovery_duration.setter def discovery_duration(self, value: int): with self._lock: self._discovery_duration = value @property def stop_discovery(self) -> threading.Event: return self._stop_discovery @property def config_ready(self) -> threading.Event: return self._config_ready class WebLogHandler(logging.Handler): """Captures log records into AppState for web UI streaming.""" def __init__(self, app_state: AppState): super().__init__() self._state = app_state def emit(self, record): try: self._state.add_log({ "timestamp": time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime(record.created) ), "level": record.levelname, "name": record.name, "message": record.getMessage(), }) except Exception: pass