#!/usr/bin/env python3
"""Train a logistic-regression rain model for one site.

Predicts whether the next hour will accumulate >= 0.2 mm of rain, from
WS90 weather-station and barometer observations stored in Postgres.
"""

import argparse
import os
from datetime import datetime

import numpy as np
import pandas as pd
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the training script."""
    parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    return parser.parse_args()


def parse_time(value: str) -> str:
    """Validate ``value`` as an ISO-8601 timestamp and return it unchanged.

    Empty input passes through as "" (meaning "no bound" in the SQL filters).

    Raises:
        ValueError: if ``value`` is non-empty and not ISO-8601 parseable.
    """
    if not value:
        return ""
    try:
        # Older Pythons' fromisoformat() rejects a trailing "Z"; normalize it
        # to an explicit UTC offset purely for validation.
        datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value
    except ValueError:
        # Suppress the implicit chained traceback; the message is enough.
        raise ValueError(f"invalid time format: {value}") from None


def fetch_ws90(conn, site, start, end):
    """Load WS90 weather-station rows for ``site``, optionally time-bounded.

    ``start``/``end`` are bound twice each so an empty string disables the
    corresponding bound inside the SQL itself.
    """
    sql = """
        SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
        FROM observations_ws90
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    # NOTE(review): pandas warns when handed a raw DBAPI connection instead
    # of a SQLAlchemy engine; results are still correct with psycopg2.
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])


def fetch_baro(conn, site, start, end):
    """Load barometer rows for ``site``, optionally time-bounded (see fetch_ws90)."""
    sql = """
        SELECT ts, pressure_hpa
        FROM observations_baro
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])


def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
    """Merge both sensor streams on a 5-minute grid and derive features/labels.

    Adds to the resampled frame:
      rain_inc          per-bucket rain (the gauge reports a cumulative total)
      rain_next_1h_mm   rain accumulated over the following hour
      rain_next_1h      label: True when rain_next_1h_mm >= 0.2; NaN where the
                        future window extends past the end of the data
      pressure_trend_1h pressure change vs. one hour earlier

    Raises:
        RuntimeError: if either input frame is empty.
    """
    if ws90.empty:
        raise RuntimeError("no ws90 observations found")
    if baro.empty:
        raise RuntimeError("no barometer observations found")

    ws90 = ws90.set_index("ts").sort_index()
    baro = baro.set_index("ts").sort_index()

    ws90_5m = ws90.resample("5min").agg(
        {
            "temperature_c": "mean",
            "humidity": "mean",
            "wind_avg_m_s": "mean",
            "wind_max_m_s": "max",
            # NOTE(review): arithmetic mean of a bearing is wrong across the
            # 0/360 wrap (350 and 10 average to 180, not 0). Harmless today
            # because wind_dir_deg is not a model feature, but switch to a
            # circular mean before adding it.
            "wind_dir_deg": "mean",
            "rain_mm": "last",
        }
    )
    baro_5m = baro.resample("5min").mean()

    df = ws90_5m.join(baro_5m, how="outer")
    # Bridge short barometer gaps only (3 buckets = 15 min).
    df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)

    # diff() recovers per-bucket rain from the cumulative counter;
    # clip(lower=0) absorbs counter resets.
    df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
    window = 12  # 12 * 5min = 1 hour
    df["rain_next_1h_mm"] = (
        df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
    )
    # BUG FIX: the tail rows have no complete future window, so
    # rain_next_1h_mm is NaN there. A plain ">= 0.2" turns those NaNs into
    # False, silently mislabeling them "no rain" — and dropna() never removes
    # them because a boolean column carries no NaN. Keep NaN in the label so
    # train_model() drops those rows instead.
    df["rain_next_1h"] = (df["rain_next_1h_mm"] >= 0.2).where(df["rain_next_1h_mm"].notna())

    # Pressure trend over the previous hour (12 buckets).
    df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)
    return df


def train_model(df: pd.DataFrame):
    """Fit a scaled logistic regression on a chronological 80/20 split.

    Returns:
        (model, metrics, feature_cols) where metrics is a plain dict.

    Raises:
        RuntimeError: if fewer than 200 usable rows remain after dropping
            rows with missing features or an unknown label.
    """
    feature_cols = [
        "pressure_trend_1h",
        "humidity",
        "temperature_c",
        "wind_avg_m_s",
        "wind_max_m_s",
    ]
    # Dropping on the label column now also removes the tail rows whose
    # future rain window is unknown (label is NaN there).
    df = df.dropna(subset=feature_cols + ["rain_next_1h"])
    if len(df) < 200:
        raise RuntimeError("not enough data after filtering (need >= 200 rows)")

    X = df[feature_cols]
    y = df["rain_next_1h"].astype(int)

    # Chronological split: no shuffling, so evaluation is on data strictly
    # after the training period (avoids temporal leakage).
    split_idx = int(len(df) * 0.8)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    model = Pipeline(
        [
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
        ]
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]
    # roc_auc_score raises ValueError when the test fold holds a single class
    # (e.g. a dry spell); report NaN instead of crashing the whole run.
    roc_auc = roc_auc_score(y_test, y_prob) if y_test.nunique() == 2 else float("nan")
    metrics = {
        "rows": len(df),
        "train_rows": len(X_train),
        "test_rows": len(X_test),
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "roc_auc": roc_auc,
        "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
    }
    return model, metrics, feature_cols


def main() -> int:
    """CLI entry point: fetch data, train, print metrics, optionally save."""
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")
    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""

    # NOTE(review): psycopg2's "with conn" commits/rolls back the transaction
    # but does NOT close the connection; acceptable for a one-shot script.
    with psycopg2.connect(args.db_url) as conn:
        ws90 = fetch_ws90(conn, args.site, start, end)
        baro = fetch_baro(conn, args.site, start, end)

    df = build_dataset(ws90, baro)
    model, metrics, features = train_model(df)

    print("Rain model metrics:")
    for k, v in metrics.items():
        print(f" {k}: {v}")

    if args.out:
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if joblib is None:
            print("joblib not installed; skipping model save.")
        else:
            joblib.dump({"model": model, "features": features}, args.out)
            print(f"Saved model to {args.out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())