182 lines
6.8 KiB
Python
182 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import psycopg2
|
|
from psycopg2.extras import Json
|
|
|
|
from rain_model_common import build_dataset, fetch_baro, fetch_ws90, model_frame, parse_time, to_builtin
|
|
|
|
try:
|
|
import joblib
|
|
except ImportError: # pragma: no cover - optional dependency
|
|
joblib = None
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        The parsed namespace with ``db_url``, ``site``, ``model_path``,
        ``model_name``, ``model_version``, ``at``, ``history_hours`` and
        ``dry_run`` attributes.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); the order
    # here is the order in which options show up in ``--help``.
    option_specs = (
        (
            "--db-url",
            {"default": os.getenv("DATABASE_URL"), "help": "Postgres connection string."},
        ),
        ("--site", {"required": True, "help": "Site name (e.g. home)."}),
        (
            "--model-path",
            {"default": "models/rain_model.pkl", "help": "Path to trained model artifact."},
        ),
        (
            "--model-name",
            {"default": "rain_next_1h", "help": "Logical prediction model name."},
        ),
        ("--model-version", {"help": "Override artifact model_version."}),
        (
            "--at",
            {"help": "Prediction timestamp (RFC3339 or YYYY-MM-DD). Default: current UTC time."},
        ),
        (
            "--history-hours",
            {
                "type": int,
                "default": 6,
                "help": "History lookback window used to build features.",
            },
        ),
        (
            "--dry-run",
            {"action": "store_true", "help": "Do not write prediction to DB."},
        ),
    )

    cli = argparse.ArgumentParser(
        description="Run rain model inference and upsert prediction to Postgres."
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
def load_artifact(path: str):
    """Load the trained model artifact stored at ``path`` and validate it.

    Args:
        path: Filesystem path of the joblib-serialized artifact.

    Returns:
        The deserialized artifact, after it passed the ``"model"`` and
        ``"features"`` membership checks.

    Raises:
        RuntimeError: If joblib is not installed, ``path`` does not exist,
            the deserialized object does not support membership tests, or a
            required key is missing.
    """
    if joblib is None:
        raise RuntimeError("joblib not installed; cannot load model artifact")
    if not os.path.exists(path):
        raise RuntimeError(f"model artifact not found: {path}")

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point --model-path at artifacts from a trusted source.
    artifact = joblib.load(path)

    for key in ("model", "features"):
        try:
            present = key in artifact
        except TypeError as err:
            # A bare estimator (or any other non-container object) does not
            # support ``in``; report it as an invalid artifact instead of
            # leaking an unrelated TypeError to the caller.
            raise RuntimeError(
                "invalid artifact: expected a mapping with 'model' and 'features', "
                f"got {type(artifact).__name__}"
            ) from err
        if not present:
            raise RuntimeError(f"invalid artifact: missing '{key}'")
    return artifact
|
|
|
|
|
|
def parse_at(value: str | None) -> datetime:
    """Resolve the ``--at`` option to a UTC datetime.

    Args:
        value: Timestamp string from the command line, or ``None``/empty to
            request the current time.

    Returns:
        ``datetime.now(timezone.utc)`` when ``value`` is falsy; otherwise the
        result of ``parse_time(value)`` parsed as ISO 8601 (with a ``Z``
        suffix rewritten to ``+00:00``) and converted to UTC.
    """
    if value:
        iso_text = parse_time(value).replace("Z", "+00:00")
        # NOTE(review): ``astimezone`` treats a naive datetime as local time,
        # so an offset-less string from ``parse_time`` would depend on the
        # host timezone — confirm ``parse_time`` always yields an offset.
        return datetime.fromisoformat(iso_text).astimezone(timezone.utc)
    return datetime.now(timezone.utc)
|
|
|
|
|
|
# Upsert keyed on (site, model_name, model_version, ts). On conflict the
# prediction fields are overwritten, while the observed-outcome columns keep
# their stored value whenever the new row carries NULL for them.
_UPSERT_PREDICTION_SQL = """
    INSERT INTO predictions_rain_1h (
        ts,
        generated_at,
        site,
        model_name,
        model_version,
        threshold,
        probability,
        predict_rain,
        rain_next_1h_mm_actual,
        rain_next_1h_actual,
        evaluated_at,
        metadata
    ) VALUES (
        %s, now(), %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
    )
    ON CONFLICT (site, model_name, model_version, ts)
    DO UPDATE SET
        generated_at = EXCLUDED.generated_at,
        threshold = EXCLUDED.threshold,
        probability = EXCLUDED.probability,
        predict_rain = EXCLUDED.predict_rain,
        rain_next_1h_mm_actual = COALESCE(EXCLUDED.rain_next_1h_mm_actual, predictions_rain_1h.rain_next_1h_mm_actual),
        rain_next_1h_actual = COALESCE(EXCLUDED.rain_next_1h_actual, predictions_rain_1h.rain_next_1h_actual),
        evaluated_at = COALESCE(EXCLUDED.evaluated_at, predictions_rain_1h.evaluated_at),
        metadata = EXCLUDED.metadata
"""


def _observed_outcome(full_df, pred_ts: datetime):
    """Return ``(actual_mm, actual_flag, evaluated_at)`` for ``pred_ts``.

    All three are ``None`` unless the data in ``full_df`` extends at least one
    hour past ``pred_ts``; individual values stay ``None`` when the stored
    target is NaN.
    """
    actual_mm = None
    actual_flag = None
    evaluated_at = None

    latest_available = full_df.index.max().to_pydatetime()
    if pred_ts + timedelta(hours=1) <= latest_available:
        next_mm = full_df.loc[pred_ts, "rain_next_1h_mm"]
        next_flag = full_df.loc[pred_ts, "rain_next_1h"]
        # NaN is the only value that is not equal to itself, so ``x == x``
        # filters out missing targets without importing a NaN helper.
        if next_mm == next_mm:
            actual_mm = float(next_mm)
        if next_flag == next_flag:
            actual_flag = bool(next_flag)
            # NOTE(review): the evaluation time is stamped only when an
            # outcome flag was recovered — confirm this is the intended rule.
            evaluated_at = datetime.now(timezone.utc)
    return actual_mm, actual_flag, evaluated_at


def _upsert_prediction(conn, params: tuple) -> None:
    """Execute the prediction upsert with ``params`` and commit it."""
    with conn.cursor() as cur:
        cur.execute(_UPSERT_PREDICTION_SQL, params)
    conn.commit()


def main() -> int:
    """Run one rain-model inference and store the result.

    Loads the model artifact, fetches WS90 and barometer history for the
    requested site from ``args.history_hours`` before the target time up to
    1 h 5 min after it, scores the latest feature-complete row at or before
    the target time, prints a summary and (unless ``--dry-run`` is given)
    upserts the prediction into ``predictions_rain_1h``.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If no database URL was supplied.
        RuntimeError: If the artifact cannot be loaded or no feature-complete
            row exists at or before the requested timestamp.
    """
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")

    at = parse_at(args.at)
    artifact = load_artifact(args.model_path)
    model = artifact["model"]
    features = artifact["features"]
    threshold = float(artifact.get("threshold", 0.5))
    model_version = args.model_version or artifact.get("model_version") or "unknown"

    fetch_start = (at - timedelta(hours=args.history_hours)).isoformat()
    # Fetch past the target time so the realized next-hour outcome can be
    # attached when the script is run for a timestamp in the past.
    fetch_end = (at + timedelta(hours=1, minutes=5)).isoformat()

    # psycopg2's ``with connection`` only scopes a transaction and leaves the
    # connection open, so the connection is closed explicitly instead.
    conn = psycopg2.connect(args.db_url)
    try:
        ws90 = fetch_ws90(conn, args.site, fetch_start, fetch_end)
        baro = fetch_baro(conn, args.site, fetch_start, fetch_end)

        full_df = build_dataset(ws90, baro)
        feature_df = model_frame(full_df, feature_cols=features, require_target=False)
        candidates = feature_df.loc[feature_df.index <= at]
        if candidates.empty:
            raise RuntimeError("no feature-complete row available at or before requested timestamp")

        row = candidates.tail(1)
        pred_ts = row.index[0].to_pydatetime()
        x = row[features]

        # Column 1 of predict_proba is used as the rain probability.
        probability = float(model.predict_proba(x)[:, 1][0])
        predict_rain = probability >= threshold

        actual_mm, actual_flag, evaluated_at = _observed_outcome(full_df, pred_ts)

        latest_values = row.iloc[0]
        metadata = {
            "artifact_path": args.model_path,
            "artifact_model_version": artifact.get("model_version"),
            "feature_values": {col: float(latest_values[col]) for col in features},
            "source_window_start": fetch_start,
            "source_window_end": fetch_end,
            "requested_at": at.isoformat(),
            "pred_ts": pred_ts.isoformat(),
        }
        metadata = to_builtin(metadata)

        print("Rain inference summary:")
        print(f" site: {args.site}")
        print(f" model_name: {args.model_name}")
        print(f" model_version: {model_version}")
        print(f" pred_ts: {pred_ts.isoformat()}")
        print(f" threshold: {threshold:.3f}")
        print(f" probability: {probability:.4f}")
        print(f" predict_rain: {predict_rain}")
        print(f" outcome_available: {actual_flag is not None}")

        if args.dry_run:
            print("dry-run enabled; skipping DB upsert.")
            return 0

        _upsert_prediction(
            conn,
            (
                pred_ts,
                args.site,
                args.model_name,
                model_version,
                threshold,
                probability,
                predict_rain,
                actual_mm,
                actual_flag,
                evaluated_at,
                Json(metadata),
            ),
        )
        print("Prediction upserted into predictions_rain_1h.")
    finally:
        conn.close()

    return 0
|
|
|
|
|
|
# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
|