work on model training

This commit is contained in:
2026-03-05 11:03:20 +11:00
parent 96e72d7c43
commit c8e38cd597
10 changed files with 534 additions and 30 deletions

View File

@@ -5,6 +5,6 @@ WORKDIR /app
COPY scripts/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY scripts/train_rain_model.py /app/train_rain_model.py
COPY scripts /app/scripts
ENTRYPOINT ["python", "/app/train_rain_model.py"]
ENTRYPOINT ["python", "/app/scripts/run_rain_ml_worker.py"]

View File

@@ -3,6 +3,7 @@
Starter weather-station data pipeline:
- MQTT ingest of WS90 payloads -> TimescaleDB
- ECMWF (Open-Meteo) forecast polling -> TimescaleDB
- Python rain-model worker (periodic training + inference writes) -> TimescaleDB
- Web UI with live metrics, comparisons, and charts
## Quick start
@@ -73,6 +74,7 @@ TimescaleDB schema is initialized from `db/init/001_schema.sql` and includes:
- `observations_ws90` (hypertable): raw WS90 observations with payload metadata, plus the full JSON payload (`payload_json`).
- `observations_baro` (hypertable): barometric pressure observations from other MQTT topics.
- `forecast_openmeteo_hourly` (hypertable): hourly forecast points keyed by `(site, model, retrieved_at, ts)`.
- `predictions_rain_1h` (hypertable): model probability + decision + realized outcome fields.
- Continuous aggregates:
- `cagg_ws90_1m`: 1-minute rollups (avg/min/max for temp, humidity, wind, uvi, light, rain).
- `cagg_ws90_5m`: 5-minute rollups (same metrics as `cagg_ws90_1m`).

View File

