#!/usr/bin/env python3
"""Train a rain prediction model (next 1h >= 0.2 mm) from site sensor data.

Fetches WS90 and barometer rows from Postgres, builds a time-indexed
feature frame, trains a scaled logistic regression with a time-ordered
train/val/test split, selects a classification threshold on validation,
refits on train+val, and writes a JSON report plus a joblib model artifact.
"""

import argparse
import contextlib
import json
import os
from datetime import datetime, timezone

import numpy as np
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from rain_model_common import (
    FEATURE_COLUMNS,
    RAIN_EVENT_THRESHOLD_MM,
    build_dataset,
    evaluate_probs,
    fetch_baro,
    fetch_ws90,
    model_frame,
    parse_time,
    select_threshold,
    split_time_ordered,
    to_builtin,
)

try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the training run.

    Returns:
        Parsed namespace; ``--site`` is required, ``--db-url`` falls back to
        the ``DATABASE_URL`` environment variable.
    """
    parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--train-ratio", type=float, default=0.7, help="Time-ordered train split ratio.")
    parser.add_argument("--val-ratio", type=float, default=0.15, help="Time-ordered validation split ratio.")
    parser.add_argument(
        "--min-precision",
        type=float,
        default=0.7,
        help="Minimum validation precision for threshold selection.",
    )
    parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
    parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    parser.add_argument(
        "--report-out",
        default="models/rain_model_report.json",
        help="Path to save JSON training report.",
    )
    parser.add_argument(
        "--model-version",
        default="rain-logreg-v1",
        help="Version label stored in artifact metadata.",
    )
    return parser.parse_args()


def make_model() -> Pipeline:
    """Build the untrained classifier pipeline: standardize, then logistic regression.

    ``class_weight="balanced"`` compensates for rain events being rare.
    """
    return Pipeline(
        [
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
        ]
    )


def main() -> int:
    """Run the full training workflow; returns 0 on success.

    Raises:
        SystemExit: if no database URL is available.
        RuntimeError: if too few model-ready rows survive filtering.
    """
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")
    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""

    # psycopg2's connection context manager only wraps a transaction — it
    # commits/rolls back but does NOT close the connection.  closing() ensures
    # the connection is actually released once the fetches are done.
    with contextlib.closing(psycopg2.connect(args.db_url)) as conn:
        with conn:
            ws90 = fetch_ws90(conn, args.site, start, end)
            baro = fetch_baro(conn, args.site, start, end)

    full_df = build_dataset(ws90, baro, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
    model_df = model_frame(full_df, FEATURE_COLUMNS, require_target=True)
    if len(model_df) < args.min_rows:
        raise RuntimeError(f"not enough model-ready rows after filtering (need >= {args.min_rows})")

    train_df, val_df, test_df = split_time_ordered(
        model_df,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
    )
    x_train = train_df[FEATURE_COLUMNS]
    y_train = train_df["rain_next_1h"].astype(int).to_numpy()
    x_val = val_df[FEATURE_COLUMNS]
    y_val = val_df["rain_next_1h"].astype(int).to_numpy()
    x_test = test_df[FEATURE_COLUMNS]
    y_test = test_df["rain_next_1h"].astype(int).to_numpy()

    # Fit on train only; the validation probabilities drive threshold choice.
    base_model = make_model()
    base_model.fit(x_train, y_train)
    y_val_prob = base_model.predict_proba(x_val)[:, 1]

    if args.threshold is not None:
        chosen_threshold = args.threshold
        threshold_info = {
            "selection_rule": "fixed_cli_threshold",
            "threshold": float(args.threshold),
        }
    else:
        chosen_threshold, threshold_info = select_threshold(
            y_true=y_val,
            y_prob=y_val_prob,
            min_precision=args.min_precision,
        )
    val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=chosen_threshold)

    # Refit on the contiguous train+val prefix (split is time-ordered, so this
    # slice is exactly train followed by val) before scoring the held-out tail.
    train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
    x_train_val = train_val_df[FEATURE_COLUMNS]
    y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
    final_model = make_model()
    final_model.fit(x_train_val, y_train_val)
    y_test_prob = final_model.predict_proba(x_test)[:, 1]
    test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)

    report = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "site": args.site,
        "model_version": args.model_version,
        "target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
        "feature_columns": FEATURE_COLUMNS,
        "data_window": {
            "requested_start": start or None,
            "requested_end": end or None,
            # Raw index timestamps; to_builtin() below is expected to make
            # these JSON-serializable — TODO confirm against its implementation.
            "actual_start": model_df.index.min(),
            "actual_end": model_df.index.max(),
            "model_rows": len(model_df),
            "ws90_rows": len(ws90),
            "baro_rows": len(baro),
        },
        "label_quality": {
            # fillna(False) removes all NaNs, so a plain sum counts True flags.
            "rain_reset_count": int(full_df["rain_reset"].fillna(False).sum()),
            "rain_spike_5m_count": int(full_df["rain_spike_5m"].fillna(False).sum()),
        },
        "split": {
            "train_ratio": args.train_ratio,
            "val_ratio": args.val_ratio,
            "train_rows": len(train_df),
            "val_rows": len(val_df),
            "test_rows": len(test_df),
            "train_start": train_df.index.min(),
            "train_end": train_df.index.max(),
            "val_start": val_df.index.min(),
            "val_end": val_df.index.max(),
            "test_start": test_df.index.min(),
            "test_end": test_df.index.max(),
        },
        "threshold_selection": {
            **threshold_info,
            "min_precision_constraint": args.min_precision,
        },
        "validation_metrics": val_metrics,
        "test_metrics": test_metrics,
    }
    report = to_builtin(report)

    print("Rain model training summary:")
    print(f"  site: {args.site}")
    print(f"  model_version: {args.model_version}")
    print(f"  rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
    print(
        "  threshold: "
        f"{report['threshold_selection']['threshold']:.3f} "
        f"({report['threshold_selection']['selection_rule']})"
    )
    # AUC metrics may be None when a split contains a single class; print 'n/a'.
    print(
        "  val metrics: "
        f"precision={report['validation_metrics']['precision']:.3f} "
        f"recall={report['validation_metrics']['recall']:.3f} "
        f"roc_auc={report['validation_metrics']['roc_auc'] if report['validation_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['validation_metrics']['pr_auc'] if report['validation_metrics']['pr_auc'] is not None else 'n/a'}"
    )
    print(
        "  test metrics: "
        f"precision={report['test_metrics']['precision']:.3f} "
        f"recall={report['test_metrics']['recall']:.3f} "
        f"roc_auc={report['test_metrics']['roc_auc'] if report['test_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['test_metrics']['pr_auc'] if report['test_metrics']['pr_auc'] is not None else 'n/a'}"
    )

    if args.report_out:
        report_dir = os.path.dirname(args.report_out)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        with open(args.report_out, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"Saved report to {args.report_out}")

    if args.out:
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if joblib is None:
            print("joblib not installed; skipping model save.")
        else:
            artifact = {
                "model": final_model,
                "features": FEATURE_COLUMNS,
                "threshold": float(chosen_threshold),
                "target_mm": float(RAIN_EVENT_THRESHOLD_MM),
                "model_version": args.model_version,
                "trained_at": datetime.now(timezone.utc).isoformat(),
                "split": report["split"],
                "threshold_selection": report["threshold_selection"],
            }
            joblib.dump(artifact, args.out)
            print(f"Saved model to {args.out}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())