Files
HAMeter/hameter/state.py
2026-03-06 12:25:27 -05:00

276 lines
8.0 KiB
Python

"""Thread-safe shared state between web server and pipeline."""
import copy
import logging
import threading
import time
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from hameter.meter import MeterReading
@dataclass
class CostState:
    """Running cost totals for one meter within the current billing period.

    Every field has a default, so ``CostState()`` is a valid
    "nothing accrued yet" value.
    """

    # Total cost accrued so far in this period (fixed charges included).
    cumulative_cost: float = 0.0
    # Last meter reading the cost was computed from, or None when no
    # reading has been applied yet. NOTE(review): units/calibration are
    # decided by the caller -- confirm.
    last_calibrated_reading: Optional[float] = None
    # Timestamp string marking the start of the billing period.
    billing_period_start: str = ""
    # Timestamp string of the most recent change to this record.
    last_updated: str = ""
    # Portion of cumulative_cost that came from fixed charges.
    fixed_charges_applied: float = 0.0
class PipelineStatus(Enum):
    """Lifecycle states reported for the meter-reading pipeline.

    Each value is the lowercase member name, so ``PipelineStatus("running")``
    looks a member up by its value.
    """

    # No usable configuration yet (the initial status held by AppState).
    UNCONFIGURED = "unconfigured"
    STOPPED = "stopped"
    STARTING = "starting"
    RUNNING = "running"
    RESTARTING = "restarting"
    # Presumably active while meter discovery runs -- confirm in pipeline.
    DISCOVERY = "discovery"
    ERROR = "error"
