#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import itertools
|
|
import json
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import psycopg2
|
|
from sklearn.calibration import CalibratedClassifierCV
|
|
from sklearn.ensemble import HistGradientBoostingClassifier
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.model_selection import TimeSeriesSplit
|
|
from sklearn.pipeline import Pipeline
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
from rain_model_common import (
|
|
AVAILABLE_FEATURE_SETS,
|
|
RAIN_EVENT_THRESHOLD_MM,
|
|
build_dataset,
|
|
evaluate_probs,
|
|
fetch_baro,
|
|
fetch_forecast,
|
|
fetch_ws90,
|
|
feature_columns_for_set,
|
|
feature_columns_need_forecast,
|
|
model_frame,
|
|
parse_time,
|
|
safe_pr_auc,
|
|
safe_roc_auc,
|
|
select_threshold,
|
|
split_time_ordered,
|
|
to_builtin,
|
|
)
|
|
|
|
# joblib is optional; downstream code must check for None before serializing.
try:
    import joblib
except ImportError: # pragma: no cover - optional dependency
    joblib = None


# Supported estimator families; "auto" trains both concrete families and
# selects the better one on validation metrics.
MODEL_FAMILIES = ("logreg", "hist_gb", "auto")

# Probability-calibration strategies accepted by --calibration-methods.
CALIBRATION_METHODS = ("none", "sigmoid", "isotonic")


