feat: add rain data audit and prediction scripts
This commit is contained in:
226
scripts/rain_model_common.py
Normal file
226
scripts/rain_model_common.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
average_precision_score,
|
||||
brier_score_loss,
|
||||
confusion_matrix,
|
||||
f1_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
roc_auc_score,
|
||||
)
|
||||
|
||||
# Default model input columns; model_frame() falls back to this list when the
# caller does not pass feature_cols.
FEATURE_COLUMNS = [
    "pressure_trend_1h",
    "humidity",
    "temperature_c",
    "wind_avg_m_s",
    "wind_max_m_s",
]

# Default forward-rain total at or above which build_dataset() sets
# rain_next_1h to True.
RAIN_EVENT_THRESHOLD_MM = 0.2
# A single-bucket rain increment at or above this value is flagged in the
# rain_spike_5m column by build_dataset().
RAIN_SPIKE_THRESHOLD_MM_5M = 5.0
# Number of 5-minute buckets used for the forward rain sum and the pressure
# trend lag in build_dataset().
RAIN_HORIZON_BUCKETS = 12  # 12 * 5m = 1h
|
||||
|
||||
|
||||
def parse_time(value: str) -> str:
    """Validate an ISO-8601 timestamp string and hand it back unchanged.

    A falsy *value* yields the empty string. A trailing/embedded ``Z`` is
    rewritten to ``+00:00`` only for the purpose of the validity check; the
    returned string is always the caller's original text.

    Raises:
        ValueError: if ``datetime.fromisoformat`` rejects the value.
    """
    if value:
        # fromisoformat() is given an explicit offset instead of the "Z" suffix.
        candidate = value.replace("Z", "+00:00")
        try:
            datetime.fromisoformat(candidate)
        except ValueError as exc:
            raise ValueError(f"invalid time format: {value}") from exc
        return value
    return ""
