initial commit
This commit is contained in:
275
hameter/state.py
Normal file
275
hameter/state.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Thread-safe shared state between web server and pipeline."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from hameter.meter import MeterReading
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostState:
|
||||
"""Tracks cumulative cost for a single meter across a billing period."""
|
||||
|
||||
cumulative_cost: float = 0.0
|
||||
last_calibrated_reading: Optional[float] = None
|
||||
billing_period_start: str = ""
|
||||
last_updated: str = ""
|
||||
fixed_charges_applied: float = 0.0
|
||||
|
||||
|
||||
class PipelineStatus(Enum):
|
||||
UNCONFIGURED = "unconfigured"
|
||||
STOPPED = "stopped"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
RESTARTING = "restarting"
|
||||
DISCOVERY = "discovery"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class AppState:
|
||||
"""Thread-safe shared state for the HAMeter application.
|
||||
|
||||
Accessed by:
|
||||
- Main thread: Pipeline reads/writes pipeline_status, last_readings
|
||||
- Flask thread: Web routes read status, trigger restart/discovery
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Pipeline state
|
||||
self._status: PipelineStatus = PipelineStatus.UNCONFIGURED
|
||||
self._status_message: str = ""
|
||||
|
||||
# Config object (set once loaded)
|
||||
self._config = None
|
||||
|
||||
# Meter readings (most recent per meter ID)
|
||||
self._last_readings: dict[int, MeterReading] = {}
|
||||
self._reading_counts: dict[int, int] = {}
|
||||
|
||||
# Cost tracking per meter
|
||||
self._cost_states: dict[int, CostState] = {}
|
||||
|
||||
# Discovery results
|
||||
self._discovery_results: dict[int, dict] = {}
|
||||
|
||||
# Log ring buffer for web UI streaming
|
||||
self._log_buffer: deque[dict] = deque(maxlen=1000)
|
||||
|
||||
# SSE subscribers (list of threading.Event per subscriber)
|
||||
self._sse_events: list[threading.Event] = []
|
||||
|
||||
# Signals from web -> pipeline
|
||||
self._restart_requested = threading.Event()
|
||||
self._discovery_requested = threading.Event()
|
||||
self._discovery_duration: int = 120
|
||||
self._stop_discovery = threading.Event()
|
||||
|
||||
# Pipeline startup gate: set once config is valid
|
||||
self._config_ready = threading.Event()
|
||||
|
||||
# --- Status ---
|
||||
|
||||
@property
|
||||
def status(self) -> PipelineStatus:
|
||||
with self._lock:
|
||||
return self._status
|
||||
|
||||
def set_status(self, status: PipelineStatus, message: str = ""):
|
||||
with self._lock:
|
||||
self._status = status
|
||||
self._status_message = message
|
||||
self._notify_sse()
|
||||
|
||||
@property
|
||||
def status_message(self) -> str:
|
||||
with self._lock:
|
||||
return self._status_message
|
||||
|
||||
# --- Config ---
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
with self._lock:
|
||||
return self._config
|
||||
|
||||
def set_config(self, config):
|
||||
with self._lock:
|
||||
self._config = config
|
||||
|
||||
# --- Readings ---
|
||||
|
||||
def record_reading(self, reading: MeterReading):
|
||||
with self._lock:
|
||||
self._last_readings[reading.meter_id] = reading
|
||||
self._reading_counts[reading.meter_id] = (
|
||||
self._reading_counts.get(reading.meter_id, 0) + 1
|
||||
)
|
||||
self._notify_sse()
|
||||
|
||||
def get_last_readings(self) -> dict[int, MeterReading]:
|
||||
with self._lock:
|
||||
return dict(self._last_readings)
|
||||
|
||||
def get_reading_counts(self) -> dict[int, int]:
|
||||
with self._lock:
|
||||
return dict(self._reading_counts)
|
||||
|
||||
def clear_readings(self, meter_id: Optional[int] = None):
|
||||
"""Clear cached readings. If meter_id given, clear only that meter."""
|
||||
with self._lock:
|
||||
if meter_id is not None:
|
||||
self._last_readings.pop(meter_id, None)
|
||||
self._reading_counts.pop(meter_id, None)
|
||||
else:
|
||||
self._last_readings.clear()
|
||||
self._reading_counts.clear()
|
||||
self._notify_sse()
|
||||
|
||||
# --- Cost state ---
|
||||
|
||||
def get_cost_states(self) -> dict[int, CostState]:
|
||||
with self._lock:
|
||||
return dict(self._cost_states)
|
||||
|
||||
def get_cost_state(self, meter_id: int) -> Optional[CostState]:
|
||||
with self._lock:
|
||||
return self._cost_states.get(meter_id)
|
||||
|
||||
def update_cost_state(self, meter_id: int, cost_state: CostState):
|
||||
with self._lock:
|
||||
self._cost_states[meter_id] = cost_state
|
||||
self._notify_sse()
|
||||
|
||||
def reset_cost_state(self, meter_id: int, timestamp: str):
|
||||
"""Reset cost tracking for a new billing period."""
|
||||
with self._lock:
|
||||
self._cost_states[meter_id] = CostState(
|
||||
cumulative_cost=0.0,
|
||||
last_calibrated_reading=None,
|
||||
billing_period_start=timestamp,
|
||||
last_updated=timestamp,
|
||||
fixed_charges_applied=0.0,
|
||||
)
|
||||
self._notify_sse()
|
||||
|
||||
def add_fixed_charges(self, meter_id: int, amount: float, timestamp: str):
|
||||
"""Add fixed charges to the cumulative cost for a meter."""
|
||||
with self._lock:
|
||||
cs = self._cost_states.get(meter_id)
|
||||
if cs:
|
||||
cs.cumulative_cost = round(cs.cumulative_cost + amount, 4)
|
||||
cs.fixed_charges_applied = round(
|
||||
cs.fixed_charges_applied + amount, 4
|
||||
)
|
||||
cs.last_updated = timestamp
|
||||
self._notify_sse()
|
||||
|
||||
def remove_cost_state(self, meter_id: int):
|
||||
"""Remove cost state for a meter (e.g. when cost_factors are cleared)."""
|
||||
with self._lock:
|
||||
self._cost_states.pop(meter_id, None)
|
||||
|
||||
# --- Discovery ---
|
||||
|
||||
def record_discovery(self, meter_id: int, info: dict):
|
||||
with self._lock:
|
||||
self._discovery_results[meter_id] = info
|
||||
self._notify_sse()
|
||||
|
||||
def get_discovery_results(self) -> dict[int, dict]:
|
||||
with self._lock:
|
||||
return dict(self._discovery_results)
|
||||
|
||||
def clear_discovery_results(self):
|
||||
with self._lock:
|
||||
self._discovery_results.clear()
|
||||
|
||||
# --- Log buffer ---
|
||||
|
||||
def add_log(self, record: dict):
|
||||
with self._lock:
|
||||
self._log_buffer.append(record)
|
||||
self._notify_sse()
|
||||
|
||||
def get_recent_logs(self, count: int = 200) -> list[dict]:
|
||||
with self._lock:
|
||||
items = list(self._log_buffer)
|
||||
return items[-count:]
|
||||
|
||||
# --- SSE notification ---
|
||||
|
||||
def subscribe_sse(self) -> threading.Event:
|
||||
event = threading.Event()
|
||||
with self._lock:
|
||||
self._sse_events.append(event)
|
||||
return event
|
||||
|
||||
def unsubscribe_sse(self, event: threading.Event):
|
||||
with self._lock:
|
||||
try:
|
||||
self._sse_events.remove(event)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _notify_sse(self):
|
||||
with self._lock:
|
||||
events = list(self._sse_events)
|
||||
for event in events:
|
||||
event.set()
|
||||
|
||||
# --- Signals ---
|
||||
|
||||
@property
|
||||
def restart_requested(self) -> threading.Event:
|
||||
return self._restart_requested
|
||||
|
||||
@property
|
||||
def discovery_requested(self) -> threading.Event:
|
||||
return self._discovery_requested
|
||||
|
||||
@property
|
||||
def discovery_duration(self) -> int:
|
||||
with self._lock:
|
||||
return self._discovery_duration
|
||||
|
||||
@discovery_duration.setter
|
||||
def discovery_duration(self, value: int):
|
||||
with self._lock:
|
||||
self._discovery_duration = value
|
||||
|
||||
@property
|
||||
def stop_discovery(self) -> threading.Event:
|
||||
return self._stop_discovery
|
||||
|
||||
@property
|
||||
def config_ready(self) -> threading.Event:
|
||||
return self._config_ready
|
||||
|
||||
|
||||
class WebLogHandler(logging.Handler):
|
||||
"""Captures log records into AppState for web UI streaming."""
|
||||
|
||||
def __init__(self, app_state: AppState):
|
||||
super().__init__()
|
||||
self._state = app_state
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
self._state.add_log({
|
||||
"timestamp": time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S", time.localtime(record.created)
|
||||
),
|
||||
"level": record.levelname,
|
||||
"name": record.name,
|
||||
"message": record.getMessage(),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user