276 lines
8.0 KiB
Python
276 lines
8.0 KiB
Python
"""Thread-safe shared state between web server and pipeline."""
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
from hameter.meter import MeterReading
|
|
|
|
|
|
@dataclass()
class CostState:
    """Running cost bookkeeping for one meter over a billing period.

    A plain mutable record; AppState keeps one of these per meter ID.
    """

    # Total cost accumulated in the current period; fixed charges added via
    # AppState.add_fixed_charges are folded into this figure as well.
    cumulative_cost: float = 0.0

    # NOTE(review): presumably the last calibrated meter value that cost was
    # computed from; None until the first one is seen — confirm with pipeline.
    last_calibrated_reading: Optional[float] = None

    # Timestamp string marking when the current billing period began.
    billing_period_start: str = ""

    # Timestamp string of the most recent change to this record.
    last_updated: str = ""

    # Portion of cumulative_cost that came from fixed charges.
    fixed_charges_applied: float = 0.0
|
|
|
|
|
|
class PipelineStatus(Enum):
    """Lifecycle states reported by the pipeline and shown by the web UI."""

    # Not yet usable / idle states
    UNCONFIGURED = "unconfigured"
    STOPPED = "stopped"

    # Transitional and active states
    STARTING = "starting"
    RUNNING = "running"
    RESTARTING = "restarting"
    DISCOVERY = "discovery"

    # Failure state
    ERROR = "error"
