388 lines
12 KiB
Python
388 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
|
|
|
|
def read_env(name: str, default: str) -> str:
    """Return environment variable *name* with surrounding whitespace removed.

    When the variable is unset, *default* is used instead (and is stripped
    the same way).
    """
    value = os.getenv(name, default)
    return value.strip()
|
|
|
|
|
|
def read_env_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float.

    Returns *default* when the variable is unset or contains only whitespace.

    Raises:
        ValueError: if the variable is set to text that is not a valid float.
            The message names the variable so a misconfiguration can be
            identified from the error alone.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be a number, got {raw!r}") from err
|
|
|
|
|
|
def read_env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Returns *default* when the variable is unset or contains only whitespace.

    Raises:
        ValueError: if the variable is set to text that is not a valid
            integer. The message names the variable so a misconfiguration can
            be identified from the error alone.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err
|
|
|
|
|
|
def read_env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset or blank, matching how the
    float and int readers treat blank values (so an empty value does not
    silently turn a ``True`` default into ``False``).

    Otherwise the value is true only when, ignoring case and surrounding
    whitespace, it is one of ``1``, ``true``, ``yes`` or ``on``; any other
    text is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    normalized = raw.strip().lower()
    if not normalized:
        return default
    return normalized in {"1", "true", "yes", "on"}
|
|
|
|
|
|
@dataclass
class WorkerConfig:
    """Settings for the rain-model worker, populated from the environment."""

    # Connection string; exported to child scripts as DATABASE_URL.
    database_url: str

    # Identity of the site/model; forwarded to the scripts as CLI flags.
    site: str
    model_name: str
    model_version_base: str  # prefix of the timestamped model version
    model_family: str
    feature_set: str
    forecast_model: str

    # Scheduling of the two recurring jobs.
    train_interval_hours: float
    predict_interval_minutes: float

    # Training options; passed through to the training script.
    lookback_days: int  # length of the training window ending "now"
    train_ratio: float
    val_ratio: float
    min_precision: float
    tune_hyperparameters: bool
    max_hyperparam_trials: int
    calibration_methods: str
    threshold_policy: str
    walk_forward_folds: int

    # True -> "--allow-empty", False -> "--strict-source-data".
    allow_empty_data: bool

    # str.format templates: the dataset template receives model_version and
    # feature_set, the model-card template receives model_version.
    dataset_path_template: str
    model_card_path_template: str

    # Artifact locations.
    model_path: Path
    model_backup_path: Path
    report_path: Path
    audit_path: Path

    # Loop control.
    run_once: bool  # stop after one successful training and prediction
    retry_delay_seconds: int
|
|
|
|
|
|
def now_utc() -> datetime:
    """Return the current time as an aware UTC datetime with microseconds zeroed."""
    current = datetime.now(tz=timezone.utc)
    return current.replace(microsecond=0)
|
|
|
|
|
|
def iso_utc(v: datetime) -> str:
    """Format *v* as an ISO-8601 string in UTC, using ``Z`` for the offset.

    The value is first converted to UTC, so aware datetimes in other zones
    are shifted accordingly.
    """
    as_utc = v.astimezone(timezone.utc)
    text = as_utc.isoformat()
    return text.replace("+00:00", "Z")
|
|
|
|
|
|
def run_cmd(cmd: list[str], env: dict[str, str]) -> None:
    """Log *cmd* and execute it with environment *env*.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status.
    """
    rendered = " ".join(cmd)
    print(f"[rain-ml] running: {rendered}", flush=True)
    subprocess.run(cmd, env=env, check=True)
|
|
|
|
|
|
def ensure_parent(path: Path) -> None:
    """Create the directory that contains *path* (and any ancestors) if it is missing."""
    parent = path.parent
    if parent.exists():
        return
    parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def with_suffix(path: Path, suffix: str) -> Path:
    """Return *path* with *suffix* appended to its final component.

    Unlike ``Path.with_suffix`` the existing extension is kept, e.g.
    ``model.pkl`` + ``.candidate`` -> ``model.pkl.candidate``.
    """
    extended_name = path.name + suffix
    return path.with_name(extended_name)
|
|
|
|
|
|
def promote_file(candidate: Path, target: Path) -> bool:
    """Move *candidate* over *target*, creating the target directory if needed.

    Returns:
        True when the candidate existed and was moved into place, False when
        there was no candidate file (nothing is touched in that case).
    """
    if candidate.exists():
        ensure_parent(target)
        candidate.replace(target)
        return True
    return False
