another bugfix

This commit is contained in:
2026-03-12 20:29:29 +11:00
parent d1237eed44
commit 20316cee91
8 changed files with 293 additions and 23 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import os
import shutil
import subprocess
import sys
import time
@@ -58,6 +59,7 @@ class WorkerConfig:
dataset_path_template: str
model_card_path_template: str
model_path: Path
model_backup_path: Path
report_path: Path
audit_path: Path
run_once: bool
@@ -82,6 +84,41 @@ def ensure_parent(path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
def with_suffix(path: Path, suffix: str) -> Path:
return path.with_name(path.name + suffix)
def promote_file(candidate: Path, target: Path) -> bool:
if not candidate.exists():
return False
ensure_parent(target)
candidate.replace(target)
return True
def promote_model_candidate(candidate: Path, target: Path, backup: Path) -> bool:
if not candidate.exists():
return False
ensure_parent(target)
ensure_parent(backup)
if target.exists():
shutil.copy2(target, backup)
try:
candidate.replace(target)
return True
except Exception:
if backup.exists():
shutil.copy2(backup, target)
raise
def remove_if_exists(path: Path) -> None:
if path.exists():
path.unlink()
def training_window(lookback_days: int) -> tuple[str, str]:
end = now_utc()
start = end - timedelta(days=lookback_days)
@@ -93,12 +130,28 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
model_version = f"{cfg.model_version_base}-{now_utc().strftime('%Y%m%d%H%M')}"
dataset_out = cfg.dataset_path_template.format(model_version=model_version, feature_set=cfg.feature_set)
model_card_out = cfg.model_card_path_template.format(model_version=model_version)
model_candidate_path = with_suffix(cfg.model_path, ".candidate")
report_candidate_path = with_suffix(cfg.report_path, ".candidate")
audit_candidate_path = with_suffix(cfg.audit_path, ".candidate")
model_card_candidate_out = f"{model_card_out}.candidate" if model_card_out else ""
model_card_candidate_path = Path(model_card_candidate_out) if model_card_candidate_out else None
ensure_parent(cfg.audit_path)
ensure_parent(cfg.report_path)
# Ensure promotions only use artifacts from the current training cycle.
remove_if_exists(model_candidate_path)
remove_if_exists(report_candidate_path)
remove_if_exists(audit_candidate_path)
if model_card_candidate_path is not None:
remove_if_exists(model_card_candidate_path)
ensure_parent(audit_candidate_path)
ensure_parent(report_candidate_path)
ensure_parent(cfg.model_path)
ensure_parent(model_candidate_path)
ensure_parent(cfg.model_backup_path)
if dataset_out:
ensure_parent(Path(dataset_out))
if model_card_candidate_path is not None:
ensure_parent(model_card_candidate_path)
if model_card_out:
ensure_parent(Path(model_card_out))
@@ -117,7 +170,7 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
"--forecast-model",
cfg.forecast_model,
"--out",
str(cfg.audit_path),
str(audit_candidate_path),
],
env,
)
@@ -152,11 +205,11 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
"--model-version",
model_version,
"--out",
str(cfg.model_path),
str(model_candidate_path),
"--report-out",
str(cfg.report_path),
str(report_candidate_path),
"--model-card-out",
model_card_out,
model_card_candidate_out,
"--dataset-out",
dataset_out,
]
@@ -168,10 +221,40 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
train_cmd.append("--strict-source-data")
run_cmd(train_cmd, env)
promoted_model = promote_model_candidate(
candidate=model_candidate_path,
target=cfg.model_path,
backup=cfg.model_backup_path,
)
if not promoted_model:
print(
"[rain-ml] training completed without new model artifact; keeping last-known-good model",
flush=True,
)
return
promote_file(report_candidate_path, cfg.report_path)
promote_file(audit_candidate_path, cfg.audit_path)
if model_card_candidate_path is not None:
promote_file(model_card_candidate_path, Path(model_card_out))
print(
f"[rain-ml] promoted new model artifact to {cfg.model_path} (backup={cfg.model_backup_path})",
flush=True,
)
def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
if not cfg.model_path.exists():
raise RuntimeError(f"model artifact not found: {cfg.model_path}")
if cfg.model_backup_path.exists():
ensure_parent(cfg.model_path)
shutil.copy2(cfg.model_backup_path, cfg.model_path)
print(f"[rain-ml] restored model from backup {cfg.model_backup_path}", flush=True)
else:
print(
f"[rain-ml] prediction skipped: model artifact not found ({cfg.model_path})",
flush=True,
)
return
run_cmd(
[
@@ -196,6 +279,10 @@ def load_config() -> WorkerConfig:
if not database_url:
raise SystemExit("DATABASE_URL is required")
model_path = Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl"))
backup_path_raw = read_env("RAIN_MODEL_BACKUP_PATH", "")
model_backup_path = Path(backup_path_raw) if backup_path_raw else with_suffix(model_path, ".last_good")
return WorkerConfig(
database_url=database_url,
site=read_env("RAIN_SITE", "home"),
@@ -223,7 +310,8 @@ def load_config() -> WorkerConfig:
"RAIN_MODEL_CARD_PATH",
"models/model_card_{model_version}.md",
),
model_path=Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl")),
model_path=model_path,
model_backup_path=model_backup_path,
report_path=Path(read_env("RAIN_REPORT_PATH", "models/rain_model_report.json")),
audit_path=Path(read_env("RAIN_AUDIT_PATH", "models/rain_data_audit.json")),
run_once=read_env_bool("RAIN_RUN_ONCE", False),
@@ -254,7 +342,8 @@ def main() -> int:
f"predict_interval_minutes={cfg.predict_interval_minutes} "
f"tune_hyperparameters={cfg.tune_hyperparameters} "
f"walk_forward_folds={cfg.walk_forward_folds} "
f"allow_empty_data={cfg.allow_empty_data}",
f"allow_empty_data={cfg.allow_empty_data} "
f"model_backup_path={cfg.model_backup_path}",
flush=True,
)