|
|
|
|
|
|
class AppState:
    """Thread-safe shared state for the HAMeter application.

    Accessed by:
      - Main thread: Pipeline reads/writes pipeline_status, last_readings
      - Flask thread: Web routes read status, trigger restart/discovery

    Every mutable field is guarded by ``self._lock``, a non-reentrant
    ``threading.Lock``.  Mutating methods release the lock *before* calling
    ``_notify_sse()``: that helper acquires the same lock, and a second
    acquisition from the thread already holding it would deadlock.
    """

    def __init__(self):
        self._lock = threading.Lock()

        # Pipeline state
        self._status: PipelineStatus = PipelineStatus.UNCONFIGURED
        self._status_message: str = ""

        # Config object (set once loaded)
        self._config = None

        # Meter readings (most recent per meter ID) and per-meter counts
        self._last_readings: dict[int, MeterReading] = {}
        self._reading_counts: dict[int, int] = {}

        # Cost tracking per meter
        self._cost_states: dict[int, CostState] = {}

        # Discovery results keyed by meter ID
        self._discovery_results: dict[int, dict] = {}

        # Log ring buffer for web UI streaming; once 1000 entries are held,
        # each append discards the oldest one.
        self._log_buffer: deque[dict] = deque(maxlen=1000)

        # SSE subscribers (one threading.Event per subscriber)
        self._sse_events: list[threading.Event] = []

        # Signals from web -> pipeline
        self._restart_requested = threading.Event()
        self._discovery_requested = threading.Event()
        self._discovery_duration: int = 120
        self._stop_discovery = threading.Event()

        # Pipeline startup gate: set once config is valid
        self._config_ready = threading.Event()

    # --- Status ---

    @property
    def status(self) -> PipelineStatus:
        """Current pipeline status."""
        with self._lock:
            return self._status

    def set_status(self, status: PipelineStatus, message: str = ""):
        """Set status and its message together, then wake SSE subscribers."""
        with self._lock:
            self._status = status
            self._status_message = message
        # Outside the lock: _notify_sse() takes self._lock itself.
        self._notify_sse()

    @property
    def status_message(self) -> str:
        """Message stored alongside the most recent status."""
        with self._lock:
            return self._status_message

    # --- Config ---

    @property
    def config(self):
        """The loaded config object, or None if not set yet."""
        with self._lock:
            return self._config

    def set_config(self, config):
        """Replace the stored config object."""
        with self._lock:
            self._config = config

    # --- Readings ---

    def record_reading(self, reading: MeterReading):
        """Store ``reading`` as its meter's latest and bump that meter's count."""
        meter_id = reading.meter_id
        with self._lock:
            self._last_readings[meter_id] = reading
            self._reading_counts[meter_id] = (
                self._reading_counts.get(meter_id, 0) + 1
            )
        self._notify_sse()

    def get_last_readings(self) -> dict[int, MeterReading]:
        """Return a snapshot copy of the latest reading per meter ID."""
        with self._lock:
            return dict(self._last_readings)

    def get_reading_counts(self) -> dict[int, int]:
        """Return a snapshot copy of the reading count per meter ID."""
        with self._lock:
            return dict(self._reading_counts)

    def clear_readings(self, meter_id: Optional[int] = None):
        """Clear cached readings. If meter_id given, clear only that meter."""
        with self._lock:
            if meter_id is not None:
                self._last_readings.pop(meter_id, None)
                self._reading_counts.pop(meter_id, None)
            else:
                self._last_readings.clear()
                self._reading_counts.clear()
        self._notify_sse()

    # --- Cost state ---

    def get_cost_states(self) -> dict[int, CostState]:
        """Return a shallow copy of the meter_id -> CostState mapping.

        NOTE(review): the CostState objects are the same instances this store
        holds (add_fixed_charges mutates them in place under the lock), so
        callers should treat them as read-only and write back through
        update_cost_state().
        """
        with self._lock:
            return dict(self._cost_states)

    def get_cost_state(self, meter_id: int) -> Optional[CostState]:
        """Return the stored CostState for ``meter_id``, or None.

        NOTE(review): returns the shared instance, not a copy — see
        get_cost_states().
        """
        with self._lock:
            return self._cost_states.get(meter_id)

    def update_cost_state(self, meter_id: int, cost_state: CostState):
        """Store ``cost_state`` as the cost state for ``meter_id``."""
        with self._lock:
            self._cost_states[meter_id] = cost_state
        self._notify_sse()

    def reset_cost_state(self, meter_id: int, timestamp: str):
        """Reset cost tracking for a new billing period."""
        fresh = CostState(
            cumulative_cost=0.0,
            last_calibrated_reading=None,
            billing_period_start=timestamp,
            last_updated=timestamp,
            fixed_charges_applied=0.0,
        )
        with self._lock:
            self._cost_states[meter_id] = fresh
        self._notify_sse()

    def add_fixed_charges(self, meter_id: int, amount: float, timestamp: str):
        """Add fixed charges to the cumulative cost for a meter.

        Does nothing to the data if the meter has no cost state yet.  Both
        totals are rounded to 4 decimal places after adding.
        """
        with self._lock:
            cs = self._cost_states.get(meter_id)
            if cs:
                cs.cumulative_cost = round(cs.cumulative_cost + amount, 4)
                cs.fixed_charges_applied = round(
                    cs.fixed_charges_applied + amount, 4
                )
                cs.last_updated = timestamp
        self._notify_sse()

    def remove_cost_state(self, meter_id: int):
        """Remove cost state for a meter (e.g. when cost_factors are cleared)."""
        with self._lock:
            self._cost_states.pop(meter_id, None)
        # Notify like every other mutator so the UI drops the stale entry.
        self._notify_sse()

    # --- Discovery ---

    def record_discovery(self, meter_id: int, info: dict):
        """Store discovery info for ``meter_id`` (overwrites any previous)."""
        with self._lock:
            self._discovery_results[meter_id] = info
        self._notify_sse()

    def get_discovery_results(self) -> dict[int, dict]:
        """Return a shallow snapshot copy of all discovery results."""
        with self._lock:
            return dict(self._discovery_results)

    def clear_discovery_results(self):
        """Drop all discovery results and wake SSE subscribers."""
        with self._lock:
            self._discovery_results.clear()
        # Notify like record_discovery() so listeners see the cleared list.
        self._notify_sse()

    # --- Log buffer ---

    def add_log(self, record: dict):
        """Append a log entry to the ring buffer and wake SSE subscribers."""
        with self._lock:
            self._log_buffer.append(record)
        self._notify_sse()

    def get_recent_logs(self, count: int = 200) -> list[dict]:
        """Return up to ``count`` most recent log entries, oldest first.

        A ``count`` of zero or less returns an empty list; without this guard
        ``items[-0:]`` would slice to the *entire* buffer.
        """
        if count <= 0:
            return []
        with self._lock:
            items = list(self._log_buffer)
        return items[-count:]

    # --- SSE notification ---

    def subscribe_sse(self) -> threading.Event:
        """Register and return a new Event that is set on every state change."""
        event = threading.Event()
        with self._lock:
            self._sse_events.append(event)
        return event

    def unsubscribe_sse(self, event: threading.Event):
        """Unregister ``event``; unknown events are ignored."""
        with self._lock:
            try:
                self._sse_events.remove(event)
            except ValueError:
                pass

    def _notify_sse(self):
        """Set every subscriber's Event.

        Must be called WITHOUT holding ``self._lock``: it takes the lock only
        to snapshot the subscriber list, then sets the events outside it.
        """
        with self._lock:
            events = list(self._sse_events)
        for event in events:
            event.set()

    # --- Signals ---

    @property
    def restart_requested(self) -> threading.Event:
        """Event the web side sets to ask the pipeline to restart."""
        return self._restart_requested

    @property
    def discovery_requested(self) -> threading.Event:
        """Event the web side sets to ask the pipeline to run discovery."""
        return self._discovery_requested

    @property
    def discovery_duration(self) -> int:
        """Requested discovery duration (defaults to 120)."""
        with self._lock:
            return self._discovery_duration

    @discovery_duration.setter
    def discovery_duration(self, value: int):
        with self._lock:
            self._discovery_duration = value

    @property
    def stop_discovery(self) -> threading.Event:
        """Event used to signal that discovery should stop."""
        return self._stop_discovery

    @property
    def config_ready(self) -> threading.Event:
        """Startup gate Event; set once config is valid."""
        return self._config_ready
|
|
|
|
|
|
class WebLogHandler(logging.Handler):
    """Captures log records into AppState for web UI streaming.

    Each record becomes a plain dict with ``timestamp`` (local time,
    ``%Y-%m-%d %H:%M:%S``), ``level``, ``name`` and ``message`` keys, which is
    handed to ``AppState.add_log``.
    """

    def __init__(self, app_state: "AppState"):
        super().__init__()
        self._state = app_state

    def emit(self, record):
        """Convert ``record`` to a dict and push it into the shared log buffer.

        Failures are reported through ``Handler.handleError``, as the stdlib
        handlers do, rather than silently discarded.  One example is a
        mismatched format string that makes ``getMessage()`` raise.
        ``handleError`` honours ``logging.raiseExceptions`` and never raises.
        ``RecursionError`` is re-raised, matching ``logging.StreamHandler``.
        """
        try:
            entry = {
                "timestamp": time.strftime(
                    "%Y-%m-%d %H:%M:%S", time.localtime(record.created)
                ),
                "level": record.levelname,
                "name": record.name,
                "message": record.getMessage(),
            }
            self._state.add_log(entry)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
|