add model training
This commit is contained in:
10
Dockerfile.train
Normal file
10
Dockerfile.train
Normal file
@@ -0,0 +1,10 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY scripts/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /app/requirements.txt
|
||||
|
||||
COPY scripts/train_rain_model.py /app/train_rain_model.py
|
||||
|
||||
ENTRYPOINT ["python", "/app/train_rain_model.py"]
|
||||
@@ -289,6 +289,60 @@ function updateText(id, text) {
|
||||
if (el) el.textContent = text;
|
||||
}
|
||||
|
||||
function lastNonNull(points, key) {
|
||||
for (let i = points.length - 1; i >= 0; i -= 1) {
|
||||
const v = points[i][key];
|
||||
if (v !== null && v !== undefined) {
|
||||
return v;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function computeRainProbability(latest, pressureTrend1h) {
|
||||
if (!latest) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let prob = 0.1;
|
||||
if (pressureTrend1h !== null && pressureTrend1h !== undefined) {
|
||||
if (pressureTrend1h <= -3.0) {
|
||||
prob += 0.5;
|
||||
} else if (pressureTrend1h <= -2.0) {
|
||||
prob += 0.35;
|
||||
} else if (pressureTrend1h <= -1.0) {
|
||||
prob += 0.2;
|
||||
} else if (pressureTrend1h <= -0.5) {
|
||||
prob += 0.1;
|
||||
}
|
||||
}
|
||||
|
||||
if (latest.rh !== null && latest.rh !== undefined) {
|
||||
if (latest.rh >= 95) {
|
||||
prob += 0.2;
|
||||
} else if (latest.rh >= 90) {
|
||||
prob += 0.15;
|
||||
} else if (latest.rh >= 85) {
|
||||
prob += 0.1;
|
||||
}
|
||||
}
|
||||
|
||||
if (latest.wind_m_s !== null && latest.wind_m_s !== undefined && latest.wind_m_s >= 6) {
|
||||
prob += 0.05;
|
||||
}
|
||||
|
||||
prob = Math.max(0.05, Math.min(0.95, prob));
|
||||
|
||||
let label = "Low";
|
||||
if (prob >= 0.6) {
|
||||
label = "High";
|
||||
} else if (prob >= 0.35) {
|
||||
label = "Medium";
|
||||
}
|
||||
|
||||
return { prob, label };
|
||||
}
|
||||
|
||||
function updateSiteMeta(site, model, tzLabel) {
|
||||
const home = document.getElementById("site-home");
|
||||
const suffix = document.getElementById("site-meta-suffix");
|
||||
@@ -400,6 +454,13 @@ function renderDashboard(data) {
|
||||
|
||||
const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
|
||||
const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
|
||||
const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h");
|
||||
const rainProb = computeRainProbability(latest, lastPressureTrend);
|
||||
if (rainProb) {
|
||||
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`);
|
||||
} else {
|
||||
updateText("live-rain-prob", "--");
|
||||
}
|
||||
|
||||
const obsTemps = obsFiltered.map((p) => p.temp_c);
|
||||
const obsWinds = obsFiltered.map((p) => p.wind_m_s);
|
||||
|
||||
@@ -58,6 +58,10 @@
|
||||
<div class="label">Pressure hPa</div>
|
||||
<div class="value" id="live-pressure">--</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="label">Rain 1h %</div>
|
||||
<div class="value" id="live-rain-prob">--</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="label">Wind m/s</div>
|
||||
<div class="value" id="live-wind">--</div>
|
||||
|
||||
102
docs/rain_prediction.md
Normal file
102
docs/rain_prediction.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Rain Prediction (Next 1 Hour)
|
||||
|
||||
This project now includes a starter training script for a **binary rain prediction**:
|
||||
|
||||
> **Will we see >= 0.2 mm of rain in the next hour?**
|
||||
|
||||
It uses local observations (WS90 + barometric pressure) and trains a lightweight
|
||||
logistic regression model. This is a baseline you can iterate on as you collect
|
||||
more data.
|
||||
|
||||
## What the script does
|
||||
- Pulls data from TimescaleDB.
|
||||
- Resamples observations to 5-minute buckets.
|
||||
- Derives **pressure trend (1h)** from barometer data.
|
||||
- Computes **future 1-hour rainfall** from the cumulative `rain_mm` counter.
|
||||
- Trains a model and prints evaluation metrics.
|
||||
|
||||
The output is a saved model file (optional) you can use later for inference.
|
||||
|
||||
## Requirements
|
||||
Python 3.10+ and:
|
||||
|
||||
```
|
||||
pandas
|
||||
numpy
|
||||
scikit-learn
|
||||
psycopg2-binary
|
||||
joblib
|
||||
```
|
||||
|
||||
Install with:
|
||||
|
||||
```sh
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r scripts/requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```sh
|
||||
python scripts/train_rain_model.py \
|
||||
--db-url "postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable" \
|
||||
--site "home" \
|
||||
--start "2026-01-01" \
|
||||
--end "2026-02-01" \
|
||||
--out "models/rain_model.pkl"
|
||||
```
|
||||
|
||||
You can also provide the connection string via `DATABASE_URL`:
|
||||
|
||||
```sh
|
||||
export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslmode=disable"
|
||||
python scripts/train_rain_model.py --site home
|
||||
```
|
||||
|
||||
## Output
|
||||
The script prints metrics including:
|
||||
- accuracy
|
||||
- precision / recall
|
||||
- ROC AUC
|
||||
- confusion matrix
|
||||
|
||||
If `joblib` is installed, it saves a model bundle:
|
||||
|
||||
```
|
||||
models/rain_model.pkl
|
||||
```
|
||||
|
||||
This bundle contains:
|
||||
- The trained model pipeline
|
||||
- The feature list used during training
|
||||
|
||||
## Data needs / when to run
|
||||
For a reliable model, you will want:
|
||||
- **At least 2-4 weeks** of observations
|
||||
- A mix of rainy and non-rainy periods
|
||||
|
||||
Training with only a few days will produce an unstable model.
|
||||
|
||||
## Features used
|
||||
The baseline model uses:
|
||||
- `pressure_trend_1h` (hPa)
|
||||
- `humidity` (%)
|
||||
- `temperature_c` (C)
|
||||
- `wind_avg_m_s` (m/s)
|
||||
- `wind_max_m_s` (m/s)
|
||||
|
||||
These are easy to expand once you have more data (e.g. add forecast features).
|
||||
|
||||
## Notes / assumptions
|
||||
- Rain detection is based on **incremental rain** derived from the WS90
|
||||
`rain_mm` cumulative counter.
|
||||
- Pressure comes from `observations_baro`.
|
||||
- All timestamps are treated as UTC.
|
||||
|
||||
## Next improvements
|
||||
Ideas once more data is available:
|
||||
- Add forecast precipitation and cloud cover as features
|
||||
- Try gradient boosted trees (e.g. XGBoost / LightGBM)
|
||||
- Train per-season models
|
||||
- Calibrate probabilities (Platt scaling / isotonic regression)
|
||||
@@ -15,15 +15,17 @@ type ObservationPoint struct {
|
||||
TempC *float64 `json:"temp_c,omitempty"`
|
||||
RH *float64 `json:"rh,omitempty"`
|
||||
PressureHPA *float64 `json:"pressure_hpa,omitempty"`
|
||||
WindMS *float64 `json:"wind_m_s,omitempty"`
|
||||
WindGustMS *float64 `json:"wind_gust_m_s,omitempty"`
|
||||
WindDirDeg *float64 `json:"wind_dir_deg,omitempty"`
|
||||
UVI *float64 `json:"uvi,omitempty"`
|
||||
LightLux *float64 `json:"light_lux,omitempty"`
|
||||
BatteryMV *float64 `json:"battery_mv,omitempty"`
|
||||
SupercapV *float64 `json:"supercap_v,omitempty"`
|
||||
RainMM *float64 `json:"rain_mm,omitempty"`
|
||||
RainStart *int64 `json:"rain_start,omitempty"`
|
||||
// PressureTrend1h is the change in pressure over the last hour (hPa).
|
||||
PressureTrend1h *float64 `json:"pressure_trend_1h,omitempty"`
|
||||
WindMS *float64 `json:"wind_m_s,omitempty"`
|
||||
WindGustMS *float64 `json:"wind_gust_m_s,omitempty"`
|
||||
WindDirDeg *float64 `json:"wind_dir_deg,omitempty"`
|
||||
UVI *float64 `json:"uvi,omitempty"`
|
||||
LightLux *float64 `json:"light_lux,omitempty"`
|
||||
BatteryMV *float64 `json:"battery_mv,omitempty"`
|
||||
SupercapV *float64 `json:"supercap_v,omitempty"`
|
||||
RainMM *float64 `json:"rain_mm,omitempty"`
|
||||
RainStart *int64 `json:"rain_start,omitempty"`
|
||||
}
|
||||
|
||||
type ForecastPoint struct {
|
||||
@@ -147,6 +149,23 @@ func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start,
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
indexByTime := make(map[time.Time]int, len(points))
|
||||
for i := range points {
|
||||
indexByTime[points[i].TS] = i
|
||||
}
|
||||
for i := range points {
|
||||
if points[i].PressureHPA == nil {
|
||||
continue
|
||||
}
|
||||
target := points[i].TS.Add(-1 * time.Hour)
|
||||
j, ok := indexByTime[target]
|
||||
if !ok || points[j].PressureHPA == nil {
|
||||
continue
|
||||
}
|
||||
trend := *points[i].PressureHPA - *points[j].PressureHPA
|
||||
points[i].PressureTrend1h = &trend
|
||||
}
|
||||
|
||||
return points, nil
|
||||
}
|
||||
|
||||
|
||||
5
scripts/requirements.txt
Normal file
5
scripts/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pandas>=2.2.0
|
||||
numpy>=1.26.0
|
||||
scikit-learn>=1.4.0
|
||||
psycopg2-binary>=2.9.0
|
||||
joblib>=1.3.0
|
||||
178
scripts/train_rain_model.py
Normal file
178
scripts/train_rain_model.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
try:
|
||||
import joblib
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
joblib = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
|
||||
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
||||
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
|
||||
parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
|
||||
parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
|
||||
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_time(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return value
|
||||
except ValueError:
|
||||
raise ValueError(f"invalid time format: {value}")
|
||||
|
||||
|
||||
def fetch_ws90(conn, site, start, end):
|
||||
sql = """
|
||||
SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
|
||||
FROM observations_ws90
|
||||
WHERE site = %s
|
||||
AND (%s = '' OR ts >= %s::timestamptz)
|
||||
AND (%s = '' OR ts <= %s::timestamptz)
|
||||
ORDER BY ts ASC
|
||||
"""
|
||||
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def fetch_baro(conn, site, start, end):
|
||||
sql = """
|
||||
SELECT ts, pressure_hpa
|
||||
FROM observations_baro
|
||||
WHERE site = %s
|
||||
AND (%s = '' OR ts >= %s::timestamptz)
|
||||
AND (%s = '' OR ts <= %s::timestamptz)
|
||||
ORDER BY ts ASC
|
||||
"""
|
||||
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
|
||||
if ws90.empty:
|
||||
raise RuntimeError("no ws90 observations found")
|
||||
if baro.empty:
|
||||
raise RuntimeError("no barometer observations found")
|
||||
|
||||
ws90 = ws90.set_index("ts").sort_index()
|
||||
baro = baro.set_index("ts").sort_index()
|
||||
|
||||
ws90_5m = ws90.resample("5min").agg(
|
||||
{
|
||||
"temperature_c": "mean",
|
||||
"humidity": "mean",
|
||||
"wind_avg_m_s": "mean",
|
||||
"wind_max_m_s": "max",
|
||||
"wind_dir_deg": "mean",
|
||||
"rain_mm": "last",
|
||||
}
|
||||
)
|
||||
baro_5m = baro.resample("5min").mean()
|
||||
|
||||
df = ws90_5m.join(baro_5m, how="outer")
|
||||
df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)
|
||||
|
||||
# Compute incremental rain and future 1-hour sum.
|
||||
df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
|
||||
window = 12 # 12 * 5min = 1 hour
|
||||
df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
|
||||
df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2
|
||||
|
||||
# Pressure trend over the previous hour.
|
||||
df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def train_model(df: pd.DataFrame):
|
||||
feature_cols = [
|
||||
"pressure_trend_1h",
|
||||
"humidity",
|
||||
"temperature_c",
|
||||
"wind_avg_m_s",
|
||||
"wind_max_m_s",
|
||||
]
|
||||
|
||||
df = df.dropna(subset=feature_cols + ["rain_next_1h"])
|
||||
if len(df) < 200:
|
||||
raise RuntimeError("not enough data after filtering (need >= 200 rows)")
|
||||
|
||||
X = df[feature_cols]
|
||||
y = df["rain_next_1h"].astype(int)
|
||||
|
||||
split_idx = int(len(df) * 0.8)
|
||||
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
|
||||
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
|
||||
|
||||
model = Pipeline(
|
||||
[
|
||||
("scaler", StandardScaler()),
|
||||
("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
|
||||
]
|
||||
)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
y_pred = model.predict(X_test)
|
||||
y_prob = model.predict_proba(X_test)[:, 1]
|
||||
|
||||
metrics = {
|
||||
"rows": len(df),
|
||||
"train_rows": len(X_train),
|
||||
"test_rows": len(X_test),
|
||||
"accuracy": accuracy_score(y_test, y_pred),
|
||||
"precision": precision_score(y_test, y_pred, zero_division=0),
|
||||
"recall": recall_score(y_test, y_pred, zero_division=0),
|
||||
"roc_auc": roc_auc_score(y_test, y_prob),
|
||||
"confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
|
||||
}
|
||||
|
||||
return model, metrics, feature_cols
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
if not args.db_url:
|
||||
raise SystemExit("missing --db-url or DATABASE_URL")
|
||||
|
||||
start = parse_time(args.start) if args.start else ""
|
||||
end = parse_time(args.end) if args.end else ""
|
||||
|
||||
with psycopg2.connect(args.db_url) as conn:
|
||||
ws90 = fetch_ws90(conn, args.site, start, end)
|
||||
baro = fetch_baro(conn, args.site, start, end)
|
||||
|
||||
df = build_dataset(ws90, baro)
|
||||
model, metrics, features = train_model(df)
|
||||
|
||||
print("Rain model metrics:")
|
||||
for k, v in metrics.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
if args.out:
|
||||
out_dir = os.path.dirname(args.out)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if joblib is None:
|
||||
print("joblib not installed; skipping model save.")
|
||||
else:
|
||||
joblib.dump({"model": model, "features": features}, args.out)
|
||||
print(f"Saved model to {args.out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user