@@ -24,14 +24,16 @@ type webServer struct {
}
type dashboardResponse struct {
GeneratedAt time.Time `json:"generated_at"`
Site string `json:"site"`
Model string `json:"model"`
RangeStart time.Time `json:"range_start"`
RangeEnd time.Time `json:"range_end"`
Observations []db.ObservationPoint `json:"observations"`
Forecast db.ForecastSeries `json:"forecast"`
Latest *db.ObservationPoint `json:"latest"`
GeneratedAt time.Time `json:"generated_at"`
Site string `json:"site"`
Model string `json:"model"`
RangeStart time.Time `json:"range_start"`
RangeEnd time.Time `json:"range_end"`
Observations []db.ObservationPoint `json:"observations"`
Forecast db.ForecastSeries `json:"forecast"`
Latest *db.ObservationPoint `json:"latest"`
LatestRainPredict *db.RainPredictionPoint `json:"latest_rain_prediction,omitempty"`
RainPredictionRange []db.RainPredictionPoint `json:"rain_predictions,omitempty"`
}
func runWebServer(ctx context.Context, d *db.DB, site providers.Site, model, addr string) error {
@@ -171,15 +173,33 @@ func (s *webServer) handleDashboard(w http.ResponseWriter, r *http.Request) {
return
}
const rainModelName = "rain_next_1h"
latestRainPrediction, err := s.db.LatestRainPrediction(r.Context(), s.site.Name, rainModelName)
if err != nil {
http.Error(w, "failed to query latest rain prediction", http.StatusInternalServerError)
log.Printf("web dashboard latest rain prediction error: %v", err)
return
}
rainPredictionRange, err := s.db.RainPredictionSeriesRange(r.Context(), s.site.Name, rainModelName, start, end)
if err != nil {
http.Error(w, "failed to query rain predictions", http.StatusInternalServerError)
log.Printf("web dashboard rain prediction range error: %v", err)
return
}
resp := dashboardResponse{
GeneratedAt: time.Now().UTC(),
Site: s.site.Name,
Model: s.model,
RangeStart: start,
RangeEnd: end,
Observations: observations,
Forecast: forecast,
Latest: latest,
GeneratedAt: time.Now().UTC(),
Site: s.site.Name,
Model: s.model,
RangeStart: start,
RangeEnd: end,
Observations: observations,
Forecast: forecast,
Latest: latest,
LatestRainPredict: latestRainPrediction,
RainPredictionRange: rainPredictionRange,
}
w.Header().Set("Content-Type", "application/json")

View File

@@ -501,6 +501,40 @@ function buildRainProbabilitySeries(points) {
return out;
}
// Convert prediction rows into Chart.js {x, y} points.
// x is epoch millis parsed from p.ts; y is the probability rendered as a
// percentage with one decimal place (0.4567 -> 45.7).
// Rows whose timestamp cannot be parsed are dropped entirely — the original
// emitted {x: null, y: null}, a point Chart.js cannot position on a time axis.
// Rows with a missing probability keep a null y so the line shows a gap.
function buildRainProbabilitySeriesFromPredictions(points) {
  const out = [];
  for (const p of points) {
    const t = new Date(p.ts).getTime();
    if (Number.isNaN(t)) continue; // unparseable timestamp: skip the row
    const hasProb = p.probability !== null && p.probability !== undefined;
    out.push({
      x: t,
      y: hasProb ? Math.round(Number(p.probability) * 1000) / 10 : null,
    });
  }
  return out;
}
// Build a flat two-point line spanning the visible axis range at the decision
// threshold, expressed as a percentage with one decimal place.
// Returns [] when either the range or the threshold is unavailable.
function thresholdSeries(range, threshold) {
  const missingRange = !range || !range.axisStart || !range.axisEnd;
  const missingThreshold = threshold === null || threshold === undefined;
  if (missingRange || missingThreshold) {
    return [];
  }
  const pct = Math.round(Number(threshold) * 1000) / 10;
  return [range.axisStart, range.axisEnd].map((edge) => ({
    x: edge.getTime(),
    y: pct,
  }));
}
// Age of a prediction in minutes relative to now, or null when the prediction
// is absent or its timestamp cannot be parsed.
function predictionAgeMinutes(prediction) {
  if (!prediction || !prediction.ts) {
    return null;
  }
  const tsMillis = new Date(prediction.ts).getTime();
  return Number.isNaN(tsMillis) ? null : (Date.now() - tsMillis) / 60000;
}
function updateWeatherIcons(latest, rainProb) {
const sunEl = document.getElementById("live-icon-sun");
const cloudEl = document.getElementById("live-icon-cloud");
@@ -693,10 +727,26 @@ function renderDashboard(data) {
const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
const forecastLine = extendForecastTo(forecast, rangeEnd);
const rainPredictions = filterRange(data.rain_predictions || [], rangeStart, rangeEnd);
const latestRainPrediction = data.latest_rain_prediction || null;
const latestPredictionAgeMin = predictionAgeMinutes(latestRainPrediction);
const modelPredictionFresh = latestPredictionAgeMin !== null && latestPredictionAgeMin <= 90;
const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h");
const rainProb = computeRainProbability(latest);
const modelRainProb = modelPredictionFresh && latestRainPrediction && latestRainPrediction.probability !== null && latestRainPrediction.probability !== undefined
? {
prob: Number(latestRainPrediction.probability),
label: classifyRainProbability(Number(latestRainPrediction.probability)),
source: "model",
}
: null;
const heuristicRainProb = computeRainProbability(latest);
const rainProb = modelRainProb || heuristicRainProb;
if (rainProb) {
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`);
const sourceLabel = rainProb.source === "model" ? "model" : "heuristic";
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label}, ${sourceLabel})`);
} else if (latestRainPrediction && latestRainPrediction.probability !== null && latestRainPrediction.probability !== undefined) {
const stalePct = Math.round(Number(latestRainPrediction.probability) * 100);
updateText("live-rain-prob", `${stalePct}% (stale model)`);
} else {
updateText("live-rain-prob", "--");
}
@@ -901,13 +951,20 @@ function renderDashboard(data) {
data: {
datasets: [
{
label: "predicted rain probability (%)",
data: buildRainProbabilitySeries(obsFiltered),
label: rainPredictions.length ? "model rain probability (%)" : "heuristic rain probability (%)",
data: rainPredictions.length ? buildRainProbabilitySeriesFromPredictions(rainPredictions) : buildRainProbabilitySeries(obsFiltered),
borderColor: colors.rain,
backgroundColor: "rgba(78, 168, 222, 0.18)",
fill: true,
yAxisID: "y",
},
{
label: "decision threshold (%)",
data: thresholdSeries(range, latestRainPrediction ? latestRainPrediction.threshold : null),
borderColor: "#f4b942",
borderDash: [6, 4],
yAxisID: "y",
},
],
},
options: rainProbOptions,

View File

@@ -198,7 +198,7 @@
</div>
<div class="chart-card" data-chart="chart-rain-prob">
<div class="chart-header">
<div class="chart-title">Predicted Rain Probability (Observed Inputs)</div>
<div class="chart-title">Predicted Rain Probability (Model)</div>
<button class="chart-link" data-chart="chart-rain-prob" title="Copy chart link">Share</button>
</div>
<div class="chart-canvas">

View File

@@ -24,5 +24,27 @@ services:
volumes:
- ./config.yaml:/app/config.yaml:ro
rainml:
build:
context: .
dockerfile: Dockerfile.train
depends_on:
- timescaledb
restart: unless-stopped
environment:
DATABASE_URL: "postgres://postgres:postgres@timescaledb:5432/micrometeo?sslmode=disable"
RAIN_SITE: "home"
RAIN_MODEL_NAME: "rain_next_1h"
RAIN_MODEL_VERSION_BASE: "rain-logreg-v1"
RAIN_LOOKBACK_DAYS: "30"
RAIN_TRAIN_INTERVAL_HOURS: "24"
RAIN_PREDICT_INTERVAL_MINUTES: "10"
RAIN_MIN_PRECISION: "0.70"
RAIN_MODEL_PATH: "/app/models/rain_model.pkl"
RAIN_REPORT_PATH: "/app/models/rain_model_report.json"
RAIN_AUDIT_PATH: "/app/models/rain_data_audit.json"
volumes:
- ./models:/app/models
volumes:
tsdata:

View File

@@ -40,6 +40,7 @@ pip install -r scripts/requirements.txt
- `scripts/train_rain_model.py`: strict time-based split training and metrics report.
- `scripts/predict_rain_model.py`: inference using saved model artifact; upserts into
`predictions_rain_1h`.
- `scripts/run_rain_ml_worker.py`: long-running worker for periodic training + prediction.
## Usage
### 1) Apply schema update (existing DBs)
@@ -90,6 +91,18 @@ export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslm
bash scripts/run_p0_rain_workflow.sh
```
### 6) Continuous training + prediction via Docker Compose
The `rainml` service in `docker-compose.yml` now runs:
- periodic retraining (default every 24 hours)
- periodic prediction writes (default every 10 minutes)
Artifacts are persisted to `./models` on the host.
```sh
docker compose up -d rainml
docker compose logs -f rainml
```
## Output
- Audit report: `models/rain_data_audit.json`
- Training report: `models/rain_model_report.json`

View File

@@ -46,6 +46,19 @@ type ForecastSeries struct {
Points []ForecastPoint `json:"points"`
}
// RainPredictionPoint is one row from the predictions_rain_1h hypertable:
// the model's rain probability and threshold decision for a timestamp, plus
// realized-outcome fields populated once the prediction has been evaluated.
type RainPredictionPoint struct {
	TS           time.Time `json:"ts"`           // prediction timestamp (presumably the hour the forecast targets — confirm against the writer)
	GeneratedAt  time.Time `json:"generated_at"` // when this prediction row was produced
	ModelName    string    `json:"model_name"`
	ModelVersion string    `json:"model_version"`
	Threshold    float64   `json:"threshold"`   // decision threshold applied to Probability
	Probability  float64   `json:"probability"` // predicted probability of rain in the next hour
	PredictRain  bool      `json:"predict_rain"`
	// Realized-outcome fields; nil until the row has been evaluated
	// (NULL columns are mapped to nil pointers when scanning).
	RainNext1hMM     *float64   `json:"rain_next_1h_mm_actual,omitempty"`
	RainNext1hActual *bool      `json:"rain_next_1h_actual,omitempty"`
	EvaluatedAt      *time.Time `json:"evaluated_at,omitempty"`
}
func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start, end time.Time) ([]ObservationPoint, error) {
if end.Before(start) || end.Equal(start) {
return nil, errors.New("invalid time range")
@@ -352,6 +365,139 @@ func (d *DB) ForecastSeriesRange(ctx context.Context, site, model string, start,
}, nil
}
// LatestRainPrediction returns the newest prediction row for the given site
// and model name, preferring the most recently generated row at the latest ts.
// It returns (nil, nil) when no prediction has been written yet.
func (d *DB) LatestRainPrediction(ctx context.Context, site, modelName string) (*RainPredictionPoint, error) {
	query := `
SELECT
    ts,
    generated_at,
    model_name,
    model_version,
    threshold,
    probability,
    predict_rain,
    rain_next_1h_mm_actual,
    rain_next_1h_actual,
    evaluated_at
FROM predictions_rain_1h
WHERE site = $1
  AND model_name = $2
ORDER BY ts DESC, generated_at DESC
LIMIT 1
`
	var (
		point       RainPredictionPoint
		threshold   sql.NullFloat64
		probability sql.NullFloat64
		rainMM      sql.NullFloat64
		predictRain sql.NullBool
		rainActual  sql.NullBool
		evaluatedAt sql.NullTime
	)
	err := d.Pool.QueryRow(ctx, query, site, modelName).Scan(
		&point.TS,
		&point.GeneratedAt,
		&point.ModelName,
		&point.ModelVersion,
		&threshold,
		&probability,
		&predictRain,
		&rainMM,
		&rainActual,
		&evaluatedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// No prediction yet is not an error for callers.
		return nil, nil
	case err != nil:
		return nil, err
	}
	// NULL numeric/bool columns leave the struct's zero value in place.
	if threshold.Valid {
		point.Threshold = threshold.Float64
	}
	if probability.Valid {
		point.Probability = probability.Float64
	}
	if predictRain.Valid {
		point.PredictRain = predictRain.Bool
	}
	point.RainNext1hMM = nullFloatPtr(rainMM)
	point.RainNext1hActual = nullBoolPtr(rainActual)
	point.EvaluatedAt = nullTimePtr(evaluatedAt)
	return &point, nil
}
// RainPredictionSeriesRange returns at most one prediction per timestamp for
// the given site/model within [start, end], inclusive. When multiple rows
// share a ts, the most recently generated one wins (DISTINCT ON (ts) with
// ORDER BY ts ASC, generated_at DESC).
func (d *DB) RainPredictionSeriesRange(ctx context.Context, site, modelName string, start, end time.Time) ([]RainPredictionPoint, error) {
	query := `
SELECT DISTINCT ON (ts)
    ts,
    generated_at,
    model_name,
    model_version,
    threshold,
    probability,
    predict_rain,
    rain_next_1h_mm_actual,
    rain_next_1h_actual,
    evaluated_at
FROM predictions_rain_1h
WHERE site = $1
  AND model_name = $2
  AND ts >= $3
  AND ts <= $4
ORDER BY ts ASC, generated_at DESC
`
	rows, err := d.Pool.Query(ctx, query, site, modelName, start, end)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	points := make([]RainPredictionPoint, 0, 256)
	for rows.Next() {
		var (
			p                              RainPredictionPoint
			rainMM, threshold, probability sql.NullFloat64
			rainActual, predictRain        sql.NullBool
			evaluatedAt                    sql.NullTime
		)
		if err := rows.Scan(
			&p.TS,
			&p.GeneratedAt,
			&p.ModelName,
			&p.ModelVersion,
			&threshold,
			&probability,
			&predictRain,
			&rainMM,
			&rainActual,
			&evaluatedAt,
		); err != nil {
			return nil, err
		}
		// NULL numeric/bool columns leave the struct's zero value in place.
		if threshold.Valid {
			p.Threshold = threshold.Float64
		}
		if probability.Valid {
			p.Probability = probability.Float64
		}
		if predictRain.Valid {
			p.PredictRain = predictRain.Bool
		}
		p.RainNext1hMM = nullFloatPtr(rainMM)
		p.RainNext1hActual = nullBoolPtr(rainActual)
		p.EvaluatedAt = nullTimePtr(evaluatedAt)
		points = append(points, p)
	}
	// Evaluate the iteration error once (the original called rows.Err twice).
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return points, nil
}
func nullFloatPtr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
@@ -367,3 +513,19 @@ func nullIntPtr(v sql.NullInt64) *int64 {
val := v.Int64
return &val
}
func nullBoolPtr(v sql.NullBool) *bool {
if !v.Valid {
return nil
}
val := v.Bool
return &val
}
func nullTimePtr(v sql.NullTime) *time.Time {
if !v.Valid {
return nil
}
val := v.Time
return &val
}

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
from __future__ import annotations
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
def read_env(name: str, default: str) -> str:
    """Return the environment variable ``name`` (or ``default``), whitespace-stripped."""
    value = os.getenv(name, default)
    return value.strip()
def read_env_float(name: str, default: float) -> float:
    """Parse ``name`` from the environment as a float; unset/blank yields ``default``."""
    raw = os.getenv(name)
    if raw and raw.strip():
        return float(raw)
    return default
def read_env_int(name: str, default: int) -> int:
    """Parse ``name`` from the environment as an int; unset/blank yields ``default``."""
    raw = os.getenv(name)
    if raw and raw.strip():
        return int(raw)
    return default
def read_env_bool(name: str, default: bool) -> bool:
    """Parse ``name`` as a boolean flag; unset yields ``default``.

    Accepted truthy spellings (case-insensitive, whitespace-trimmed):
    "1", "true", "yes", "on". Anything else is False.
    """
    truthy = {"1", "true", "yes", "on"}
    value = os.getenv(name)
    return default if value is None else value.strip().lower() in truthy
@dataclass
class WorkerConfig:
    """Runtime settings for the rain-ML worker, sourced from environment
    variables (see load_config for the RAIN_* / DATABASE_URL names and
    their defaults)."""
    database_url: str        # Postgres/Timescale DSN; required (worker exits without it)
    site: str                # site name scoping audits, training, and predictions
    model_name: str          # logical model name passed to the predict script
    model_version_base: str  # prefix for the timestamped model-version string
    train_interval_hours: float      # how often a training cycle runs
    predict_interval_minutes: float  # how often a prediction pass runs
    lookback_days: int       # training-window length ending at "now"
    train_ratio: float       # --train-ratio forwarded to the trainer
    val_ratio: float         # --val-ratio forwarded to the trainer
    min_precision: float     # --min-precision floor forwarded to the trainer
    model_path: Path         # where the trained model artifact is written/read
    report_path: Path        # training metrics report output path
    audit_path: Path         # data-audit report output path
    run_once: bool           # exit 0 after one successful train + predict
    retry_delay_seconds: int # back-off before retrying after a failure
def now_utc() -> datetime:
    """Current UTC time, truncated to whole seconds."""
    current = datetime.now(tz=timezone.utc)
    return current.replace(microsecond=0)
def iso_utc(v: datetime) -> str:
    """Format ``v`` as an ISO-8601 UTC timestamp with a trailing ``Z``."""
    as_utc = v.astimezone(timezone.utc)
    return as_utc.isoformat().replace("+00:00", "Z")
def run_cmd(cmd: list[str], env: dict[str, str]) -> None:
    """Echo and execute *cmd* with *env*; a non-zero exit raises CalledProcessError."""
    rendered = ' '.join(cmd)
    print(f"[rain-ml] running: {rendered}", flush=True)
    subprocess.run(cmd, env=env, check=True)
def ensure_parent(path: Path) -> None:
    """Create ``path``'s parent directory (and any ancestors) if it is missing."""
    parent = path.parent
    if parent and not parent.exists():
        parent.mkdir(parents=True, exist_ok=True)
def training_window(lookback_days: int) -> tuple[str, str]:
    """Return (start, end) ISO-8601 UTC strings covering the trailing
    ``lookback_days`` window ending at the current time."""
    window_end = now_utc()
    window_start = window_end - timedelta(days=lookback_days)
    return iso_utc(window_start), iso_utc(window_end)
def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run one data audit followed by one model training pass.

    The window covers the trailing cfg.lookback_days; the model version is
    the configured base plus a minute-resolution UTC timestamp. Output
    directories are created before the child scripts run.
    """
    start, end = training_window(cfg.lookback_days)
    model_version = f"{cfg.model_version_base}-{now_utc().strftime('%Y%m%d%H%M')}"
    # Make sure all artifact directories exist before the scripts write to them.
    for artifact in (cfg.audit_path, cfg.report_path, cfg.model_path):
        ensure_parent(artifact)
    audit_cmd = [
        sys.executable,
        "scripts/audit_rain_data.py",
        "--site", cfg.site,
        "--start", start,
        "--end", end,
        "--out", str(cfg.audit_path),
    ]
    run_cmd(audit_cmd, env)
    train_cmd = [
        sys.executable,
        "scripts/train_rain_model.py",
        "--site", cfg.site,
        "--start", start,
        "--end", end,
        "--train-ratio", str(cfg.train_ratio),
        "--val-ratio", str(cfg.val_ratio),
        "--min-precision", str(cfg.min_precision),
        "--model-version", model_version,
        "--out", str(cfg.model_path),
        "--report-out", str(cfg.report_path),
    ]
    run_cmd(train_cmd, env)
def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run a single inference pass.

    Raises RuntimeError when the trained model artifact is not on disk yet
    (e.g. before the first training cycle has completed).
    """
    if not cfg.model_path.exists():
        raise RuntimeError(f"model artifact not found: {cfg.model_path}")
    predict_cmd = [
        sys.executable,
        "scripts/predict_rain_model.py",
        "--site", cfg.site,
        "--model-path", str(cfg.model_path),
        "--model-name", cfg.model_name,
    ]
    run_cmd(predict_cmd, env)
def load_config() -> WorkerConfig:
    """Build a WorkerConfig from the environment.

    DATABASE_URL is mandatory; the process exits with an error without it.
    Every other setting falls back to a sensible default.
    """
    database_url = read_env("DATABASE_URL", "")
    if not database_url:
        raise SystemExit("DATABASE_URL is required")
    artifact_paths = {
        "model_path": Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl")),
        "report_path": Path(read_env("RAIN_REPORT_PATH", "models/rain_model_report.json")),
        "audit_path": Path(read_env("RAIN_AUDIT_PATH", "models/rain_data_audit.json")),
    }
    return WorkerConfig(
        database_url=database_url,
        site=read_env("RAIN_SITE", "home"),
        model_name=read_env("RAIN_MODEL_NAME", "rain_next_1h"),
        model_version_base=read_env("RAIN_MODEL_VERSION_BASE", "rain-logreg-v1"),
        train_interval_hours=read_env_float("RAIN_TRAIN_INTERVAL_HOURS", 24.0),
        predict_interval_minutes=read_env_float("RAIN_PREDICT_INTERVAL_MINUTES", 10.0),
        lookback_days=read_env_int("RAIN_LOOKBACK_DAYS", 30),
        train_ratio=read_env_float("RAIN_TRAIN_RATIO", 0.7),
        val_ratio=read_env_float("RAIN_VAL_RATIO", 0.15),
        min_precision=read_env_float("RAIN_MIN_PRECISION", 0.70),
        run_once=read_env_bool("RAIN_RUN_ONCE", False),
        retry_delay_seconds=read_env_int("RAIN_RETRY_DELAY_SECONDS", 60),
        **artifact_paths,
    )
def main() -> int:
    """Scheduler loop: retrain every train_interval_hours and write
    predictions every predict_interval_minutes until stopped — or, with
    RAIN_RUN_ONCE, until one of each has succeeded (returns 0)."""
    cfg = load_config()
    env = os.environ.copy()
    env["DATABASE_URL"] = cfg.database_url
    train_every = timedelta(hours=cfg.train_interval_hours)
    predict_every = timedelta(minutes=cfg.predict_interval_minutes)
    # Both tasks are due immediately at startup.
    next_train = now_utc()
    next_predict = now_utc()
    trained_once = False
    predicted_once = False
    print(
        "[rain-ml] worker start "
        f"site={cfg.site} "
        f"model_name={cfg.model_name} "
        f"train_interval_hours={cfg.train_interval_hours} "
        f"predict_interval_minutes={cfg.predict_interval_minutes}",
        flush=True,
    )
    while True:
        now = now_utc()
        try:
            if now >= next_train:
                run_training_cycle(cfg, env)
                # Rescheduled from loop entry, not completion: a long training
                # run shortens the gap before the next one.
                next_train = now + train_every
                trained_once = True
            if now >= next_predict:
                run_predict_once(cfg, env)
                next_predict = now + predict_every
                predicted_once = True
            if cfg.run_once and trained_once and predicted_once:
                print("[rain-ml] run-once complete", flush=True)
                return 0
        except subprocess.CalledProcessError as exc:
            # A child script failed: back off, then retry the whole cycle.
            print(f"[rain-ml] command failed exit={exc.returncode}; retrying in {cfg.retry_delay_seconds}s", flush=True)
            time.sleep(cfg.retry_delay_seconds)
            continue
        except Exception as exc:  # pragma: no cover - defensive for runtime worker
            print(f"[rain-ml] worker error: {exc}; retrying in {cfg.retry_delay_seconds}s", flush=True)
            time.sleep(cfg.retry_delay_seconds)
            continue
        # Sleep until the next due task, capped at 30s.
        # NOTE(review): "now" was captured before the work above, so sleep_for
        # can be negative after a long cycle; the guard then skips sleeping.
        sleep_for = min((next_train - now).total_seconds(), (next_predict - now).total_seconds(), 30.0)
        if sleep_for > 0:
            time.sleep(sleep_for)
if __name__ == "__main__":
raise SystemExit(main())

14
todo.md
View File

@@ -9,9 +9,9 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
- [x] [P0] Freeze training window with explicit UTC start/end timestamps.
## 2) Data Quality and Label Validation
- [ ] [P0] Audit `observations_ws90` and `observations_baro` for missingness, gaps, duplicates, and out-of-order rows. (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
- [ ] [P0] Validate rain label construction from `rain_mm` (counter resets, negative deltas, spikes). (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
- [ ] [P0] Measure class balance by week (rain-positive vs rain-negative). (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
- [x] [P0] Audit `observations_ws90` and `observations_baro` for missingness, gaps, duplicates, and out-of-order rows. (completed on runtime machine)
- [x] [P0] Validate rain label construction from `rain_mm` (counter resets, negative deltas, spikes). (completed on runtime machine)
- [x] [P0] Measure class balance by week (rain-positive vs rain-negative). (completed on runtime machine)
- [ ] [P1] Document known data issues and mitigation rules.
## 3) Dataset and Feature Engineering
@@ -38,10 +38,10 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
- [ ] [P1] Produce a short model card (data window, features, metrics, known limitations).
## 6) Packaging and Deployment
- [ ] [P1] Version model artifacts and feature schema together.
- [x] [P1] Version model artifacts and feature schema together.
- [x] [P0] Implement inference path with feature parity between training and serving.
- [x] [P0] Add prediction storage table for predicted probabilities and realized outcomes.
- [ ] [P1] Expose predictions via API and optionally surface in web dashboard.
- [x] [P1] Expose predictions via API and optionally surface in web dashboard.
- [ ] [P2] Add scheduled retraining with rollback to last-known-good model.
## 7) Monitoring and Operations
@@ -51,7 +51,7 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
- [ ] [P1] Document runbook for train/evaluate/deploy/rollback.
## 8) Immediate Next Steps (This Week)
- [ ] [P0] Run first full data audit and label-quality checks. (blocked here; run on runtime machine)
- [ ] [P0] Train baseline model on full available history and capture metrics. (blocked here; run on runtime machine)
- [x] [P0] Run first full data audit and label-quality checks. (completed on runtime machine)
- [x] [P0] Train baseline model on full available history and capture metrics. (completed on runtime machine)
- [ ] [P1] Add one expanded feature set and rerun evaluation.
- [x] [P0] Decide v1 threshold and define deployment interface.