add model training

This commit is contained in:
2026-02-02 17:08:43 +11:00
parent 737eef85ea
commit 8edd0dc8b0
7 changed files with 388 additions and 9 deletions

178
scripts/train_rain_model.py Normal file
View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
import argparse
import os
from datetime import datetime
import numpy as np
import pandas as pd
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
try:
import joblib
except ImportError: # pragma: no cover - optional dependency
joblib = None
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
return parser.parse_args()
def parse_time(value: str) -> str:
if not value:
return ""
try:
datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
except ValueError:
raise ValueError(f"invalid time format: {value}")
def fetch_ws90(conn, site, start, end):
sql = """
SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
FROM observations_ws90
WHERE site = %s
AND (%s = '' OR ts >= %s::timestamptz)
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
def fetch_baro(conn, site, start, end):
sql = """
SELECT ts, pressure_hpa
FROM observations_baro
WHERE site = %s
AND (%s = '' OR ts >= %s::timestamptz)
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
if ws90.empty:
raise RuntimeError("no ws90 observations found")
if baro.empty:
raise RuntimeError("no barometer observations found")
ws90 = ws90.set_index("ts").sort_index()
baro = baro.set_index("ts").sort_index()
ws90_5m = ws90.resample("5min").agg(
{
"temperature_c": "mean",
"humidity": "mean",
"wind_avg_m_s": "mean",
"wind_max_m_s": "max",
"wind_dir_deg": "mean",
"rain_mm": "last",
}
)
baro_5m = baro.resample("5min").mean()
df = ws90_5m.join(baro_5m, how="outer")
df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)
# Compute incremental rain and future 1-hour sum.
df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
window = 12 # 12 * 5min = 1 hour
df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2
# Pressure trend over the previous hour.
df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)
return df
def train_model(df: pd.DataFrame):
feature_cols = [
"pressure_trend_1h",
"humidity",
"temperature_c",
"wind_avg_m_s",
"wind_max_m_s",
]
df = df.dropna(subset=feature_cols + ["rain_next_1h"])
if len(df) < 200:
raise RuntimeError("not enough data after filtering (need >= 200 rows)")
X = df[feature_cols]
y = df["rain_next_1h"].astype(int)
split_idx = int(len(df) * 0.8)
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
model = Pipeline(
[
("scaler", StandardScaler()),
("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
]
)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
y_prob = model.predict_proba(X_test)[:, 1]
metrics = {
"rows": len(df),
"train_rows": len(X_train),
"test_rows": len(X_test),
"accuracy": accuracy_score(y_test, y_pred),
"precision": precision_score(y_test, y_pred, zero_division=0),
"recall": recall_score(y_test, y_pred, zero_division=0),
"roc_auc": roc_auc_score(y_test, y_prob),
"confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
}
return model, metrics, feature_cols
def main() -> int:
args = parse_args()
if not args.db_url:
raise SystemExit("missing --db-url or DATABASE_URL")
start = parse_time(args.start) if args.start else ""
end = parse_time(args.end) if args.end else ""
with psycopg2.connect(args.db_url) as conn:
ws90 = fetch_ws90(conn, args.site, start, end)
baro = fetch_baro(conn, args.site, start, end)
df = build_dataset(ws90, baro)
model, metrics, features = train_model(df)
print("Rain model metrics:")
for k, v in metrics.items():
print(f" {k}: {v}")
if args.out:
out_dir = os.path.dirname(args.out)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
if joblib is None:
print("joblib not installed; skipping model save.")
else:
joblib.dump({"model": model, "features": features}, args.out)
print(f"Saved model to {args.out}")
return 0
if __name__ == "__main__":
raise SystemExit(main())