|
||||
|
||||
|
||||
def fetch_ws90(conn, site: str, start: str, end: str) -> pd.DataFrame:
    """Read WS90 observations for one site, oldest first.

    Args:
        conn: Database connection handed straight to ``pandas.read_sql_query``.
            NOTE(review): the ``%s`` placeholders and ``::timestamptz`` casts
            suggest a PostgreSQL DB-API connection -- confirm against callers.
        site: Value matched against the ``site`` column.
        start: Inclusive lower bound on ``ts``; an empty string disables it.
        end: Inclusive upper bound on ``ts``; an empty string disables it.

    Returns:
        A DataFrame with the selected columns; ``ts`` and ``received_at`` are
        parsed as dates.
    """
    query = """
        SELECT ts, station_id, received_at, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
        FROM observations_ws90
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    # Each bound is bound twice: once for the "is it empty?" test and once for
    # the actual comparison.
    bind_values = (site, start, start, end, end)
    date_columns = ["ts", "received_at"]
    return pd.read_sql_query(query, conn, params=bind_values, parse_dates=date_columns)
|
||||
|
||||
|
||||
def fetch_baro(conn, site: str, start: str, end: str) -> pd.DataFrame:
    """Read barometer observations for one site, oldest first.

    Args:
        conn: Database connection handed straight to ``pandas.read_sql_query``.
            NOTE(review): the ``%s`` placeholders and ``::timestamptz`` casts
            suggest a PostgreSQL DB-API connection -- confirm against callers.
        site: Value matched against the ``site`` column.
        start: Inclusive lower bound on ``ts``; an empty string disables it.
        end: Inclusive upper bound on ``ts``; an empty string disables it.

    Returns:
        A DataFrame with ``ts``, ``source``, ``received_at`` and
        ``pressure_hpa``; ``ts`` and ``received_at`` are parsed as dates.
    """
    query = """
        SELECT ts, source, received_at, pressure_hpa
        FROM observations_baro
        WHERE site = %s
          AND (%s = '' OR ts >= %s::timestamptz)
          AND (%s = '' OR ts <= %s::timestamptz)
        ORDER BY ts ASC
    """
    # Each bound is bound twice: once for the "is it empty?" test and once for
    # the actual comparison.
    bind_values = (site, start, start, end, end)
    date_columns = ["ts", "received_at"]
    return pd.read_sql_query(query, conn, params=bind_values, parse_dates=date_columns)
|
||||
|
||||
|
||||
def build_dataset(
    ws90: pd.DataFrame,
    baro: pd.DataFrame,
    rain_event_threshold_mm: float = RAIN_EVENT_THRESHOLD_MM,
) -> pd.DataFrame:
    """Merge WS90 and barometer observations into one 5-minute feature table.

    Both inputs are indexed by their ``ts`` column, resampled to 5-minute
    buckets and outer-joined. Derived columns are then appended in this order:
    ``rain_inc_raw``, ``rain_reset``, ``rain_inc``, ``rain_spike_5m``,
    ``rain_next_1h_mm``, ``rain_next_1h`` and ``pressure_trend_1h``.

    Args:
        ws90: WS90 observations with a ``ts`` column plus the weather columns
            aggregated below.
        baro: Barometer observations with ``ts`` and ``pressure_hpa``.
        rain_event_threshold_mm: Forward rain total at or above which
            ``rain_next_1h`` is True.

    Returns:
        The joined, bucketed DataFrame. The input frames are not modified.

    Raises:
        RuntimeError: if either input frame is empty.
    """
    if ws90.empty:
        raise RuntimeError("no ws90 observations found")
    if baro.empty:
        raise RuntimeError("no barometer observations found")

    station = ws90.set_index("ts").sort_index()
    pressure_obs = baro.set_index("ts").sort_index()

    # "last" for rain_mm: the increments below are taken as bucket-to-bucket
    # differences, i.e. rain_mm is treated as a running counter.
    station_aggregations = {
        "temperature_c": "mean",
        "humidity": "mean",
        "wind_avg_m_s": "mean",
        "wind_max_m_s": "max",
        "wind_dir_deg": "mean",
        "rain_mm": "last",
    }
    station_5m = station.resample("5min").agg(station_aggregations)
    pressure_5m = pressure_obs.resample("5min").agg({"pressure_hpa": "mean"})

    merged = station_5m.join(pressure_5m, how="outer")
    # Bridge short barometer gaps only (at most 3 consecutive buckets).
    merged["pressure_hpa"] = merged["pressure_hpa"].interpolate(limit=3)

    raw_increment = merged["rain_mm"].diff()
    merged["rain_inc_raw"] = raw_increment
    # A negative step means the counter went backwards; flag it and count it
    # as zero rain rather than negative rain.
    merged["rain_reset"] = raw_increment < 0
    increment = raw_increment.clip(lower=0)
    merged["rain_inc"] = increment
    merged["rain_spike_5m"] = increment >= RAIN_SPIKE_THRESHOLD_MM_5M

    horizon = RAIN_HORIZON_BUCKETS
    # Trailing sum over `horizon` buckets, shifted back so that each row holds
    # the total for itself and the following horizon - 1 buckets. The last
    # horizon - 1 rows therefore have no value (NaN).
    trailing_sum = increment.rolling(window=horizon, min_periods=1).sum()
    merged["rain_next_1h_mm"] = trailing_sum.shift(1 - horizon)
    # NOTE(review): NaN >= threshold evaluates to False, so rows whose forward
    # total is unknown are labelled False here rather than missing. Consumers
    # that must exclude them should test rain_next_1h_mm for NaN.
    merged["rain_next_1h"] = merged["rain_next_1h_mm"] >= rain_event_threshold_mm

    pressure = merged["pressure_hpa"]
    merged["pressure_trend_1h"] = pressure - pressure.shift(horizon)

    return merged
|
||||
|
||||
|
||||
def model_frame(df: pd.DataFrame, feature_cols: list[str] | None = None, require_target: bool = True) -> pd.DataFrame:
    """Return the rows of *df* that are complete enough for modelling.

    Args:
        df: Dataset as produced by ``build_dataset``.
        feature_cols: Columns that must be non-null. Falls back to
            ``FEATURE_COLUMNS`` when omitted or empty.
        require_target: When True, rows without a known target are dropped
            as well.

    Returns:
        A copy of the surviving rows, sorted by index.
    """
    features = feature_cols or FEATURE_COLUMNS
    required = list(features)
    if require_target:
        required.append("rain_next_1h")
        # build_dataset() derives rain_next_1h as `rain_next_1h_mm >= threshold`.
        # A comparison against NaN is False, so that flag is never null and
        # dropna() on it alone keeps rows whose forward rain total is unknown
        # (e.g. the final buckets of the series), mislabelled as "no rain".
        # Requiring the underlying total removes them. The column check keeps
        # frames that lack rain_next_1h_mm working as before.
        if "rain_next_1h_mm" in df.columns:
            required.append("rain_next_1h_mm")
    out = df.dropna(subset=required).copy()
    return out.sort_index()
|
||||
|
||||
|
||||
def split_time_ordered(df: pd.DataFrame, train_ratio: float = 0.7, val_ratio: float = 0.15) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Cut *df* into consecutive train / validation / test slices.

    Rows are taken in their existing order (no shuffling): the first
    ``train_ratio`` share goes to train, the next ``val_ratio`` share to
    validation, and the remainder to test. The cut points are clamped so each
    slice gets at least one row.

    Raises:
        ValueError: if a ratio is out of range or the two ratios sum to 1 or more.
        RuntimeError: if *df* has fewer than 100 rows, or a slice ends up empty.
    """
    train_ratio_ok = 0 < train_ratio < 1
    if not train_ratio_ok:
        raise ValueError("train_ratio must be between 0 and 1")
    val_ratio_ok = 0 <= val_ratio < 1
    if not val_ratio_ok:
        raise ValueError("val_ratio must be between 0 and 1")
    if train_ratio + val_ratio >= 1:
        raise ValueError("train_ratio + val_ratio must be < 1")

    n = len(df)
    if n < 100:
        raise RuntimeError("not enough rows after filtering (need >= 100)")

    # Train cut: at least one row in train, at least two rows left after it.
    train_stop = max(int(n * train_ratio), 1)
    train_stop = min(train_stop, n - 2)

    # Validation cut: strictly after the train cut, at least one row left for test.
    val_stop = max(int(n * (train_ratio + val_ratio)), train_stop + 1)
    val_stop = min(val_stop, n - 1)

    bounds = ((0, train_stop), (train_stop, val_stop), (val_stop, n))
    train_df, val_df, test_df = (df.iloc[lo:hi] for lo, hi in bounds)

    if any(part.empty for part in (train_df, val_df, test_df)):
        raise RuntimeError("time split produced empty train/val/test set")
    return train_df, val_df, test_df
