feat: add rain data audit and prediction scripts
This commit is contained in:
@@ -1,16 +1,29 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from rain_model_common import (
|
||||
FEATURE_COLUMNS,
|
||||
RAIN_EVENT_THRESHOLD_MM,
|
||||
build_dataset,
|
||||
evaluate_probs,
|
||||
fetch_baro,
|
||||
fetch_ws90,
|
||||
model_frame,
|
||||
parse_time,
|
||||
select_threshold,
|
||||
split_time_ordered,
|
||||
to_builtin,
|
||||
)
|
||||
|
||||
try:
|
||||
import joblib
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
@@ -18,128 +31,42 @@ except ImportError: # pragma: no cover - optional dependency
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for rain-model training.

    Returns:
        argparse.Namespace with connection, data-window, split, threshold,
        and output-path settings. ``--site`` is required; ``--db-url``
        defaults to the DATABASE_URL environment variable.
    """
    # NOTE: the source had two ArgumentParser assignments (a merge artifact);
    # only the current description is kept.
    parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--train-ratio", type=float, default=0.7, help="Time-ordered train split ratio.")
    parser.add_argument("--val-ratio", type=float, default=0.15, help="Time-ordered validation split ratio.")
    parser.add_argument(
        "--min-precision",
        type=float,
        default=0.7,
        help="Minimum validation precision for threshold selection.",
    )
    parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
    parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    parser.add_argument(
        "--report-out",
        default="models/rain_model_report.json",
        help="Path to save JSON training report.",
    )
    parser.add_argument(
        "--model-version",
        default="rain-logreg-v1",
        help="Version label stored in artifact metadata.",
    )
    return parser.parse_args()
||||
|
||||
|
||||
def parse_time(value: str) -> str:
    """Validate *value* as an ISO-8601/RFC3339 timestamp and return it unchanged.

    An empty/falsy value returns "" (used downstream as "no time bound").
    A trailing "Z" is accepted by rewriting it to "+00:00" before passing it
    to ``datetime.fromisoformat``; plain dates like "2024-01-02" also parse.

    Raises:
        ValueError: if the string is not a parseable timestamp.
    """
    if not value:
        return ""
    try:
        datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        # Re-raise with the offending input; suppress the implicit chained
        # traceback so callers see a single, clear error.
        raise ValueError(f"invalid time format: {value}") from None
    return value
|
||||
|
||||
|
||||
def fetch_ws90(conn, site, start, end):
    """Fetch WS90 station observations for *site*, ordered by timestamp.

    start/end are RFC3339 strings; an empty string disables that bound
    (the SQL tests ``%s = ''`` before applying the ``::timestamptz`` cast).
    Returns a DataFrame with ``ts`` parsed as datetimes.
    """
    # Each bound is bound twice: once for the empty-string guard, once for the
    # casted comparison — hence (site, start, start, end, end) below.
    sql = """
        SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
        FROM observations_ws90
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def fetch_baro(conn, site, start, end):
    """Fetch barometer observations for *site*, ordered by timestamp.

    start/end are RFC3339 strings; an empty string disables that bound
    (the SQL tests ``%s = ''`` before applying the ``::timestamptz`` cast).
    Returns a DataFrame with ``ts`` parsed as datetimes.
    """
    # Each bound is bound twice: once for the empty-string guard, once for the
    # casted comparison — hence (site, start, start, end, end) below.
    sql = """
        SELECT ts, pressure_hpa
        FROM observations_baro
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
    """Join WS90 and barometer observations on a 5-minute grid and derive labels.

    Produces ``rain_inc`` (per-interval rain), ``rain_next_1h_mm`` /
    ``rain_next_1h`` (future 1-hour rain sum and >= 0.2 mm flag), and
    ``pressure_trend_1h`` (1-hour pressure change).

    Raises:
        RuntimeError: if either input frame is empty.
    """
    if ws90.empty:
        raise RuntimeError("no ws90 observations found")
    if baro.empty:
        raise RuntimeError("no barometer observations found")

    ws90 = ws90.set_index("ts").sort_index()
    baro = baro.set_index("ts").sort_index()

    # Downsample both sources onto a common 5-minute grid.
    # NOTE(review): "mean" on wind_dir_deg is an arithmetic mean, not a circular
    # mean — readings straddling 0/360 degrees average incorrectly; confirm intent.
    ws90_5m = ws90.resample("5min").agg(
        {
            "temperature_c": "mean",
            "humidity": "mean",
            "wind_avg_m_s": "mean",
            "wind_max_m_s": "max",
            "wind_dir_deg": "mean",
            "rain_mm": "last",
        }
    )
    baro_5m = baro.resample("5min").mean()

    df = ws90_5m.join(baro_5m, how="outer")
    # Fill only short barometer gaps (up to 3 intervals = 15 minutes).
    df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)

    # Compute incremental rain and future 1-hour sum.
    # clip(lower=0) presumes rain_mm is a cumulative counter, so a counter
    # reset (large negative diff) is zeroed rather than subtracted — TODO confirm.
    df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
    window = 12  # 12 * 5min = 1 hour
    # Trailing rolling sum shifted back by (window - 1) so each row holds the
    # rain total for the hour beginning at that row.
    df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
    df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2

    # Pressure trend over the previous hour.
    df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)

    return df
|
||||
|
||||
|
||||
def make_model() -> Pipeline:
    """Build the standard rain classifier: feature scaling + balanced logistic regression."""
    return Pipeline(
        [
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
        ]
    )


def train_model(df: pd.DataFrame):
    """Train the rain classifier on a time-ordered 80/20 split of *df*.

    The source contained a merge artifact — a dangling ``model = Pipeline(``
    colliding with the inserted ``make_model`` definition; this version routes
    pipeline construction through ``make_model()``.

    Args:
        df: output of ``build_dataset`` (must contain the feature columns
            and the ``rain_next_1h`` label).

    Returns:
        (fitted model, metrics dict, feature column list).

    Raises:
        RuntimeError: if fewer than 200 usable rows remain after dropping NaNs.
    """
    feature_cols = [
        "pressure_trend_1h",
        "humidity",
        "temperature_c",
        "wind_avg_m_s",
        "wind_max_m_s",
    ]

    df = df.dropna(subset=feature_cols + ["rain_next_1h"])
    if len(df) < 200:
        raise RuntimeError("not enough data after filtering (need >= 200 rows)")

    X = df[feature_cols]
    y = df["rain_next_1h"].astype(int)

    # Time-ordered split: first 80% train, last 20% test (no shuffling,
    # to avoid leaking future weather into training).
    split_idx = int(len(df) * 0.8)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    model = make_model()
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    metrics = {
        "rows": len(df),
        "train_rows": len(X_train),
        "test_rows": len(X_test),
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "roc_auc": roc_auc_score(y_test, y_prob),
        "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
    }

    return model, metrics, feature_cols
|
||||
|
||||
|
||||
def main() -> int:
|
||||
@@ -154,12 +81,124 @@ def main() -> int:
|
||||
ws90 = fetch_ws90(conn, args.site, start, end)
|
||||
baro = fetch_baro(conn, args.site, start, end)
|
||||
|
||||
df = build_dataset(ws90, baro)
|
||||
model, metrics, features = train_model(df)
|
||||
full_df = build_dataset(ws90, baro, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
|
||||
model_df = model_frame(full_df, FEATURE_COLUMNS, require_target=True)
|
||||
if len(model_df) < args.min_rows:
|
||||
raise RuntimeError(f"not enough model-ready rows after filtering (need >= {args.min_rows})")
|
||||
|
||||
print("Rain model metrics:")
|
||||
for k, v in metrics.items():
|
||||
print(f" {k}: {v}")
|
||||
train_df, val_df, test_df = split_time_ordered(
|
||||
model_df,
|
||||
train_ratio=args.train_ratio,
|
||||
val_ratio=args.val_ratio,
|
||||
)
|
||||
|
||||
x_train = train_df[FEATURE_COLUMNS]
|
||||
y_train = train_df["rain_next_1h"].astype(int).to_numpy()
|
||||
x_val = val_df[FEATURE_COLUMNS]
|
||||
y_val = val_df["rain_next_1h"].astype(int).to_numpy()
|
||||
x_test = test_df[FEATURE_COLUMNS]
|
||||
y_test = test_df["rain_next_1h"].astype(int).to_numpy()
|
||||
|
||||
base_model = make_model()
|
||||
base_model.fit(x_train, y_train)
|
||||
y_val_prob = base_model.predict_proba(x_val)[:, 1]
|
||||
|
||||
if args.threshold is not None:
|
||||
chosen_threshold = args.threshold
|
||||
threshold_info = {
|
||||
"selection_rule": "fixed_cli_threshold",
|
||||
"threshold": float(args.threshold),
|
||||
}
|
||||
else:
|
||||
chosen_threshold, threshold_info = select_threshold(
|
||||
y_true=y_val,
|
||||
y_prob=y_val_prob,
|
||||
min_precision=args.min_precision,
|
||||
)
|
||||
|
||||
val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=chosen_threshold)
|
||||
|
||||
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
||||
x_train_val = train_val_df[FEATURE_COLUMNS]
|
||||
y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
|
||||
|
||||
final_model = make_model()
|
||||
final_model.fit(x_train_val, y_train_val)
|
||||
y_test_prob = final_model.predict_proba(x_test)[:, 1]
|
||||
test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
|
||||
|
||||
report = {
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"site": args.site,
|
||||
"model_version": args.model_version,
|
||||
"target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
|
||||
"feature_columns": FEATURE_COLUMNS,
|
||||
"data_window": {
|
||||
"requested_start": start or None,
|
||||
"requested_end": end or None,
|
||||
"actual_start": model_df.index.min(),
|
||||
"actual_end": model_df.index.max(),
|
||||
"model_rows": len(model_df),
|
||||
"ws90_rows": len(ws90),
|
||||
"baro_rows": len(baro),
|
||||
},
|
||||
"label_quality": {
|
||||
"rain_reset_count": int(np.nansum(full_df["rain_reset"].fillna(False).to_numpy(dtype=int))),
|
||||
"rain_spike_5m_count": int(np.nansum(full_df["rain_spike_5m"].fillna(False).to_numpy(dtype=int))),
|
||||
},
|
||||
"split": {
|
||||
"train_ratio": args.train_ratio,
|
||||
"val_ratio": args.val_ratio,
|
||||
"train_rows": len(train_df),
|
||||
"val_rows": len(val_df),
|
||||
"test_rows": len(test_df),
|
||||
"train_start": train_df.index.min(),
|
||||
"train_end": train_df.index.max(),
|
||||
"val_start": val_df.index.min(),
|
||||
"val_end": val_df.index.max(),
|
||||
"test_start": test_df.index.min(),
|
||||
"test_end": test_df.index.max(),
|
||||
},
|
||||
"threshold_selection": {
|
||||
**threshold_info,
|
||||
"min_precision_constraint": args.min_precision,
|
||||
},
|
||||
"validation_metrics": val_metrics,
|
||||
"test_metrics": test_metrics,
|
||||
}
|
||||
report = to_builtin(report)
|
||||
|
||||
print("Rain model training summary:")
|
||||
print(f" site: {args.site}")
|
||||
print(f" model_version: {args.model_version}")
|
||||
print(f" rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
|
||||
print(
|
||||
" threshold: "
|
||||
f"{report['threshold_selection']['threshold']:.3f} "
|
||||
f"({report['threshold_selection']['selection_rule']})"
|
||||
)
|
||||
print(
|
||||
" val metrics: "
|
||||
f"precision={report['validation_metrics']['precision']:.3f} "
|
||||
f"recall={report['validation_metrics']['recall']:.3f} "
|
||||
f"roc_auc={report['validation_metrics']['roc_auc'] if report['validation_metrics']['roc_auc'] is not None else 'n/a'} "
|
||||
f"pr_auc={report['validation_metrics']['pr_auc'] if report['validation_metrics']['pr_auc'] is not None else 'n/a'}"
|
||||
)
|
||||
print(
|
||||
" test metrics: "
|
||||
f"precision={report['test_metrics']['precision']:.3f} "
|
||||
f"recall={report['test_metrics']['recall']:.3f} "
|
||||
f"roc_auc={report['test_metrics']['roc_auc'] if report['test_metrics']['roc_auc'] is not None else 'n/a'} "
|
||||
f"pr_auc={report['test_metrics']['pr_auc'] if report['test_metrics']['pr_auc'] is not None else 'n/a'}"
|
||||
)
|
||||
|
||||
if args.report_out:
|
||||
report_dir = os.path.dirname(args.report_out)
|
||||
if report_dir:
|
||||
os.makedirs(report_dir, exist_ok=True)
|
||||
with open(args.report_out, "w", encoding="utf-8") as f:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f"Saved report to {args.report_out}")
|
||||
|
||||
if args.out:
|
||||
out_dir = os.path.dirname(args.out)
|
||||
@@ -168,7 +207,17 @@ def main() -> int:
|
||||
if joblib is None:
|
||||
print("joblib not installed; skipping model save.")
|
||||
else:
|
||||
joblib.dump({"model": model, "features": features}, args.out)
|
||||
artifact = {
|
||||
"model": final_model,
|
||||
"features": FEATURE_COLUMNS,
|
||||
"threshold": float(chosen_threshold),
|
||||
"target_mm": float(RAIN_EVENT_THRESHOLD_MM),
|
||||
"model_version": args.model_version,
|
||||
"trained_at": datetime.now(timezone.utc).isoformat(),
|
||||
"split": report["split"],
|
||||
"threshold_selection": report["threshold_selection"],
|
||||
}
|
||||
joblib.dump(artifact, args.out)
|
||||
print(f"Saved model to {args.out}")
|
||||
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user