Update for 4-hour rain forecast
This commit is contained in:
@@ -9,13 +9,18 @@ import psycopg2
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from rain_model_common import (
|
||||
DEFAULT_HORIZON_HOURS,
|
||||
build_dataset,
|
||||
feature_columns_need_forecast,
|
||||
fetch_baro,
|
||||
fetch_forecast,
|
||||
fetch_ws90,
|
||||
model_frame,
|
||||
normalize_horizon_hours,
|
||||
parse_time,
|
||||
prediction_table_for_horizon,
|
||||
rain_next_flag_col,
|
||||
rain_next_mm_col,
|
||||
to_builtin,
|
||||
)
|
||||
|
||||
@@ -30,8 +35,13 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
||||
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
|
||||
parser.add_argument("--model-path", default="models/rain_model.pkl", help="Path to trained model artifact.")
|
||||
parser.add_argument("--model-name", default="rain_next_1h", help="Logical prediction model name.")
|
||||
parser.add_argument("--model-name", default="rain_next_4h", help="Logical prediction model name.")
|
||||
parser.add_argument("--model-version", help="Override artifact model_version.")
|
||||
parser.add_argument(
|
||||
"--horizon-hours",
|
||||
type=int,
|
||||
help="Prediction horizon in hours. Defaults to artifact horizon when present, else 4.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--at",
|
||||
help="Prediction timestamp (RFC3339 or YYYY-MM-DD). Default: current UTC time.",
|
||||
@@ -98,9 +108,21 @@ def main() -> int:
|
||||
threshold = float(artifact.get("threshold", 0.5))
|
||||
model_version = args.model_version or artifact.get("model_version") or "unknown"
|
||||
forecast_model = str(artifact.get("forecast_model") or args.forecast_model)
|
||||
artifact_horizon = artifact.get("horizon_hours")
|
||||
if args.horizon_hours is not None:
|
||||
horizon_hours = normalize_horizon_hours(args.horizon_hours)
|
||||
elif artifact_horizon is not None:
|
||||
horizon_hours = normalize_horizon_hours(int(artifact_horizon))
|
||||
else:
|
||||
horizon_hours = DEFAULT_HORIZON_HOURS
|
||||
target_col = str(artifact.get("target_col") or rain_next_flag_col(horizon_hours))
|
||||
target_mm_col = str(artifact.get("target_mm_col") or rain_next_mm_col(horizon_hours))
|
||||
prediction_table = prediction_table_for_horizon(horizon_hours)
|
||||
actual_mm_col = f"{target_mm_col}_actual"
|
||||
actual_flag_col = f"{target_col}_actual"
|
||||
|
||||
fetch_start = (at - timedelta(hours=args.history_hours)).isoformat()
|
||||
fetch_end = (at + timedelta(hours=1, minutes=5)).isoformat()
|
||||
fetch_end = (at + timedelta(hours=horizon_hours, minutes=5)).isoformat()
|
||||
|
||||
with psycopg2.connect(args.db_url) as conn:
|
||||
ws90 = fetch_ws90(conn, args.site, fetch_start, fetch_end)
|
||||
@@ -122,7 +144,7 @@ def main() -> int:
|
||||
return 0
|
||||
raise RuntimeError(message)
|
||||
|
||||
full_df = build_dataset(ws90, baro, forecast=forecast)
|
||||
full_df = build_dataset(ws90, baro, forecast=forecast, horizon_hours=horizon_hours)
|
||||
feature_df = model_frame(full_df, feature_cols=features, require_target=False)
|
||||
candidates = feature_df.loc[feature_df.index <= at]
|
||||
if candidates.empty:
|
||||
@@ -143,9 +165,9 @@ def main() -> int:
|
||||
actual_flag = None
|
||||
evaluated_at = None
|
||||
latest_available = full_df.index.max().to_pydatetime()
|
||||
if pred_ts + timedelta(hours=1) <= latest_available:
|
||||
next_mm = full_df.loc[pred_ts, "rain_next_1h_mm"]
|
||||
next_flag = full_df.loc[pred_ts, "rain_next_1h"]
|
||||
if pred_ts + timedelta(hours=horizon_hours) <= latest_available:
|
||||
next_mm = full_df.loc[pred_ts, target_mm_col]
|
||||
next_flag = full_df.loc[pred_ts, target_col]
|
||||
if next_mm == next_mm: # NaN-safe check
|
||||
actual_mm = float(next_mm)
|
||||
if next_flag == next_flag:
|
||||
@@ -156,6 +178,10 @@ def main() -> int:
|
||||
"artifact_path": args.model_path,
|
||||
"artifact_model_version": artifact.get("model_version"),
|
||||
"artifact_feature_set": feature_set,
|
||||
"horizon_hours": horizon_hours,
|
||||
"target_col": target_col,
|
||||
"target_mm_col": target_mm_col,
|
||||
"prediction_table": prediction_table,
|
||||
"forecast_model": forecast_model if needs_forecast else None,
|
||||
"needs_forecast_features": needs_forecast,
|
||||
"feature_values": {col: float(row.iloc[0][col]) for col in features},
|
||||
@@ -170,6 +196,7 @@ def main() -> int:
|
||||
print(f" site: {args.site}")
|
||||
print(f" model_name: {args.model_name}")
|
||||
print(f" model_version: {model_version}")
|
||||
print(f" horizon_hours: {horizon_hours}")
|
||||
if feature_set:
|
||||
print(f" feature_set: {feature_set}")
|
||||
print(f" pred_ts: {pred_ts.isoformat()}")
|
||||
@@ -182,10 +209,8 @@ def main() -> int:
|
||||
print("dry-run enabled; skipping DB upsert.")
|
||||
return 0
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO predictions_rain_1h (
|
||||
query = f"""
|
||||
INSERT INTO {prediction_table} (
|
||||
ts,
|
||||
generated_at,
|
||||
site,
|
||||
@@ -194,8 +219,8 @@ def main() -> int:
|
||||
threshold,
|
||||
probability,
|
||||
predict_rain,
|
||||
rain_next_1h_mm_actual,
|
||||
rain_next_1h_actual,
|
||||
{actual_mm_col},
|
||||
{actual_flag_col},
|
||||
evaluated_at,
|
||||
metadata
|
||||
) VALUES (
|
||||
@@ -207,11 +232,14 @@ def main() -> int:
|
||||
threshold = EXCLUDED.threshold,
|
||||
probability = EXCLUDED.probability,
|
||||
predict_rain = EXCLUDED.predict_rain,
|
||||
rain_next_1h_mm_actual = COALESCE(EXCLUDED.rain_next_1h_mm_actual, predictions_rain_1h.rain_next_1h_mm_actual),
|
||||
rain_next_1h_actual = COALESCE(EXCLUDED.rain_next_1h_actual, predictions_rain_1h.rain_next_1h_actual),
|
||||
evaluated_at = COALESCE(EXCLUDED.evaluated_at, predictions_rain_1h.evaluated_at),
|
||||
{actual_mm_col} = COALESCE(EXCLUDED.{actual_mm_col}, {prediction_table}.{actual_mm_col}),
|
||||
{actual_flag_col} = COALESCE(EXCLUDED.{actual_flag_col}, {prediction_table}.{actual_flag_col}),
|
||||
evaluated_at = COALESCE(EXCLUDED.evaluated_at, {prediction_table}.evaluated_at),
|
||||
metadata = EXCLUDED.metadata
|
||||
""",
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
query,
|
||||
(
|
||||
pred_ts,
|
||||
args.site,
|
||||
@@ -227,7 +255,7 @@ def main() -> int:
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
print("Prediction upserted into predictions_rain_1h.")
|
||||
print(f"Prediction upserted into {prediction_table}.")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user