|
|
|
|
|
|
def promote_model_candidate(candidate: Path, target: Path, backup: Path) -> bool:
    """Replace the live model with a freshly trained candidate, keeping a backup.

    If *target* already exists it is first copied to *backup*. The candidate
    is then moved over *target*. Should that move fail, *target* is restored
    from *backup* (when a backup file exists) and the error is re-raised.

    Returns:
        True when the candidate was promoted, False when no candidate file
        exists.
    """
    if not candidate.exists():
        return False

    for destination in (target, backup):
        ensure_parent(destination)

    # Preserve the current model before it gets overwritten.
    if target.exists():
        shutil.copy2(target, backup)

    try:
        candidate.replace(target)
    except Exception:
        if backup.exists():
            shutil.copy2(backup, target)
        raise
    return True
|
|
|
|
|
|
def remove_if_exists(path: Path) -> None:
    """Delete the file at *path*; do nothing when it is already gone.

    The unlink is attempted directly rather than guarded by ``exists()``:
    the guarded form can fail if the file disappears between the check and
    the delete, and ``exists()`` follows symlinks, so a dangling symlink
    would be reported as missing and left in place.
    """
    try:
        path.unlink()
    except FileNotFoundError:
        pass
|
|
|
|
|
|
def training_window(lookback_days: int) -> tuple[str, str]:
    """Return ``(start, end)`` ISO-8601 UTC strings for the training window.

    The window ends at the current time and begins *lookback_days* days
    earlier.
    """
    window_end = now_utc()
    window_start = window_end - timedelta(days=lookback_days)
    return iso_utc(window_start), iso_utc(window_end)