|
||||
|
||||
|
||||
def evaluate_probs(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, Any]:
    """Score predicted probabilities against binary labels at one threshold.

    Args:
        y_true: Ground-truth labels.
        y_prob: Predicted probabilities; a row is predicted positive when its
            probability is at or above *threshold*.
        threshold: Decision cut-off applied to *y_prob*.

    Returns:
        A metrics dict (row count, positive rate, threshold, accuracy,
        precision, recall, F1, ROC AUC, PR AUC, Brier score and the 2x2
        confusion matrix as nested lists), passed through ``to_builtin``.
    """
    predicted = (y_prob >= threshold).astype(int)

    # The ranking metrics are only computed when both classes are present;
    # otherwise they stay at the NaN placeholder.
    has_both_classes = len(np.unique(y_true)) > 1
    roc_auc = roc_auc_score(y_true, y_prob) if has_both_classes else float("nan")
    pr_auc = average_precision_score(y_true, y_prob) if has_both_classes else float("nan")

    # labels=[0, 1] pins the matrix to 2x2 even if a class is absent.
    matrix = confusion_matrix(y_true, predicted, labels=[0, 1])

    report: dict[str, Any] = {}
    report["rows"] = int(len(y_true))
    report["positive_rate"] = float(np.mean(y_true))
    report["threshold"] = float(threshold)
    report["accuracy"] = accuracy_score(y_true, predicted)
    report["precision"] = precision_score(y_true, predicted, zero_division=0)
    report["recall"] = recall_score(y_true, predicted, zero_division=0)
    report["f1"] = f1_score(y_true, predicted, zero_division=0)
    report["roc_auc"] = roc_auc
    report["pr_auc"] = pr_auc
    report["brier"] = brier_score_loss(y_true, y_prob)
    report["confusion_matrix"] = matrix.tolist()
    return to_builtin(report)
