#!/usr/bin/env python3
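"""Train a logistic-regression rain model from weather-station observations.

Reads WS90 and barometer rows for one site from Postgres, resamples both to
5-minute bins, labels each bin with whether >= 0.2 mm of rain falls in the
following hour, fits a scaled LogisticRegression on a chronological 80/20
split, prints evaluation metrics, and (if joblib is available) saves the
fitted pipeline to --out.
"""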

import argparse
import os
from datetime import datetime

import pandas as pd
import psycopg2
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

try:
    import joblib
except ImportError:  # pragma: no cover - optional dependency
    joblib = None


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train a simple rain prediction model (next 1h >= 0.2mm).")
    parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
    parser.add_argument("--site", required=True, help="Site name (e.g. home).")
    parser.add_argument("--start", help="Start time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--end", help="End time (RFC3339 or YYYY-MM-DD).")
    parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
    return parser.parse_args()


def parse_time(value: str) -> str:
    """Validate an RFC3339/ISO timestamp (or YYYY-MM-DD) and return it unchanged."""
    if not value:
        return ""
    try:
        datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value
    except ValueError:
        raise ValueError(f"invalid time format: {value}") from None


def fetch_ws90(conn, site, start, end):
    # Empty start/end strings disable the corresponding bound: each
    # "(%s = '' OR ts >= %s::timestamptz)" clause is a no-op when the
    # parameter is ''.
    sql = """
        SELECT ts, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
        FROM observations_ws90
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])


def fetch_baro(conn, site, start, end):
    # Same optional time-bound pattern as fetch_ws90.
    sql = """
        SELECT ts, pressure_hpa
        FROM observations_baro
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    return pd.read_sql_query(sql, conn, params=(site, start, start, end, end), parse_dates=["ts"])


def build_dataset(ws90: pd.DataFrame, baro: pd.DataFrame) -> pd.DataFrame:
    if ws90.empty:
        raise RuntimeError("no ws90 observations found")
    if baro.empty:
        raise RuntimeError("no barometer observations found")

    ws90 = ws90.set_index("ts").sort_index()
    baro = baro.set_index("ts").sort_index()

    ws90_5m = ws90.resample("5min").agg(
        {
            "temperature_c": "mean",
            "humidity": "mean",
            "wind_avg_m_s": "mean",
            "wind_max_m_s": "max",
            # NOTE: a plain mean of direction ignores the 0/360 wraparound;
            # acceptable here since direction is not used as a model feature.
            "wind_dir_deg": "mean",
            "rain_mm": "last",
        }
    )
    baro_5m = baro.resample("5min").mean()

    df = ws90_5m.join(baro_5m, how="outer")
    # Bridge short barometer gaps (up to 3 bins = 15 minutes).
    df["pressure_hpa"] = df["pressure_hpa"].interpolate(limit=3)

    # Incremental rain per 5-minute bin; clip(lower=0) absorbs counter resets.
    df["rain_inc"] = df["rain_mm"].diff().clip(lower=0)
    window = 12  # 12 * 5min = 1 hour
    # rolling(...).sum() at time t covers (t-55min, t]; shifting by -(window-1)
    # re-aligns it so each row holds the rain falling in [t, t+55min].
    df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
    df["rain_next_1h"] = df["rain_next_1h_mm"] >= 0.2

    # Pressure trend over the previous hour.
    df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(12)

    return df


def train_model(df: pd.DataFrame):
    feature_cols = [
        "pressure_trend_1h",
        "humidity",
        "temperature_c",
        "wind_avg_m_s",
        "wind_max_m_s",
    ]

    # Drop on rain_next_1h_mm, not the boolean label: NaN >= 0.2 evaluates to
    # False, so filtering on rain_next_1h would silently keep rows whose future
    # window is unknown and label them as "no rain".
    df = df.dropna(subset=feature_cols + ["rain_next_1h_mm"])
    if len(df) < 200:
        raise RuntimeError("not enough data after filtering (need >= 200 rows)")

    X = df[feature_cols]
    y = df["rain_next_1h"].astype(int)

    # Chronological 80/20 split (no shuffle) so the test set lies strictly in
    # the future relative to the training data.
    split_idx = int(len(df) * 0.8)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

    model = Pipeline(
        [
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
        ]
    )
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]

    metrics = {
        "rows": len(df),
        "train_rows": len(X_train),
        "test_rows": len(X_test),
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, zero_division=0),
        "recall": recall_score(y_test, y_pred, zero_division=0),
        "roc_auc": roc_auc_score(y_test, y_prob),
        "confusion_matrix": confusion_matrix(y_test, y_pred).tolist(),
    }

    return model, metrics, feature_cols


def main() -> int:
    args = parse_args()
    if not args.db_url:
        raise SystemExit("missing --db-url or DATABASE_URL")

    start = parse_time(args.start) if args.start else ""
    end = parse_time(args.end) if args.end else ""

    # psycopg2's connection context manager wraps a transaction, not the
    # connection itself: it commits/rolls back on exit but leaves conn open.
    with psycopg2.connect(args.db_url) as conn:
        ws90 = fetch_ws90(conn, args.site, start, end)
        baro = fetch_baro(conn, args.site, start, end)

    df = build_dataset(ws90, baro)
    model, metrics, features = train_model(df)

    print("Rain model metrics:")
    for k, v in metrics.items():
        print(f"  {k}: {v}")

    if args.out:
        out_dir = os.path.dirname(args.out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if joblib is None:
            print("joblib not installed; skipping model save.")
        else:
            joblib.dump({"model": model, "features": features}, args.out)
            print(f"Saved model to {args.out}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
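

# Sketch of downstream use, assuming the model was saved by this script and
# joblib is installed (path and variable names are illustrative):
#
#   import joblib
#   bundle = joblib.load("models/rain_model.pkl")
#   X = latest_5min_rows[bundle["features"]]  # same feature columns, same order
#   rain_prob = bundle["model"].predict_proba(X)[:, 1]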