add model training

This commit is contained in:
2026-02-02 17:08:43 +11:00
parent 737eef85ea
commit 8edd0dc8b0
7 changed files with 388 additions and 9 deletions

10
Dockerfile.train Normal file
View File

@@ -0,0 +1,10 @@
FROM python:3.11-slim
WORKDIR /app
COPY scripts/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY scripts/train_rain_model.py /app/train_rain_model.py
ENTRYPOINT ["python", "/app/train_rain_model.py"]

View File

@@ -289,6 +289,60 @@ function updateText(id, text) {
if (el) el.textContent = text; if (el) el.textContent = text;
} }
function lastNonNull(points, key) {
for (let i = points.length - 1; i >= 0; i -= 1) {
const v = points[i][key];
if (v !== null && v !== undefined) {
return v;
}
}
return null;
}
function computeRainProbability(latest, pressureTrend1h) {
if (!latest) {
return null;
}
let prob = 0.1;
if (pressureTrend1h !== null && pressureTrend1h !== undefined) {
if (pressureTrend1h <= -3.0) {
prob += 0.5;
} else if (pressureTrend1h <= -2.0) {
prob += 0.35;
} else if (pressureTrend1h <= -1.0) {
prob += 0.2;
} else if (pressureTrend1h <= -0.5) {
prob += 0.1;
}
}
if (latest.rh !== null && latest.rh !== undefined) {
if (latest.rh >= 95) {
prob += 0.2;
} else if (latest.rh >= 90) {
prob += 0.15;
} else if (latest.rh >= 85) {
prob += 0.1;
}
}
if (latest.wind_m_s !== null && latest.wind_m_s !== undefined && latest.wind_m_s >= 6) {
prob += 0.05;
}
prob = Math.max(0.05, Math.min(0.95, prob));
let label = "Low";
if (prob >= 0.6) {
label = "High";
} else if (prob >= 0.35) {
label = "Medium";
}
return { prob, label };
}
function updateSiteMeta(site, model, tzLabel) { function updateSiteMeta(site, model, tzLabel) {
const home = document.getElementById("site-home"); const home = document.getElementById("site-home");
const suffix = document.getElementById("site-meta-suffix"); const suffix = document.getElementById("site-meta-suffix");
@@ -400,6 +454,13 @@ function renderDashboard(data) {
const obsFiltered = filterRange(obs, rangeStart, rangeEnd); const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
const forecast = filterRange(forecastAll, rangeStart, rangeEnd); const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h");
const rainProb = computeRainProbability(latest, lastPressureTrend);
if (rainProb) {
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`);
} else {
updateText("live-rain-prob", "--");
}
const obsTemps = obsFiltered.map((p) => p.temp_c); const obsTemps = obsFiltered.map((p) => p.temp_c);
const obsWinds = obsFiltered.map((p) => p.wind_m_s); const obsWinds = obsFiltered.map((p) => p.wind_m_s);

View File

@@ -58,6 +58,10 @@
<div class="label">Pressure hPa</div> <div class="label">Pressure hPa</div>
<div class="value" id="live-pressure">--</div> <div class="value" id="live-pressure">--</div>
</div> </div>
<div class="metric">
<div class="label">Rain 1h %</div>
<div class="value" id="live-rain-prob">--</div>
</div>
<div class="metric"> <div class="metric">
<div class="label">Wind m/s</div> <div class="label">Wind m/s</div>
<div class="value" id="live-wind">--</div> <div class="value" id="live-wind">--</div>

102
docs/rain_prediction.md Normal file
View File

@@ -0,0 +1,102 @@
# Rain Prediction (Next 1 Hour)
This project now includes a starter training script for a **binary rain prediction**:
> **Will we see >= 0.2 mm of rain in the next hour?**
It uses local observations (WS90 + barometric pressure) and trains a lightweight
logistic regression model. This is a baseline you can iterate on as you collect
more data.
## What the script does
- Pulls data from TimescaleDB.
- Resamples observations to 5-minute buckets.
- Derives **pressure trend (1h)** from barometer data.
- Computes **future 1-hour rainfall** from the cumulative `rain_mm` counter.
- Trains a model and prints evaluation metrics.
The output is a saved model file (optional) you can use later for inference.
## Requirements
Python 3.10+ and:
```
pandas
numpy
scikit-learn
psycopg2-binary
joblib
```
Install with:
```sh
python3 -m venv .venv
source .venv/bin/activate
pip install -r scripts/requirements.txt
```
## Usage
```sh
python scripts/train_rain_model.py \
--db-url "postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable" \
--site "home" \
--start "2026-01-01" \
--end "2026-02-01" \
--out "models/rain_model.pkl"
```
You can also provide the connection string via `DATABASE_URL`:
```sh
export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable"
python scripts/train_rain_model.py --site home
```
## Output
The script prints metrics including:
- accuracy
- precision / recall
- ROC AUC
- confusion matrix
If `joblib` is installed, it saves a model bundle:
```
models/rain_model.pkl
```
This bundle contains:
- The trained model pipeline
- The feature list used during training
## Data needs / when to run
For a reliable model, you will want:
- **At least 2-4 weeks** of observations
- A mix of rainy and non-rainy periods
Training with only a few days will produce an unstable model.
## Features used
The baseline model uses:
- `pressure_trend_1h` (hPa)
- `humidity` (%)
- `temperature_c` (C)
- `wind_avg_m_s` (m/s)
- `wind_max_m_s` (m/s)
These are easy to expand once you have more data (e.g. add forecast features).
## Notes / assumptions
- Rain detection is based on **incremental rain** derived from the WS90
`rain_mm` cumulative counter.
- Pressure comes from `observations_baro`.
- All timestamps are treated as UTC.
## Next improvements
Ideas once more data is available:
- Add forecast precipitation and cloud cover as features
- Try gradient boosted trees (e.g. XGBoost / LightGBM)
- Train per-season models
- Calibrate probabilities (Platt scaling / isotonic regression)

View File

@@ -15,6 +15,8 @@ type ObservationPoint struct {
TempC *float64 `json:"temp_c,omitempty"` TempC *float64 `json:"temp_c,omitempty"`
RH *float64 `json:"rh,omitempty"` RH *float64 `json:"rh,omitempty"`
PressureHPA *float64 `json:"pressure_hpa,omitempty"` PressureHPA *float64 `json:"pressure_hpa,omitempty"`
// PressureTrend1h is the change in pressure over the last hour (hPa).
PressureTrend1h *float64 `json:"pressure_trend_1h,omitempty"`
WindMS *float64 `json:"wind_m_s,omitempty"` WindMS *float64 `json:"wind_m_s,omitempty"`
WindGustMS *float64 `json:"wind_gust_m_s,omitempty"` WindGustMS *float64 `json:"wind_gust_m_s,omitempty"`
WindDirDeg *float64 `json:"wind_dir_deg,omitempty"` WindDirDeg *float64 `json:"wind_dir_deg,omitempty"`
@@ -147,6 +149,23 @@ func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start,
return nil, rows.Err() return nil, rows.Err()
} }
indexByTime := make(map[time.Time]int, len(points))
for i := range points {
indexByTime[points[i].TS] = i
}
for i := range points {
if points[i].PressureHPA == nil {
continue
}
target := points[i].TS.Add(-1 * time.Hour)
j, ok := indexByTime[target]
if !ok || points[j].PressureHPA == nil {
continue
}
trend := *points[i].PressureHPA - *points[j].PressureHPA
points[i].PressureTrend1h = &trend
}
return points, nil return points, nil
} }

5
scripts/requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
pandas>=2.2.0
numpy>=1.26.0
scikit-learn>=1.4.0
psycopg2-binary>=2.9.0
joblib>=1.3.0

178
scripts/train_rain_model.py Normal file
View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
import argparse
import os
from datetime import datetime
import numpy as np
import pandas as pd
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
try:
import joblib
except ImportError: # pragma: no cover - optional dependency
joblib = None
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
return parser.parse_args()
def parse_time(value: str) -> str:
if not value:
return ""
try:
datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
except ValueError:
raise ValueError(f"invalid time format: {value}")
def fetch_ws90(conn, site, start, end):
sql = """
SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
FROM observations_ws90
WHERE site = %s
AND (%s = '' OR ts >= %s::timestamptz)
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
def fetch_baro(conn, site, start, end):
sql = """
SELECT ts, pressure_hpa
FROM observations_baro
WHERE site = %s
AND (%s = '' OR ts >= %s::timestamptz)
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
if ws90.empty:
raise RuntimeError("no ws90 observations found")
if baro.empty:
raise RuntimeError("no barometer observations found")
ws90 = ws90.set_index("ts").sort_index()
baro = baro.set_index("ts").sort_index()
ws90_5m = ws90.resample("5min").agg(
{
"temperature_c": "mean",
"humidity": "mean",
"wind_avg_m_s": "mean",
"wind_max_m_s": "max",
"wind_dir_deg": "mean",
"rain_mm": "last",
}
)
baro_5m = baro.resample("5min").mean()
df = ws90_5m.join(baro_5m, how="outer")
df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)
# Compute incremental rain and future 1-hour sum.
df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
window = 12 # 12 * 5min = 1 hour
df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2
# Pressure trend over the previous hour.
df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)
return df
def train_model(df: pd.DataFrame):
feature_cols = [
"pressure_trend_1h",
"humidity",
"temperature_c",
"wind_avg_m_s",
"wind_max_m_s",
]
df = df.dropna(subset=feature_cols + ["rain_next_1h"])
if len(df) < 200:
raise RuntimeError("not enough data after filtering (need >= 200 rows)")
X = df[feature_cols]
y = df["rain_next_1h"].astype(int)
split_idx = int(len(df) * 0.8)
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
model = Pipeline(
[
("scaler", StandardScaler()),
("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
]
)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
y_prob = model.predict_proba(X_test)[:, 1]
metrics = {
"rows": len(df),
"train_rows": len(X_train),
"test_rows": len(X_test),
"accuracy": accuracy_score(y_test, y_pred),
"precision": precision_score(y_test, y_pred, zero_division=0),
"recall": recall_score(y_test, y_pred, zero_division=0),
"roc_auc": roc_auc_score(y_test, y_prob),
"confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
}
return model, metrics, feature_cols
def main() -> int:
args = parse_args()
if not args.db_url:
raise SystemExit("missing --db-url or DATABASE_URL")
start = parse_time(args.start) if args.start else ""
end = parse_time(args.end) if args.end else ""
with psycopg2.connect(args.db_url) as conn:
ws90 = fetch_ws90(conn, args.site, start, end)
baro = fetch_baro(conn, args.site, start, end)
df = build_dataset(ws90, baro)
model, metrics, features = train_model(df)
print("Rain model metrics:")
for k, v in metrics.items():
print(f" {k}: {v}")
if args.out:
out_dir = os.path.dirname(args.out)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
if joblib is None:
print("joblib not installed; skipping model save.")
else:
joblib.dump({"model": model, "features": features}, args.out)
print(f"Saved model to {args.out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())