|
|
|
|
|
|
def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run one data-audit + training cycle and promote its artifacts.

    The audit and training scripts write to ``*.candidate`` files. Only when
    training produced a model candidate are the model, report, audit and
    model card moved to their final locations; otherwise the existing model
    is left untouched.

    Raises:
        subprocess.CalledProcessError: if the audit or training script exits
            with a non-zero status.
    """
    window_start, window_end = training_window(cfg.lookback_days)
    stamp = now_utc().strftime("%Y%m%d%H%M")
    model_version = f"{cfg.model_version_base}-{stamp}"
    dataset_out = cfg.dataset_path_template.format(
        model_version=model_version,
        feature_set=cfg.feature_set,
    )
    model_card_out = cfg.model_card_path_template.format(model_version=model_version)

    model_candidate_path = with_suffix(cfg.model_path, ".candidate")
    report_candidate_path = with_suffix(cfg.report_path, ".candidate")
    audit_candidate_path = with_suffix(cfg.audit_path, ".candidate")
    if model_card_out:
        model_card_candidate_out = f"{model_card_out}.candidate"
        model_card_candidate_path = Path(model_card_candidate_out)
    else:
        # An empty template disables the model card candidate entirely.
        model_card_candidate_out = ""
        model_card_candidate_path = None

    # Ensure promotions only use artifacts from the current training cycle.
    stale_candidates = [model_candidate_path, report_candidate_path, audit_candidate_path]
    if model_card_candidate_path is not None:
        stale_candidates.append(model_card_candidate_path)
    for leftover in stale_candidates:
        remove_if_exists(leftover)

    # Every output location needs an existing directory before the scripts run.
    output_paths = [
        audit_candidate_path,
        report_candidate_path,
        cfg.model_path,
        model_candidate_path,
        cfg.model_backup_path,
    ]
    if dataset_out:
        output_paths.append(Path(dataset_out))
    if model_card_candidate_path is not None:
        output_paths.append(model_card_candidate_path)
    if model_card_out:
        output_paths.append(Path(model_card_out))
    for output_path in output_paths:
        ensure_parent(output_path)

    audit_cmd = [
        sys.executable,
        "scripts/audit_rain_data.py",
        "--site", cfg.site,
        "--start", window_start,
        "--end", window_end,
        "--feature-set", cfg.feature_set,
        "--forecast-model", cfg.forecast_model,
        "--out", str(audit_candidate_path),
    ]
    run_cmd(audit_cmd, env)

    train_cmd = [
        sys.executable,
        "scripts/train_rain_model.py",
        "--site", cfg.site,
        "--start", window_start,
        "--end", window_end,
        "--train-ratio", str(cfg.train_ratio),
        "--val-ratio", str(cfg.val_ratio),
        "--min-precision", str(cfg.min_precision),
        "--max-hyperparam-trials", str(cfg.max_hyperparam_trials),
        "--calibration-methods", cfg.calibration_methods,
        "--threshold-policy", cfg.threshold_policy,
        "--walk-forward-folds", str(cfg.walk_forward_folds),
        "--feature-set", cfg.feature_set,
        "--model-family", cfg.model_family,
        "--forecast-model", cfg.forecast_model,
        "--model-version", model_version,
        "--out", str(model_candidate_path),
        "--report-out", str(report_candidate_path),
        "--model-card-out", model_card_candidate_out,
        "--dataset-out", dataset_out,
    ]
    if cfg.tune_hyperparameters:
        train_cmd.append("--tune-hyperparameters")
    train_cmd.append("--allow-empty" if cfg.allow_empty_data else "--strict-source-data")
    run_cmd(train_cmd, env)

    promoted_model = promote_model_candidate(
        candidate=model_candidate_path,
        target=cfg.model_path,
        backup=cfg.model_backup_path,
    )
    if not promoted_model:
        print(
            "[rain-ml] training completed without new model artifact; keeping last-known-good model",
            flush=True,
        )
        return

    # The companion artifacts are only promoted together with a new model.
    promote_file(report_candidate_path, cfg.report_path)
    promote_file(audit_candidate_path, cfg.audit_path)
    if model_card_candidate_path is not None:
        promote_file(model_card_candidate_path, Path(model_card_out))
    print(
        f"[rain-ml] promoted new model artifact to {cfg.model_path} (backup={cfg.model_backup_path})",
        flush=True,
    )
|
|
|
|
|
|
def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
    """Run the prediction script once against the current model artifact.

    When the model file is missing it is restored from the backup copy; if
    there is no backup either, the prediction is skipped with a log line.

    Raises:
        subprocess.CalledProcessError: if the prediction script exits with a
            non-zero status.
    """
    model_path = cfg.model_path
    backup_path = cfg.model_backup_path

    if not model_path.exists():
        if not backup_path.exists():
            print(
                f"[rain-ml] prediction skipped: model artifact not found ({model_path})",
                flush=True,
            )
            return
        ensure_parent(model_path)
        shutil.copy2(backup_path, model_path)
        print(f"[rain-ml] restored model from backup {backup_path}", flush=True)

    source_flag = "--allow-empty" if cfg.allow_empty_data else "--strict-source-data"
    predict_cmd = [
        sys.executable,
        "scripts/predict_rain_model.py",
        "--site", cfg.site,
        "--model-path", str(model_path),
        "--model-name", cfg.model_name,
        "--forecast-model", cfg.forecast_model,
        source_flag,
    ]
    run_cmd(predict_cmd, env)
|
|
|
|
|
|
def load_config() -> WorkerConfig:
    """Build the worker configuration from environment variables.

    Every setting except ``DATABASE_URL`` has a built-in default. The backup
    model path defaults to the model path with ``.last_good`` appended.

    Raises:
        SystemExit: if ``DATABASE_URL`` is unset or blank.
    """
    database_url = read_env("DATABASE_URL", "")
    if database_url == "":
        raise SystemExit("DATABASE_URL is required")

    model_path = Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl"))
    backup_override = read_env("RAIN_MODEL_BACKUP_PATH", "")
    if backup_override:
        model_backup_path = Path(backup_override)
    else:
        model_backup_path = with_suffix(model_path, ".last_good")

    return WorkerConfig(
        database_url=database_url,
        site=read_env("RAIN_SITE", "home"),
        model_name=read_env("RAIN_MODEL_NAME", "rain_next_1h"),
        model_version_base=read_env("RAIN_MODEL_VERSION_BASE", "rain-logreg-v1"),
        model_family=read_env("RAIN_MODEL_FAMILY", "logreg"),
        feature_set=read_env("RAIN_FEATURE_SET", "baseline"),
        forecast_model=read_env("RAIN_FORECAST_MODEL", "ecmwf"),
        train_interval_hours=read_env_float("RAIN_TRAIN_INTERVAL_HOURS", 24.0),
        predict_interval_minutes=read_env_float("RAIN_PREDICT_INTERVAL_MINUTES", 10.0),
        lookback_days=read_env_int("RAIN_LOOKBACK_DAYS", 30),
        train_ratio=read_env_float("RAIN_TRAIN_RATIO", 0.7),
        val_ratio=read_env_float("RAIN_VAL_RATIO", 0.15),
        min_precision=read_env_float("RAIN_MIN_PRECISION", 0.70),
        tune_hyperparameters=read_env_bool("RAIN_TUNE_HYPERPARAMETERS", False),
        max_hyperparam_trials=read_env_int("RAIN_MAX_HYPERPARAM_TRIALS", 12),
        calibration_methods=read_env("RAIN_CALIBRATION_METHODS", "none,sigmoid,isotonic"),
        threshold_policy=read_env("RAIN_THRESHOLD_POLICY", "walk_forward"),
        walk_forward_folds=read_env_int("RAIN_WALK_FORWARD_FOLDS", 0),
        allow_empty_data=read_env_bool("RAIN_ALLOW_EMPTY_DATA", True),
        dataset_path_template=read_env(
            "RAIN_DATASET_PATH",
            "models/datasets/rain_dataset_{model_version}_{feature_set}.csv",
        ),
        model_card_path_template=read_env(
            "RAIN_MODEL_CARD_PATH",
            "models/model_card_{model_version}.md",
        ),
        model_path=model_path,
        model_backup_path=model_backup_path,
        report_path=Path(read_env("RAIN_REPORT_PATH", "models/rain_model_report.json")),
        audit_path=Path(read_env("RAIN_AUDIT_PATH", "models/rain_data_audit.json")),
        run_once=read_env_bool("RAIN_RUN_ONCE", False),
        retry_delay_seconds=read_env_int("RAIN_RETRY_DELAY_SECONDS", 60),
    )
|
|
|
|
|
|
def _log_step_failure(exc: Exception, retry_delay_seconds: int) -> None:
    """Print the worker's failure line for a train/predict step that raised."""
    if isinstance(exc, subprocess.CalledProcessError):
        print(f"[rain-ml] command failed exit={exc.returncode}; retrying in {retry_delay_seconds}s", flush=True)
    else:
        print(f"[rain-ml] worker error: {exc}; retrying in {retry_delay_seconds}s", flush=True)


