#!/usr/bin/env python3
"""Run rain-model inference for one site and upsert the prediction into Postgres.

The script loads a joblib model artifact and fetches WS90 and barometer rows
(plus forecast rows when the model's features need them) for a lookback
window. It scores the last feature-complete row at or before the requested
timestamp and upserts the result into the horizon-specific prediction table.
When the fetched data already covers the full horizon after that row, the
observed outcome is written alongside the prediction.
"""

from __future__ import annotations

import argparse
import os
import re
from collections.abc import Mapping
from contextlib import closing
from datetime import datetime, timedelta, timezone

import psycopg2
from psycopg2.extras import Json

from rain_model_common import (
    DEFAULT_HORIZON_HOURS,
    build_dataset,
    feature_columns_need_forecast,
    fetch_baro,
    fetch_forecast,
    fetch_ws90,
    model_frame,
    normalize_horizon_hours,
    parse_time,
    prediction_table_for_horizon,
    rain_next_flag_col,
    rain_next_mm_col,
    to_builtin,
)

try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None

# Plain, unquoted SQL identifier. Column names cannot be bound as query
# parameters, so names interpolated into SQL text must match this pattern.
_SQL_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def parse_args() -> argparse.Namespace:
    """Parse command-line options for a single inference run."""
    parser = argparse.ArgumentParser(description="Run rain model inference and upsert prediction to Postgres.")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--model-path", default="models/rain_model.pkl", help="Path to trained model artifact.")
    parser.add_argument("--model-name", default="rain_next_4h", help="Logical prediction model name.")
    parser.add_argument("--model-version", help="Override artifact model_version.")
    parser.add_argument(
        "--horizon-hours",
        type=int,
        help="Prediction horizon in hours. Defaults to artifact horizon when present, else 4.",
    )
    parser.add_argument(
        "--at",
        help="Prediction timestamp (RFC3339 or YYYY-MM-DD). Default: current UTC time.",
    )
    parser.add_argument(
        "--history-hours",
        type=int,
        default=6,
        help="History lookback window used to build features.",
    )
    parser.add_argument(
        "--forecast-model",
        default="ecmwf",
        help="Forecast model name when inference features require forecast columns.",
    )
    # --allow-empty and --strict-source-data toggle the same destination;
    # lenient behaviour is the default.
    parser.set_defaults(allow_empty=True)
    parser.add_argument(
        "--allow-empty",
        dest="allow_empty",
        action="store_true",
        help="Exit successfully when required source rows are temporarily unavailable (default: enabled).",
    )
    parser.add_argument(
        "--strict-source-data",
        dest="allow_empty",
        action="store_false",
        help="Fail when required source rows are unavailable.",
    )
    parser.add_argument("--dry-run", action="store_true", help="Do not write prediction to DB.")
    return parser.parse_args()


def load_artifact(path: str):
    """Load a joblib model artifact and check its required keys.

    Raises:
        RuntimeError: joblib is not installed, the file does not exist, or the
            artifact is not a mapping containing 'model' and 'features'.
    """
    if joblib is None:
        raise RuntimeError("joblib not installed; cannot load model artifact")
    if not os.path.exists(path):
        raise RuntimeError(f"model artifact not found: {path}")
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code.
    # Only point --model-path at trusted artifacts.
    artifact = joblib.load(path)
    if not isinstance(artifact, Mapping):
        raise RuntimeError("invalid artifact: expected a mapping of artifact fields")
    if "model" not in artifact:
        raise RuntimeError("invalid artifact: missing 'model'")
    if "features" not in artifact:
        raise RuntimeError("invalid artifact: missing 'features'")
    return artifact


