training: add selectable feature sets (baseline/extended) and Open-Meteo forecast features

This commit is contained in:
2026-03-05 20:49:19 +11:00
parent 76270e5650
commit 5b8cad905f
9 changed files with 380 additions and 48 deletions

View File

@@ -17,7 +17,7 @@ from sklearn.metrics import (
roc_auc_score,
)
FEATURE_COLUMNS = [
BASELINE_FEATURE_COLUMNS = [
"pressure_trend_1h",
"humidity",
"temperature_c",
@@ -25,6 +25,49 @@ FEATURE_COLUMNS = [
"wind_max_m_s",
]
FORECAST_FEATURE_COLUMNS = [
"fc_temp_c",
"fc_rh",
"fc_pressure_msl_hpa",
"fc_wind_m_s",
"fc_wind_gust_m_s",
"fc_precip_mm",
"fc_precip_prob",
"fc_cloud_cover",
]
EXTENDED_FEATURE_COLUMNS = [
"pressure_trend_1h",
"temperature_c",
"humidity",
"wind_avg_m_s",
"wind_max_m_s",
"wind_dir_sin",
"wind_dir_cos",
"temp_lag_5m",
"temp_roll_1h_mean",
"temp_roll_1h_std",
"humidity_lag_5m",
"humidity_roll_1h_mean",
"humidity_roll_1h_std",
"wind_avg_lag_5m",
"wind_avg_roll_1h_mean",
"wind_gust_roll_1h_max",
"pressure_lag_5m",
"pressure_roll_1h_mean",
"pressure_roll_1h_std",
"rain_last_1h_mm",
*FORECAST_FEATURE_COLUMNS,
]
FEATURE_SETS: dict[str, list[str]] = {
"baseline": BASELINE_FEATURE_COLUMNS,
"extended": EXTENDED_FEATURE_COLUMNS,
}
AVAILABLE_FEATURE_SETS = tuple(sorted(FEATURE_SETS.keys()))
FEATURE_COLUMNS = BASELINE_FEATURE_COLUMNS
RAIN_EVENT_THRESHOLD_MM = 0.2
RAIN_SPIKE_THRESHOLD_MM_5M = 5.0
RAIN_HORIZON_BUCKETS = 12 # 12 * 5m = 1h
@@ -40,6 +83,34 @@ def parse_time(value: str) -> str:
raise ValueError(f"invalid time format: {value}") from exc
def feature_columns_for_set(feature_set: str) -> list[str]:
    """Return a fresh copy of the column list registered under *feature_set*.

    Lookup is case-insensitive; raises ValueError for an unregistered name.
    """
    try:
        columns = FEATURE_SETS[feature_set.lower()]
    except KeyError:
        raise ValueError(f"unknown feature set: {feature_set}") from None
    return list(columns)
def feature_columns_need_forecast(feature_cols: list[str]) -> bool:
    """Return True when *feature_cols* includes at least one forecast (fc_*) column."""
    forecast_cols = set(FORECAST_FEATURE_COLUMNS)
    return not forecast_cols.isdisjoint(feature_cols)
def feature_set_needs_forecast(feature_set: str) -> bool:
    """Return True when the named feature set relies on forecast columns."""
    columns = feature_columns_for_set(feature_set)
    return feature_columns_need_forecast(columns)
def _fetch_df(conn, sql: str, params: tuple[Any, ...], parse_dt_cols: list[str]) -> pd.DataFrame:
with conn.cursor() as cur:
cur.execute(sql, params)
rows = cur.fetchall()
cols = [d.name for d in cur.description]
df = pd.DataFrame.from_records(rows, columns=cols)
if not df.empty:
for col in parse_dt_cols:
df[col] = pd.to_datetime(df[col], utc=True)
return df
def fetch_ws90(conn, site: str, start: str, end: str) -> pd.DataFrame:
sql = """
SELECT ts, station_id, received_at, temperature_c, humidity, wind_avg_m_s, wind_max_m_s, wind_dir_deg, rain_mm
@@ -49,16 +120,7 @@ def fetch_ws90(conn, site: str, start: str, end: str) -> pd.DataFrame:
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
with conn.cursor() as cur:
cur.execute(sql, (site, start, start, end, end))
rows = cur.fetchall()
cols = [d.name for d in cur.description]
df = pd.DataFrame.from_records(rows, columns=cols)
if not df.empty:
df["ts"] = pd.to_datetime(df["ts"], utc=True)
df["received_at"] = pd.to_datetime(df["received_at"], utc=True)
return df
return _fetch_df(conn, sql, (site, start, start, end, end), ["ts", "received_at"])
def fetch_baro(conn, site: str, start: str, end: str) -> pd.DataFrame:
@@ -70,21 +132,80 @@ def fetch_baro(conn, site: str, start: str, end: str) -> pd.DataFrame:
AND (%s = '' OR ts <= %s::timestamptz)
ORDER BY ts ASC
"""
with conn.cursor() as cur:
cur.execute(sql, (site, start, start, end, end))
rows = cur.fetchall()
cols = [d.name for d in cur.description]
return _fetch_df(conn, sql, (site, start, start, end, end), ["ts", "received_at"])
df = pd.DataFrame.from_records(rows, columns=cols)
if not df.empty:
df["ts"] = pd.to_datetime(df["ts"], utc=True)
df["received_at"] = pd.to_datetime(df["received_at"], utc=True)
return df
def fetch_forecast(conn, site: str, start: str, end: str, model: str = "ecmwf") -> pd.DataFrame:
    """Fetch hourly Open-Meteo forecast rows for *site* and *model*.

    DISTINCT ON (ts) with retrieved_at DESC keeps only the most recently
    retrieved row per forecast timestamp. The requested window is padded by
    two hours on each side so edge timestamps can still be joined downstream.
    An empty *start* or *end* disables that bound.
    """
    sql = """
        SELECT DISTINCT ON (ts)
            ts,
            retrieved_at,
            temp_c,
            rh,
            pressure_msl_hpa,
            wind_m_s,
            wind_gust_m_s,
            precip_mm,
            precip_prob,
            cloud_cover
        FROM forecast_openmeteo_hourly
        WHERE site = %s
          AND model = %s
          AND (%s = '' OR ts >= %s::timestamptz - INTERVAL '2 hours')
          AND (%s = '' OR ts <= %s::timestamptz + INTERVAL '2 hours')
        ORDER BY ts ASC, retrieved_at DESC
    """
    params = (site, model, start, start, end, end)
    return _fetch_df(conn, sql, params, ["ts", "retrieved_at"])
def _apply_forecast_features(df: pd.DataFrame, forecast: pd.DataFrame | None) -> pd.DataFrame:
    """Attach forecast-derived feature columns (fc_*) to the observation frame.

    Returns a copy of *df* with every column in FORECAST_FEATURE_COLUMNS
    present; the columns stay all-NaN when *forecast* is None or empty.

    NOTE(review): assumes *df* is indexed by a tz-aware timestamp on the 5m
    observation grid so the join against the resampled forecast aligns —
    confirm with build_dataset's caller.
    """
    out = df.copy()
    # Pre-create all forecast columns so downstream code can rely on their
    # presence regardless of whether forecast data was available.
    for col in FORECAST_FEATURE_COLUMNS:
        out[col] = np.nan
    if forecast is None or forecast.empty:
        return out
    # Index by forecast timestamp and rename raw columns into the fc_* namespace.
    fc = forecast.set_index("ts").sort_index().rename(
        columns={
            "temp_c": "fc_temp_c",
            "rh": "fc_rh",
            "pressure_msl_hpa": "fc_pressure_msl_hpa",
            "wind_m_s": "fc_wind_m_s",
            "wind_gust_m_s": "fc_wind_gust_m_s",
            "precip_mm": "fc_precip_mm",
            "precip_prob": "fc_precip_prob",
            "cloud_cover": "fc_cloud_cover",
        }
    )
    keep = [c for c in FORECAST_FEATURE_COLUMNS if c in fc.columns]
    fc = fc[keep]
    # Bring hourly forecast onto 5m observation grid.
    # limit=12 carries an hourly value forward at most one hour (12 * 5m).
    fc_5m = fc.resample("5min").ffill(limit=12)
    # out already holds NaN fc_* columns, so the overlapping joined columns
    # arrive with the "_dup" suffix.
    out = out.join(fc_5m, how="left", rsuffix="_dup")
    # Prefer joined forecast values and softly fill small gaps.
    for col in keep:
        dup_col = f"{col}_dup"
        if dup_col in out.columns:
            out[col] = out[dup_col]
            out.drop(columns=[dup_col], inplace=True)
        # Forward-fill up to 1h and back-fill at most 10m of leading gap.
        out[col] = out[col].ffill(limit=12).bfill(limit=2)
    # Normalize precip probability to 0..1 if source is 0..100.
    if "fc_precip_prob" in out.columns:
        mask = out["fc_precip_prob"] > 1.0
        out.loc[mask, "fc_precip_prob"] = out.loc[mask, "fc_precip_prob"] / 100.0
        out["fc_precip_prob"] = out["fc_precip_prob"].clip(lower=0.0, upper=1.0)
    return out
def build_dataset(
ws90: pd.DataFrame,
baro: pd.DataFrame,
forecast: pd.DataFrame | None = None,
rain_event_threshold_mm: float = RAIN_EVENT_THRESHOLD_MM,
) -> pd.DataFrame:
if ws90.empty:
@@ -116,11 +237,33 @@ def build_dataset(
df["rain_spike_5m"] = df["rain_inc"] >= RAIN_SPIKE_THRESHOLD_MM_5M
window = RAIN_HORIZON_BUCKETS
df["rain_last_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum()
df["rain_next_1h_mm"] = df["rain_inc"].rolling(window=window, min_periods=1).sum().shift(-(window - 1))
df["rain_next_1h"] = df["rain_next_1h_mm"] >= rain_event_threshold_mm
df["pressure_trend_1h"] = df["pressure_hpa"] - df["pressure_hpa"].shift(window)
# Wind direction cyclical encoding.
radians = np.deg2rad(df["wind_dir_deg"] % 360.0)
df["wind_dir_sin"] = np.sin(radians)
df["wind_dir_cos"] = np.cos(radians)
# Lags and rolling features (core sensors).
df["temp_lag_5m"] = df["temperature_c"].shift(1)
df["humidity_lag_5m"] = df["humidity"].shift(1)
df["wind_avg_lag_5m"] = df["wind_avg_m_s"].shift(1)
df["pressure_lag_5m"] = df["pressure_hpa"].shift(1)
df["temp_roll_1h_mean"] = df["temperature_c"].rolling(window=window, min_periods=3).mean()
df["temp_roll_1h_std"] = df["temperature_c"].rolling(window=window, min_periods=3).std()
df["humidity_roll_1h_mean"] = df["humidity"].rolling(window=window, min_periods=3).mean()
df["humidity_roll_1h_std"] = df["humidity"].rolling(window=window, min_periods=3).std()
df["wind_avg_roll_1h_mean"] = df["wind_avg_m_s"].rolling(window=window, min_periods=3).mean()
df["wind_gust_roll_1h_max"] = df["wind_max_m_s"].rolling(window=window, min_periods=3).max()
df["pressure_roll_1h_mean"] = df["pressure_hpa"].rolling(window=window, min_periods=3).mean()
df["pressure_roll_1h_std"] = df["pressure_hpa"].rolling(window=window, min_periods=3).std()
df = _apply_forecast_features(df, forecast)
return df
@@ -133,12 +276,16 @@ def model_frame(df: pd.DataFrame, feature_cols: list[str] | None = None, require
return out.sort_index()
def split_time_ordered(df: pd.DataFrame, train_ratio: float = 0.7, val_ratio: float = 0.15) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
def split_time_ordered(
df: pd.DataFrame,
train_ratio: float = 0.7,
val_ratio: float = 0.15,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
if not (0 < train_ratio < 1):
raise ValueError("train_ratio must be between 0 and 1")
if not (0 <= val_ratio < 1):
raise ValueError("val_ratio must be between 0 and 1")
if train_ratio+val_ratio >= 1:
if train_ratio + val_ratio >= 1:
raise ValueError("train_ratio + val_ratio must be < 1")
n = len(df)
@@ -223,6 +370,26 @@ def select_threshold(y_true: np.ndarray, y_prob: np.ndarray, min_precision: floa
return float(best["threshold"]), best
def _safe_metric(v: dict[str, Any], key: str) -> float:
    """Read v[key] as a float, mapping missing/None/NaN to -inf.

    -inf makes degenerate metrics sort last when selecting a best model.
    """
    value = v.get(key)
    if value is None:
        return float("-inf")
    out = float(value)
    if np.isnan(out):
        return float("-inf")
    return out


def safe_pr_auc(v: dict[str, Any]) -> float:
    """Return the PR-AUC from a metrics dict, or -inf when missing or NaN."""
    return _safe_metric(v, "pr_auc")


def safe_roc_auc(v: dict[str, Any]) -> float:
    """Return the ROC-AUC from a metrics dict, or -inf when missing or NaN."""
    return _safe_metric(v, "roc_auc")
def to_builtin(v: Any) -> Any:
if isinstance(v, dict):
return {k: to_builtin(val) for k, val in v.items()}