#!/usr/bin/env python3


import argparse
import json
import os
from datetime import datetime, timezone
from typing import Any

import numpy as np
import psycopg2
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from rain_model_common import (
    AVAILABLE_FEATURE_SETS,
    RAIN_EVENT_THRESHOLD_MM,
    build_dataset,
    evaluate_probs,
    fetch_baro,
    fetch_forecast,
    fetch_ws90,
    feature_columns_for_set,
    feature_columns_need_forecast,
    model_frame,
    parse_time,
    safe_pr_auc,
    safe_roc_auc,
    select_threshold,
    split_time_ordered,
    to_builtin,
)

try:
|
|
import joblib
|
|
except ImportError: # pragma: no cover - optional dependency
|
|
joblib = None
|
|
|
|
|
|
MODEL_FAMILIES = ("logreg", "hist_gb", "auto")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
|
|
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
|
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
|
|
parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
|
|
parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
|
|
parser.add_argument("--train-ratio", type=float, default=0.7, help="Time-ordered train split ratio.")
|
|
parser.add_argument("--val-ratio", type=float, default=0.15, help="Time-ordered validation split ratio.")
|
|
parser.add_argument(
|
|
"--min-precision",
|
|
type=float,
|
|
default=0.7,
|
|
help="Minimum validation precision for threshold selection.",
|
|
)
|
|
parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
|
|
parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
|
|
parser.add_argument(
|
|
"--feature-set",
|
|
default="baseline",
|
|
choices=AVAILABLE_FEATURE_SETS,
|
|
help="Named feature set to train with.",
|
|
)
|
|
parser.add_argument(
|
|
"--forecast-model",
|
|
default="ecmwf",
|
|
help="Forecast model name when feature set requires forecast columns.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-family",
|
|
default="logreg",
|
|
choices=MODEL_FAMILIES,
|
|
help=(
|
|
"Estimator family. "
|
|
"'auto' compares logreg and hist_gb on validation and selects best by PR-AUC/ROC-AUC/F1."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--random-state",
|
|
type=int,
|
|
default=42,
|
|
help="Random seed for stochastic estimators.",
|
|
)
|
|
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
|
parser.add_argument(
|
|
"--report-out",
|
|
default="models/rain_model_report.json",
|
|
help="Path to save JSON training report.",
|
|
)
|
|
parser.add_argument(
|
|
"--dataset-out",
|
|
default="models/datasets/rain_dataset_{model_version}_{feature_set}.csv",
|
|
help=(
|
|
"Path (or template) for model-ready dataset snapshot. "
|
|
"Supports {model_version} and {feature_set}. Use empty string to disable."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--model-version",
|
|
default="rain-logreg-v1",
|
|
help="Version label stored in artifact metadata.",
|
|
)
|
|
return parser.parse_args()
|
|
|
|
|
|
def make_model(model_family: str, random_state: int):
|
|
if model_family == "logreg":
|
|
return Pipeline(
|
|
[
|
|
("scaler", StandardScaler()),
|
|
("clf", LogisticRegression(max_iter=1000, class_weight="balanced", random_state=random_state)),
|
|
]
|
|
)
|
|
if model_family == "hist_gb":
|
|
return HistGradientBoostingClassifier(
|
|
max_iter=300,
|
|
learning_rate=0.05,
|
|
max_depth=5,
|
|
min_samples_leaf=20,
|
|
random_state=random_state,
|
|
)
|
|
raise ValueError(f"unknown model_family: {model_family}")
|
|
|
|
|
|
def train_candidate(
|
|
model_family: str,
|
|
x_train,
|
|
y_train: np.ndarray,
|
|
x_val,
|
|
y_val: np.ndarray,
|
|
random_state: int,
|
|
min_precision: float,
|
|
fixed_threshold: float | None,
|
|
) -> dict[str, Any]:
|
|
model = make_model(model_family=model_family, random_state=random_state)
|
|
model.fit(x_train, y_train)
|
|
y_val_prob = model.predict_proba(x_val)[:, 1]
|
|
|
|
if fixed_threshold is not None:
|
|
threshold = fixed_threshold
|
|
threshold_info = {
|
|
"selection_rule": "fixed_cli_threshold",
|
|
"threshold": float(fixed_threshold),
|
|
}
|
|
else:
|
|
threshold, threshold_info = select_threshold(
|
|
y_true=y_val,
|
|
y_prob=y_val_prob,
|
|
min_precision=min_precision,
|
|
)
|
|
|
|
val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=threshold)
|
|
return {
|
|
"model_family": model_family,
|
|
"threshold": float(threshold),
|
|
"threshold_info": threshold_info,
|
|
"validation_metrics": val_metrics,
|
|
}
|
|
|
|
|
|
def main() -> int:
|
|
args = parse_args()
|
|
if not args.db_url:
|
|
raise SystemExit("missing --db-url or DATABASE_URL")
|
|
|
|
start = parse_time(args.start) if args.start else ""
|
|
end = parse_time(args.end) if args.end else ""
|
|
feature_cols = feature_columns_for_set(args.feature_set)
|
|
needs_forecast = feature_columns_need_forecast(feature_cols)
|
|
|
|
with psycopg2.connect(args.db_url) as conn:
|
|
ws90 = fetch_ws90(conn, args.site, start, end)
|
|
baro = fetch_baro(conn, args.site, start, end)
|
|
forecast = None
|
|
if needs_forecast:
|
|
forecast = fetch_forecast(conn, args.site, start, end, model=args.forecast_model)
|
|
|
|
full_df = build_dataset(ws90, baro, forecast=forecast, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
|
|
model_df = model_frame(full_df, feature_cols, require_target=True)
|
|
if len(model_df) < args.min_rows:
|
|
raise RuntimeError(f"not enough model-ready rows after filtering (need >= {args.min_rows})")
|
|
|
|
train_df, val_df, test_df = split_time_ordered(
|
|
model_df,
|
|
train_ratio=args.train_ratio,
|
|
val_ratio=args.val_ratio,
|
|
)
|
|
|
|
x_train = train_df[feature_cols]
|
|
y_train = train_df["rain_next_1h"].astype(int).to_numpy()
|
|
x_val = val_df[feature_cols]
|
|
y_val = val_df["rain_next_1h"].astype(int).to_numpy()
|
|
x_test = test_df[feature_cols]
|
|
y_test = test_df["rain_next_1h"].astype(int).to_numpy()
|
|
|
|
candidate_families = ["logreg", "hist_gb"] if args.model_family == "auto" else [args.model_family]
|
|
candidates = [
|
|
train_candidate(
|
|
model_family=family,
|
|
x_train=x_train,
|
|
y_train=y_train,
|
|
x_val=x_val,
|
|
y_val=y_val,
|
|
random_state=args.random_state,
|
|
min_precision=args.min_precision,
|
|
fixed_threshold=args.threshold,
|
|
)
|
|
for family in candidate_families
|
|
]
|
|
best_candidate = max(
|
|
candidates,
|
|
key=lambda c: (
|
|
safe_pr_auc(c["validation_metrics"]),
|
|
safe_roc_auc(c["validation_metrics"]),
|
|
float(c["validation_metrics"]["f1"]),
|
|
),
|
|
)
|
|
selected_model_family = str(best_candidate["model_family"])
|
|
chosen_threshold = float(best_candidate["threshold"])
|
|
threshold_info = best_candidate["threshold_info"]
|
|
val_metrics = best_candidate["validation_metrics"]
|
|
|
|
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
|
x_train_val = train_val_df[feature_cols]
|
|
y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
|
|
|
|
final_model = make_model(model_family=selected_model_family, random_state=args.random_state)
|
|
final_model.fit(x_train_val, y_train_val)
|
|
y_test_prob = final_model.predict_proba(x_test)[:, 1]
|
|
test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
|
|
|
|
report = {
|
|
"generated_at": datetime.now(timezone.utc).isoformat(),
|
|
"site": args.site,
|
|
"model_version": args.model_version,
|
|
"model_family_requested": args.model_family,
|
|
"model_family": selected_model_family,
|
|
"feature_set": args.feature_set,
|
|
"target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
|
|
"feature_columns": feature_cols,
|
|
"forecast_model": args.forecast_model if needs_forecast else None,
|
|
"data_window": {
|
|
"requested_start": start or None,
|
|
"requested_end": end or None,
|
|
"actual_start": model_df.index.min(),
|
|
"actual_end": model_df.index.max(),
|
|
"model_rows": len(model_df),
|
|
"ws90_rows": len(ws90),
|
|
"baro_rows": len(baro),
|
|
"forecast_rows": int(len(forecast)) if forecast is not None else 0,
|
|
},
|
|
"label_quality": {
|
|
"rain_reset_count": int(np.nansum(full_df["rain_reset"].fillna(False).to_numpy(dtype=int))),
|
|
"rain_spike_5m_count": int(np.nansum(full_df["rain_spike_5m"].fillna(False).to_numpy(dtype=int))),
|
|
},
|
|
"feature_missingness_ratio": {col: float(full_df[col].isna().mean()) for col in feature_cols if col in full_df.columns},
|
|
"split": {
|
|
"train_ratio": args.train_ratio,
|
|
"val_ratio": args.val_ratio,
|
|
"train_rows": len(train_df),
|
|
"val_rows": len(val_df),
|
|
"test_rows": len(test_df),
|
|
"train_start": train_df.index.min(),
|
|
"train_end": train_df.index.max(),
|
|
"val_start": val_df.index.min(),
|
|
"val_end": val_df.index.max(),
|
|
"test_start": test_df.index.min(),
|
|
"test_end": test_df.index.max(),
|
|
},
|
|
"threshold_selection": {
|
|
**threshold_info,
|
|
"min_precision_constraint": args.min_precision,
|
|
},
|
|
"candidate_models": [
|
|
{
|
|
"model_family": c["model_family"],
|
|
"threshold_selection": {
|
|
**c["threshold_info"],
|
|
"min_precision_constraint": args.min_precision,
|
|
},
|
|
"validation_metrics": c["validation_metrics"],
|
|
}
|
|
for c in candidates
|
|
],
|
|
"validation_metrics": val_metrics,
|
|
"test_metrics": test_metrics,
|
|
}
|
|
report = to_builtin(report)
|
|
|
|
print("Rain model training summary:")
|
|
print(f" site: {args.site}")
|
|
print(f" model_version: {args.model_version}")
|
|
print(f" model_family: {selected_model_family} (requested={args.model_family})")
|
|
print(f" feature_set: {args.feature_set} ({len(feature_cols)} features)")
|
|
print(f" rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
|
|
print(
|
|
" threshold: "
|
|
f"{report['threshold_selection']['threshold']:.3f} "
|
|
f"({report['threshold_selection']['selection_rule']})"
|
|
)
|
|
print(
|
|
" val metrics: "
|
|
f"precision={report['validation_metrics']['precision']:.3f} "
|
|
f"recall={report['validation_metrics']['recall']:.3f} "
|
|
f"roc_auc={report['validation_metrics']['roc_auc'] if report['validation_metrics']['roc_auc'] is not None else 'n/a'} "
|
|
f"pr_auc={report['validation_metrics']['pr_auc'] if report['validation_metrics']['pr_auc'] is not None else 'n/a'}"
|
|
)
|
|
print(
|
|
" test metrics: "
|
|
f"precision={report['test_metrics']['precision']:.3f} "
|
|
f"recall={report['test_metrics']['recall']:.3f} "
|
|
f"roc_auc={report['test_metrics']['roc_auc'] if report['test_metrics']['roc_auc'] is not None else 'n/a'} "
|
|
f"pr_auc={report['test_metrics']['pr_auc'] if report['test_metrics']['pr_auc'] is not None else 'n/a'}"
|
|
)
|
|
|
|
if args.report_out:
|
|
report_dir = os.path.dirname(args.report_out)
|
|
if report_dir:
|
|
os.makedirs(report_dir, exist_ok=True)
|
|
with open(args.report_out, "w", encoding="utf-8") as f:
|
|
json.dump(report, f, indent=2)
|
|
print(f"Saved report to {args.report_out}")
|
|
|
|
if args.dataset_out:
|
|
dataset_out = args.dataset_out.format(model_version=args.model_version, feature_set=args.feature_set)
|
|
dataset_dir = os.path.dirname(dataset_out)
|
|
if dataset_dir:
|
|
os.makedirs(dataset_dir, exist_ok=True)
|
|
snapshot_cols = list(dict.fromkeys(feature_cols + ["rain_next_1h", "rain_next_1h_mm"]))
|
|
model_df[snapshot_cols].to_csv(dataset_out, index=True, index_label="ts")
|
|
print(f"Saved dataset snapshot to {dataset_out}")
|
|
|
|
if args.out:
|
|
out_dir = os.path.dirname(args.out)
|
|
if out_dir:
|
|
os.makedirs(out_dir, exist_ok=True)
|
|
if joblib is None:
|
|
print("joblib not installed; skipping model save.")
|
|
else:
|
|
artifact = {
|
|
"model": final_model,
|
|
"model_family": selected_model_family,
|
|
"features": feature_cols,
|
|
"feature_set": args.feature_set,
|
|
"forecast_model": args.forecast_model if needs_forecast else None,
|
|
"threshold": float(chosen_threshold),
|
|
"target_mm": float(RAIN_EVENT_THRESHOLD_MM),
|
|
"model_version": args.model_version,
|
|
"trained_at": datetime.now(timezone.utc).isoformat(),
|
|
"split": report["split"],
|
|
"threshold_selection": report["threshold_selection"],
|
|
}
|
|
joblib.dump(artifact, args.out)
|
|
print(f"Saved model to {args.out}")
|
|
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise SystemExit(main())
|