From 8edd0dc8b03d3c63785925acc506536eafbb8f6a Mon Sep 17 00:00:00 2001 From: Nathan Coad Date: Mon, 2 Feb 2026 17:08:43 +1100 Subject: [PATCH] add model training --- Dockerfile.train | 10 ++ cmd/ingestd/web/app.js | 61 ++++++++++++ cmd/ingestd/web/index.html | 4 + docs/rain_prediction.md | 102 +++++++++++++++++++++ internal/db/series.go | 37 ++++++-- scripts/requirements.txt | 5 + scripts/train_rain_model.py | 178 ++++++++++++++++++++++++++++++++++++ 7 files changed, 388 insertions(+), 9 deletions(-) create mode 100644 Dockerfile.train create mode 100644 docs/rain_prediction.md create mode 100644 scripts/requirements.txt create mode 100644 scripts/train_rain_model.py diff --git a/Dockerfile.train b/Dockerfile.train new file mode 100644 index 0000000..e0f77dc --- /dev/null +++ b/Dockerfile.train @@ -0,0 +1,10 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY scripts/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -r /app/requirements.txt + +COPY scripts/train_rain_model.py /app/train_rain_model.py + +ENTRYPOINT ["python", "/app/train_rain_model.py"] diff --git a/cmd/ingestd/web/app.js b/cmd/ingestd/web/app.js index 49156b8..7d17efb 100644 --- a/cmd/ingestd/web/app.js +++ b/cmd/ingestd/web/app.js @@ -289,6 +289,60 @@ function updateText(id, text) { if (el) el.textContent = text; } +function lastNonNull(points, key) { + for (let i = points.length - 1; i >= 0; i -= 1) { + const v = points[i][key]; + if (v !== null && v !== undefined) { + return v; + } + } + return null; +} + +function computeRainProbability(latest, pressureTrend1h) { + if (!latest) { + return null; + } + + let prob = 0.1; + if (pressureTrend1h !== null && pressureTrend1h !== undefined) { + if (pressureTrend1h <= -3.0) { + prob += 0.5; + } else if (pressureTrend1h <= -2.0) { + prob += 0.35; + } else if (pressureTrend1h <= -1.0) { + prob += 0.2; + } else if (pressureTrend1h <= -0.5) { + prob += 0.1; + } + } + + if (latest.rh !== null && latest.rh !== undefined) { + if (latest.rh >= 95) { + prob += 0.2; + } else if (latest.rh >= 90) { + prob += 0.15; + } else if (latest.rh >= 85) { + prob += 0.1; + } + } + + if (latest.wind_m_s !== null && latest.wind_m_s !== undefined && latest.wind_m_s >= 6) { + prob += 0.05; + } + + prob = Math.max(0.05, Math.min(0.95, prob)); + + let label = "Low"; + if (prob >= 0.6) { + label = "High"; + } else if (prob >= 0.35) { + label = "Medium"; + } + + return { prob, label }; +} + function updateSiteMeta(site, model, tzLabel) { const home = document.getElementById("site-home"); const suffix = document.getElementById("site-meta-suffix"); @@ -400,6 +454,13 @@ function renderDashboard(data) { const obsFiltered = filterRange(obs, rangeStart, rangeEnd); const forecast = filterRange(forecastAll, rangeStart, rangeEnd); + const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h"); + const rainProb = computeRainProbability(latest, lastPressureTrend); + if (rainProb) { + updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`); + } else { + updateText("live-rain-prob", "--"); + } const obsTemps = obsFiltered.map((p) => p.temp_c); const obsWinds = obsFiltered.map((p) => p.wind_m_s); diff --git a/cmd/ingestd/web/index.html b/cmd/ingestd/web/index.html index 6f71dc3..fe4137c 100644 --- a/cmd/ingestd/web/index.html +++ b/cmd/ingestd/web/index.html @@ -58,6 +58,10 @@
Pressure hPa
--
+
+
Rain 1h %
+
--
+
Wind m/s
--
diff --git a/docs/rain_prediction.md b/docs/rain_prediction.md new file mode 100644 index 0000000..c128529 --- /dev/null +++ b/docs/rain_prediction.md @@ -0,0 +1,102 @@ +# Rain Prediction (Next 1 Hour) + +This project now includes a starter training script for a **binary rain prediction**: + +> **Will we see >= 0.2 mm of rain in the next hour?** + +It uses local observations (WS90 + barometric pressure) and trains a lightweight +logistic regression model. This is a baseline you can iterate on as you collect +more data. + +## What the script does +- Pulls data from TimescaleDB. +- Resamples observations to 5-minute buckets. +- Derives **pressure trend (1h)** from barometer data. +- Computes **future 1-hour rainfall** from the cumulative `rain_mm` counter. +- Trains a model and prints evaluation metrics. + +The output is a saved model file (optional) you can use later for inference. + +## Requirements +Python 3.10+ and: + +``` +pandas +numpy +scikit-learn +psycopg2-binary +joblib +``` + +Install with: + +```sh +python3 -m venv .venv +source .venv/bin/activate +pip install -r scripts/requirements.txt +``` + +## Usage + +```sh +python scripts/train_rain_model.py \ + --db-url "postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable" \ + --site "home" \ + --start "2026-01-01" \ + --end "2026-02-01" \ + --out "models/rain_model.pkl" +``` + +You can also provide the connection string via `DATABASE_URL`: + +```sh +export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable" +python scripts/train_rain_model.py --site home +``` + +## Output +The script prints metrics including: +- accuracy +- precision / recall +- ROC AUC +- confusion matrix + +If `joblib` is installed, it saves a model bundle: + +``` +models/rain_model.pkl +``` + +This bundle contains: +- The trained model pipeline +- The feature list used during training + +## Data needs / when to run +For a reliable model, you will want: +- **At least 2-4 weeks** of observations +- A mix of rainy and non-rainy periods + +Training with only a few days will produce an unstable model. + +## Features used +The baseline model uses: +- `pressure_trend_1h` (hPa) +- `humidity` (%) +- `temperature_c` (C) +- `wind_avg_m_s` (m/s) +- `wind_max_m_s` (m/s) + +These are easy to expand once you have more data (e.g. add forecast features). + +## Notes / assumptions +- Rain detection is based on **incremental rain** derived from the WS90 + `rain_mm` cumulative counter. +- Pressure comes from `observations_baro`. +- All timestamps are treated as UTC. + +## Next improvements +Ideas once more data is available: +- Add forecast precipitation and cloud cover as features +- Try gradient boosted trees (e.g. XGBoost / LightGBM) +- Train per-season models +- Calibrate probabilities (Platt scaling / isotonic regression) diff --git a/internal/db/series.go b/internal/db/series.go index 59bdede..f9ebb7f 100644 --- a/internal/db/series.go +++ b/internal/db/series.go @@ -15,15 +15,17 @@ type ObservationPoint struct { TempC *float64 `json:"temp_c,omitempty"` RH *float64 `json:"rh,omitempty"` PressureHPA *float64 `json:"pressure_hpa,omitempty"` - WindMS *float64 `json:"wind_m_s,omitempty"` - WindGustMS *float64 `json:"wind_gust_m_s,omitempty"` - WindDirDeg *float64 `json:"wind_dir_deg,omitempty"` - UVI *float64 `json:"uvi,omitempty"` - LightLux *float64 `json:"light_lux,omitempty"` - BatteryMV *float64 `json:"battery_mv,omitempty"` - SupercapV *float64 `json:"supercap_v,omitempty"` - RainMM *float64 `json:"rain_mm,omitempty"` - RainStart *int64 `json:"rain_start,omitempty"` + // PressureTrend1h is the change in pressure over the last hour (hPa). + PressureTrend1h *float64 `json:"pressure_trend_1h,omitempty"` + WindMS *float64 `json:"wind_m_s,omitempty"` + WindGustMS *float64 `json:"wind_gust_m_s,omitempty"` + WindDirDeg *float64 `json:"wind_dir_deg,omitempty"` + UVI *float64 `json:"uvi,omitempty"` + LightLux *float64 `json:"light_lux,omitempty"` + BatteryMV *float64 `json:"battery_mv,omitempty"` + SupercapV *float64 `json:"supercap_v,omitempty"` + RainMM *float64 `json:"rain_mm,omitempty"` + RainStart *int64 `json:"rain_start,omitempty"` } type ForecastPoint struct { @@ -147,6 +149,23 @@ func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start, return nil, rows.Err() } + indexByTime := make(map[time.Time]int, len(points)) + for i := range points { + indexByTime[points[i].TS] = i + } + for i := range points { + if points[i].PressureHPA == nil { + continue + } + target := points[i].TS.Add(-1 * time.Hour) + j, ok := indexByTime[target] + if !ok || points[j].PressureHPA == nil { + continue + } + trend := *points[i].PressureHPA - *points[j].PressureHPA + points[i].PressureTrend1h = &trend + } + return points, nil } diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 0000000..928727b --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,5 @@ +pandas>=2.2.0 +numpy>=1.26.0 +scikit-learn>=1.4.0 +psycopg2-binary>=2.9.0 +joblib>=1.3.0 diff --git a/scripts/train_rain_model.py b/scripts/train_rain_model.py new file mode 100644 index 0000000..2dadd94 --- /dev/null +++ b/scripts/train_rain_model.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +import argparse +import os +from datetime import datetime + +import numpy as np +import pandas as pd +import psycopg2 +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler + +try: + import joblib +except ImportError: # pragma: no cover - optional dependency + joblib = None + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).") + parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.") + parser.add_argument("--site", required=True, help="Site name (e.g. home).") + parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).") + parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).") + parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.") + return parser.parse_args() + + +def parse_time(value: str) -> str: + if not value: + return "" + try: + datetime.fromisoformat(value.replace("Z", "+00:00")) + return value + except ValueError: + raise ValueError(f"invalid time format: {value}") + + +def fetch_ws90(conn, site, start, end): + sql = """ + SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm + FROM observations_ws90 + WHERE site = %s + AND (%s = '' OR ts >= %s::timestamptz) + AND (%s = '' OR ts <= %s::timestamptz) + ORDER BY ts ASC + """ + return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"]) + + +def fetch_baro(conn, site, start, end): + sql = """ + SELECT ts, pressure_hpa + FROM observations_baro + WHERE site = %s + AND (%s = '' OR ts >= %s::timestamptz) + AND (%s = '' OR ts <= %s::timestamptz) + ORDER BY ts ASC + """ + return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"]) + + +def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame: + if ws90.empty: + raise RuntimeError("no ws90 observations found") + if baro.empty: + raise RuntimeError("no barometer observations found") + + ws90 = ws90.set_index("ts").sort_index() + baro = baro.set_index("ts").sort_index() + + ws90_5m = ws90.resample("5min").agg( + { + "temperature_c": "mean", + "humidity": "mean", + "wind_avg_m_s": "mean", + "wind_max_m_s": "max", + "wind_dir_deg": "mean", + "rain_mm": "last", + } + ) + baro_5m = baro.resample("5min").mean() + + df = ws90_5m.join(baro_5m, how="outer") + df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3) + + # Compute incremental rain and future 1-hour sum. + df["rain_inc"] = df["rain_mm"].diff().clip(lower=0) + window = 12 # 12 * 5min = 1 hour + df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1)) + df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2 + + # Pressure trend over the previous hour. + df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12) + + return df + + +def train_model(df: pd.DataFrame): + feature_cols = [ + "pressure_trend_1h", + "humidity", + "temperature_c", + "wind_avg_m_s", + "wind_max_m_s", + ] + + df = df.dropna(subset=feature_cols + ["rain_next_1h"]) + if len(df) < 200: + raise RuntimeError("not enough data after filtering (need >= 200 rows)") + + X = df[feature_cols] + y = df["rain_next_1h"].astype(int) + + split_idx = int(len(df) * 0.8) + X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:] + y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:] + + model = Pipeline( + [ + ("scaler", StandardScaler()), + ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")), + ] + ) + model.fit(X_train, y_train) + + y_pred = model.predict(X_test) + y_prob = model.predict_proba(X_test)[:, 1] + + metrics = { + "rows": len(df), + "train_rows": len(X_train), + "test_rows": len(X_test), + "accuracy": accuracy_score(y_test, y_pred), + "precision": precision_score(y_test, y_pred, zero_division=0), + "recall": recall_score(y_test, y_pred, zero_division=0), + "roc_auc": roc_auc_score(y_test, y_prob), + "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(), + } + + return model, metrics, feature_cols + + +def main() -> int: + args = parse_args() + if not args.db_url: + raise SystemExit("missing --db-url or DATABASE_URL") + + start = parse_time(args.start) if args.start else "" + end = parse_time(args.end) if args.end else "" + + with psycopg2.connect(args.db_url) as conn: + ws90 = fetch_ws90(conn, args.site, start, end) + baro = fetch_baro(conn, args.site, start, end) + + df = build_dataset(ws90, baro) + model, metrics, features = train_model(df) + + print("Rain model metrics:") + for k, v in metrics.items(): + print(f" {k}: {v}") + + if args.out: + out_dir = os.path.dirname(args.out) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + if joblib is None: + print("joblib not installed; skipping model save.") + else: + joblib.dump({"model": model, "features": features}, args.out) + print(f"Saved model to {args.out}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())