#!/usr/bin/env python3 from __future__ import annotations import argparse import os from datetime import datetime, timedelta, timezone import psycopg2 from psycopg2.extras import Json from rain_model_common import ( build_dataset, feature_columns_need_forecast, fetch_baro, fetch_forecast, fetch_ws90, model_frame, parse_time, to_builtin, ) try: import joblib except ImportError: # pragma: no cover - optional dependency joblib = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run rain model inference and upsert prediction to Postgres.") parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.") parser.add_argument("--site", required=True, help="Site name (e.g. home).") parser.add_argument("--model-path", default="models/rain_model.pkl", help="Path to trained model artifact.") parser.add_argument("--model-name", default="rain_next_1h", help="Logical prediction model name.") parser.add_argument("--model-version", help="Override artifact model_version.") parser.add_argument( "--at", help="Prediction timestamp (RFC3339 or YYYY-MM-DD). Default: current UTC time.", ) parser.add_argument( "--history-hours", type=int, default=6, help="History lookback window used to build features.", ) parser.add_argument( "--forecast-model", default="ecmwf", help="Forecast model name when inference features require forecast columns.", ) parser.add_argument("--dry-run", action="store_true", help="Do not write prediction to DB.") return parser.parse_args() def load_artifact(path: str): if joblib is None: raise RuntimeError("joblib not installed; cannot load model artifact") if not os.path.exists(path): raise RuntimeError(f"model artifact not found: {path}") artifact = joblib.load(path) if "model" not in artifact: raise RuntimeError("invalid artifact: missing 'model'") if "features" not in artifact: raise RuntimeError("invalid artifact: missing 'features'") return artifact def parse_at(value: str | None) -> datetime: if not value: return datetime.now(timezone.utc) parsed = parse_time(value) return datetime.fromisoformat(parsed.replace("Z", "+00:00")).astimezone(timezone.utc) def main() -> int: args = parse_args() if not args.db_url: raise SystemExit("missing --db-url or DATABASE_URL") at = parse_at(args.at) artifact = load_artifact(args.model_path) model = artifact["model"] features = list(artifact["features"]) feature_set = artifact.get("feature_set") needs_forecast = feature_columns_need_forecast(features) threshold = float(artifact.get("threshold", 0.5)) model_version = args.model_version or artifact.get("model_version") or "unknown" forecast_model = str(artifact.get("forecast_model") or args.forecast_model) fetch_start = (at - timedelta(hours=args.history_hours)).isoformat() fetch_end = (at + timedelta(hours=1, minutes=5)).isoformat() with psycopg2.connect(args.db_url) as conn: ws90 = fetch_ws90(conn, args.site, fetch_start, fetch_end) baro = fetch_baro(conn, args.site, fetch_start, fetch_end) forecast = None if needs_forecast: forecast = fetch_forecast(conn, args.site, fetch_start, fetch_end, model=forecast_model) full_df = build_dataset(ws90, baro, forecast=forecast) feature_df = model_frame(full_df, feature_cols=features, require_target=False) candidates = feature_df.loc[feature_df.index <= at] if candidates.empty: raise RuntimeError("no feature-complete row available at or before requested timestamp") row = candidates.tail(1) pred_ts = row.index[0].to_pydatetime() x = row[features] probability = float(model.predict_proba(x)[:, 1][0]) predict_rain = probability >= threshold actual_mm = None actual_flag = None evaluated_at = None latest_available = full_df.index.max().to_pydatetime() if pred_ts + timedelta(hours=1) <= latest_available: next_mm = full_df.loc[pred_ts, "rain_next_1h_mm"] next_flag = full_df.loc[pred_ts, "rain_next_1h"] if next_mm == next_mm: # NaN-safe check actual_mm = float(next_mm) if next_flag == next_flag: actual_flag = bool(next_flag) evaluated_at = datetime.now(timezone.utc) metadata = { "artifact_path": args.model_path, "artifact_model_version": artifact.get("model_version"), "artifact_feature_set": feature_set, "forecast_model": forecast_model if needs_forecast else None, "needs_forecast_features": needs_forecast, "feature_values": {col: float(row.iloc[0][col]) for col in features}, "source_window_start": fetch_start, "source_window_end": fetch_end, "requested_at": at.isoformat(), "pred_ts": pred_ts.isoformat(), } metadata = to_builtin(metadata) print("Rain inference summary:") print(f" site: {args.site}") print(f" model_name: {args.model_name}") print(f" model_version: {model_version}") if feature_set: print(f" feature_set: {feature_set}") print(f" pred_ts: {pred_ts.isoformat()}") print(f" threshold: {threshold:.3f}") print(f" probability: {probability:.4f}") print(f" predict_rain: {predict_rain}") print(f" outcome_available: {actual_flag is not None}") if args.dry_run: print("dry-run enabled; skipping DB upsert.") return 0 with conn.cursor() as cur: cur.execute( """ INSERT INTO predictions_rain_1h ( ts, generated_at, site, model_name, model_version, threshold, probability, predict_rain, rain_next_1h_mm_actual, rain_next_1h_actual, evaluated_at, metadata ) VALUES ( %s, now(), %s, %s, %s, %s, %s, %s, %s, %s, %s, %s ) ON CONFLICT (site, model_name, model_version, ts) DO UPDATE SET generated_at = EXCLUDED.generated_at, threshold = EXCLUDED.threshold, probability = EXCLUDED.probability, predict_rain = EXCLUDED.predict_rain, rain_next_1h_mm_actual = COALESCE(EXCLUDED.rain_next_1h_mm_actual, predictions_rain_1h.rain_next_1h_mm_actual), rain_next_1h_actual = COALESCE(EXCLUDED.rain_next_1h_actual, predictions_rain_1h.rain_next_1h_actual), evaluated_at = COALESCE(EXCLUDED.evaluated_at, predictions_rain_1h.evaluated_at), metadata = EXCLUDED.metadata """, ( pred_ts, args.site, args.model_name, model_version, threshold, probability, predict_rain, actual_mm, actual_flag, evaluated_at, Json(metadata), ), ) conn.commit() print("Prediction upserted into predictions_rain_1h.") return 0 if __name__ == "__main__": raise SystemExit(main())