This commit is contained in:
2026-03-12 19:55:51 +11:00
parent 76851f0816
commit d1237eed44
12 changed files with 1444 additions and 82 deletions

View File

@@ -50,7 +50,13 @@ class WorkerConfig:
train_ratio: float
val_ratio: float
min_precision: float
tune_hyperparameters: bool
max_hyperparam_trials: int
calibration_methods: str
walk_forward_folds: int
allow_empty_data: bool
dataset_path_template: str
model_card_path_template: str
model_path: Path
report_path: Path
audit_path: Path
@@ -86,12 +92,15 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
start, end = training_window(cfg.lookback_days)
model_version = f"{cfg.model_version_base}-{now_utc().strftime('%Y%m%d%H%M')}"
dataset_out = cfg.dataset_path_template.format(model_version=model_version, feature_set=cfg.feature_set)
model_card_out = cfg.model_card_path_template.format(model_version=model_version)
ensure_parent(cfg.audit_path)
ensure_parent(cfg.report_path)
ensure_parent(cfg.model_path)
if dataset_out:
ensure_parent(Path(dataset_out))
if model_card_out:
ensure_parent(Path(model_card_out))
run_cmd(
[
@@ -113,39 +122,51 @@ def run_training_cycle(cfg: WorkerConfig, env: dict[str, str]) -> None:
env,
)
run_cmd(
[
sys.executable,
"scripts/train_rain_model.py",
"--site",
cfg.site,
"--start",
start,
"--end",
end,
"--train-ratio",
str(cfg.train_ratio),
"--val-ratio",
str(cfg.val_ratio),
"--min-precision",
str(cfg.min_precision),
"--feature-set",
cfg.feature_set,
"--model-family",
cfg.model_family,
"--forecast-model",
cfg.forecast_model,
"--model-version",
model_version,
"--out",
str(cfg.model_path),
"--report-out",
str(cfg.report_path),
"--dataset-out",
dataset_out,
],
env,
)
train_cmd = [
sys.executable,
"scripts/train_rain_model.py",
"--site",
cfg.site,
"--start",
start,
"--end",
end,
"--train-ratio",
str(cfg.train_ratio),
"--val-ratio",
str(cfg.val_ratio),
"--min-precision",
str(cfg.min_precision),
"--max-hyperparam-trials",
str(cfg.max_hyperparam_trials),
"--calibration-methods",
cfg.calibration_methods,
"--walk-forward-folds",
str(cfg.walk_forward_folds),
"--feature-set",
cfg.feature_set,
"--model-family",
cfg.model_family,
"--forecast-model",
cfg.forecast_model,
"--model-version",
model_version,
"--out",
str(cfg.model_path),
"--report-out",
str(cfg.report_path),
"--model-card-out",
model_card_out,
"--dataset-out",
dataset_out,
]
if cfg.tune_hyperparameters:
train_cmd.append("--tune-hyperparameters")
if cfg.allow_empty_data:
train_cmd.append("--allow-empty")
else:
train_cmd.append("--strict-source-data")
run_cmd(train_cmd, env)
def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
@@ -164,6 +185,7 @@ def run_predict_once(cfg: WorkerConfig, env: dict[str, str]) -> None:
cfg.model_name,
"--forecast-model",
cfg.forecast_model,
*(["--allow-empty"] if cfg.allow_empty_data else ["--strict-source-data"]),
],
env,
)
@@ -188,10 +210,19 @@ def load_config() -> WorkerConfig:
train_ratio=read_env_float("RAIN_TRAIN_RATIO", 0.7),
val_ratio=read_env_float("RAIN_VAL_RATIO", 0.15),
min_precision=read_env_float("RAIN_MIN_PRECISION", 0.70),
tune_hyperparameters=read_env_bool("RAIN_TUNE_HYPERPARAMETERS", False),
max_hyperparam_trials=read_env_int("RAIN_MAX_HYPERPARAM_TRIALS", 12),
calibration_methods=read_env("RAIN_CALIBRATION_METHODS", "none,sigmoid,isotonic"),
walk_forward_folds=read_env_int("RAIN_WALK_FORWARD_FOLDS", 0),
allow_empty_data=read_env_bool("RAIN_ALLOW_EMPTY_DATA", True),
dataset_path_template=read_env(
"RAIN_DATASET_PATH",
"models/datasets/rain_dataset_{model_version}_{feature_set}.csv",
),
model_card_path_template=read_env(
"RAIN_MODEL_CARD_PATH",
"models/model_card_{model_version}.md",
),
model_path=Path(read_env("RAIN_MODEL_PATH", "models/rain_model.pkl")),
report_path=Path(read_env("RAIN_REPORT_PATH", "models/rain_model_report.json")),
audit_path=Path(read_env("RAIN_AUDIT_PATH", "models/rain_data_audit.json")),
@@ -220,7 +251,10 @@ def main() -> int:
f"feature_set={cfg.feature_set} "
f"forecast_model={cfg.forecast_model} "
f"train_interval_hours={cfg.train_interval_hours} "
f"predict_interval_minutes={cfg.predict_interval_minutes}",
f"predict_interval_minutes={cfg.predict_interval_minutes} "
f"tune_hyperparameters={cfg.tune_hyperparameters} "
f"walk_forward_folds={cfg.walk_forward_folds} "
f"allow_empty_data={cfg.allow_empty_data}",
flush=True,
)