bugfix wunderground reporting
This commit is contained in:
@@ -45,6 +45,7 @@ site:
|
|||||||
latitude: -33.8688 # WGS84 latitude
|
latitude: -33.8688 # WGS84 latitude
|
||||||
longitude: 151.2093 # WGS84 longitude
|
longitude: 151.2093 # WGS84 longitude
|
||||||
elevation_m: 50 # Currently informational (not used by Open-Meteo ECMWF endpoint)
|
elevation_m: 50 # Currently informational (not used by Open-Meteo ECMWF endpoint)
|
||||||
|
timezone: "Australia/Sydney" # IANA timezone used for daily rain boundary (e.g. Wunderground dailyrainin)
|
||||||
|
|
||||||
pollers:
|
pollers:
|
||||||
open_meteo:
|
open_meteo:
|
||||||
@@ -67,6 +68,7 @@ wunderground:
|
|||||||
- The Open-Meteo ECMWF endpoint is queried by the poller only. The UI reads forecasts from TimescaleDB.
|
- The Open-Meteo ECMWF endpoint is queried by the poller only. The UI reads forecasts from TimescaleDB.
|
||||||
- Web UI supports Local/UTC toggle and date-aligned ranges (6h, 24h, 72h, 7d).
|
- Web UI supports Local/UTC toggle and date-aligned ranges (6h, 24h, 72h, 7d).
|
||||||
- `mqtt.topic` is still supported for single-topic configs, but `mqtt.topics` is preferred.
|
- `mqtt.topic` is still supported for single-topic configs, but `mqtt.topics` is preferred.
|
||||||
|
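- Set `site.timezone` to your station's IANA timezone (e.g. `Australia/Sydney`) so Wunderground daily rain resets at the station's local midnight. If it is left empty, the process's local timezone is used.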
|
||||||
|
|
||||||
## Schema & tables
|
## Schema & tables
|
||||||
TimescaleDB schema is initialized from `db/init/001_schema.sql` and includes:
|
TimescaleDB schema is initialized from `db/init/001_schema.sql` and includes:
|
||||||
|
|||||||
@@ -39,7 +39,17 @@ func main() {
|
|||||||
cancel()
|
cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
latest := &mqttingest.Latest{}
|
rainDayLoc := time.Local
|
||||||
|
if cfg.Site.Timezone != "" {
|
||||||
|
loc, err := time.LoadLocation(cfg.Site.Timezone)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("site timezone load: %v", err)
|
||||||
|
}
|
||||||
|
rainDayLoc = loc
|
||||||
|
}
|
||||||
|
log.Printf("rain day timezone: %s", rainDayLoc)
|
||||||
|
|
||||||
|
latest := mqttingest.NewLatest(rainDayLoc)
|
||||||
forecastCache := &ForecastCache{}
|
forecastCache := &ForecastCache{}
|
||||||
|
|
||||||
d, err := db.Open(ctx, cfg.DB.ConnString)
|
d, err := db.Open(ctx, cfg.DB.ConnString)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ site:
|
|||||||
latitude: -33.8688
|
latitude: -33.8688
|
||||||
longitude: 151.2093
|
longitude: 151.2093
|
||||||
elevation_m: 50
|
elevation_m: 50
|
||||||
|
timezone: "Australia/Sydney"
|
||||||
|
|
||||||
pollers:
|
pollers:
|
||||||
open_meteo:
|
open_meteo:
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ services:
|
|||||||
RAIN_SITE: "home"
|
RAIN_SITE: "home"
|
||||||
RAIN_MODEL_NAME: "rain_next_1h"
|
RAIN_MODEL_NAME: "rain_next_1h"
|
||||||
RAIN_MODEL_VERSION_BASE: "rain-logreg-v1"
|
RAIN_MODEL_VERSION_BASE: "rain-logreg-v1"
|
||||||
|
RAIN_MODEL_FAMILY: "logreg"
|
||||||
RAIN_FEATURE_SET: "baseline"
|
RAIN_FEATURE_SET: "baseline"
|
||||||
RAIN_FORECAST_MODEL: "ecmwf"
|
RAIN_FORECAST_MODEL: "ecmwf"
|
||||||
RAIN_LOOKBACK_DAYS: "30"
|
RAIN_LOOKBACK_DAYS: "30"
|
||||||
|
|||||||
@@ -47,6 +47,12 @@ Feature-set options:
|
|||||||
- `extended`: adds wind-direction encoding, lag/rolling stats, recent rain accumulation,
|
- `extended`: adds wind-direction encoding, lag/rolling stats, recent rain accumulation,
|
||||||
and aligned forecast features from `forecast_openmeteo_hourly`.
|
and aligned forecast features from `forecast_openmeteo_hourly`.
|
||||||
|
|
||||||
|
Model-family options (`train_rain_model.py`):
|
||||||
|
- `logreg`: logistic regression baseline.
|
||||||
|
- `hist_gb`: histogram gradient boosting (tree-based baseline).
|
||||||
|
- `auto`: trains both `logreg` and `hist_gb`, picks the best validation model by
|
||||||
|
PR-AUC, then ROC-AUC, then F1.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
### 1) Apply schema update (existing DBs)
|
### 1) Apply schema update (existing DBs)
|
||||||
`001_schema.sql` now includes `predictions_rain_1h`.
|
`001_schema.sql` now includes `predictions_rain_1h`.
|
||||||
@@ -79,6 +85,7 @@ python scripts/train_rain_model.py \
|
|||||||
--val-ratio 0.15 \
|
--val-ratio 0.15 \
|
||||||
--min-precision 0.70 \
|
--min-precision 0.70 \
|
||||||
--feature-set "baseline" \
|
--feature-set "baseline" \
|
||||||
|
--model-family "logreg" \
|
||||||
--model-version "rain-logreg-v1" \
|
--model-version "rain-logreg-v1" \
|
||||||
--out "models/rain_model.pkl" \
|
--out "models/rain_model.pkl" \
|
||||||
--report-out "models/rain_model_report.json" \
|
--report-out "models/rain_model_report.json" \
|
||||||
@@ -92,6 +99,7 @@ python scripts/train_rain_model.py \
|
|||||||
--start "2026-02-01T00:00:00Z" \
|
--start "2026-02-01T00:00:00Z" \
|
||||||
--end "2026-03-03T23:55:00Z" \
|
--end "2026-03-03T23:55:00Z" \
|
||||||
--feature-set "extended" \
|
--feature-set "extended" \
|
||||||
|
--model-family "logreg" \
|
||||||
--forecast-model "ecmwf" \
|
--forecast-model "ecmwf" \
|
||||||
--model-version "rain-logreg-v1-extended" \
|
--model-version "rain-logreg-v1-extended" \
|
||||||
--out "models/rain_model_extended.pkl" \
|
--out "models/rain_model_extended.pkl" \
|
||||||
@@ -99,6 +107,35 @@ python scripts/train_rain_model.py \
|
|||||||
--dataset-out "models/datasets/rain_dataset_{model_version}_{feature_set}.csv"
|
--dataset-out "models/datasets/rain_dataset_{model_version}_{feature_set}.csv"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 3c) Train tree-based baseline (P1)
|
||||||
|
```sh
|
||||||
|
python scripts/train_rain_model.py \
|
||||||
|
--site "home" \
|
||||||
|
--start "2026-02-01T00:00:00Z" \
|
||||||
|
--end "2026-03-03T23:55:00Z" \
|
||||||
|
--feature-set "extended" \
|
||||||
|
--model-family "hist_gb" \
|
||||||
|
--forecast-model "ecmwf" \
|
||||||
|
--model-version "rain-hgb-v1-extended" \
|
||||||
|
--out "models/rain_model_hgb.pkl" \
|
||||||
|
--report-out "models/rain_model_report_hgb.json" \
|
||||||
|
--dataset-out "models/datasets/rain_dataset_{model_version}_{feature_set}.csv"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3d) Auto-compare logistic vs tree baseline
|
||||||
|
```sh
|
||||||
|
python scripts/train_rain_model.py \
|
||||||
|
--site "home" \
|
||||||
|
--start "2026-02-01T00:00:00Z" \
|
||||||
|
--end "2026-03-03T23:55:00Z" \
|
||||||
|
--feature-set "extended" \
|
||||||
|
--model-family "auto" \
|
||||||
|
--forecast-model "ecmwf" \
|
||||||
|
--model-version "rain-auto-v1-extended" \
|
||||||
|
--out "models/rain_model_auto.pkl" \
|
||||||
|
--report-out "models/rain_model_report_auto.json"
|
||||||
|
```
|
||||||
|
|
||||||
### 4) Run inference and store prediction
|
### 4) Run inference and store prediction
|
||||||
```sh
|
```sh
|
||||||
python scripts/predict_rain_model.py \
|
python scripts/predict_rain_model.py \
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,6 +32,7 @@ type Config struct {
|
|||||||
Latitude float64 `yaml:"latitude"`
|
Latitude float64 `yaml:"latitude"`
|
||||||
Longitude float64 `yaml:"longitude"`
|
Longitude float64 `yaml:"longitude"`
|
||||||
ElevationM float64 `yaml:"elevation_m"`
|
ElevationM float64 `yaml:"elevation_m"`
|
||||||
|
Timezone string `yaml:"timezone"`
|
||||||
} `yaml:"site"`
|
} `yaml:"site"`
|
||||||
|
|
||||||
Pollers struct {
|
Pollers struct {
|
||||||
@@ -105,6 +107,11 @@ func Load(path string) (*Config, error) {
|
|||||||
if c.Site.Name == "" {
|
if c.Site.Name == "" {
|
||||||
c.Site.Name = "default"
|
c.Site.Name = "default"
|
||||||
}
|
}
|
||||||
|
if c.Site.Timezone != "" {
|
||||||
|
if _, err := time.LoadLocation(c.Site.Timezone); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid site timezone %q: %w", c.Site.Timezone, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if c.Pollers.OpenMeteo.Model == "" {
|
if c.Pollers.OpenMeteo.Model == "" {
|
||||||
c.Pollers.OpenMeteo.Model = "ecmwf"
|
c.Pollers.OpenMeteo.Model = "ecmwf"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ type Latest struct {
|
|||||||
|
|
||||||
// rolling sums built from "rain increment" values (mm)
|
// rolling sums built from "rain increment" values (mm)
|
||||||
rainIncs []rainIncPoint // last 1h
|
rainIncs []rainIncPoint // last 1h
|
||||||
dailyIncs []rainIncPoint // since midnight (or since start; we’ll trim daily by midnight)
|
dailyIncs []rainIncPoint // since midnight in rainDayLoc (or since start; trimmed each update)
|
||||||
|
|
||||||
|
rainDayLoc *time.Location
|
||||||
}
|
}
|
||||||
|
|
||||||
type rainIncPoint struct {
|
type rainIncPoint struct {
|
||||||
@@ -35,6 +37,15 @@ type rainIncPoint struct {
|
|||||||
mm float64 // incremental rainfall at this timestamp (mm)
|
mm float64 // incremental rainfall at this timestamp (mm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewLatest returns a Latest whose daily rain window is bounded by midnight
// in rainDayLoc. A nil rainDayLoc falls back to time.Local.
func NewLatest(rainDayLoc *time.Location) *Latest {
|
||||||
|
if rainDayLoc == nil {
|
||||||
|
rainDayLoc = time.Local
|
||||||
|
}
|
||||||
|
return &Latest{
|
||||||
|
rainDayLoc: rainDayLoc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Latest) Update(ts time.Time, p *WS90Payload) {
|
func (l *Latest) Update(ts time.Time, p *WS90Payload) {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
@@ -49,9 +60,9 @@ func (l *Latest) Update(ts time.Time, p *WS90Payload) {
|
|||||||
cutoff := ts.Add(-1 * time.Hour)
|
cutoff := ts.Add(-1 * time.Hour)
|
||||||
l.rainIncs = trimBefore(l.rainIncs, cutoff)
|
l.rainIncs = trimBefore(l.rainIncs, cutoff)
|
||||||
|
|
||||||
// Track daily increments: trim before local midnight
|
// Track daily increments: trim before midnight in configured rain day timezone.
|
||||||
l.dailyIncs = append(l.dailyIncs, rainIncPoint{ts: ts, mm: inc})
|
l.dailyIncs = append(l.dailyIncs, rainIncPoint{ts: ts, mm: inc})
|
||||||
midnight := localMidnight(ts)
|
midnight := l.rainDayMidnight(ts)
|
||||||
l.dailyIncs = trimBefore(l.dailyIncs, midnight)
|
l.dailyIncs = trimBefore(l.dailyIncs, midnight)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,10 +79,13 @@ func trimBefore(a []rainIncPoint, cutoff time.Time) []rainIncPoint {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
// localMidnight returns midnight in the local timezone of the *process*.
|
// rainDayMidnight returns midnight in the configured rain day timezone.
|
||||||
// If you want a specific timezone (e.g. Australia/Sydney) we can wire that in later.
|
func (l *Latest) rainDayMidnight(t time.Time) time.Time {
|
||||||
func localMidnight(t time.Time) time.Time {
|
loc := l.rainDayLoc
|
||||||
lt := t.Local()
|
// A Latest that was not built via NewLatest has no location set; use the process-local zone.
if loc == nil {
|
||||||
|
loc = time.Local
|
||||||
|
}
|
||||||
|
lt := t.In(loc)
|
||||||
return time.Date(lt.Year(), lt.Month(), lt.Day(), 0, 0, 0, 0, lt.Location())
|
return time.Date(lt.Year(), lt.Month(), lt.Day(), 0, 0, 0, 0, lt.Location())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
50
internal/mqttingest/latest_test.go
Normal file
50
internal/mqttingest/latest_test.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package mqttingest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestLatestDailyRainUsesConfiguredTimezone feeds three updates that straddle
// UTC midnight but fall on one Sydney calendar day, and expects DailyRainMM to be 2.
func TestLatestDailyRainUsesConfiguredTimezone(t *testing.T) {
|
||||||
|
loc, err := time.LoadLocation("Australia/Sydney")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load location: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := NewLatest(loc)
|
||||||
|
|
||||||
|
// Crosses UTC midnight but remains the same local day in Sydney (UTC+11 during DST).
|
||||||
|
l.Update(time.Date(2026, time.January, 14, 22, 0, 0, 0, time.UTC), &WS90Payload{RainMM: 0})
|
||||||
|
l.Update(time.Date(2026, time.January, 14, 23, 30, 0, 0, time.UTC), &WS90Payload{RainMM: 2})
|
||||||
|
l.Update(time.Date(2026, time.January, 15, 0, 5, 0, 0, time.UTC), &WS90Payload{RainMM: 2})
|
||||||
|
|
||||||
|
snap, ok := l.Snapshot()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected snapshot")
|
||||||
|
}
|
||||||
|
if snap.DailyRainMM != 2 {
|
||||||
|
t.Fatalf("expected daily rain 2.0mm, got %.2fmm", snap.DailyRainMM)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLatestDailyRainResetsAtConfiguredLocalMidnight reports rain shortly before
// Sydney local midnight, then sends one more update after it, and expects
// DailyRainMM to be 0 in the resulting snapshot.
func TestLatestDailyRainResetsAtConfiguredLocalMidnight(t *testing.T) {
|
||||||
|
loc, err := time.LoadLocation("Australia/Sydney")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("load location: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l := NewLatest(loc)
|
||||||
|
|
||||||
|
// Crosses local midnight in Sydney (00:00 local == 13:00 UTC during DST).
|
||||||
|
l.Update(time.Date(2026, time.January, 15, 12, 30, 0, 0, time.UTC), &WS90Payload{RainMM: 0})
|
||||||
|
l.Update(time.Date(2026, time.January, 15, 12, 50, 0, 0, time.UTC), &WS90Payload{RainMM: 1})
|
||||||
|
l.Update(time.Date(2026, time.January, 15, 13, 10, 0, 0, time.UTC), &WS90Payload{RainMM: 1})
|
||||||
|
|
||||||
|
snap, ok := l.Snapshot()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected snapshot")
|
||||||
|
}
|
||||||
|
if snap.DailyRainMM != 0 {
|
||||||
|
t.Fatalf("expected daily rain reset after local midnight; got %.2fmm", snap.DailyRainMM)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ MODEL_PATH="${MODEL_PATH:-models/rain_model.pkl}"
|
|||||||
REPORT_PATH="${REPORT_PATH:-models/rain_model_report.json}"
|
REPORT_PATH="${REPORT_PATH:-models/rain_model_report.json}"
|
||||||
AUDIT_PATH="${AUDIT_PATH:-models/rain_data_audit.json}"
|
AUDIT_PATH="${AUDIT_PATH:-models/rain_data_audit.json}"
|
||||||
FEATURE_SET="${FEATURE_SET:-baseline}"
|
FEATURE_SET="${FEATURE_SET:-baseline}"
|
||||||
|
MODEL_FAMILY="${MODEL_FAMILY:-logreg}"
|
||||||
FORECAST_MODEL="${FORECAST_MODEL:-ecmwf}"
|
FORECAST_MODEL="${FORECAST_MODEL:-ecmwf}"
|
||||||
DATASET_PATH="${DATASET_PATH:-models/datasets/rain_dataset_${MODEL_VERSION}_${FEATURE_SET}.csv}"
|
DATASET_PATH="${DATASET_PATH:-models/datasets/rain_dataset_${MODEL_VERSION}_${FEATURE_SET}.csv}"
|
||||||
|
|
||||||
@@ -36,6 +37,7 @@ python scripts/train_rain_model.py \
|
|||||||
--val-ratio 0.15 \
|
--val-ratio 0.15 \
|
||||||
--min-precision 0.70 \
|
--min-precision 0.70 \
|
||||||
--feature-set "$FEATURE_SET" \
|
--feature-set "$FEATURE_SET" \
|
||||||
|
--model-family "$MODEL_FAMILY" \
|
||||||
--forecast-model "$FORECAST_MODEL" \
|
--forecast-model "$FORECAST_MODEL" \
|
||||||
--model-version "$MODEL_VERSION" \
|
--model-version "$MODEL_VERSION" \
|
||||||
--out "$MODEL_PATH" \
|
--out "$MODEL_PATH" \
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class WorkerConfig:
|
|||||||
site: str
|
site: str
|
||||||
model_name: str
|
model_name: str
|
||||||
model_version_base: str
|
model_version_base: str
|
||||||
|
model_family: str
|
||||||
feature_set: str
|
feature_set: str
|
||||||
forecast_model: str
|
forecast_model: str
|
||||||
train_interval_hours: float
|
train_interval_hours: float
|
||||||
@@ -130,6 +131,8 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
|
|||||||
str(cfg.min_precision),
|
str(cfg.min_precision),
|
||||||
"--feature-set",
|
"--feature-set",
|
||||||
cfg.feature_set,
|
cfg.feature_set,
|
||||||
|
"--model-family",
|
||||||
|
cfg.model_family,
|
||||||
"--forecast-model",
|
"--forecast-model",
|
||||||
cfg.forecast_model,
|
cfg.forecast_model,
|
||||||
"--model-version",
|
"--model-version",
|
||||||
@@ -176,6 +179,7 @@ def load_config() -> WorkerConfig:
|
|||||||
site=read_env("RAIN_SITE", "home"),
|
site=read_env("RAIN_SITE", "home"),
|
||||||
model_name=read_env("RAIN_MODEL_NAME", "rain_next_1h"),
|
model_name=read_env("RAIN_MODEL_NAME", "rain_next_1h"),
|
||||||
model_version_base=read_env("RAIN_MODEL_VERSION_BASE", "rain-logreg-v1"),
|
model_version_base=read_env("RAIN_MODEL_VERSION_BASE", "rain-logreg-v1"),
|
||||||
|
model_family=read_env("RAIN_MODEL_FAMILY", "logreg"),
|
||||||
feature_set=read_env("RAIN_FEATURE_SET", "baseline"),
|
feature_set=read_env("RAIN_FEATURE_SET", "baseline"),
|
||||||
forecast_model=read_env("RAIN_FORECAST_MODEL", "ecmwf"),
|
forecast_model=read_env("RAIN_FORECAST_MODEL", "ecmwf"),
|
||||||
train_interval_hours=read_env_float("RAIN_TRAIN_INTERVAL_HOURS", 24.0),
|
train_interval_hours=read_env_float("RAIN_TRAIN_INTERVAL_HOURS", 24.0),
|
||||||
@@ -212,6 +216,7 @@ def main() -> int:
|
|||||||
"[rain-ml] worker start "
|
"[rain-ml] worker start "
|
||||||
f"site={cfg.site} "
|
f"site={cfg.site} "
|
||||||
f"model_name={cfg.model_name} "
|
f"model_name={cfg.model_name} "
|
||||||
|
f"model_family={cfg.model_family} "
|
||||||
f"feature_set={cfg.feature_set} "
|
f"feature_set={cfg.feature_set} "
|
||||||
f"forecast_model={cfg.forecast_model} "
|
f"forecast_model={cfg.forecast_model} "
|
||||||
f"train_interval_hours={cfg.train_interval_hours} "
|
f"train_interval_hours={cfg.train_interval_hours} "
|
||||||
|
|||||||
@@ -3,9 +3,11 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psycopg2
|
import psycopg2
|
||||||
|
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
from sklearn.pipeline import Pipeline
|
from sklearn.pipeline import Pipeline
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
@@ -22,6 +24,8 @@ from rain_model_common import (
|
|||||||
feature_columns_need_forecast,
|
feature_columns_need_forecast,
|
||||||
model_frame,
|
model_frame,
|
||||||
parse_time,
|
parse_time,
|
||||||
|
safe_pr_auc,
|
||||||
|
safe_roc_auc,
|
||||||
select_threshold,
|
select_threshold,
|
||||||
split_time_ordered,
|
split_time_ordered,
|
||||||
to_builtin,
|
to_builtin,
|
||||||
@@ -33,6 +37,9 @@ except ImportError: # pragma: no cover - optional dependency
|
|||||||
joblib = None
|
joblib = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FAMILIES = ("logreg", "hist_gb", "auto")
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
|
parser = argparse.ArgumentParser(description="Train a rain prediction model (next 1h >= 0.2mm).")
|
||||||
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
parser.add_argument("--db-url", default=os.getenv("DATABASE_URL"), help="Postgres connection string.")
|
||||||
@@ -60,6 +67,21 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default="ecmwf",
|
default="ecmwf",
|
||||||
help="Forecast model name when feature set requires forecast columns.",
|
help="Forecast model name when feature set requires forecast columns.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-family",
|
||||||
|
default="logreg",
|
||||||
|
choices=MODEL_FAMILIES,
|
||||||
|
help=(
|
||||||
|
"Estimator family. "
|
||||||
|
"'auto' compares logreg and hist_gb on validation and selects best by PR-AUC/ROC-AUC/F1."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--random-state",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="Random seed for stochastic estimators.",
|
||||||
|
)
|
||||||
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--report-out",
|
"--report-out",
|
||||||
@@ -82,13 +104,59 @@ def parse_args() -> argparse.Namespace:
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def make_model() -> Pipeline:
|
# Build an unfitted estimator for model_family:
#   "logreg"  -> Pipeline of StandardScaler + class-balanced LogisticRegression
#   "hist_gb" -> HistGradientBoostingClassifier
# Any other value (including "auto", which the caller expands into concrete
# families before calling this) raises ValueError.
def make_model(model_family: str, random_state: int):
|
||||||
return Pipeline(
|
if model_family == "logreg":
|
||||||
[
|
return Pipeline(
|
||||||
("scaler", StandardScaler()),
|
[
|
||||||
("clf", LogisticRegression(max_iter=1000, class_weight="balanced")),
|
("scaler", StandardScaler()),
|
||||||
]
|
("clf", LogisticRegression(max_iter=1000, class_weight="balanced", random_state=random_state)),
|
||||||
)
|
]
|
||||||
|
)
|
||||||
|
if model_family == "hist_gb":
|
||||||
|
return HistGradientBoostingClassifier(
|
||||||
|
max_iter=300,
|
||||||
|
learning_rate=0.05,
|
||||||
|
max_depth=5,
|
||||||
|
min_samples_leaf=20,
|
||||||
|
random_state=random_state,
|
||||||
|
)
|
||||||
|
raise ValueError(f"unknown model_family: {model_family}")
|
||||||
|
|
||||||
|
|
||||||
|
# Fit one estimator of the given model_family on the training split and score
# it on the validation split.
#
# The decision threshold is fixed_threshold when provided; otherwise it is
# chosen by select_threshold() on the validation probabilities under the
# min_precision constraint.
#
# Returns a dict with "model_family", "threshold", "threshold_info" and
# "validation_metrics". The fitted model itself is not returned.
def train_candidate(
|
||||||
|
model_family: str,
|
||||||
|
x_train,
|
||||||
|
y_train: np.ndarray,
|
||||||
|
x_val,
|
||||||
|
y_val: np.ndarray,
|
||||||
|
random_state: int,
|
||||||
|
min_precision: float,
|
||||||
|
fixed_threshold: float | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
model = make_model(model_family=model_family, random_state=random_state)
|
||||||
|
model.fit(x_train, y_train)
|
||||||
|
y_val_prob = model.predict_proba(x_val)[:, 1]
|
||||||
|
|
||||||
|
if fixed_threshold is not None:
|
||||||
|
threshold = fixed_threshold
|
||||||
|
threshold_info = {
|
||||||
|
"selection_rule": "fixed_cli_threshold",
|
||||||
|
"threshold": float(fixed_threshold),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
threshold, threshold_info = select_threshold(
|
||||||
|
y_true=y_val,
|
||||||
|
y_prob=y_val_prob,
|
||||||
|
min_precision=min_precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=threshold)
|
||||||
|
return {
|
||||||
|
"model_family": model_family,
|
||||||
|
"threshold": float(threshold),
|
||||||
|
"threshold_info": threshold_info,
|
||||||
|
"validation_metrics": val_metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
@@ -126,30 +194,38 @@ def main() -> int:
|
|||||||
x_test = test_df[feature_cols]
|
x_test = test_df[feature_cols]
|
||||||
y_test = test_df["rain_next_1h"].astype(int).to_numpy()
|
y_test = test_df["rain_next_1h"].astype(int).to_numpy()
|
||||||
|
|
||||||
base_model = make_model()
|
candidate_families = ["logreg", "hist_gb"] if args.model_family == "auto" else [args.model_family]
|
||||||
base_model.fit(x_train, y_train)
|
candidates = [
|
||||||
y_val_prob = base_model.predict_proba(x_val)[:, 1]
|
train_candidate(
|
||||||
|
model_family=family,
|
||||||
if args.threshold is not None:
|
x_train=x_train,
|
||||||
chosen_threshold = args.threshold
|
y_train=y_train,
|
||||||
threshold_info = {
|
x_val=x_val,
|
||||||
"selection_rule": "fixed_cli_threshold",
|
y_val=y_val,
|
||||||
"threshold": float(args.threshold),
|
random_state=args.random_state,
|
||||||
}
|
|
||||||
else:
|
|
||||||
chosen_threshold, threshold_info = select_threshold(
|
|
||||||
y_true=y_val,
|
|
||||||
y_prob=y_val_prob,
|
|
||||||
min_precision=args.min_precision,
|
min_precision=args.min_precision,
|
||||||
|
fixed_threshold=args.threshold,
|
||||||
)
|
)
|
||||||
|
for family in candidate_families
|
||||||
val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=chosen_threshold)
|
]
|
||||||
|
best_candidate = max(
|
||||||
|
candidates,
|
||||||
|
key=lambda c: (
|
||||||
|
safe_pr_auc(c["validation_metrics"]),
|
||||||
|
safe_roc_auc(c["validation_metrics"]),
|
||||||
|
float(c["validation_metrics"]["f1"]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
selected_model_family = str(best_candidate["model_family"])
|
||||||
|
chosen_threshold = float(best_candidate["threshold"])
|
||||||
|
threshold_info = best_candidate["threshold_info"]
|
||||||
|
val_metrics = best_candidate["validation_metrics"]
|
||||||
|
|
||||||
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
||||||
x_train_val = train_val_df[feature_cols]
|
x_train_val = train_val_df[feature_cols]
|
||||||
y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
|
y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
|
||||||
|
|
||||||
final_model = make_model()
|
final_model = make_model(model_family=selected_model_family, random_state=args.random_state)
|
||||||
final_model.fit(x_train_val, y_train_val)
|
final_model.fit(x_train_val, y_train_val)
|
||||||
y_test_prob = final_model.predict_proba(x_test)[:, 1]
|
y_test_prob = final_model.predict_proba(x_test)[:, 1]
|
||||||
test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
|
test_metrics = evaluate_probs(y_true=y_test, y_prob=y_test_prob, threshold=chosen_threshold)
|
||||||
@@ -158,6 +234,8 @@ def main() -> int:
|
|||||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||||
"site": args.site,
|
"site": args.site,
|
||||||
"model_version": args.model_version,
|
"model_version": args.model_version,
|
||||||
|
"model_family_requested": args.model_family,
|
||||||
|
"model_family": selected_model_family,
|
||||||
"feature_set": args.feature_set,
|
"feature_set": args.feature_set,
|
||||||
"target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
|
"target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
|
||||||
"feature_columns": feature_cols,
|
"feature_columns": feature_cols,
|
||||||
@@ -194,6 +272,17 @@ def main() -> int:
|
|||||||
**threshold_info,
|
**threshold_info,
|
||||||
"min_precision_constraint": args.min_precision,
|
"min_precision_constraint": args.min_precision,
|
||||||
},
|
},
|
||||||
|
"candidate_models": [
|
||||||
|
{
|
||||||
|
"model_family": c["model_family"],
|
||||||
|
"threshold_selection": {
|
||||||
|
**c["threshold_info"],
|
||||||
|
"min_precision_constraint": args.min_precision,
|
||||||
|
},
|
||||||
|
"validation_metrics": c["validation_metrics"],
|
||||||
|
}
|
||||||
|
for c in candidates
|
||||||
|
],
|
||||||
"validation_metrics": val_metrics,
|
"validation_metrics": val_metrics,
|
||||||
"test_metrics": test_metrics,
|
"test_metrics": test_metrics,
|
||||||
}
|
}
|
||||||
@@ -202,6 +291,7 @@ def main() -> int:
|
|||||||
print("Rain model training summary:")
|
print("Rain model training summary:")
|
||||||
print(f" site: {args.site}")
|
print(f" site: {args.site}")
|
||||||
print(f" model_version: {args.model_version}")
|
print(f" model_version: {args.model_version}")
|
||||||
|
print(f" model_family: {selected_model_family} (requested={args.model_family})")
|
||||||
print(f" feature_set: {args.feature_set} ({len(feature_cols)} features)")
|
print(f" feature_set: {args.feature_set} ({len(feature_cols)} features)")
|
||||||
print(f" rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
|
print(f" rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
|
||||||
print(
|
print(
|
||||||
@@ -250,6 +340,7 @@ def main() -> int:
|
|||||||
else:
|
else:
|
||||||
artifact = {
|
artifact = {
|
||||||
"model": final_model,
|
"model": final_model,
|
||||||
|
"model_family": selected_model_family,
|
||||||
"features": feature_cols,
|
"features": feature_cols,
|
||||||
"feature_set": args.feature_set,
|
"feature_set": args.feature_set,
|
||||||
"forecast_model": args.forecast_model if needs_forecast else None,
|
"forecast_model": args.forecast_model if needs_forecast else None,
|
||||||
|
|||||||
2
todo.md
2
todo.md
@@ -24,7 +24,7 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
|
|||||||
|
|
||||||
## 4) Modeling and Validation
|
## 4) Modeling and Validation
|
||||||
- [x] [P0] Keep logistic regression as baseline.
|
- [x] [P0] Keep logistic regression as baseline.
|
||||||
- [ ] [P1] Add at least one tree-based baseline (e.g. gradient boosting).
|
- [x] [P1] Add at least one tree-based baseline (e.g. gradient boosting). (implemented via `hist_gb`; runtime evaluation pending local Python deps)
|
||||||
- [x] [P0] Use strict time-based train/validation/test splits (no random shuffling).
|
- [x] [P0] Use strict time-based train/validation/test splits (no random shuffling).
|
||||||
- [ ] [P1] Add walk-forward backtesting across multiple temporal folds.
|
- [ ] [P1] Add walk-forward backtesting across multiple temporal folds.
|
||||||
- [ ] [P1] Tune hyperparameters on validation data only.
|
- [ ] [P1] Tune hyperparameters on validation data only.
|
||||||
|
|||||||
Reference in New Issue
Block a user