def parse_at(value: str | None) -> datetime:
    """Return the requested prediction time as an aware UTC datetime.

    An empty or missing value means "now". Otherwise the value is normalised by
    ``parse_time`` and parsed as ISO 8601; a trailing 'Z' is accepted.
    """
    if not value:
        return datetime.now(timezone.utc)
    parsed = datetime.fromisoformat(parse_time(value).replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        # astimezone() would interpret a naive value in the host's local
        # timezone; this script works in UTC, so pin naive values to UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def _skip_or_raise(message: str, allow_empty: bool) -> int:
    """Print a skip notice and return exit code 0, or raise RuntimeError when strict."""
    if allow_empty:
        print(f"Rain inference skipped: {message}.")
        return 0
    raise RuntimeError(message)


def _resolve_horizon_hours(cli_value: int | None, artifact_value) -> int:
    """Pick the horizon: CLI override first, then the artifact's value, then the default."""
    if cli_value is not None:
        return normalize_horizon_hours(cli_value)
    if artifact_value is not None:
        return normalize_horizon_hours(int(artifact_value))
    return DEFAULT_HORIZON_HOURS


def _is_missing(value) -> bool:
    """Return True for None, NaN, or NA-like scalars whose truth value is undefined."""
    if value is None:
        return True
    try:
        # NaN is the only value that compares unequal to itself.
        return bool(value != value)
    except TypeError:
        # NA-like scalars (e.g. pandas.NA) refuse boolean conversion.
        return True


def _observed_outcome(full_df, pred_ts: datetime, horizon_hours: int, target_mm_col: str, target_col: str):
    """Return ``(actual_mm, actual_flag, evaluated_at)`` for the row at *pred_ts*.

    All three are None when the fetched data does not yet reach
    ``pred_ts + horizon_hours``, or when both target values are missing.
    ``evaluated_at`` is only stamped when at least one outcome was recorded.
    """
    latest_available = full_df.index.max().to_pydatetime()
    if pred_ts + timedelta(hours=horizon_hours) > latest_available:
        return None, None, None
    next_mm = full_df.loc[pred_ts, target_mm_col]
    next_flag = full_df.loc[pred_ts, target_col]
    actual_mm = None if _is_missing(next_mm) else float(next_mm)
    actual_flag = None if _is_missing(next_flag) else bool(next_flag)
    if actual_mm is None and actual_flag is None:
        return None, None, None
    return actual_mm, actual_flag, datetime.now(timezone.utc)


def _build_upsert_query(prediction_table: str, actual_mm_col: str, actual_flag_col: str) -> str:
    """Build the INSERT ... ON CONFLICT upsert for one prediction row.

    Identifiers cannot be bound as parameters, so they are interpolated into
    the SQL text. The outcome column names derive from artifact fields and are
    validated first. On conflict, stored outcome values and evaluated_at are
    kept whenever the new value is NULL.

    Raises:
        RuntimeError: an outcome column name is not a plain SQL identifier.
    """
    for column in (actual_mm_col, actual_flag_col):
        if not _SQL_IDENTIFIER.fullmatch(column):
            raise RuntimeError(f"invalid SQL column name derived from artifact: {column!r}")
    return f"""
        INSERT INTO {prediction_table} (
            ts, generated_at, site, model_name, model_version, threshold, probability,
            predict_rain, {actual_mm_col}, {actual_flag_col}, evaluated_at, metadata
        ) VALUES (
            %s, now(), %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
        )
        ON CONFLICT (site, model_name, model_version, ts) DO UPDATE SET
            generated_at = EXCLUDED.generated_at,
            threshold = EXCLUDED.threshold,
            probability = EXCLUDED.probability,
            predict_rain = EXCLUDED.predict_rain,
            {actual_mm_col} = COALESCE(EXCLUDED.{actual_mm_col}, {prediction_table}.{actual_mm_col}),
            {actual_flag_col} = COALESCE(EXCLUDED.{actual_flag_col}, {prediction_table}.{actual_flag_col}),
            evaluated_at = COALESCE(EXCLUDED.evaluated_at, {prediction_table}.evaluated_at),
            metadata = EXCLUDED.metadata
    """


def main() -> int:
    """Run one inference pass; return the process exit code.

    Returns 0 on success, on dry runs, and on allowed skips when source rows
    are missing.

    Raises:
        SystemExit: no database URL was supplied.
        RuntimeError: artifact problems, or missing source rows under
            --strict-source-data.
    """
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")

    at = parse_at(args.at)
    artifact = load_artifact(args.model_path)
    model = artifact["model"]
    features = list(artifact["features"])
    feature_set = artifact.get("feature_set")
    needs_forecast = feature_columns_need_forecast(features)
    threshold = float(artifact.get("threshold", 0.5))
    model_version = args.model_version or artifact.get("model_version") or "unknown"
    forecast_model = str(artifact.get("forecast_model") or args.forecast_model)

    horizon_hours = _resolve_horizon_hours(args.horizon_hours, artifact.get("horizon_hours"))
    target_col = str(artifact.get("target_col") or rain_next_flag_col(horizon_hours))
    target_mm_col = str(artifact.get("target_mm_col") or rain_next_mm_col(horizon_hours))
    prediction_table = prediction_table_for_horizon(horizon_hours)
    actual_mm_col = f"{target_mm_col}_actual"
    actual_flag_col = f"{target_col}_actual"

    # The window extends past `at` by the horizon (+5 min) so that an already
    # elapsed outcome can be evaluated in the same run.
    fetch_start = (at - timedelta(hours=args.history_hours)).isoformat()
    fetch_end = (at + timedelta(hours=horizon_hours, minutes=5)).isoformat()

    # psycopg2's `with conn:` only commits or rolls back the transaction; it
    # does not close the connection. closing() releases it on every exit path.
    with closing(psycopg2.connect(args.db_url)) as conn, conn:
        ws90 = fetch_ws90(conn, args.site, fetch_start, fetch_end)
        baro = fetch_baro(conn, args.site, fetch_start, fetch_end)
        forecast = None
        if needs_forecast:
            forecast = fetch_forecast(conn, args.site, fetch_start, fetch_end, model=forecast_model)

        if ws90.empty:
            return _skip_or_raise("no ws90 observations found in source window", args.allow_empty)
        if baro.empty:
            return _skip_or_raise("no barometer observations found in source window", args.allow_empty)

        full_df = build_dataset(ws90, baro, forecast=forecast, horizon_hours=horizon_hours)
        feature_df = model_frame(full_df, feature_cols=features, require_target=False)
        candidates = feature_df.loc[feature_df.index <= at]
        if candidates.empty:
            return _skip_or_raise(
                "no feature-complete row available at or before requested timestamp",
                args.allow_empty,
            )

        # Sort so tail(1) really is the latest candidate even if the index is unordered.
        row = candidates.sort_index().tail(1)
        pred_ts = row.index[0].to_pydatetime()
        x = row[features]
        probability = float(model.predict_proba(x)[:, 1][0])
        predict_rain = probability >= threshold

        actual_mm, actual_flag, evaluated_at = _observed_outcome(
            full_df, pred_ts, horizon_hours, target_mm_col, target_col
        )

        metadata = to_builtin(
            {
                "artifact_path": args.model_path,
                "artifact_model_version": artifact.get("model_version"),
                "artifact_feature_set": feature_set,
                "horizon_hours": horizon_hours,
                "target_col": target_col,
                "target_mm_col": target_mm_col,
                "prediction_table": prediction_table,
                "forecast_model": forecast_model if needs_forecast else None,
                "needs_forecast_features": needs_forecast,
                "feature_values": {col: float(row.iloc[0][col]) for col in features},
                "source_window_start": fetch_start,
                "source_window_end": fetch_end,
                "requested_at": at.isoformat(),
                "pred_ts": pred_ts.isoformat(),
            }
        )

        print("Rain inference summary:")
        print(f" site: {args.site}")
        print(f" model_name: {args.model_name}")
        print(f" model_version: {model_version}")
        print(f" horizon_hours: {horizon_hours}")
        if feature_set:
            print(f" feature_set: {feature_set}")
        print(f" pred_ts: {pred_ts.isoformat()}")
        print(f" threshold: {threshold:.3f}")
        print(f" probability: {probability:.4f}")
        print(f" predict_rain: {predict_rain}")
        print(f" outcome_available: {actual_flag is not None}")

        if args.dry_run:
            print("dry-run enabled; skipping DB upsert.")
            return 0

        query = _build_upsert_query(prediction_table, actual_mm_col, actual_flag_col)
        with conn.cursor() as cur:
            cur.execute(
                query,
                (
                    pred_ts,
                    args.site,
                    args.model_name,
                    model_version,
                    threshold,
                    probability,
                    predict_rain,
                    actual_mm,
                    actual_flag,
                    evaluated_at,
                    Json(metadata),
                ),
            )
        conn.commit()
        print(f"Prediction upserted into {prediction_table}.")
        return 0


if __name__ == "__main__":
    raise SystemExit(main())