add model training
This commit is contained in:
178
scripts/train_rain_model.py
Normal file
178
scripts/train_rain_model.py
Normal file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, roc_auc_score
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
try:
|
||||
import joblib
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
joblib = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
|
||||
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
||||
parser.add_argument("--site", required=True, help="Site name (e.g. home).")
|
||||
parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
|
||||
parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
|
||||
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def parse_time(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
return value
|
||||
except ValueError:
|
||||
raise ValueError(f"invalid time format: {value}")
|
||||
|
||||
|
||||
def fetch_ws90(conn, site, start, end):
|
||||
sql = """
|
||||
SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
|
||||
FROM observations_ws90
|
||||
WHERE site = %s
|
||||
AND (%s = '' OR ts >= %s::timestamptz)
|
||||
AND (%s = '' OR ts <= %s::timestamptz)
|
||||
ORDER BY ts ASC
|
||||
"""
|
||||
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def fetch_baro(conn, site, start, end):
|
||||
sql = """
|
||||
SELECT ts, pressure_hpa
|
||||
FROM observations_baro
|
||||
WHERE site = %s
|
||||
AND (%s = '' OR ts >= %s::timestamptz)
|
||||
AND (%s = '' OR ts <= %s::timestamptz)
|
||||
ORDER BY ts ASC
|
||||
"""
|
||||
return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])
|
||||
|
||||
|
||||
def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
|
||||
if ws90.empty:
|
||||
raise RuntimeError("no ws90 observations found")
|
||||
if baro.empty:
|
||||
raise RuntimeError("no barometer observations found")
|
||||
|
||||
ws90 = ws90.set_index("ts").sort_index()
|
||||
baro = baro.set_index("ts").sort_index()
|
||||
|
||||
ws90_5m = ws90.resample("5min").agg(
|
||||
{
|
||||
"temperature_c": "mean",
|
||||
"humidity": "mean",
|
||||
"wind_avg_m_s": "mean",
|
||||
"wind_max_m_s": "max",
|
||||
"wind_dir_deg": "mean",
|
||||
"rain_mm": "last",
|
||||
}
|
||||
)
|
||||
baro_5m = baro.resample("5min").mean()
|
||||
|
||||
df = ws90_5m.join(baro_5m, how="outer")
|
||||
df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)
|
||||
|
||||
# Compute incremental rain and future 1-hour sum.
|
||||
df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
|
||||
window = 12 # 12 * 5min = 1 hour
|
||||
df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
|
||||
df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2
|
||||
|
||||
# Pressure trend over the previous hour.
|
||||
df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def train_model(df: pd.DataFrame):
|
||||
feature_cols = [
|
||||
"pressure_trend_1h",
|
||||
"humidity",
|
||||
"temperature_c",
|
||||
"wind_avg_m_s",
|
||||
"wind_max_m_s",
|
||||
]
|
||||
|
||||
df = df.dropna(subset=feature_cols + ["rain_next_1h"])
|
||||
if len(df) < 200:
|
||||
raise RuntimeError("not enough data after filtering (need >= 200 rows)")
|
||||
|
||||
X = df[feature_cols]
|
||||
y = df["rain_next_1h"].astype(int)
|
||||
|
||||
split_idx = int(len(df) * 0.8)
|
||||
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
|
||||
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
|
||||
|
||||
model = Pipeline(
|
||||
[
|
||||
("scaler", StandardScaler()),
|
||||
("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
|
||||
]
|
||||
)
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
y_pred = model.predict(X_test)
|
||||
y_prob = model.predict_proba(X_test)[:, 1]
|
||||
|
||||
metrics = {
|
||||
"rows": len(df),
|
||||
"train_rows": len(X_train),
|
||||
"test_rows": len(X_test),
|
||||
"accuracy": accuracy_score(y_test, y_pred),
|
||||
"precision": precision_score(y_test, y_pred, zero_division=0),
|
||||
"recall": recall_score(y_test, y_pred, zero_division=0),
|
||||
"roc_auc": roc_auc_score(y_test, y_prob),
|
||||
"confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
|
||||
}
|
||||
|
||||
return model, metrics, feature_cols
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
if not args.db_url:
|
||||
raise SystemExit("missing --db-url or DATABASE_URL")
|
||||
|
||||
start = parse_time(args.start) if args.start else ""
|
||||
end = parse_time(args.end) if args.end else ""
|
||||
|
||||
with psycopg2.connect(args.db_url) as conn:
|
||||
ws90 = fetch_ws90(conn, args.site, start, end)
|
||||
baro = fetch_baro(conn, args.site, start, end)
|
||||
|
||||
df = build_dataset(ws90, baro)
|
||||
model, metrics, features = train_model(df)
|
||||
|
||||
print("Rain model metrics:")
|
||||
for k, v in metrics.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
if args.out:
|
||||
out_dir = os.path.dirname(args.out)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
if joblib is None:
|
||||
print("joblib not installed; skipping model save.")
|
||||
else:
|
||||
joblib.dump({"model": model, "features": features}, args.out)
|
||||
print(f"Saved model to {args.out}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user