def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for rain-model training.

    Note that --allow-empty is enabled by default (via set_defaults) and
    --strict-source-data flips the same `allow_empty` dest to False.
    """
    parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--train-ratio", type=float, default=0.7, help="Time-ordered train split ratio.")
    parser.add_argument("--val-ratio", type=float, default=0.15, help="Time-ordered validation split ratio.")
    parser.add_argument(
        "--min-precision",
        type=float,
        default=0.7,
        help="Minimum validation precision for threshold selection.",
    )
    parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
    parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
    # Lenient behavior is the default; --strict-source-data below turns it off.
    parser.set_defaults(allow_empty=True)
    parser.add_argument(
        "--allow-empty",
        dest="allow_empty",
        action="store_true",
        help="Exit successfully when source/model-ready rows are temporarily unavailable (default: enabled).",
    )
    parser.add_argument(
        "--strict-source-data",
        dest="allow_empty",
        action="store_false",
        help="Fail when source/model-ready rows are unavailable.",
    )
    parser.add_argument(
        "--feature-set",
        default="baseline",
        choices=AVAILABLE_FEATURE_SETS,
        help="Named feature set to train with.",
    )
    parser.add_argument(
        "--forecast-model",
        default="ecmwf",
        help="Forecast model name when feature set requires forecast columns.",
    )
    parser.add_argument(
        "--model-family",
        default="logreg",
        choices=MODEL_FAMILIES,
        help=(
            "Estimator family. "
            "'auto' compares logreg and hist_gb on validation and selects best by PR-AUC/ROC-AUC/F1."
        ),
    )
    parser.add_argument(
        "--tune-hyperparameters",
        action="store_true",
        help="Run a lightweight validation-only hyperparameter search before final training.",
    )
    parser.add_argument(
        "--max-hyperparam-trials",
        type=int,
        default=12,
        help="Maximum hyperparameter trials per model family when tuning is enabled.",
    )
    parser.add_argument(
        "--calibration-methods",
        default="none,sigmoid,isotonic",
        help="Comma-separated methods from: none,sigmoid,isotonic. Best method selected by validation Brier/ECE.",
    )
    parser.add_argument(
        "--walk-forward-folds",
        type=int,
        default=4,
        help="Number of temporal folds for walk-forward backtest (0 to disable).",
    )
    parser.add_argument(
        "--random-state",
        type=int,
        default=42,
        help="Random seed for stochastic estimators.",
    )
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    parser.add_argument(
        "--report-out",
        default="models/rain_model_report.json",
        help="Path to save JSON training report.",
    )
    parser.add_argument(
        "--model-card-out",
        default="models/model_card_{model_version}.md",
        help="Path (or template) for markdown model card. Supports {model_version}. Use empty string to disable.",
    )
    parser.add_argument(
        "--dataset-out",
        default="models/datasets/rain_dataset_{model_version}_{feature_set}.csv",
        help=(
            "Path (or template) for model-ready dataset snapshot. "
            "Supports {model_version} and {feature_set}. Use empty string to disable."
        ),
    )
    parser.add_argument(
        "--model-version",
        default="rain-logreg-v1",
        help="Version label stored in artifact metadata.",
    )
    return parser.parse_args()


def parse_calibration_methods(value: str) -> list[str]:
    """Parse a comma-separated calibration-method list into unique, validated names.

    Names are stripped and lowercased; duplicates keep their first position.
    Raises ValueError on a name not in CALIBRATION_METHODS. A blank/empty
    input yields ["none"].
    """
    selected: list[str] = []
    for raw_token in (value or "").split(","):
        name = raw_token.strip().lower()
        if not name:
            continue
        if name not in CALIBRATION_METHODS:
            raise ValueError(f"unknown calibration method: {name}")
        if name in selected:
            continue
        selected.append(name)
    return selected or ["none"]


def default_model_params(model_family: str) -> dict[str, Any]:
    """Return a fresh dict of default hyperparameters for *model_family*.

    Raises ValueError for unrecognized family names.
    """
    defaults: dict[str, dict[str, Any]] = {
        "logreg": {"c": 1.0},
        "hist_gb": {
            "max_iter": 300,
            "learning_rate": 0.05,
            "max_depth": 5,
            "min_samples_leaf": 20,
        },
    }
    params = defaults.get(model_family)
    if params is None:
        raise ValueError(f"unknown model_family: {model_family}")
    return params


def model_param_grid(model_family: str) -> list[dict[str, Any]]:
    """Enumerate the hyperparameter search grid for *model_family*.

    Raises ValueError for unrecognized family names.
    """
    if model_family == "logreg":
        # Only the regularization strength is swept for logistic regression.
        return [{"c": strength} for strength in [0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]]
    if model_family == "hist_gb":
        combos = itertools.product(
            [0.03, 0.05, 0.08],  # learning_rate
            [3, 5],              # max_depth
            [200, 300],          # max_iter
            [10, 20],            # min_samples_leaf
        )
        return [
            {
                "learning_rate": lr,
                "max_depth": depth,
                "max_iter": iters,
                "min_samples_leaf": leaf,
            }
            for lr, depth, iters, leaf in combos
        ]
    raise ValueError(f"unknown model_family: {model_family}")


def limit_trials(grid: list[dict[str, Any]], max_trials: int) -> list[dict[str, Any]]:
    """Subsample *grid* evenly down to at most *max_trials* entries.

    A non-positive *max_trials* disables limiting; grids already within the
    budget are returned unchanged.
    """
    if max_trials <= 0:
        return grid
    if len(grid) <= max_trials:
        return grid
    stride = max(1, len(grid) // max_trials)
    sampled = grid[::stride]
    return sampled[:max_trials]


def make_model(model_family: str, random_state: int, params: dict[str, Any] | None = None):
    """Construct an unfitted estimator for *model_family*.

    Family defaults from default_model_params() are merged with any *params*
    overrides. "logreg" yields a StandardScaler + class-balanced
    LogisticRegression pipeline; "hist_gb" yields a
    HistGradientBoostingClassifier. Raises ValueError otherwise.
    """
    merged = default_model_params(model_family)
    merged.update(params or {})

    if model_family == "logreg":
        classifier = LogisticRegression(
            C=float(merged["c"]),
            max_iter=1000,
            class_weight="balanced",
            random_state=random_state,
        )
        return Pipeline([("scaler", StandardScaler()), ("clf", classifier)])
    if model_family == "hist_gb":
        return HistGradientBoostingClassifier(
            max_iter=int(merged["max_iter"]),
            learning_rate=float(merged["learning_rate"]),
            max_depth=int(merged["max_depth"]),
            min_samples_leaf=int(merged["min_samples_leaf"]),
            random_state=random_state,
        )
    raise ValueError(f"unknown model_family: {model_family}")


def threshold_from_probs(
|
|
y_true: np.ndarray,
|
|
y_prob: np.ndarray,
|
|
min_precision: float,
|
|
fixed_threshold: float | None,
|
|
) -> tuple[float, dict[str, Any]]:
|
|
if fixed_threshold is not None:
|
|
return float(fixed_threshold), {
|
|
"selection_rule": "fixed_cli_threshold",
|
|
"threshold": float(fixed_threshold),
|
|
}
|
|
return select_threshold(y_true=y_true, y_prob=y_prob, min_precision=min_precision)
|
|
|
|
|
|
def metric_key(metrics: dict[str, Any]) -> tuple[float, float, float]:
    """Ranking key for candidate models: PR-AUC, then ROC-AUC, then F1 (all higher-is-better)."""
    pr_auc = safe_pr_auc(metrics)
    roc_auc = safe_roc_auc(metrics)
    f1 = float(metrics["f1"])
    return pr_auc, roc_auc, f1


def expected_calibration_error(y_true: np.ndarray, y_prob: np.ndarray, bins: int = 10) -> float:
    """Compute expected calibration error over equal-width probability bins.

    ECE is the bucket-weighted mean of |observed frequency - mean confidence|.
    The last bin is closed on the right so a probability of exactly 1.0 is
    still counted. Returns NaN for empty input.
    """
    n = len(y_true)
    if n == 0:
        return float("nan")

    boundaries = np.linspace(0.0, 1.0, bins + 1)
    error = 0.0
    for b in range(bins):
        lower = boundaries[b]
        upper = boundaries[b + 1]
        if b == bins - 1:
            in_bin = (y_prob >= lower) & (y_prob <= upper)
        else:
            in_bin = (y_prob >= lower) & (y_prob < upper)
        count = float(np.sum(in_bin))
        if count == 0.0:
            continue
        confidence = float(np.mean(y_prob[in_bin]))
        frequency = float(np.mean(y_true[in_bin]))
        error += (count / float(n)) * abs(frequency - confidence)
    return float(error)


def calibration_cv_splits(rows: int) -> int:
    """Choose the calibration CV fold count from the training row count.

    Returns 0 when there are too few rows (< 140) to calibrate at all.
    """
    ladder = ((800, 5), (400, 4), (240, 3), (140, 2))
    for minimum_rows, splits in ladder:
        if rows >= minimum_rows:
            return splits
    return 0


def fit_with_optional_calibration(
    model_family: str,
    model_params: dict[str, Any],
    random_state: int,
    x_train,
    y_train: np.ndarray,
    calibration_method: str,
    fallback_to_none: bool = True,
):
    """Fit an estimator, optionally wrapped in time-series-aware probability calibration.

    Returns (fitted_model, fit_info); fit_info records the requested vs
    effective calibration method, the CV fold count, and any fallback warning.
    With fallback_to_none=True (default), insufficient rows or a calibration
    failure degrade to an uncalibrated fit instead of raising.
    """
    base = make_model(model_family=model_family, random_state=random_state, params=model_params)
    if calibration_method == "none":
        base.fit(x_train, y_train)
        return base, {
            "requested_method": "none",
            "effective_method": "none",
            "cv_splits": 0,
        }

    # Fold count scales with training size; fewer than 2 folds means
    # calibration is not viable on this data.
    splits = calibration_cv_splits(len(y_train))
    if splits < 2:
        if not fallback_to_none:
            raise RuntimeError("not enough rows for calibration folds")
        base.fit(x_train, y_train)
        return base, {
            "requested_method": calibration_method,
            "effective_method": "none",
            "cv_splits": splits,
            "warning": "insufficient rows for calibration folds; falling back to uncalibrated model",
        }

    try:
        # TimeSeriesSplit keeps calibration folds temporally ordered,
        # avoiding look-ahead leakage in the calibration fit.
        calibrated = CalibratedClassifierCV(
            estimator=base,
            method=calibration_method,
            cv=TimeSeriesSplit(n_splits=splits),
        )
        calibrated.fit(x_train, y_train)
        return calibrated, {
            "requested_method": calibration_method,
            "effective_method": calibration_method,
            "cv_splits": splits,
        }
    except Exception as exc:
        # Deliberately broad: any calibration failure degrades to the plain
        # model (unless the caller asked for strict behavior), with the error
        # preserved in the returned warning.
        if not fallback_to_none:
            raise
        base.fit(x_train, y_train)
        return base, {
            "requested_method": calibration_method,
            "effective_method": "none",
            "cv_splits": splits,
            "warning": f"calibration failed ({exc}); falling back to uncalibrated model",
        }


def hyperparameter_search(
    model_family: str,
    x_train,
    y_train: np.ndarray,
    x_val,
    y_val: np.ndarray,
    random_state: int,
    min_precision: float,
    fixed_threshold: float | None,
    enabled: bool,
    max_trials: int,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Validation-only hyperparameter search for one model family.

    When *enabled*, every (trial-limited) grid point is fit on the train split
    and scored on validation; otherwise only the family defaults are evaluated.
    Returns (best_params, all_trial_records), best chosen via metric_key.
    """
    if enabled:
        candidates = limit_trials(model_param_grid(model_family), max_trials=max_trials)
    else:
        candidates = [default_model_params(model_family)]

    trials: list[dict[str, Any]] = []
    for candidate in candidates:
        estimator = make_model(model_family=model_family, random_state=random_state, params=candidate)
        estimator.fit(x_train, y_train)
        val_prob = estimator.predict_proba(x_val)[:, 1]
        threshold, selection_info = threshold_from_probs(
            y_true=y_val,
            y_prob=val_prob,
            min_precision=min_precision,
            fixed_threshold=fixed_threshold,
        )
        record = {
            "model_family": model_family,
            "model_params": candidate,
            "threshold": float(threshold),
            "threshold_info": selection_info,
            "validation_metrics": evaluate_probs(y_true=y_val, y_prob=val_prob, threshold=threshold),
        }
        trials.append(record)

    winner = max(trials, key=lambda trial: metric_key(trial["validation_metrics"]))
    return winner["model_params"], trials