def main() -> int:
    """Run the worker loop: retrain and predict on their configured intervals.

    Training and prediction are scheduled and guarded independently. A
    failing step is logged and rescheduled ``retry_delay_seconds`` later
    without blocking the other step, so a training cycle that keeps failing
    does not stop predictions from running against the existing model.

    Returns:
        0 once both a training cycle and a prediction have succeeded when
        ``run_once`` is set; otherwise the loop runs until the process is
        interrupted.
    """
    cfg = load_config()
    env = os.environ.copy()
    env["DATABASE_URL"] = cfg.database_url

    train_every = timedelta(hours=cfg.train_interval_hours)
    predict_every = timedelta(minutes=cfg.predict_interval_minutes)
    # A negative delay is meaningless; treat it as "retry immediately".
    retry_delay_seconds = max(cfg.retry_delay_seconds, 0)
    retry_every = timedelta(seconds=retry_delay_seconds)
    next_train = now_utc()
    next_predict = now_utc()
    trained_once = False
    predicted_once = False

    print(
        "[rain-ml] worker start "
        f"site={cfg.site} "
        f"model_name={cfg.model_name} "
        f"model_family={cfg.model_family} "
        f"feature_set={cfg.feature_set} "
        f"forecast_model={cfg.forecast_model} "
        f"train_interval_hours={cfg.train_interval_hours} "
        f"predict_interval_minutes={cfg.predict_interval_minutes} "
        f"tune_hyperparameters={cfg.tune_hyperparameters} "
        f"threshold_policy={cfg.threshold_policy} "
        f"walk_forward_folds={cfg.walk_forward_folds} "
        f"allow_empty_data={cfg.allow_empty_data} "
        f"model_backup_path={cfg.model_backup_path}",
        flush=True,
    )

    while True:
        now = now_utc()

        if now >= next_train:
            try:
                run_training_cycle(cfg, env)
            except Exception as exc:  # pragma: no cover - defensive for runtime worker
                _log_step_failure(exc, retry_delay_seconds)
                next_train = now_utc() + retry_every
            else:
                # Intervals are measured from the start of the iteration.
                next_train = now + train_every
                trained_once = True

        if now >= next_predict:
            try:
                run_predict_once(cfg, env)
            except Exception as exc:  # pragma: no cover - defensive for runtime worker
                _log_step_failure(exc, retry_delay_seconds)
                next_predict = now_utc() + retry_every
            else:
                next_predict = now + predict_every
                predicted_once = True

        if cfg.run_once and trained_once and predicted_once:
            print("[rain-ml] run-once complete", flush=True)
            return 0

        # Re-read the clock: the steps above may have taken a long time, and
        # an overdue step should run on the next iteration without sleeping.
        current = now_utc()
        sleep_for = min(
            (next_train - current).total_seconds(),
            (next_predict - current).total_seconds(),
            30.0,
        )
        if sleep_for > 0:
            time.sleep(sleep_for)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise SystemExit(main())
|