|
||||
|
||||
|
||||
def select_threshold(y_true: np.ndarray, y_prob: np.ndarray, min_precision: float = 0.7) -> tuple[float, dict[str, Any]]:
    """Pick a decision threshold from a fixed grid of 91 values in [0.05, 0.95].

    Preferred rule: among thresholds whose precision is at least
    *min_precision*, take the one with the highest recall, breaking recall
    ties by higher F1 (earlier grid value wins a full tie). If no threshold
    meets the precision floor, fall back to the threshold with the highest F1.

    Returns:
        ``(threshold, info)`` where *info* holds the threshold, its precision,
        recall and F1, and a ``selection_rule`` string naming the rule used.
    """
    top_f1: dict[str, Any] | None = None
    top_constrained: dict[str, Any] | None = None

    for cutoff in np.linspace(0.05, 0.95, 91):
        predicted = (y_prob >= cutoff).astype(int)
        precision = precision_score(y_true, predicted, zero_division=0)
        recall = recall_score(y_true, predicted, zero_division=0)
        f1 = f1_score(y_true, predicted, zero_division=0)
        scored = {
            "threshold": float(cutoff),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
        }

        # Unconstrained fallback: strictly better F1 replaces the incumbent.
        if top_f1 is None or scored["f1"] > top_f1["f1"]:
            top_f1 = scored

        meets_floor = precision >= min_precision
        if not meets_floor:
            continue
        # Lexicographic (recall, f1) comparison: higher recall wins, equal
        # recall is decided by F1, and a full tie keeps the incumbent.
        if top_constrained is None or (scored["recall"], scored["f1"]) > (
            top_constrained["recall"],
            top_constrained["f1"],
        ):
            top_constrained = scored

    if top_constrained is not None:
        top_constrained["selection_rule"] = f"max_recall_where_precision>={min_precision:.2f}"
        return float(top_constrained["threshold"]), top_constrained

    # The grid is never empty, so the loop always sets top_f1; the assert only
    # narrows the Optional type.
    assert top_f1 is not None
    top_f1["selection_rule"] = "fallback_max_f1"
    return float(top_f1["threshold"]), top_f1
|
||||
|
||||
|
||||
def to_builtin(v: Any) -> Any:
    """Recursively convert *v* into plain Python values suitable for JSON.

    Conversions:
        * dict -> dict with converted values (keys are left as they are)
        * list / tuple / ``np.ndarray`` -> list with converted items
        * ``np.bool_`` -> bool, ``np.integer`` -> int
        * Python ``float`` and ``np.floating`` -> float, with NaN mapped to None
        * ``pd.Timestamp`` / ``datetime`` -> ISO-8601 string

    Anything else is returned unchanged.
    """
    if isinstance(v, dict):
        return {k: to_builtin(val) for k, val in v.items()}
    if isinstance(v, (list, tuple)):
        return [to_builtin(i) for i in v]
    if isinstance(v, np.ndarray):
        # tolist() yields nested builtins; recurse so NaN items are still mapped.
        return to_builtin(v.tolist())
    if isinstance(v, np.bool_):
        # np.bool_ is not a Python bool and json cannot serialise it.
        return bool(v)
    if isinstance(v, np.integer):
        return int(v)
    # Plain Python floats are included on purpose: a float("nan") placeholder
    # is not an np.floating and would otherwise slip through as NaN, which is
    # not valid JSON.
    if isinstance(v, (float, np.floating)):
        out = float(v)
        if np.isnan(out):
            return None
        return out
    if isinstance(v, (pd.Timestamp, datetime)):
        return v.isoformat()
    return v
|
||||
Reference in New Issue
Block a user