def calibration_selection_key(result: dict[str, Any]) -> tuple[float, float, float, float, float]:
    """Sort key for calibration candidates, intended for min().

    Prefers lower Brier, then lower 10-bin ECE, then higher PR-AUC, ROC-AUC
    and F1 (negated so smaller tuples win).
    """
    metrics = result["validation_metrics"]
    return (
        float(metrics["brier"]),
        float(result["calibration_quality"]["ece_10"]),
        -safe_pr_auc(metrics),
        -safe_roc_auc(metrics),
        -float(metrics["f1"]),
    )


def evaluate_calibration_methods(
    model_family: str,
    model_params: dict[str, Any],
    calibration_methods: list[str],
    x_train,
    y_train: np.ndarray,
    x_val,
    y_val: np.ndarray,
    random_state: int,
    min_precision: float,
    fixed_threshold: float | None,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Fit one model per requested calibration method and score each on validation.

    Returns (selected_result, all_results); selection minimizes
    calibration_selection_key (Brier first, then ECE, then discrimination).
    """
    candidates: list[dict[str, Any]] = []
    for method in calibration_methods:
        model, fit_info = fit_with_optional_calibration(
            model_family=model_family,
            model_params=model_params,
            random_state=random_state,
            x_train=x_train,
            y_train=y_train,
            calibration_method=method,
            fallback_to_none=True,
        )
        val_prob = model.predict_proba(x_val)[:, 1]
        threshold, selection_info = threshold_from_probs(
            y_true=y_val,
            y_prob=val_prob,
            min_precision=min_precision,
            fixed_threshold=fixed_threshold,
        )
        candidates.append(
            {
                "calibration_method": method,
                "fit": fit_info,
                "threshold": float(threshold),
                "threshold_info": selection_info,
                "validation_metrics": evaluate_probs(y_true=y_val, y_prob=val_prob, threshold=threshold),
                "calibration_quality": {
                    "ece_10": expected_calibration_error(y_true=y_val, y_prob=val_prob, bins=10),
                },
            }
        )

    best = min(candidates, key=calibration_selection_key)
    return best, candidates


def evaluate_naive_baselines(test_df, y_test: np.ndarray) -> dict[str, Any]:
    """Score simple non-ML baselines on the test split for comparison.

    Baselines (each emitted only when its input columns are present):
    persistence on last-hour rain, the raw forecast probability, and a
    combined forecast probability/amount threshold rule.
    """
    baselines: dict[str, Any] = {}

    if "rain_last_1h_mm" in test_df.columns:
        recent_rain = test_df["rain_last_1h_mm"].to_numpy(dtype=float)
        persistence = (recent_rain >= RAIN_EVENT_THRESHOLD_MM).astype(float)
        baselines["persistence_last_1h"] = {
            "rule": f"predict rain when rain_last_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
            "metrics": evaluate_probs(y_true=y_test, y_prob=persistence, threshold=0.5),
        }

    have_prob = "fc_precip_prob" in test_df.columns and test_df["fc_precip_prob"].notna().any()
    have_mm = "fc_precip_mm" in test_df.columns and test_df["fc_precip_mm"].notna().any()

    if have_prob:
        direct_prob = test_df["fc_precip_prob"].fillna(0.0).clip(lower=0.0, upper=1.0).to_numpy(dtype=float)
        baselines["forecast_precip_prob"] = {
            "rule": "use fc_precip_prob directly as baseline probability",
            "metrics": evaluate_probs(y_true=y_test, y_prob=direct_prob, threshold=0.5),
        }

    if have_prob or have_mm:
        if "fc_precip_prob" in test_df.columns:
            prob_arr = test_df["fc_precip_prob"].fillna(0.0).clip(lower=0.0, upper=1.0).to_numpy(dtype=float)
        else:
            prob_arr = np.zeros(len(test_df), dtype=float)
        if "fc_precip_mm" in test_df.columns:
            mm_arr = test_df["fc_precip_mm"].fillna(0.0).to_numpy(dtype=float)
        else:
            mm_arr = np.zeros(len(test_df), dtype=float)
        rule_hits = ((prob_arr >= 0.5) | (mm_arr >= RAIN_EVENT_THRESHOLD_MM)).astype(float)
        baselines["forecast_threshold_rule"] = {
            "rule": (
                "predict rain when (fc_precip_prob >= 0.50) "
                f"OR (fc_precip_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f})"
            ),
            "metrics": evaluate_probs(y_true=y_test, y_prob=rule_hits, threshold=0.5),
        }

    return baselines


def evaluate_sliced_performance(
    test_df,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
    min_rows_per_slice: int = 30,
) -> dict[str, Any]:
    """Score predictions on interpretable test-set slices (day/night, rainy weeks, ...).

    Each slice entry reports row count and status; metrics plus 10-bin ECE are
    added only when the slice has at least *min_rows_per_slice* rows.
    Assumes test_df has a DatetimeIndex (hour/isocalendar are read from it);
    hours are interpreted as UTC in the slice labels.
    """
    frame = pd.DataFrame(
        {
            "y_true": y_true.astype(int),
            "y_prob": y_prob.astype(float),
        },
        index=test_df.index,
    )
    overall_rate = float(np.mean(y_true))
    hour = frame.index.hour
    is_day = (hour >= 6) & (hour < 18)

    # ISO year-week labels (e.g. "2024-W07") used to bucket rows by week.
    weekly_key = frame.index.to_series().dt.isocalendar()
    week_label = weekly_key["year"].astype(str) + "-W" + weekly_key["week"].astype(str).str.zfill(2)
    weekly_positive_rate = frame.groupby(week_label)["y_true"].transform("mean")
    # A week counts as "rainy" when its positive rate meets or exceeds the
    # overall test positive rate.
    rainy_week = weekly_positive_rate >= overall_rate

    # Wet/dry context from the last-hour rain column; all-dry when absent.
    rain_context = test_df["rain_last_1h_mm"].to_numpy(dtype=float) if "rain_last_1h_mm" in test_df.columns else np.zeros(len(test_df))
    wet_context = rain_context >= RAIN_EVENT_THRESHOLD_MM

    wind_values = test_df["wind_max_m_s"].to_numpy(dtype=float) if "wind_max_m_s" in test_df.columns else np.full(len(test_df), np.nan)
    if np.isfinite(wind_values).any():
        wind_q75 = float(np.nanquantile(wind_values, 0.75))
        # NaN winds are imputed at the q75 boundary, so they count as "windy".
        windy = np.nan_to_num(wind_values, nan=wind_q75) >= wind_q75
    else:
        windy = np.zeros(len(test_df), dtype=bool)

    # (slice name, boolean row mask, human-readable description).
    definitions: list[tuple[str, np.ndarray, str]] = [
        ("daytime_utc", np.asarray(is_day, dtype=bool), "06:00-17:59 UTC"),
        ("nighttime_utc", np.asarray(~is_day, dtype=bool), "18:00-05:59 UTC"),
        ("rainy_weeks", np.asarray(rainy_week, dtype=bool), "weeks with positive-rate >= test positive-rate"),
        ("non_rainy_weeks", np.asarray(~rainy_week, dtype=bool), "weeks with positive-rate < test positive-rate"),
        ("wet_context_last_1h", np.asarray(wet_context, dtype=bool), f"rain_last_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}"),
        ("dry_context_last_1h", np.asarray(~wet_context, dtype=bool), f"rain_last_1h_mm < {RAIN_EVENT_THRESHOLD_MM:.2f}"),
        ("windy_q75", np.asarray(windy, dtype=bool), "wind_max_m_s >= test 75th percentile"),
        ("calm_below_q75", np.asarray(~windy, dtype=bool), "wind_max_m_s < test 75th percentile"),
    ]

    out: dict[str, Any] = {}
    for name, mask, description in definitions:
        rows = int(np.sum(mask))
        if rows == 0:
            out[name] = {
                "rows": rows,
                "description": description,
                "status": "empty",
            }
            continue
        y_slice = y_true[mask]
        p_slice = y_prob[mask]
        # Too-small slices are reported but not scored to avoid noisy metrics.
        if rows < min_rows_per_slice:
            out[name] = {
                "rows": rows,
                "description": description,
                "status": "insufficient_rows",
                "min_rows_required": min_rows_per_slice,
            }
            continue
        metrics = evaluate_probs(y_true=y_slice, y_prob=p_slice, threshold=threshold)
        out[name] = {
            "rows": rows,
            "description": description,
            "status": "ok",
            "metrics": metrics,
            "ece_10": expected_calibration_error(y_true=y_slice, y_prob=p_slice, bins=10),
        }
    return out


def walk_forward_backtest(
    model_df,
    feature_cols: list[str],
    model_family: str,
    model_params: dict[str, Any],
    calibration_method: str,
    random_state: int,
    min_precision: float,
    fixed_threshold: float | None,
    folds: int,
) -> dict[str, Any]:
    """Run an expanding-window walk-forward backtest over time-ordered rows.

    Each fold trains on all rows before a cut point and tests on the next
    chunk; the operating threshold is chosen on an inner validation tail of
    the fold's training data. Returns {"enabled", "folds", "summary"}.
    Assumes model_df rows are time-ordered with a "rain_next_1h" target.
    """
    if folds <= 0:
        return {"enabled": False, "folds": [], "summary": None}

    n = len(model_df)
    # First fold always trains on at least 200 rows or 40% of the data.
    min_train_rows = max(200, int(0.4 * n))
    remaining = n - min_train_rows
    if remaining < 50:
        return {
            "enabled": True,
            "folds": [],
            "summary": {
                "error": "not enough rows for walk-forward folds",
                "requested_folds": folds,
                "min_train_rows": min_train_rows,
            },
        }

    fold_size = max(25, remaining // folds)
    results: list[dict[str, Any]] = []

    for idx in range(folds):
        train_end = min_train_rows + idx * fold_size
        # The last fold absorbs any leftover rows at the end of the series.
        test_end = n if idx == folds - 1 else min(min_train_rows + (idx + 1) * fold_size, n)
        if train_end >= test_end:
            continue

        fold_train = model_df.iloc[:train_end]
        fold_test = model_df.iloc[train_end:test_end]
        if len(fold_train) < 160 or len(fold_test) < 25:
            continue

        y_fold_train = fold_train["rain_next_1h"].astype(int).to_numpy()
        y_fold_test = fold_test["rain_next_1h"].astype(int).to_numpy()
        # Skip folds whose training data contains only one class.
        if len(np.unique(y_fold_train)) < 2:
            continue

        # Inner split: the tail of the fold's training data is held out to
        # pick the fold's operating threshold.
        inner_val_rows = max(40, int(0.2 * len(fold_train)))
        if len(fold_train) - inner_val_rows < 80:
            continue
        inner_train = fold_train.iloc[:-inner_val_rows]
        inner_val = fold_train.iloc[-inner_val_rows:]
        y_inner_train = inner_train["rain_next_1h"].astype(int).to_numpy()
        y_inner_val = inner_val["rain_next_1h"].astype(int).to_numpy()
        if len(np.unique(y_inner_train)) < 2:
            continue

        try:
            # Model fit on the inner-train portion, used only to select the
            # threshold on the inner validation tail.
            threshold_model, threshold_fit = fit_with_optional_calibration(
                model_family=model_family,
                model_params=model_params,
                random_state=random_state,
                x_train=inner_train[feature_cols],
                y_train=y_inner_train,
                calibration_method=calibration_method,
                fallback_to_none=True,
            )
            inner_val_prob = threshold_model.predict_proba(inner_val[feature_cols])[:, 1]
            fold_threshold, fold_threshold_info = threshold_from_probs(
                y_true=y_inner_val,
                y_prob=inner_val_prob,
                min_precision=min_precision,
                fixed_threshold=fixed_threshold,
            )

            # Final fold model is refit on the whole fold training window.
            fold_model, fold_fit = fit_with_optional_calibration(
                model_family=model_family,
                model_params=model_params,
                random_state=random_state,
                x_train=fold_train[feature_cols],
                y_train=y_fold_train,
                calibration_method=calibration_method,
                fallback_to_none=True,
            )
            fold_test_prob = fold_model.predict_proba(fold_test[feature_cols])[:, 1]
            fold_metrics = evaluate_probs(y_true=y_fold_test, y_prob=fold_test_prob, threshold=fold_threshold)
            fold_result = {
                "fold_index": idx + 1,
                "train_rows": len(fold_train),
                "test_rows": len(fold_test),
                "train_start": fold_train.index.min(),
                "train_end": fold_train.index.max(),
                "test_start": fold_test.index.min(),
                "test_end": fold_test.index.max(),
                "threshold": float(fold_threshold),
                "threshold_selection": fold_threshold_info,
                "threshold_fit": threshold_fit,
                "model_fit": fold_fit,
                "metrics": fold_metrics,
            }
            results.append(fold_result)
        except Exception as exc:
            # A failing fold is recorded (not raised) so the backtest can
            # still summarize the folds that succeeded.
            results.append(
                {
                    "fold_index": idx + 1,
                    "train_rows": len(fold_train),
                    "test_rows": len(fold_test),
                    "error": str(exc),
                }
            )

    good = [r for r in results if "metrics" in r]
    if not good:
        return {
            "enabled": True,
            "folds": results,
            "summary": {
                "error": "no successful fold evaluations",
                "requested_folds": folds,
            },
        }

    # AUC means are only reported when at least one fold produced the metric.
    summary = {
        "successful_folds": len(good),
        "requested_folds": folds,
        "mean_precision": float(np.mean([f["metrics"]["precision"] for f in good])),
        "mean_recall": float(np.mean([f["metrics"]["recall"] for f in good])),
        "mean_f1": float(np.mean([f["metrics"]["f1"] for f in good])),
        "mean_brier": float(np.mean([f["metrics"]["brier"] for f in good])),
        "mean_pr_auc": float(np.mean([f["metrics"]["pr_auc"] for f in good if f["metrics"]["pr_auc"] is not None]))
        if any(f["metrics"]["pr_auc"] is not None for f in good)
        else None,
        "mean_roc_auc": float(np.mean([f["metrics"]["roc_auc"] for f in good if f["metrics"]["roc_auc"] is not None]))
        if any(f["metrics"]["roc_auc"] is not None for f in good)
        else None,
    }
    return {"enabled": True, "folds": results, "summary": summary}


def write_model_card(path: str, report: dict[str, Any]) -> None:
    """Render a markdown model card from the training *report* and write it to *path*.

    Creates parent directories as needed. Only slices with status "ok" are
    included in the sliced-performance section.
    """
    lines = [
        "# Rain Model Card",
        "",
        f"- Model version: `{report['model_version']}`",
        f"- Generated at (UTC): `{report['generated_at']}`",
        f"- Site: `{report['site']}`",
        f"- Target: `{report['target_definition']}`",
        f"- Feature set: `{report['feature_set']}`",
        f"- Model family: `{report['model_family']}`",
        f"- Model params: `{report['model_params']}`",
        f"- Calibration method: `{report['calibration_method']}`",
        f"- Operating threshold: `{report['threshold_selection']['threshold']:.3f}`",
        "",
        "## Data Window",
        "",
        f"- Requested start: `{report['data_window']['requested_start']}`",
        f"- Requested end: `{report['data_window']['requested_end']}`",
        f"- Actual start: `{report['data_window']['actual_start']}`",
        f"- Actual end: `{report['data_window']['actual_end']}`",
        f"- Rows: train `{report['split']['train_rows']}`, val `{report['split']['val_rows']}`, test `{report['split']['test_rows']}`",
        "",
        "## Features",
        "",
    ]
    for col in report["feature_columns"]:
        lines.append(f"- `{col}`")

    lines.extend(
        [
            "",
            "## Performance",
            "",
            "- Validation:",
            f"  precision `{report['validation_metrics']['precision']:.3f}`, "
            f"recall `{report['validation_metrics']['recall']:.3f}`, "
            f"PR-AUC `{report['validation_metrics']['pr_auc']}`, "
            f"ROC-AUC `{report['validation_metrics']['roc_auc']}`, "
            f"Brier `{report['validation_metrics']['brier']:.4f}`",
            "- Test:",
            f"  precision `{report['test_metrics']['precision']:.3f}`, "
            f"recall `{report['test_metrics']['recall']:.3f}`, "
            f"PR-AUC `{report['test_metrics']['pr_auc']}`, "
            f"ROC-AUC `{report['test_metrics']['roc_auc']}`, "
            f"Brier `{report['test_metrics']['brier']:.4f}`",
            "",
            "## Sliced Performance (Test)",
            "",
        ]
    )
    # Only scored slices make the card; empty/insufficient slices are omitted.
    for slice_name, info in report.get("sliced_performance_test", {}).items():
        if info.get("status") != "ok":
            continue
        metrics = info["metrics"]
        lines.append(
            f"- `{slice_name}` ({info['rows']} rows): "
            f"precision `{metrics['precision']:.3f}`, "
            f"recall `{metrics['recall']:.3f}`, "
            f"PR-AUC `{metrics['pr_auc']}`, "
            f"Brier `{metrics['brier']:.4f}`"
        )

    lines.extend(
        [
            "",
            "## Known Limitations",
            "",
            "- Sensor rain counter resets are clipped at zero increment; extreme spikes are flagged but not fully removed.",
            "- Forecast feature availability can vary by upstream model freshness.",
            "- Performance may drift seasonally and should be monitored with the drift queries in docs.",
            "",
        ]
    )

    card_dir = os.path.dirname(path)
    if card_dir:
        os.makedirs(card_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def train_candidate(
    model_family: str,
    x_train,
    y_train: np.ndarray,
    x_val,
    y_val: np.ndarray,
    random_state: int,
    min_precision: float,
    fixed_threshold: float | None,
    tune_hyperparameters: bool,
    max_hyperparam_trials: int,
    calibration_methods: list[str],
) -> dict[str, Any]:
    """Train one model family end to end: hyperparameter search, then calibration choice.

    Returns a record with the chosen params, calibration method, operating
    threshold and validation metrics, plus full tuning/calibration trial logs.
    """
    best_params, tuning_trials = hyperparameter_search(
        model_family=model_family,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        random_state=random_state,
        min_precision=min_precision,
        fixed_threshold=fixed_threshold,
        enabled=tune_hyperparameters,
        max_trials=max_hyperparam_trials,
    )

    chosen, calibration_trials = evaluate_calibration_methods(
        model_family=model_family,
        model_params=best_params,
        calibration_methods=calibration_methods,
        x_train=x_train,
        y_train=y_train,
        x_val=x_val,
        y_val=y_val,
        random_state=random_state,
        min_precision=min_precision,
        fixed_threshold=fixed_threshold,
    )

    record: dict[str, Any] = {
        "model_family": model_family,
        "model_params": best_params,
        "hyperparameter_tuning": {
            "enabled": tune_hyperparameters,
            "trial_count": len(tuning_trials),
            "trials": tuning_trials,
        },
        "calibration_comparison": calibration_trials,
    }
    record["calibration_method"] = chosen["calibration_method"]
    record["calibration_fit"] = chosen["fit"]
    record["threshold"] = float(chosen["threshold"])
    record["threshold_info"] = chosen["threshold_info"]
    record["validation_metrics"] = chosen["validation_metrics"]
    record["calibration_quality"] = chosen["calibration_quality"]
    return record


def main() -> int:
    """Train, evaluate, and persist the rain prediction model end-to-end.

    Pipeline: fetch observations from Postgres, build the feature/label
    dataset, split chronologically, select the best model family and
    calibration on the validation split, refit on train+val, evaluate on the
    held-out test split, then optionally write the JSON report, model card,
    dataset snapshot, and serialized model artifact.

    Returns:
        0 on success, including the ``--allow-empty`` soft-skip paths.

    Raises:
        SystemExit: if no DB connection string is provided.
        RuntimeError: on unrecoverable data problems (empty sources, too few
            rows, single-class training split) when --allow-empty is not set.
    """
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")

    # Normalize the optional time window; empty string means "unbounded".
    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""
    feature_cols = feature_columns_for_set(args.feature_set)
    # Forecast rows are only fetched when the chosen feature set needs them.
    needs_forecast = feature_columns_need_forecast(feature_cols)
    calibration_methods = parse_calibration_methods(args.calibration_methods)

    # NOTE(review): psycopg2's "with connection" commits/rolls back the
    # transaction but does not close the connection itself.
    with psycopg2.connect(args.db_url) as conn:
        ws90 = fetch_ws90(conn, args.site, start, end)
        baro = fetch_baro(conn, args.site, start, end)
        forecast = None
        if needs_forecast:
            forecast = fetch_forecast(conn, args.site, start, end, model=args.forecast_model)

    # With --allow-empty, missing source data is a soft skip (exit 0) so
    # scheduled runs don't fail before enough observations accumulate.
    if ws90.empty:
        message = "no ws90 observations found in requested window"
        if args.allow_empty:
            print(f"Rain model training skipped: {message}.")
            return 0
        raise RuntimeError(message)
    if baro.empty:
        message = "no barometer observations found in requested window"
        if args.allow_empty:
            print(f"Rain model training skipped: {message}.")
            return 0
        raise RuntimeError(message)

    full_df = build_dataset(ws90, baro, forecast=forecast, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
    # model_frame drops rows unusable for supervised training (require_target=True).
    model_df = model_frame(full_df, feature_cols, require_target=True)
    if len(model_df) < args.min_rows:
        message = f"not enough model-ready rows after filtering (need >= {args.min_rows})"
        if args.allow_empty:
            print(f"Rain model training skipped: {message}.")
            return 0
        raise RuntimeError(message)

    # Chronological split (no shuffling) to avoid temporal leakage.
    train_df, val_df, test_df = split_time_ordered(
        model_df,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
    )

    x_train = train_df[feature_cols]
    y_train = train_df["rain_next_1h"].astype(int).to_numpy()
    x_val = val_df[feature_cols]
    y_val = val_df["rain_next_1h"].astype(int).to_numpy()
    x_test = test_df[feature_cols]
    y_test = test_df["rain_next_1h"].astype(int).to_numpy()

    # A single-class training split cannot fit a binary classifier; a
    # single-class validation split only degrades metrics, so warn instead.
    if len(np.unique(y_train)) < 2:
        raise RuntimeError("training split does not contain both classes; cannot train classifier")
    if len(np.unique(y_val)) < 2:
        print("warning: validation split has one class; AUC metrics may be unavailable", flush=True)

    # "auto" trains both families and keeps the best by validation metric_key.
    candidate_families = ["logreg", "hist_gb"] if args.model_family == "auto" else [args.model_family]
    candidates = [
        train_candidate(
            model_family=family,
            x_train=x_train,
            y_train=y_train,
            x_val=x_val,
            y_val=y_val,
            random_state=args.random_state,
            min_precision=args.min_precision,
            fixed_threshold=args.threshold,
            tune_hyperparameters=args.tune_hyperparameters,
            max_hyperparam_trials=args.max_hyperparam_trials,
            calibration_methods=calibration_methods,
        )
        for family in candidate_families
    ]
    best_candidate = max(
        candidates,
        key=lambda c: metric_key(c["validation_metrics"]),
    )
    selected_model_family = str(best_candidate["model_family"])
    selected_model_params = best_candidate["model_params"]
    selected_calibration_method = str(best_candidate["calibration_method"])
    chosen_threshold = float(best_candidate["threshold"])
    threshold_info = best_candidate["threshold_info"]
    val_metrics = best_candidate["validation_metrics"]

    # Refit the selected configuration on train+val (the leading rows of the
    # time-ordered frame) before scoring the untouched test split.
    train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
    x_train_val = train_val_df[feature_cols]
    y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()

    final_model, final_fit_info = fit_with_optional_calibration(
        model_family=selected_model_family,
        model_params=selected_model_params,
        random_state=args.random_state,
        x_train=x_train_val,
        y_train=y_train_val,
        calibration_method=selected_calibration_method,
        fallback_to_none=True,
    )
    y_test_prob = final_model.predict_proba(x_test)[:, 1]
    test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
    test_calibration = {
        "ece_10": expected_calibration_error(y_true=y_test, y_prob=y_test_prob, bins=10),
    }
    naive_baselines_test = evaluate_naive_baselines(test_df=test_df, y_test=y_test)
    sliced_performance = evaluate_sliced_performance(
        test_df=test_df,
        y_true=y_test,
        y_prob=y_test_prob,
        threshold=chosen_threshold,
    )
    walk_forward = walk_forward_backtest(
        model_df=model_df,
        feature_cols=feature_cols,
        model_family=selected_model_family,
        model_params=selected_model_params,
        calibration_method=selected_calibration_method,
        random_state=args.random_state,
        min_precision=args.min_precision,
        fixed_threshold=args.threshold,
        folds=args.walk_forward_folds,
    )

    # Assemble the full training report; it is the single source of truth for
    # the console summary, the JSON report, and the model card.
    report = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "site": args.site,
        "model_version": args.model_version,
        "model_family_requested": args.model_family,
        "model_family": selected_model_family,
        "model_params": selected_model_params,
        "feature_set": args.feature_set,
        "target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
        "feature_columns": feature_cols,
        "forecast_model": args.forecast_model if needs_forecast else None,
        "calibration_method_requested": calibration_methods,
        "calibration_method": selected_calibration_method,
        "calibration_fit": final_fit_info,
        "data_window": {
            "requested_start": start or None,
            "requested_end": end or None,
            "actual_start": model_df.index.min(),
            "actual_end": model_df.index.max(),
            "model_rows": len(model_df),
            "ws90_rows": len(ws90),
            "baro_rows": len(baro),
            "forecast_rows": int(len(forecast)) if forecast is not None else 0,
        },
        "label_quality": {
            # Counts of label-affecting anomalies flagged upstream in build_dataset.
            "rain_reset_count": int(np.nansum(full_df["rain_reset"].fillna(False).to_numpy(dtype=int))),
            "rain_spike_5m_count": int(np.nansum(full_df["rain_spike_5m"].fillna(False).to_numpy(dtype=int))),
        },
        "feature_missingness_ratio": {col: float(full_df[col].isna().mean()) for col in feature_cols if col in full_df.columns},
        "split": {
            "train_ratio": args.train_ratio,
            "val_ratio": args.val_ratio,
            "train_rows": len(train_df),
            "val_rows": len(val_df),
            "test_rows": len(test_df),
            "train_start": train_df.index.min(),
            "train_end": train_df.index.max(),
            "val_start": val_df.index.min(),
            "val_end": val_df.index.max(),
            "test_start": test_df.index.min(),
            "test_end": test_df.index.max(),
        },
        "threshold_selection": {
            **threshold_info,
            "min_precision_constraint": args.min_precision,
        },
        # Every trained candidate is recorded so losing configurations remain auditable.
        "candidate_models": [
            {
                "model_family": c["model_family"],
                "model_params": c["model_params"],
                "hyperparameter_tuning": c["hyperparameter_tuning"],
                "calibration_method": c["calibration_method"],
                "calibration_fit": c["calibration_fit"],
                "calibration_comparison": c["calibration_comparison"],
                "threshold_selection": {
                    **c["threshold_info"],
                    "min_precision_constraint": args.min_precision,
                },
                "calibration_quality": c["calibration_quality"],
                "validation_metrics": c["validation_metrics"],
            }
            for c in candidates
        ],
        "validation_metrics": val_metrics,
        "test_metrics": test_metrics,
        "test_calibration_quality": test_calibration,
        "naive_baselines_test": naive_baselines_test,
        "sliced_performance_test": sliced_performance,
        "walk_forward_backtest": walk_forward,
    }
    # Convert numpy/pandas scalars and timestamps into JSON-serializable builtins.
    report = to_builtin(report)

    print("Rain model training summary:")
    print(f"  site: {args.site}")
    print(f"  model_version: {args.model_version}")
    print(f"  model_family: {selected_model_family} (requested={args.model_family})")
    print(f"  model_params: {selected_model_params}")
    print(f"  calibration_method: {report['calibration_method']}")
    print(f"  feature_set: {args.feature_set} ({len(feature_cols)} features)")
    print(
        "  rows: "
        f"total={report['data_window']['model_rows']} "
        f"train={report['split']['train_rows']} "
        f"val={report['split']['val_rows']} "
        f"test={report['split']['test_rows']}"
    )
    print(
        "  threshold: "
        f"{report['threshold_selection']['threshold']:.3f} "
        f"({report['threshold_selection']['selection_rule']})"
    )
    print(
        "  val metrics: "
        f"precision={report['validation_metrics']['precision']:.3f} "
        f"recall={report['validation_metrics']['recall']:.3f} "
        f"roc_auc={report['validation_metrics']['roc_auc'] if report['validation_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['validation_metrics']['pr_auc'] if report['validation_metrics']['pr_auc'] is not None else 'n/a'}"
    )
    print(
        "  test metrics: "
        f"precision={report['test_metrics']['precision']:.3f} "
        f"recall={report['test_metrics']['recall']:.3f} "
        f"roc_auc={report['test_metrics']['roc_auc'] if report['test_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['test_metrics']['pr_auc'] if report['test_metrics']['pr_auc'] is not None else 'n/a'} "
        f"brier={report['test_metrics']['brier']:.4f} "
        f"ece10={report['test_calibration_quality']['ece_10']:.4f}"
    )
    if report["walk_forward_backtest"]["summary"] is not None:
        summary = report["walk_forward_backtest"]["summary"]
        if "error" in summary:
            print(f"  walk-forward: {summary['error']}")
        else:
            print(
                "  walk-forward: "
                f"folds={summary['successful_folds']}/{summary['requested_folds']} "
                f"mean_precision={summary['mean_precision']:.3f} "
                f"mean_recall={summary['mean_recall']:.3f} "
                f"mean_pr_auc={summary['mean_pr_auc'] if summary['mean_pr_auc'] is not None else 'n/a'}"
            )
    if report["naive_baselines_test"]:
        print("  naive baselines (test):")
        for name, baseline in report["naive_baselines_test"].items():
            m = baseline["metrics"]
            print(
                f"    {name}: "
                f"precision={m['precision']:.3f} recall={m['recall']:.3f} "
                f"pr_auc={m['pr_auc'] if m['pr_auc'] is not None else 'n/a'} "
                f"brier={m['brier']:.4f}"
            )
    # Only slices with status "ok" carry metrics; others are skipped silently here.
    sliced_ok = [
        (name, item)
        for name, item in report["sliced_performance_test"].items()
        if item.get("status") == "ok"
    ]
    if sliced_ok:
        print("  sliced performance (test):")
        for name, item in sliced_ok:
            m = item["metrics"]
            print(
                f"    {name}: rows={item['rows']} "
                f"precision={m['precision']:.3f} recall={m['recall']:.3f} "
                f"pr_auc={m['pr_auc'] if m['pr_auc'] is not None else 'n/a'} "
                f"brier={m['brier']:.4f}"
            )

    # Optional outputs: each is written only when its CLI flag is provided.
    if args.report_out:
        report_dir = os.path.dirname(args.report_out)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        with open(args.report_out, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"Saved report to {args.report_out}")

    if args.model_card_out:
        # Output paths support {model_version} (and below {feature_set}) templating.
        model_card_out = args.model_card_out.format(model_version=args.model_version)
        write_model_card(model_card_out, report)
        print(f"Saved model card to {model_card_out}")

    if args.dataset_out:
        dataset_out = args.dataset_out.format(model_version=args.model_version, feature_set=args.feature_set)
        dataset_dir = os.path.dirname(dataset_out)
        if dataset_dir:
            os.makedirs(dataset_dir, exist_ok=True)
        # dict.fromkeys de-duplicates while preserving column order.
        snapshot_cols = list(dict.fromkeys(feature_cols + ["rain_next_1h", "rain_next_1h_mm"]))
        model_df[snapshot_cols].to_csv(dataset_out, index=True, index_label="ts")
        print(f"Saved dataset snapshot to {dataset_out}")

    if args.out:
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # joblib is optional; without it the model artifact is skipped, not fatal.
        if joblib is None:
            print("joblib not installed; skipping model save.")
        else:
            # Bundle the fitted model with the metadata inference needs
            # (features, threshold, calibration) in one artifact.
            artifact = {
                "model": final_model,
                "model_family": selected_model_family,
                "model_params": selected_model_params,
                "calibration_method": selected_calibration_method,
                "calibration_fit": final_fit_info,
                "features": feature_cols,
                "feature_set": args.feature_set,
                "forecast_model": args.forecast_model if needs_forecast else None,
                "threshold": float(chosen_threshold),
                "target_mm": float(RAIN_EVENT_THRESHOLD_MM),
                "model_version": args.model_version,
                "trained_at": datetime.now(timezone.utc).isoformat(),
                "split": report["split"],
                "threshold_selection": report["threshold_selection"],
            }
            joblib.dump(artifact, args.out)
            print(f"Saved model to {args.out}")

    return 0
|
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit code.
    raise SystemExit(main())