work on model training
This commit is contained in:
@@ -5,6 +5,6 @@ WORKDIR /app
|
||||
COPY scripts/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /app/requirements.txt
|
||||
|
||||
COPY scripts/train_rain_model.py /app/train_rain_model.py
|
||||
COPY scripts /app/scripts
|
||||
|
||||
ENTRYPOINT ["python", "/app/train_rain_model.py"]
|
||||
ENTRYPOINT ["python", "/app/scripts/run_rain_ml_worker.py"]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
Starter weather-station data pipeline:
|
||||
- MQTT ingest of WS90 payloads -> TimescaleDB
|
||||
- ECMWF (Open-Meteo) forecast polling -> TimescaleDB
|
||||
- Python rain-model worker (periodic training + inference writes) -> TimescaleDB
|
||||
- Web UI with live metrics, comparisons, and charts
|
||||
|
||||
## Quick start
|
||||
@@ -73,6 +74,7 @@ TimescaleDB schema is initialized from `db/init/001_schema.sql` and includes:
|
||||
- `observations_ws90` (hypertable): raw WS90 observations with payload metadata, plus the full JSON payload (`payload_json`).
|
||||
- `observations_baro` (hypertable): barometric pressure observations from other MQTT topics.
|
||||
- `forecast_openmeteo_hourly` (hypertable): hourly forecast points keyed by `(site, model, retrieved_at, ts)`.
|
||||
- `predictions_rain_1h` (hypertable): model probability + decision + realized outcome fields.
|
||||
- Continuous aggregates:
|
||||
- `cagg_ws90_1m`: 1‑minute rollups (avg/min/max for temp, humidity, wind, uvi, light, rain).
|
||||
- `cagg_ws90_5m`: 5‑minute rollups (same metrics as `cagg_ws90_1m`).
|
||||
|
||||
@@ -24,14 +24,16 @@ type webServer struct {
|
||||
}
|
||||
|
||||
type dashboardResponse struct {
|
||||
GeneratedAt time.Time `json:"generated_at"`
|
||||
Site string `json:"site"`
|
||||
Model string `json:"model"`
|
||||
RangeStart time.Time `json:"range_start"`
|
||||
RangeEnd time.Time `json:"range_end"`
|
||||
Observations []db.ObservationPoint `json:"observations"`
|
||||
Forecast db.ForecastSeries `json:"forecast"`
|
||||
Latest *db.ObservationPoint `json:"latest"`
|
||||
GeneratedAt time.Time `json:"generated_at"`
|
||||
Site string `json:"site"`
|
||||
Model string `json:"model"`
|
||||
RangeStart time.Time `json:"range_start"`
|
||||
RangeEnd time.Time `json:"range_end"`
|
||||
Observations []db.ObservationPoint `json:"observations"`
|
||||
Forecast db.ForecastSeries `json:"forecast"`
|
||||
Latest *db.ObservationPoint `json:"latest"`
|
||||
LatestRainPredict *db.RainPredictionPoint `json:"latest_rain_prediction,omitempty"`
|
||||
RainPredictionRange []db.RainPredictionPoint `json:"rain_predictions,omitempty"`
|
||||
}
|
||||
|
||||
func runWebServer(ctx context.Context, d *db.DB, site providers.Site, model, addr string) error {
|
||||
@@ -171,15 +173,33 @@ func (s *webServer) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
const rainModelName = "rain_next_1h"
|
||||
|
||||
latestRainPrediction, err := s.db.LatestRainPrediction(r.Context(), s.site.Name, rainModelName)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to query latest rain prediction", http.StatusInternalServerError)
|
||||
log.Printf("web dashboard latest rain prediction error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rainPredictionRange, err := s.db.RainPredictionSeriesRange(r.Context(), s.site.Name, rainModelName, start, end)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to query rain predictions", http.StatusInternalServerError)
|
||||
log.Printf("web dashboard rain prediction range error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := dashboardResponse{
|
||||
GeneratedAt: time.Now().UTC(),
|
||||
Site: s.site.Name,
|
||||
Model: s.model,
|
||||
RangeStart: start,
|
||||
RangeEnd: end,
|
||||
Observations: observations,
|
||||
Forecast: forecast,
|
||||
Latest: latest,
|
||||
GeneratedAt: time.Now().UTC(),
|
||||
Site: s.site.Name,
|
||||
Model: s.model,
|
||||
RangeStart: start,
|
||||
RangeEnd: end,
|
||||
Observations: observations,
|
||||
Forecast: forecast,
|
||||
Latest: latest,
|
||||
LatestRainPredict: latestRainPrediction,
|
||||
RainPredictionRange: rainPredictionRange,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -501,6 +501,40 @@ function buildRainProbabilitySeries(points) {
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
 * Convert stored model predictions into chart points.
 *
 * Each input point is expected to carry a `ts` (anything `new Date()` can
 * parse) and a `probability` expressed as a fraction. The output keeps one
 * entry per input, in the same order:
 *   - unparseable timestamp      -> { x: null, y: null }
 *   - missing/null probability   -> { x: <epoch ms>, y: null }
 *   - otherwise                  -> { x: <epoch ms>, y: <percent, 1 decimal> }
 */
function buildRainProbabilitySeriesFromPredictions(points) {
  // Fraction -> percentage rounded to one decimal place.
  const toPercent = (fraction) => Math.round(Number(fraction) * 1000) / 10;

  return points.map((point) => {
    const epochMs = new Date(point.ts).getTime();
    if (Number.isNaN(epochMs)) {
      return { x: null, y: null };
    }

    const hasProbability = point.probability !== null && point.probability !== undefined;
    return {
      x: epochMs,
      y: hasProbability ? toPercent(point.probability) : null,
    };
  });
}
|
||||
|
||||
/**
 * Build a flat two-point line at the decision threshold spanning the chart axis.
 *
 * `range` must expose `axisStart` and `axisEnd` Date objects; `threshold` is a
 * fraction and is plotted as a percentage rounded to one decimal place.
 * Returns an empty array when the range, either axis bound, or the threshold
 * is missing.
 */
function thresholdSeries(range, threshold) {
  const hasAxis = Boolean(range && range.axisStart && range.axisEnd);
  const hasThreshold = threshold !== null && threshold !== undefined;
  if (!(hasAxis && hasThreshold)) {
    return [];
  }

  const percent = Math.round(Number(threshold) * 1000) / 10;
  return [range.axisStart, range.axisEnd].map((edge) => ({ x: edge.getTime(), y: percent }));
}
|
||||
|
||||
/**
 * Age of a prediction in minutes, measured from its `ts` to the current time.
 *
 * Returns null when the prediction is missing, has no `ts`, or the `ts`
 * cannot be parsed as a date.
 */
function predictionAgeMinutes(prediction) {
  const rawTs = prediction ? prediction.ts : null;
  if (!rawTs) {
    return null;
  }

  const issuedMs = new Date(rawTs).getTime();
  return Number.isNaN(issuedMs) ? null : (Date.now() - issuedMs) / (60 * 1000);
}
|
||||
|
||||
function updateWeatherIcons(latest, rainProb) {
|
||||
const sunEl = document.getElementById("live-icon-sun");
|
||||
const cloudEl = document.getElementById("live-icon-cloud");
|
||||
@@ -693,10 +727,26 @@ function renderDashboard(data) {
|
||||
const obsFiltered = filterRange(obs, rangeStart, rangeEnd);
|
||||
const forecast = filterRange(forecastAll, rangeStart, rangeEnd);
|
||||
const forecastLine = extendForecastTo(forecast, rangeEnd);
|
||||
const rainPredictions = filterRange(data.rain_predictions || [], rangeStart, rangeEnd);
|
||||
const latestRainPrediction = data.latest_rain_prediction || null;
|
||||
const latestPredictionAgeMin = predictionAgeMinutes(latestRainPrediction);
|
||||
const modelPredictionFresh = latestPredictionAgeMin !== null && latestPredictionAgeMin <= 90;
|
||||
const lastPressureTrend = lastNonNull(obsFiltered, "pressure_trend_1h");
|
||||
const rainProb = computeRainProbability(latest);
|
||||
const modelRainProb = modelPredictionFresh && latestRainPrediction && latestRainPrediction.probability !== null && latestRainPrediction.probability !== undefined
|
||||
? {
|
||||
prob: Number(latestRainPrediction.probability),
|
||||
label: classifyRainProbability(Number(latestRainPrediction.probability)),
|
||||
source: "model",
|
||||
}
|
||||
: null;
|
||||
const heuristicRainProb = computeRainProbability(latest);
|
||||
const rainProb = modelRainProb || heuristicRainProb;
|
||||
if (rainProb) {
|
||||
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label})`);
|
||||
const sourceLabel = rainProb.source === "model" ? "model" : "heuristic";
|
||||
updateText("live-rain-prob", `${Math.round(rainProb.prob * 100)}% (${rainProb.label}, ${sourceLabel})`);
|
||||
} else if (latestRainPrediction && latestRainPrediction.probability !== null && latestRainPrediction.probability !== undefined) {
|
||||
const stalePct = Math.round(Number(latestRainPrediction.probability) * 100);
|
||||
updateText("live-rain-prob", `${stalePct}% (stale model)`);
|
||||
} else {
|
||||
updateText("live-rain-prob", "--");
|
||||
}
|
||||
@@ -901,13 +951,20 @@ function renderDashboard(data) {
|
||||
data: {
|
||||
datasets: [
|
||||
{
|
||||
label: "predicted rain probability (%)",
|
||||
data: buildRainProbabilitySeries(obsFiltered),
|
||||
label: rainPredictions.length ? "model rain probability (%)" : "heuristic rain probability (%)",
|
||||
data: rainPredictions.length ? buildRainProbabilitySeriesFromPredictions(rainPredictions) : buildRainProbabilitySeries(obsFiltered),
|
||||
borderColor: colors.rain,
|
||||
backgroundColor: "rgba(78, 168, 222, 0.18)",
|
||||
fill: true,
|
||||
yAxisID: "y",
|
||||
},
|
||||
{
|
||||
label: "decision threshold (%)",
|
||||
data: thresholdSeries(range, latestRainPrediction ? latestRainPrediction.threshold : null),
|
||||
borderColor: "#f4b942",
|
||||
borderDash: [6, 4],
|
||||
yAxisID: "y",
|
||||
},
|
||||
],
|
||||
},
|
||||
options: rainProbOptions,
|
||||
|
||||
@@ -198,7 +198,7 @@
|
||||
</div>
|
||||
<div class="chart-card" data-chart="chart-rain-prob">
|
||||
<div class="chart-header">
|
||||
<div class="chart-title">Predicted Rain Probability (Observed Inputs)</div>
|
||||
<div class="chart-title">Predicted Rain Probability (Model)</div>
|
||||
<button class="chart-link" data-chart="chart-rain-prob" title="Copy chart link">Share</button>
|
||||
</div>
|
||||
<div class="chart-canvas">
|
||||
|
||||
@@ -24,5 +24,27 @@ services:
|
||||
volumes:
|
||||
- ./config.yaml:/app/config.yaml:ro
|
||||
|
||||
rainml:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.train
|
||||
depends_on:
|
||||
- timescaledb
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
DATABASE_URL: "postgres://postgres:postgres@timescaledb:5432/micrometeo?sslmode=disable"
|
||||
RAIN_SITE: "home"
|
||||
RAIN_MODEL_NAME: "rain_next_1h"
|
||||
RAIN_MODEL_VERSION_BASE: "rain-logreg-v1"
|
||||
RAIN_LOOKBACK_DAYS: "30"
|
||||
RAIN_TRAIN_INTERVAL_HOURS: "24"
|
||||
RAIN_PREDICT_INTERVAL_MINUTES: "10"
|
||||
RAIN_MIN_PRECISION: "0.70"
|
||||
RAIN_MODEL_PATH: "/app/models/rain_model.pkl"
|
||||
RAIN_REPORT_PATH: "/app/models/rain_model_report.json"
|
||||
RAIN_AUDIT_PATH: "/app/models/rain_data_audit.json"
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
|
||||
volumes:
|
||||
tsdata:
|
||||
|
||||
@@ -40,6 +40,7 @@ pip install -r scripts/requirements.txt
|
||||
- `scripts/train_rain_model.py`: strict time-based split training and metrics report.
|
||||
- `scripts/predict_rain_model.py`: inference using saved model artifact; upserts into
|
||||
`predictions_rain_1h`.
|
||||
- `scripts/run_rain_ml_worker.py`: long-running worker for periodic training + prediction.
|
||||
|
||||
## Usage
|
||||
### 1) Apply schema update (existing DBs)
|
||||
@@ -90,6 +91,18 @@ export DATABASE_URL="postgres://postgres:postgres@localhost:5432/micrometeo?sslm
|
||||
bash scripts/run_p0_rain_workflow.sh
|
||||
```
|
||||
|
||||
### 6) Continuous training + prediction via Docker Compose
|
||||
The `rainml` service in `docker-compose.yml` now runs:
|
||||
- periodic retraining (default every 24 hours)
|
||||
- periodic prediction writes (default every 10 minutes)
|
||||
|
||||
Artifacts are persisted to `./models` on the host.
|
||||
|
||||
```sh
|
||||
docker compose up -d rainml
|
||||
docker compose logs -f rainml
|
||||
```
|
||||
|
||||
## Output
|
||||
- Audit report: `models/rain_data_audit.json`
|
||||
- Training report: `models/rain_model_report.json`
|
||||
|
||||
@@ -46,6 +46,19 @@ type ForecastSeries struct {
|
||||
Points []ForecastPoint `json:"points"`
|
||||
}
|
||||
|
||||
// RainPredictionPoint is one row of the predictions_rain_1h table as exposed
// over JSON: the model's output for a timestamp plus, once available, the
// realized outcome fields.
type RainPredictionPoint struct {
	// TS is the timestamp the prediction applies to.
	TS time.Time `json:"ts"`
	// GeneratedAt is the stored generated_at value for the row.
	GeneratedAt time.Time `json:"generated_at"`

	// ModelName and ModelVersion identify the model that produced the row.
	ModelName    string `json:"model_name"`
	ModelVersion string `json:"model_version"`

	// Threshold, Probability, and PredictRain are the stored model outputs.
	// They are plain values, so a NULL column is indistinguishable from zero/false.
	Threshold   float64 `json:"threshold"`
	Probability float64 `json:"probability"`
	PredictRain bool    `json:"predict_rain"`

	// Realized-outcome fields; nil (and omitted from JSON) when not set.
	RainNext1hMM     *float64   `json:"rain_next_1h_mm_actual,omitempty"`
	RainNext1hActual *bool      `json:"rain_next_1h_actual,omitempty"`
	EvaluatedAt      *time.Time `json:"evaluated_at,omitempty"`
}
|
||||
|
||||
func (d *DB) ObservationSeries(ctx context.Context, site, bucket string, start, end time.Time) ([]ObservationPoint, error) {
|
||||
if end.Before(start) || end.Equal(start) {
|
||||
return nil, errors.New("invalid time range")
|
||||
@@ -352,6 +365,139 @@ func (d *DB) ForecastSeriesRange(ctx context.Context, site, model string, start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DB) LatestRainPrediction(ctx context.Context, site, modelName string) (*RainPredictionPoint, error) {
|
||||
query := `
|
||||
SELECT
|
||||
ts,
|
||||
generated_at,
|
||||
model_name,
|
||||
model_version,
|
||||
threshold,
|
||||
probability,
|
||||
predict_rain,
|
||||
rain_next_1h_mm_actual,
|
||||
rain_next_1h_actual,
|
||||
evaluated_at
|
||||
FROM predictions_rain_1h
|
||||
WHERE site = $1
|
||||
AND model_name = $2
|
||||
ORDER BY ts DESC, generated_at DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
var (
|
||||
p RainPredictionPoint
|
||||
rainMM, threshold, probability sql.NullFloat64
|
||||
rainActual sql.NullBool
|
||||
evaluatedAt sql.NullTime
|
||||
predictRain sql.NullBool
|
||||
)
|
||||
|
||||
err := d.Pool.QueryRow(ctx, query, site, modelName).Scan(
|
||||
&p.TS,
|
||||
&p.GeneratedAt,
|
||||
&p.ModelName,
|
||||
&p.ModelVersion,
|
||||
&threshold,
|
||||
&probability,
|
||||
&predictRain,
|
||||
&rainMM,
|
||||
&rainActual,
|
||||
&evaluatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if threshold.Valid {
|
||||
p.Threshold = threshold.Float64
|
||||
}
|
||||
if probability.Valid {
|
||||
p.Probability = probability.Float64
|
||||
}
|
||||
if predictRain.Valid {
|
||||
p.PredictRain = predictRain.Bool
|
||||
}
|
||||
p.RainNext1hMM = nullFloatPtr(rainMM)
|
||||
p.RainNext1hActual = nullBoolPtr(rainActual)
|
||||
p.EvaluatedAt = nullTimePtr(evaluatedAt)
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (d *DB) RainPredictionSeriesRange(ctx context.Context, site, modelName string, start, end time.Time) ([]RainPredictionPoint, error) {
|
||||
query := `
|
||||
SELECT DISTINCT ON (ts)
|
||||
ts,
|
||||
generated_at,
|
||||
model_name,
|
||||
model_version,
|
||||
threshold,
|
||||
probability,
|
||||
predict_rain,
|
||||
rain_next_1h_mm_actual,
|
||||
rain_next_1h_actual,
|
||||
evaluated_at
|
||||
FROM predictions_rain_1h
|
||||
WHERE site = $1
|
||||
AND model_name = $2
|
||||
AND ts >= $3
|
||||
AND ts <= $4
|
||||
ORDER BY ts ASC, generated_at DESC
|
||||
`
|
||||
|
||||
rows, err := d.Pool.Query(ctx, query, site, modelName, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
points := make([]RainPredictionPoint, 0, 256)
|
||||
for rows.Next() {
|
||||
var (
|
||||
p RainPredictionPoint
|
||||
rainMM, threshold, probability sql.NullFloat64
|
||||
rainActual sql.NullBool
|
||||
evaluatedAt sql.NullTime
|
||||
predictRain sql.NullBool
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&p.TS,
|
||||
&p.GeneratedAt,
|
||||
&p.ModelName,
|
||||
&p.ModelVersion,
|
||||
&threshold,
|
||||
&probability,
|
||||
&predictRain,
|
||||
&rainMM,
|
||||
&rainActual,
|
||||
&evaluatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if threshold.Valid {
|
||||
p.Threshold = threshold.Float64
|
||||
}
|
||||
if probability.Valid {
|
||||
p.Probability = probability.Float64
|
||||
}
|
||||
if predictRain.Valid {
|
||||
p.PredictRain = predictRain.Bool
|
||||
}
|
||||
p.RainNext1hMM = nullFloatPtr(rainMM)
|
||||
p.RainNext1hActual = nullBoolPtr(rainActual)
|
||||
p.EvaluatedAt = nullTimePtr(evaluatedAt)
|
||||
points = append(points, p)
|
||||
}
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
|
||||
return points, nil
|
||||
}
|
||||
|
||||
func nullFloatPtr(v sql.NullFloat64) *float64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
@@ -367,3 +513,19 @@ func nullIntPtr(v sql.NullInt64) *int64 {
|
||||
val := v.Int64
|
||||
return &val
|
||||
}
|
||||
|
||||
// nullBoolPtr converts a nullable SQL boolean into a *bool: nil for NULL,
// otherwise a pointer to a copy of the value.
func nullBoolPtr(v sql.NullBool) *bool {
	if v.Valid {
		b := v.Bool
		return &b
	}
	return nil
}
|
||||
|
||||
// nullTimePtr converts a nullable SQL timestamp into a *time.Time: nil for
// NULL, otherwise a pointer to a copy of the value.
func nullTimePtr(v sql.NullTime) *time.Time {
	if v.Valid {
		ts := v.Time
		return &ts
	}
	return nil
}
|
||||
|
||||
228
scripts/run_rain_ml_worker.py
Normal file
228
scripts/run_rain_ml_worker.py
Normal file
@@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_env(name: str, default: str) -> str:
    """Return environment variable ``name`` (or ``default`` if unset), stripped of surrounding whitespace.

    A variable that is set but empty yields ``""``, not ``default``.
    """
    value = os.environ.get(name, default)
    return value.strip()
|
||||
|
||||
|
||||
def read_env_float(name: str, default: float) -> float:
    """Read environment variable ``name`` as a float.

    Returns ``default`` when the variable is unset or contains only whitespace.
    Raises ``ValueError`` (from ``float``) when the value is not a valid number.
    """
    raw = os.getenv(name)
    is_blank = raw is None or not raw.strip()
    return default if is_blank else float(raw)
|
||||
|
||||
|
||||
def read_env_int(name: str, default: int) -> int:
    """Read environment variable ``name`` as an int.

    Returns ``default`` when the variable is unset or contains only whitespace.
    Raises ``ValueError`` (from ``int``) when the value is not a valid integer.
    """
    raw = os.getenv(name)
    is_blank = raw is None or not raw.strip()
    return default if is_blank else int(raw)
|
||||
|
||||
|
||||
def read_env_bool(name: str, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is unset or contains only
    whitespace, matching how ``read_env_float`` and ``read_env_int`` treat
    blank values. Otherwise the value is true when it equals one of
    ``1``, ``true``, ``yes``, ``on`` (case-insensitive, surrounding
    whitespace ignored) and false for anything else.
    """
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        # A blank value (e.g. `VAR=` in a compose file) means "not configured".
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
@dataclass
class WorkerConfig:
    """Runtime settings for the rain-model worker."""

    # Connection string exported to child scripts as DATABASE_URL.
    database_url: str

    # Identity of the site and model passed to the child scripts.
    site: str
    model_name: str
    model_version_base: str  # a UTC timestamp suffix is appended per training run

    # Scheduling cadence.
    train_interval_hours: float
    predict_interval_minutes: float

    # Training parameters forwarded to the training script.
    lookback_days: int
    train_ratio: float
    val_ratio: float
    min_precision: float

    # Artifact locations.
    model_path: Path
    report_path: Path
    audit_path: Path

    # Loop control.
    run_once: bool  # exit after one successful train + predict
    retry_delay_seconds: int  # pause after a failed step
|
||||
|
||||
|
||||
def now_utc() -> datetime:
    """Return the current timezone-aware UTC time, truncated to whole seconds."""
    current = datetime.now(tz=timezone.utc)
    return current.replace(microsecond=0)
|
||||
|
||||
|
||||
def iso_utc(v: datetime) -> str:
    """Format ``v`` as an ISO-8601 string in UTC, using a ``Z`` suffix instead of ``+00:00``."""
    as_utc = v.astimezone(timezone.utc)
    text = as_utc.isoformat()
    return text.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def run_cmd(cmd: list[str], env: dict[str, str]) -> None:
    """Log ``cmd`` and run it as a child process with environment ``env``.

    Raises ``subprocess.CalledProcessError`` when the command exits non-zero.
    """
    rendered = " ".join(cmd)
    print(f"[rain-ml] running: {rendered}", flush=True)
    subprocess.run(cmd, env=env, check=True)
|
||||
|
||||
|
||||
def ensure_parent(path: Path) -> None:
    """Create the parent directory of ``path`` (and any missing ancestors) if it does not exist yet."""
    parent = path.parent
    if not parent or parent.exists():
        return
    parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def training_window(lookback_days: int) -> tuple[str, str]:
    """Return ``(start, end)`` ISO-8601 UTC strings covering the last ``lookback_days`` days up to now."""
    window_end = now_utc()
    window_start = window_end - timedelta(days=lookback_days)
    bounds = (window_start, window_end)
    return tuple(iso_utc(bound) for bound in bounds)
|
||||
|
||||
|
||||
def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run one audit + training pass over the configured lookback window.

    Runs ``scripts/audit_rain_data.py`` followed by
    ``scripts/train_rain_model.py`` as child processes, both over the same
    time window. The model version is ``cfg.model_version_base`` with a UTC
    ``YYYYmmddHHMM`` suffix. Any non-zero exit propagates as
    ``subprocess.CalledProcessError`` from ``run_cmd``.
    """
    window_start, window_end = training_window(cfg.lookback_days)
    stamp = now_utc().strftime("%Y%m%d%H%M")
    model_version = f"{cfg.model_version_base}-{stamp}"

    # Make sure every output location has a directory to land in.
    for artifact in (cfg.audit_path, cfg.report_path, cfg.model_path):
        ensure_parent(artifact)

    # Arguments shared by the audit and training scripts.
    window_args = ["--site", cfg.site, "--start", window_start, "--end", window_end]

    audit_cmd = [
        sys.executable,
        "scripts/audit_rain_data.py",
        *window_args,
        "--out",
        str(cfg.audit_path),
    ]
    run_cmd(audit_cmd, env)

    train_cmd = [
        sys.executable,
        "scripts/train_rain_model.py",
        *window_args,
        "--train-ratio",
        str(cfg.train_ratio),
        "--val-ratio",
        str(cfg.val_ratio),
        "--min-precision",
        str(cfg.min_precision),
        "--model-version",
        model_version,
        "--out",
        str(cfg.model_path),
        "--report-out",
        str(cfg.report_path),
    ]
    run_cmd(train_cmd, env)
|
||||
|
||||
|
||||
def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run ``scripts/predict_rain_model.py`` once against the saved model artifact.

    Raises ``RuntimeError`` when ``cfg.model_path`` does not exist, and
    ``subprocess.CalledProcessError`` (via ``run_cmd``) when the script fails.
    """
    artifact = cfg.model_path
    if not artifact.exists():
        raise RuntimeError(f"model artifact not found: {cfg.model_path}")

    predict_cmd = [sys.executable, "scripts/predict_rain_model.py"]
    predict_cmd += ["--site", cfg.site]
    predict_cmd += ["--model-path", str(artifact)]
    predict_cmd += ["--model-name", cfg.model_name]
    run_cmd(predict_cmd, env)
|
||||
|
||||
|
||||
def load_config() -> WorkerConfig:
    """Build the worker configuration from environment variables.

    ``DATABASE_URL`` is mandatory; the process exits via ``SystemExit`` when it
    is unset or blank. Every ``RAIN_*`` variable is optional and falls back to
    the default shown below.
    """
    dsn = read_env("DATABASE_URL", "")
    if dsn == "":
        raise SystemExit("DATABASE_URL is required")

    settings = dict(
        database_url=dsn,
        site=read_env("RAIN_SITE", "home"),
        model_name=read_env("RAIN_MODEL_NAME", "rain_next_1h"),
        model_version_base=read_env("RAIN_MODEL_VERSION_BASE", "rain-logreg-v1"),
        train_interval_hours=read_env_float("RAIN_TRAIN_INTERVAL_HOURS", 24.0),
        predict_interval_minutes=read_env_float("RAIN_PREDICT_INTERVAL_MINUTES", 10.0),
        lookback_days=read_env_int("RAIN_LOOKBACK_DAYS", 30),
        train_ratio=read_env_float("RAIN_TRAIN_RATIO", 0.7),
        val_ratio=read_env_float("RAIN_VAL_RATIO", 0.15),
        min_precision=read_env_float("RAIN_MIN_PRECISION", 0.70),
        model_path=Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl")),
        report_path=Path(read_env("RAIN_REPORT_PATH", "models/rain_model_report.json")),
        audit_path=Path(read_env("RAIN_AUDIT_PATH", "models/rain_data_audit.json")),
        run_once=read_env_bool("RAIN_RUN_ONCE", False),
        retry_delay_seconds=read_env_int("RAIN_RETRY_DELAY_SECONDS", 60),
    )
    return WorkerConfig(**settings)
|
||||
|
||||
|
||||
def main() -> int:
    """Run the worker loop: periodic training and periodic prediction.

    Both tasks are due immediately at start-up and are then rescheduled at
    ``train_interval_hours`` / ``predict_interval_minutes`` measured from the
    start of the loop iteration in which they succeeded.

    Training and prediction are attempted independently within an iteration:
    a failed training cycle no longer prevents a due prediction from running,
    so an existing model artifact keeps producing predictions while
    retraining is retried. A failed step is not rescheduled; the loop sleeps
    ``retry_delay_seconds`` and tries it again.

    Returns 0 only in run-once mode (``RAIN_RUN_ONCE``), after both a training
    cycle and a prediction have succeeded; otherwise it loops until the
    process is terminated.
    """
    cfg = load_config()
    env = os.environ.copy()
    env["DATABASE_URL"] = cfg.database_url

    train_every = timedelta(hours=cfg.train_interval_hours)
    predict_every = timedelta(minutes=cfg.predict_interval_minutes)
    next_train = now_utc()
    next_predict = now_utc()
    trained_once = False
    predicted_once = False

    print(
        "[rain-ml] worker start "
        f"site={cfg.site} "
        f"model_name={cfg.model_name} "
        f"train_interval_hours={cfg.train_interval_hours} "
        f"predict_interval_minutes={cfg.predict_interval_minutes}",
        flush=True,
    )

    def attempt(step) -> bool:
        """Run one step; log and report failure instead of raising."""
        try:
            step(cfg, env)
        except subprocess.CalledProcessError as exc:
            print(f"[rain-ml] command failed exit={exc.returncode}; retrying in {cfg.retry_delay_seconds}s", flush=True)
            return False
        except Exception as exc:  # pragma: no cover - defensive for runtime worker
            print(f"[rain-ml] worker error: {exc}; retrying in {cfg.retry_delay_seconds}s", flush=True)
            return False
        return True

    while True:
        now = now_utc()
        any_failed = False

        if now >= next_train:
            if attempt(run_training_cycle):
                next_train = now + train_every
                trained_once = True
            else:
                any_failed = True

        # Deliberately not skipped when training failed: a previously saved
        # model artifact is still usable for inference.
        if now >= next_predict:
            if attempt(run_predict_once):
                next_predict = now + predict_every
                predicted_once = True
            else:
                any_failed = True

        if cfg.run_once and trained_once and predicted_once:
            print("[rain-ml] run-once complete", flush=True)
            return 0

        if any_failed:
            # The failed step keeps its old due time, so it is retried on the
            # next iteration after this pause.
            time.sleep(cfg.retry_delay_seconds)
            continue

        # Wake at least every 30 seconds, or sooner if a task comes due.
        sleep_for = min((next_train - now).total_seconds(), (next_predict - now).total_seconds(), 30.0)
        if sleep_for > 0:
            time.sleep(sleep_for)
|
||||
|
||||
|
||||
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())
|
||||
14
todo.md
14
todo.md
@@ -9,9 +9,9 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
|
||||
- [x] [P0] Freeze training window with explicit UTC start/end timestamps.
|
||||
|
||||
## 2) Data Quality and Label Validation
|
||||
- [ ] [P0] Audit `observations_ws90` and `observations_baro` for missingness, gaps, duplicates, and out-of-order rows. (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
|
||||
- [ ] [P0] Validate rain label construction from `rain_mm` (counter resets, negative deltas, spikes). (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
|
||||
- [ ] [P0] Measure class balance by week (rain-positive vs rain-negative). (script ready: `scripts/audit_rain_data.py`; run on runtime machine)
|
||||
- [x] [P0] Audit `observations_ws90` and `observations_baro` for missingness, gaps, duplicates, and out-of-order rows. (completed on runtime machine)
|
||||
- [x] [P0] Validate rain label construction from `rain_mm` (counter resets, negative deltas, spikes). (completed on runtime machine)
|
||||
- [x] [P0] Measure class balance by week (rain-positive vs rain-negative). (completed on runtime machine)
|
||||
- [ ] [P1] Document known data issues and mitigation rules.
|
||||
|
||||
## 3) Dataset and Feature Engineering
|
||||
@@ -38,10 +38,10 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
|
||||
- [ ] [P1] Produce a short model card (data window, features, metrics, known limitations).
|
||||
|
||||
## 6) Packaging and Deployment
|
||||
- [ ] [P1] Version model artifacts and feature schema together.
|
||||
- [x] [P1] Version model artifacts and feature schema together.
|
||||
- [x] [P0] Implement inference path with feature parity between training and serving.
|
||||
- [x] [P0] Add prediction storage table for predicted probabilities and realized outcomes.
|
||||
- [ ] [P1] Expose predictions via API and optionally surface in web dashboard.
|
||||
- [x] [P1] Expose predictions via API and optionally surface in web dashboard.
|
||||
- [ ] [P2] Add scheduled retraining with rollback to last-known-good model.
|
||||
|
||||
## 7) Monitoring and Operations
|
||||
@@ -51,7 +51,7 @@ Priority key: `P0` = critical/blocking, `P1` = important, `P2` = later optimizat
|
||||
- [ ] [P1] Document runbook for train/evaluate/deploy/rollback.
|
||||
|
||||
## 8) Immediate Next Steps (This Week)
|
||||
- [ ] [P0] Run first full data audit and label-quality checks. (blocked here; run on runtime machine)
|
||||
- [ ] [P0] Train baseline model on full available history and capture metrics. (blocked here; run on runtime machine)
|
||||
- [x] [P0] Run first full data audit and label-quality checks. (completed on runtime machine)
|
||||
- [x] [P0] Train baseline model on full available history and capture metrics. (completed on runtime machine)
|
||||
- [ ] [P1] Add one expanded feature set and rerun evaluation.
|
||||
- [x] [P0] Decide v1 threshold and define deployment interface.
|
||||
|
||||
Reference in New Issue
Block a user