Files
go-weatherstation/scripts/train_rain_model.py

228 lines
8.4 KiB
Python

#!/usr/bin/env python3
import argparse
import json
import os
from datetime import datetime, timezone
import numpy as np
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from rain_model_common import (
FEATURE_COLUMNS,
RAIN_EVENT_THRESHOLD_MM,
build_dataset,
evaluate_probs,
fetch_baro,
fetch_ws90,
model_frame,
parse_time,
select_threshold,
split_time_ordered,
to_builtin,
)
# joblib is optional: without it we can still train and emit the JSON report,
# but the trained model artifact is not saved (see main()).
try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a rain-model training run."""
    p = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
    p.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    p.add_argument("--site", required=True, help="Site name (e.g. home).")
    p.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    p.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    p.add_argument("--train-ratio", type=float, default=0.7, help="Time-ordered train split ratio.")
    p.add_argument("--val-ratio", type=float, default=0.15, help="Time-ordered validation split ratio.")
    p.add_argument(
        "--min-precision",
        type=float,
        default=0.7,
        help="Minimum validation precision for threshold selection.",
    )
    p.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
    p.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
    p.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    p.add_argument(
        "--report-out",
        default="models/rain_model_report.json",
        help="Path to save JSON training report.",
    )
    p.add_argument(
        "--model-version",
        default="rain-logreg-v1",
        help="Version label stored in artifact metadata.",
    )
    return p.parse_args()
def make_model() -> Pipeline:
    """Build the untrained classifier: feature scaling + balanced logistic regression."""
    steps = [
        ("scaler", StandardScaler()),
        ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
    ]
    return Pipeline(steps)
def main() -> int:
    """Train, evaluate, and persist the rain probability model.

    Workflow: fetch WS90 and barometer rows from Postgres, build the labeled
    feature frame, fit a logistic regression on a time-ordered train split,
    choose a probability threshold on the validation split (or take the one
    given on the CLI), refit on train+val, evaluate on the held-out test
    split, then write a JSON report and — if joblib is available — a model
    artifact.

    Returns:
        0 on success; configuration/data problems raise SystemExit or
        RuntimeError instead of returning a nonzero code.

    Raises:
        SystemExit: when no DB connection string is supplied.
        RuntimeError: when too few model-ready rows survive filtering.
    """
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")
    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""
    # NOTE: psycopg2's connection context manager only wraps a transaction —
    # it does not close the connection. Close it explicitly so the DB session
    # is not held open through the (potentially long) training phase below.
    conn = psycopg2.connect(args.db_url)
    try:
        with conn:
            ws90 = fetch_ws90(conn, args.site, start, end)
            baro = fetch_baro(conn, args.site, start, end)
    finally:
        conn.close()
    full_df = build_dataset(ws90, baro, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
    model_df = model_frame(full_df, FEATURE_COLUMNS, require_target=True)
    if len(model_df) < args.min_rows:
        raise RuntimeError(f"not enough model-ready rows after filtering (need >= {args.min_rows})")
    # Chronological split: train on the oldest data, validate on the middle,
    # test on the newest, to avoid temporal leakage.
    train_df, val_df, test_df = split_time_ordered(
        model_df,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
    )
    x_train = train_df[FEATURE_COLUMNS]
    y_train = train_df["rain_next_1h"].astype(int).to_numpy()
    x_val = val_df[FEATURE_COLUMNS]
    y_val = val_df["rain_next_1h"].astype(int).to_numpy()
    x_test = test_df[FEATURE_COLUMNS]
    y_test = test_df["rain_next_1h"].astype(int).to_numpy()
    # First fit: train-only model, used solely to pick the threshold on val.
    base_model = make_model()
    base_model.fit(x_train, y_train)
    y_val_prob = base_model.predict_proba(x_val)[:, 1]
    if args.threshold is not None:
        chosen_threshold = args.threshold
        threshold_info = {
            "selection_rule": "fixed_cli_threshold",
            "threshold": float(args.threshold),
        }
    else:
        chosen_threshold, threshold_info = select_threshold(
            y_true=y_val,
            y_prob=y_val_prob,
            min_precision=args.min_precision,
        )
    val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=chosen_threshold)
    # Final fit: retrain on train+val (model_df is time-ordered, so the first
    # len(train)+len(val) rows are exactly those two splits), then score the
    # untouched test tail with the threshold chosen above.
    train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
    x_train_val = train_val_df[FEATURE_COLUMNS]
    y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
    final_model = make_model()
    final_model.fit(x_train_val, y_train_val)
    y_test_prob = final_model.predict_proba(x_test)[:, 1]
    test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
    report = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "site": args.site,
        "model_version": args.model_version,
        "target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
        "feature_columns": FEATURE_COLUMNS,
        "data_window": {
            "requested_start": start or None,
            "requested_end": end or None,
            "actual_start": model_df.index.min(),
            "actual_end": model_df.index.max(),
            "model_rows": len(model_df),
            "ws90_rows": len(ws90),
            "baro_rows": len(baro),
        },
        "label_quality": {
            # Counts of label-quality flags (counter resets / 5-minute spikes).
            "rain_reset_count": int(np.nansum(full_df["rain_reset"].fillna(False).to_numpy(dtype=int))),
            "rain_spike_5m_count": int(np.nansum(full_df["rain_spike_5m"].fillna(False).to_numpy(dtype=int))),
        },
        "split": {
            "train_ratio": args.train_ratio,
            "val_ratio": args.val_ratio,
            "train_rows": len(train_df),
            "val_rows": len(val_df),
            "test_rows": len(test_df),
            "train_start": train_df.index.min(),
            "train_end": train_df.index.max(),
            "val_start": val_df.index.min(),
            "val_end": val_df.index.max(),
            "test_start": test_df.index.min(),
            "test_end": test_df.index.max(),
        },
        "threshold_selection": {
            **threshold_info,
            "min_precision_constraint": args.min_precision,
        },
        "validation_metrics": val_metrics,
        "test_metrics": test_metrics,
    }
    # Convert pandas/numpy scalars and timestamps to JSON-serializable builtins.
    report = to_builtin(report)
    print("Rain model training summary:")
    print(f"  site: {args.site}")
    print(f"  model_version: {args.model_version}")
    print(f"  rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
    print(
        "  threshold: "
        f"{report['threshold_selection']['threshold']:.3f} "
        f"({report['threshold_selection']['selection_rule']})"
    )
    print(
        "  val metrics: "
        f"precision={report['validation_metrics']['precision']:.3f} "
        f"recall={report['validation_metrics']['recall']:.3f} "
        f"roc_auc={report['validation_metrics']['roc_auc'] if report['validation_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['validation_metrics']['pr_auc'] if report['validation_metrics']['pr_auc'] is not None else 'n/a'}"
    )
    print(
        "  test metrics: "
        f"precision={report['test_metrics']['precision']:.3f} "
        f"recall={report['test_metrics']['recall']:.3f} "
        f"roc_auc={report['test_metrics']['roc_auc'] if report['test_metrics']['roc_auc'] is not None else 'n/a'} "
        f"pr_auc={report['test_metrics']['pr_auc'] if report['test_metrics']['pr_auc'] is not None else 'n/a'}"
    )
    if args.report_out:
        report_dir = os.path.dirname(args.report_out)
        if report_dir:
            os.makedirs(report_dir, exist_ok=True)
        with open(args.report_out, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"Saved report to {args.report_out}")
    if args.out:
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if joblib is None:
            # joblib is an optional dependency (see the import guard above);
            # training and reporting still succeed without the artifact.
            print("joblib not installed; skipping model save.")
        else:
            artifact = {
                "model": final_model,
                "features": FEATURE_COLUMNS,
                "threshold": float(chosen_threshold),
                "target_mm": float(RAIN_EVENT_THRESHOLD_MM),
                "model_version": args.model_version,
                "trained_at": datetime.now(timezone.utc).isoformat(),
                "split": report["split"],
                "threshold_selection": report["threshold_selection"],
            }
            joblib.dump(artifact, args.out)
            print(f"Saved model to {args.out}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer status to the shell as the process exit code.
    raise SystemExit(main())