add model training
This commit is contained in:
10
Dockerfile.train
Normal file
10
Dockerfile.train
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Training image for the rain prediction model (scripts/train_rain_model.py).
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so this layer stays cached until
# requirements.txt itself changes.
COPY scripts/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

# Copy the training script last: code edits do not invalidate the
# dependency layer above.
COPY scripts/train_rain_model.py /app/train_rain_model.py

ENTRYPOINT ["python", "/app/train_rain_model.py"]
|
||||||
@@ -289,6 +289,60 @@ function updateText(id, text) {
|
|||||||
if (el) el.textContent = text;
|
if (el) el.textContent = text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function lastNonNull(points, key) {
|
||||||
|
for (let i = points.length - 1; i >= 0; i -= 1) {
|
||||||
|
const v = points[i][key];
|
||||||
|
if (v !== null && v !== undefined) {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function computeRainProbability(latest, pressureTrend1h) {
|
||||||
|
if (!latest) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
let prob = 0.1;
|
||||||
|
if (pressureTrend1h !== null && pressureTrend1h !== undefined) {
|
||||||
|
if (pressureTrend1h <= -3.0) {
|
||||||
|
prob += 0.5;
|
||||||
|
} else if (pressureTrend1h <= -2.0) {
|
||||||
|
prob += 0.35;
|
||||||
|
} else if (pressureTrend1h <= -1.0) {
|
||||||
|
prob += 0.2;
|
||||||
|
} else if (pressureTrend1h <= -0.5) {
|
||||||
|
prob += 0.1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (latest.rh !== null && latest.rh !== undefined) {
|
||||||
|
if (latest.rh >= 95) {
|
||||||
|
prob += 0.2;
|
||||||
|
} else if (latest.rh >= 90) {
|
||||||
|
prob += 0.15;
|
||||||
|
} else if (latest.rh >= 85) {
|
||||||
|
prob += 0.1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (latest.wind_m_s !== null && latest.wind_m_s !== undefined && latest.wind_m_s >= 6) {
|
||||||
|
prob += 0.05;
|
||||||
|
}
|
||||||
|
|
||||||
|
prob = Math.max(0.05, Math.min(0.95, prob));
|
||||||
|
|
||||||
|
let label = "Low";
|
||||||
|
if (prob >= 0.6) {
|
||||||
|
label = "High";
|
||||||
|
} else if (prob >= 0.35) {
|
||||||
|
label = "Medium";
|
||||||
|
}
|
||||||
|
|
||||||
|
return { prob, label };
|
||||||
|
}
|
||||||
|
|
||||||
function updateSiteMeta(site, model, tzLabel) {
|
function updateSiteMeta(site, model, tzLabel) {
|
||||||
const home = document.getElementById("site-home");
|
const home = document.getElementById("site-home");
|
||||||
const suffix = document.getElementById("site-meta-suffix");
|
const suffix = document.getElementById("site-meta-suffix");
|
||||||
@@ -400,6 +454,13 @@ function renderDashboard(data) {
|
|||||||
|
|
||||||
const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
|
const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
|
||||||
const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
|
const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
|
||||||
|
const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h");
|
||||||
|
const rainProb = computeRainProbability(latest, lastPressureTrend);
|
||||||
|
if (rainProb) {
|
||||||
|
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`);
|
||||||
|
} else {
|
||||||
|
updateText("live-rain-prob", "--");
|
||||||
|
}
|
||||||
|
|
||||||
const obsTemps = obsFiltered.map((p) => p.temp_c);
|
const obsTemps = obsFiltered.map((p) => p.temp_c);
|
||||||
const obsWinds = obsFiltered.map((p) => p.wind_m_s);
|
const obsWinds = obsFiltered.map((p) => p.wind_m_s);
|
||||||
|
|||||||
@@ -58,6 +58,10 @@
|
|||||||
<div class="label">Pressure hPa</div>
|
<div class="label">Pressure hPa</div>
|
||||||
<div class="value" id="live-pressure">--</div>
|
<div class="value" id="live-pressure">--</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="metric">
|
||||||
|
<div class="label">Rain 1h %</div>
|
||||||
|
<div class="value" id="live-rain-prob">--</div>
|
||||||
|
</div>
|
||||||
<div class="metric">
|
<div class="metric">
|
||||||
<div class="label">Wind m/s</div>
|
<div class="label">Wind m/s</div>
|
||||||
<div class="value" id="live-wind">--</div>
|
<div class="value" id="live-wind">--</div>
|
||||||
|
|||||||
102
docs/rain_prediction.md
Normal file
102
docs/rain_prediction.md
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# Rain Prediction (Next 1 Hour)
|
||||||
|
|
||||||
|
This project now includes a starter training script for a **binary rain prediction**:
|
||||||
|
|
||||||
|
> **Will we see >= 0.2 mm of rain in the next hour?**
|
||||||
|
|
||||||
|
It uses local observations (WS90 + barometric pressure) and trains a lightweight
|
||||||
|
logistic regression model. This is a baseline you can iterate on as you collect
|
||||||
|
more data.
|
||||||
|
|
||||||
|
## What the script does
|
||||||
|
- Pulls data from TimescaleDB.
|
||||||
|
- Resamples observations to 5-minute buckets.
|
||||||
|
- Derives **pressure trend (1h)** from barometer data.
|
||||||
|
- Computes **future 1-hour rainfall** from the cumulative `rain_mm` counter.
|
||||||
|
- Trains a model and prints evaluation metrics.
|
||||||
|
|
||||||
|
The output is a saved model file (optional) you can use later for inference.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
Python 3.10+ and:
|
||||||
|
|
||||||
|
```
|
||||||
|
pandas
|
||||||
|
numpy
|
||||||
|
scikit-learn
|
||||||
|
psycopg2-binary
|
||||||
|
joblib
|
||||||
|
```
|
||||||
|
|
||||||
|
Install with:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install -r scripts/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python scripts/train_rain_model.py \
|
||||||
|
--db-url "postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable" \
|
||||||
|
--site "home" \
|
||||||
|
--start "2026-01-01" \
|
||||||
|
--end "2026-02-01" \
|
||||||
|
--out "models/rain_model.pkl"
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also provide the connection string via `DATABASE_URL`:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable"
|
||||||
|
python scripts/train_rain_model.py --site home
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output
|
||||||
|
The script prints metrics including:
|
||||||
|
- accuracy
|
||||||
|
- precision / recall
|
||||||
|
- ROC AUC
|
||||||
|
- confusion matrix
|
||||||
|
|
||||||
|
If `joblib` is installed, it saves a model bundle:
|
||||||
|
|
||||||
|
```
|
||||||
|
models/rain_model.pkl
|
||||||
|
```
|
||||||
|
|
||||||
|
This bundle contains:
|
||||||
|
- The trained model pipeline
|
||||||
|
- The feature list used during training
|
||||||
|
|
||||||
|
## Data needs / when to run
|
||||||
|
For a reliable model, you will want:
|
||||||
|
- **At least 2-4 weeks** of observations
|
||||||
|
- A mix of rainy and non-rainy periods
|
||||||
|
|
||||||
|
Training with only a few days will produce an unstable model.
|
||||||
|
|
||||||
|
## Features used
|
||||||
|
The baseline model uses:
|
||||||
|
- `pressure_trend_1h` (hPa)
|
||||||
|
- `humidity` (%)
|
||||||
|
- `temperature_c` (°C)
|
||||||
|
- `wind_avg_m_s` (m/s)
|
||||||
|
- `wind_max_m_s` (m/s)
|
||||||
|
|
||||||
|
These are easy to expand once you have more data (e.g. add forecast features).
|
||||||
|
|
||||||
|
## Notes / assumptions
|
||||||
|
- Rain detection is based on **incremental rain** derived from the WS90
|
||||||
|
`rain_mm` cumulative counter.
|
||||||
|
- Pressure comes from `observations_baro`.
|
||||||
|
- All timestamps are treated as UTC.
|
||||||
|
|
||||||
|
## Next improvements
|
||||||
|
Ideas once more data is available:
|
||||||
|
- Add forecast precipitation and cloud cover as features
|
||||||
|
- Try gradient boosted trees (e.g. XGBoost / LightGBM)
|
||||||
|
- Train per-season models
|
||||||
|
- Calibrate probabilities (Platt scaling / isotonic regression)
|
||||||
@@ -15,6 +15,8 @@ type ObservationPoint struct {
|
|||||||
TempC *float64 `json:"temp_c,omitempty"`
|
TempC *float64 `json:"temp_c,omitempty"`
|
||||||
RH *float64 `json:"rh,omitempty"`
|
RH *float64 `json:"rh,omitempty"`
|
||||||
PressureHPA *float64 `json:"pressure_hpa,omitempty"`
|
PressureHPA *float64 `json:"pressure_hpa,omitempty"`
|
||||||
|
// PressureTrend1h is the change in pressure over the last hour (hPa).
|
||||||
|
PressureTrend1h *float64 `json:"pressure_trend_1h,omitempty"`
|
||||||
WindMS *float64 `json:"wind_m_s,omitempty"`
|
WindMS *float64 `json:"wind_m_s,omitempty"`
|
||||||
WindGustMS *float64 `json:"wind_gust_m_s,omitempty"`
|
WindGustMS *float64 `json:"wind_gust_m_s,omitempty"`
|
||||||
WindDirDeg *float64 `json:"wind_dir_deg,omitempty"`
|
WindDirDeg *float64 `json:"wind_dir_deg,omitempty"`
|
||||||
@@ -147,6 +149,23 @@ func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start,
|
|||||||
return nil, rows.Err()
|
return nil, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
indexByTime := make(map[time.Time]int, len(points))
|
||||||
|
for i := range points {
|
||||||
|
indexByTime[points[i].TS] = i
|
||||||
|
}
|
||||||
|
for i := range points {
|
||||||
|
if points[i].PressureHPA == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target := points[i].TS.Add(-1 * time.Hour)
|
||||||
|
j, ok := indexByTime[target]
|
||||||
|
if !ok || points[j].PressureHPA == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
trend := *points[i].PressureHPA - *points[j].PressureHPA
|
||||||
|
points[i].PressureTrend1h = &trend
|
||||||
|
}
|
||||||
|
|
||||||
return points, nil
|
return points, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
5
scripts/requirements.txt
Normal file
5
scripts/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pandas>=2.2.0
|
||||||
|
numpy>=1.26.0
|
||||||
|
scikit-learn>=1.4.0
|
||||||
|
psycopg2-binary>=2.9.0
|
||||||
|
joblib>=1.3.0
|
||||||
178
scripts/train_rain_model.py
Normal file
178
scripts/train_rain_model.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import psycopg2
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
try:
|
||||||
|
import joblib
|
||||||
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
|
joblib = None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line arguments for the training script.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.
            Accepting an explicit list keeps the parser testable without
            touching process-global state (backward compatible: existing
            callers pass nothing).

    Returns:
        The parsed :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
    # DATABASE_URL serves as a fallback so the flag can be omitted in CI/cron.
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_time(value: str) -> str:
    """Validate a timestamp string and return it unchanged.

    Accepts ISO-8601 / RFC3339 (a trailing ``Z`` is treated as ``+00:00``)
    or a plain ``YYYY-MM-DD`` date. Falsy input passes through as ``""``,
    which the SQL filters interpret as "no bound".

    Raises:
        ValueError: if the value cannot be parsed as a timestamp.
    """
    if not value:
        return ""
    try:
        # Parse only to validate; the original string is what gets bound
        # into the SQL query.
        datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError as err:
        # Chain the original parse error so the root cause stays visible.
        raise ValueError(f"invalid time format: {value}") from err
    return value
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_ws90(conn, site, start, end):
    """Load WS90 sensor observations for one site as a DataFrame.

    Empty ``start``/``end`` strings disable the corresponding time bound.
    Rows come back ordered by ``ts`` ascending with ``ts`` parsed to
    datetimes.
    """
    query = """
    SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
    FROM observations_ws90
    WHERE site = %s
      AND (%s = '' OR ts >= %s::timestamptz)
      AND (%s = '' OR ts <= %s::timestamptz)
    ORDER BY ts ASC
    """
    # start/end are bound twice each: once for the emptiness check, once
    # for the actual comparison.
    bind = (site, start, start, end, end)
    frame = pd.read_sql_query(query, conn, params=bind, parse_dates=["ts"])
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_baro(conn, site, start, end):
    """Load barometric pressure readings for one site as a DataFrame.

    Empty ``start``/``end`` strings disable the corresponding time bound.
    Rows are ordered by ``ts`` ascending with ``ts`` parsed to datetimes.
    """
    query = """
    SELECT ts, pressure_hpa
    FROM observations_baro
    WHERE site = %s
      AND (%s = '' OR ts >= %s::timestamptz)
      AND (%s = '' OR ts <= %s::timestamptz)
    ORDER BY ts ASC
    """
    # Each bound parameter is used twice: emptiness check + comparison.
    bind = (site, start, start, end, end)
    frame = pd.read_sql_query(query, conn, params=bind, parse_dates=["ts"])
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
    """Join WS90 and barometer observations into a 5-minute training frame.

    Produces per-bucket features plus the supervised label:
      - ``rain_inc``: incremental rain derived from the cumulative ``rain_mm``
        counter (negative diffs from counter resets are clipped to 0).
      - ``rain_next_1h_mm`` / ``rain_next_1h``: rain total over the *next*
        hour and whether it reaches 0.2 mm.
      - ``pressure_trend_1h``: pressure change over the previous hour.

    Args:
        ws90: raw WS90 rows with a ``ts`` column (see ``fetch_ws90``).
        baro: raw barometer rows with a ``ts`` column (see ``fetch_baro``).

    Raises:
        RuntimeError: if either input frame is empty.
    """
    if ws90.empty:
        raise RuntimeError("no ws90 observations found")
    if baro.empty:
        raise RuntimeError("no barometer observations found")

    ws90 = ws90.set_index("ts").sort_index()
    baro = baro.set_index("ts").sort_index()

    ws90_5m = ws90.resample("5min").agg(
        {
            "temperature_c": "mean",
            "humidity": "mean",
            "wind_avg_m_s": "mean",
            "wind_max_m_s": "max",
            "wind_dir_deg": "mean",
            # Cumulative counter: the last reading in the bucket is the total.
            "rain_mm": "last",
        }
    )
    baro_5m = baro.resample("5min").mean()

    df = ws90_5m.join(baro_5m, how="outer")
    # Bridge short pressure gaps (up to 15 min) so the trend feature survives
    # isolated missing readings.
    df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)

    # Incremental rain per bucket; clip guards against counter resets.
    df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
    window = 12  # 12 * 5min = 1 hour

    # Label: rain over buckets t+1 .. t+12, i.e. the hour strictly after t.
    # rolling(...).sum() at index i covers i-11..i, so shifting by -window
    # aligns the window t+1..t+12 onto index t. (Shifting by -(window-1)
    # would include the current bucket's already-observed rain and only
    # 55 minutes of future — an off-by-one leak.)
    future_rain = df["rain_inc"].rolling(window=window, min_periods=window).sum().shift(-window)
    df["rain_next_1h_mm"] = future_rain
    # Keep the label NaN where the future window is incomplete (tail of the
    # series); a plain `NaN >= 0.2` would silently label those rows False.
    df["rain_next_1h"] = (future_rain >= 0.2).where(future_rain.notna())

    # Pressure trend over the previous hour.
    df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(window)

    return df
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(df: pd.DataFrame):
    """Train and evaluate the logistic-regression rain classifier.

    Uses a chronological 80/20 split (no shuffling) so the evaluation
    reflects predicting genuinely future data.

    Args:
        df: frame from ``build_dataset`` containing the feature columns and
            the ``rain_next_1h`` label.

    Returns:
        Tuple of (fitted sklearn Pipeline, metrics dict, feature column list).

    Raises:
        RuntimeError: if fewer than 200 usable rows remain after dropping
            rows with missing features or labels.
    """
    feature_cols = [
        "pressure_trend_1h",
        "humidity",
        "temperature_c",
        "wind_avg_m_s",
        "wind_max_m_s",
    ]

    df = df.dropna(subset=feature_cols + ["rain_next_1h"])
    if len(df) < 200:
        raise RuntimeError("not enough data after filtering (need >= 200 rows)")

    X = df[feature_cols]
    y = df["rain_next_1h"].astype(int)

    # Chronological split: first 80% train, last 20% test.
    split_idx = int(len(df) * 0.8)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    model = Pipeline(
        [
            ("scaler", StandardScaler()),
            # class_weight="balanced" compensates for rain being the rare class.
            ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
        ]
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    # roc_auc_score raises ValueError when the test slice contains only one
    # class (e.g. an entirely dry evaluation window); report NaN instead of
    # crashing the whole run.
    if y_test.nunique() > 1:
        roc_auc = roc_auc_score(y_test, y_prob)
    else:
        roc_auc = float("nan")

    metrics = {
        "rows": len(df),
        "train_rows": len(X_train),
        "test_rows": len(X_test),
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "roc_auc": roc_auc,
        "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
    }

    return model, metrics, feature_cols
|
||||||
|
|
||||||
|
|
||||||
|
def _save_model(model, features, out_path):
    """Persist the model bundle with joblib, creating parent dirs as needed."""
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if joblib is None:
        print("joblib not installed; skipping model save.")
        return
    joblib.dump({"model": model, "features": features}, out_path)
    print(f"Saved model to {out_path}")


def main() -> int:
    """Entry point: fetch observations, build the dataset, train the model,
    print metrics, and optionally save the fitted pipeline."""
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")

    # Empty strings disable the corresponding time bound in the SQL filters.
    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""

    with psycopg2.connect(args.db_url) as conn:
        ws90 = fetch_ws90(conn, args.site, start, end)
        baro = fetch_baro(conn, args.site, start, end)

    df = build_dataset(ws90, baro)
    model, metrics, features = train_model(df)

    print("Rain model metrics:")
    for k, v in metrics.items():
        print(f"  {k}: {v}")

    if args.out:
        _save_model(model, features, args.out)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
Reference in New Issue
Block a user