class AppState:
    """Thread-safe shared state for the HAMeter application.

    Accessed by:
    - Main thread: Pipeline reads/writes pipeline_status, last_readings
    - Flask thread: Web routes read status, trigger restart/discovery

    Locking rules:
    - All mutable fields are guarded by one ``threading.Lock``. That lock
      is NOT reentrant.
    - ``_notify_sse`` acquires the same lock, so every method calls it only
      after leaving its ``with self._lock:`` block. Calling it while holding
      the lock would deadlock.
    - ``CostState`` objects are copied when stored and when returned.
      Callers therefore never share the instance that ``add_fixed_charges``
      mutates under the lock.
    """

    def __init__(self, log_buffer_size: int = 1000):
        """Create empty state.

        Args:
            log_buffer_size: Maximum number of log records retained for the
                web UI. Once full, the oldest records are dropped first.
                Defaults to 1000.
        """
        self._lock = threading.Lock()

        # Pipeline state
        self._status: PipelineStatus = PipelineStatus.UNCONFIGURED
        self._status_message: str = ""

        # Config object (set once loaded)
        self._config = None

        # Most recent reading and number of readings seen, per meter ID
        self._last_readings: dict[int, MeterReading] = {}
        self._reading_counts: dict[int, int] = {}

        # Cost tracking per meter
        self._cost_states: dict[int, CostState] = {}

        # Discovery results keyed by meter ID
        self._discovery_results: dict[int, dict] = {}

        # Log ring buffer for web UI streaming (deque drops oldest when full)
        self._log_buffer: deque[dict] = deque(maxlen=log_buffer_size)

        # SSE subscribers (one threading.Event per subscriber)
        self._sse_events: list[threading.Event] = []

        # Signals from web -> pipeline. threading.Event is itself
        # thread-safe, so these are exposed without taking self._lock.
        self._restart_requested = threading.Event()
        self._discovery_requested = threading.Event()
        self._discovery_duration: int = 120
        self._stop_discovery = threading.Event()

        # Pipeline startup gate: set once config is valid
        self._config_ready = threading.Event()

    # --- Status ---

    @property
    def status(self) -> PipelineStatus:
        """Current pipeline status."""
        with self._lock:
            return self._status

    def set_status(self, status: PipelineStatus, message: str = ""):
        """Set status and message together, then wake SSE subscribers."""
        with self._lock:
            self._status = status
            self._status_message = message
        self._notify_sse()

    @property
    def status_message(self) -> str:
        """Message recorded by the last ``set_status`` call."""
        with self._lock:
            return self._status_message

    # --- Config ---

    @property
    def config(self):
        """Loaded config object, or None if not set yet."""
        with self._lock:
            return self._config

    def set_config(self, config):
        """Replace the stored config object."""
        with self._lock:
            self._config = config

    # --- Readings ---

    def record_reading(self, reading: MeterReading):
        """Store ``reading`` as its meter's latest and bump that meter's count."""
        with self._lock:
            meter_id = reading.meter_id
            self._last_readings[meter_id] = reading
            self._reading_counts[meter_id] = (
                self._reading_counts.get(meter_id, 0) + 1
            )
        self._notify_sse()

    def get_last_readings(self) -> dict[int, MeterReading]:
        """Return a shallow copy of the latest reading per meter ID."""
        with self._lock:
            return dict(self._last_readings)

    def get_reading_counts(self) -> dict[int, int]:
        """Return a copy of the per-meter reading counters."""
        with self._lock:
            return dict(self._reading_counts)

    def clear_readings(self, meter_id: Optional[int] = None):
        """Clear cached readings. If meter_id given, clear only that meter."""
        with self._lock:
            if meter_id is not None:
                self._last_readings.pop(meter_id, None)
                self._reading_counts.pop(meter_id, None)
            else:
                self._last_readings.clear()
                self._reading_counts.clear()
        self._notify_sse()

    # --- Cost state ---

    def get_cost_states(self) -> dict[int, CostState]:
        """Return a snapshot of every meter's cost state.

        The values are copies. Mutating them does not affect stored state;
        use ``update_cost_state`` to persist changes.
        """
        with self._lock:
            return {
                meter_id: copy.copy(cost_state)
                for meter_id, cost_state in self._cost_states.items()
            }

    def get_cost_state(self, meter_id: int) -> Optional[CostState]:
        """Return a copy of ``meter_id``'s cost state, or None if absent."""
        with self._lock:
            cost_state = self._cost_states.get(meter_id)
            return None if cost_state is None else copy.copy(cost_state)

    def update_cost_state(self, meter_id: int, cost_state: CostState):
        """Store a copy of ``cost_state`` for ``meter_id``.

        A copy is stored so that later in-place updates made under the lock
        (``add_fixed_charges``) never touch an object the caller still holds.
        """
        snapshot = copy.copy(cost_state)
        with self._lock:
            self._cost_states[meter_id] = snapshot
        self._notify_sse()

    def reset_cost_state(self, meter_id: int, timestamp: str):
        """Reset cost tracking for a new billing period."""
        with self._lock:
            self._cost_states[meter_id] = CostState(
                cumulative_cost=0.0,
                last_calibrated_reading=None,
                billing_period_start=timestamp,
                last_updated=timestamp,
                fixed_charges_applied=0.0,
            )
        self._notify_sse()

    def add_fixed_charges(
        self, meter_id: int, amount: float, timestamp: str
    ) -> bool:
        """Add fixed charges to the cumulative cost for a meter.

        Both ``cumulative_cost`` and ``fixed_charges_applied`` grow by
        ``amount``, rounded to 4 decimal places.

        Returns:
            True if applied. False if ``meter_id`` has no cost state, in
            which case nothing changes and no SSE notification is sent.
        """
        with self._lock:
            cost_state = self._cost_states.get(meter_id)
            if cost_state is None:
                return False
            cost_state.cumulative_cost = round(
                cost_state.cumulative_cost + amount, 4
            )
            cost_state.fixed_charges_applied = round(
                cost_state.fixed_charges_applied + amount, 4
            )
            cost_state.last_updated = timestamp
        self._notify_sse()
        return True

    def remove_cost_state(self, meter_id: int):
        """Remove cost state for a meter (e.g. when cost_factors are cleared)."""
        with self._lock:
            self._cost_states.pop(meter_id, None)
        # Notify like every other mutator so the UI drops the meter's cost.
        self._notify_sse()

    # --- Discovery ---

    def record_discovery(self, meter_id: int, info: dict):
        """Store (or overwrite) discovery info for ``meter_id``."""
        with self._lock:
            self._discovery_results[meter_id] = info
        self._notify_sse()

    def get_discovery_results(self) -> dict[int, dict]:
        """Return a shallow copy of the discovery results."""
        with self._lock:
            return dict(self._discovery_results)

    def clear_discovery_results(self):
        """Forget all discovery results and wake SSE subscribers."""
        with self._lock:
            self._discovery_results.clear()
        self._notify_sse()

    # --- Log buffer ---

    def add_log(self, record: dict):
        """Append a log record to the ring buffer and wake SSE subscribers."""
        with self._lock:
            self._log_buffer.append(record)
        self._notify_sse()

    def get_recent_logs(self, count: int = 200) -> list[dict]:
        """Return up to ``count`` most recent log records, oldest first.

        A non-positive ``count`` returns an empty list. A bare
        ``items[-0:]`` slice would instead return the whole buffer.
        """
        if count <= 0:
            return []
        with self._lock:
            items = list(self._log_buffer)
        return items[-count:]

    # --- SSE notification ---

    def subscribe_sse(self) -> threading.Event:
        """Register and return a new Event that every state change sets."""
        event = threading.Event()
        with self._lock:
            self._sse_events.append(event)
        return event

    def unsubscribe_sse(self, event: threading.Event):
        """Unregister ``event``. Unknown events are ignored."""
        with self._lock:
            try:
                self._sse_events.remove(event)
            except ValueError:
                pass

    def _notify_sse(self):
        """Set every subscriber's Event.

        Must be called WITHOUT holding ``self._lock``: it acquires that
        non-reentrant lock itself. Subscribers are snapshotted under the
        lock and set outside it.
        """
        with self._lock:
            events = list(self._sse_events)
        for event in events:
            event.set()

    # --- Signals ---

    @property
    def restart_requested(self) -> threading.Event:
        return self._restart_requested

    @property
    def discovery_requested(self) -> threading.Event:
        return self._discovery_requested

    @property
    def discovery_duration(self) -> int:
        with self._lock:
            return self._discovery_duration

    @discovery_duration.setter
    def discovery_duration(self, value: int):
        with self._lock:
            self._discovery_duration = value

    @property
    def stop_discovery(self) -> threading.Event:
        return self._stop_discovery

    @property
    def config_ready(self) -> threading.Event:
        return self._config_ready
class WebLogHandler(logging.Handler):
    """Captures log records into AppState for web UI streaming.

    Each record becomes a dict and is passed to ``app_state.add_log``. The
    dict has these keys:
    - ``timestamp``: local time, ``YYYY-MM-DD HH:MM:SS``
    - ``level``: the level name
    - ``name``: the logger name
    - ``message``: the fully formatted message
    """

    def __init__(self, app_state: "AppState"):
        super().__init__()
        self._state = app_state

    def emit(self, record: logging.LogRecord) -> None:
        """Forward ``record`` to the app state's log buffer.

        Failures go to ``Handler.handleError``, the stdlib convention,
        instead of being silently discarded. handleError reports to stderr
        only when ``logging.raiseExceptions`` is true, and never propagates
        to the code that logged. ``RecursionError`` is re-raised, matching
        the stdlib handlers.
        """
        try:
            entry = {
                "timestamp": time.strftime(
                    "%Y-%m-%d %H:%M:%S", time.localtime(record.created)
                ),
                "level": record.levelname,
                "name": record.name,
                "message": record.getMessage(),
            }
            self._state.add_log(entry)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)