135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Any
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the cutover-gate command-line parser and parse arguments.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is how
            ``main()`` calls this function.

    Returns:
        Namespace with ``baseline`` and ``candidate`` report paths plus the
        float thresholds ``min_candidate_precision``, ``max_precision_drop``,
        ``max_pr_auc_drop``, ``max_roc_auc_drop`` and ``max_brier_increase``.

    Raises:
        SystemExit: Raised by argparse when arguments are missing or invalid.
    """
    parser = argparse.ArgumentParser(description="Evaluate go/no-go cutover gate for rain model reports.")
    parser.add_argument("--baseline", required=True, help="Baseline report JSON path (for example 1h production).")
    parser.add_argument("--candidate", required=True, help="Candidate report JSON path (for example 4h shadow).")
    parser.add_argument(
        "--min-candidate-precision",
        type=float,
        default=0.60,
        help="Minimum allowed candidate test precision.",
    )
    parser.add_argument(
        "--max-precision-drop",
        type=float,
        default=0.05,
        help="Maximum allowed drop: candidate_precision >= baseline_precision - value.",
    )
    parser.add_argument(
        "--max-pr-auc-drop",
        type=float,
        default=0.05,
        help="Maximum allowed drop: candidate_pr_auc >= baseline_pr_auc - value.",
    )
    parser.add_argument(
        "--max-roc-auc-drop",
        type=float,
        default=0.05,
        help="Maximum allowed drop: candidate_roc_auc >= baseline_roc_auc - value.",
    )
    parser.add_argument(
        "--max-brier-increase",
        type=float,
        default=0.03,
        help="Maximum allowed increase: candidate_brier <= baseline_brier + value.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_report(path: str) -> dict[str, Any]:
    """Load a model report from a UTF-8 JSON file.

    Args:
        path: Filesystem path to the report JSON.

    Returns:
        The decoded top-level JSON object.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or its top level is not a JSON object.
    """
    # open() raises FileNotFoundError on its own; a separate exists() check
    # beforehand is redundant and can race with the file being removed.
    with Path(path).open("r", encoding="utf-8") as f:
        report = json.load(f)
    # Downstream code calls report.get(...); a list or scalar would otherwise
    # fail later with an unhelpful AttributeError.
    if not isinstance(report, dict):
        raise ValueError(f"report {path} must be a JSON object, got {type(report).__name__}")
    return report
|
|
|
|
|
|
def metric(report: dict[str, Any], key: str) -> float:
    """Return ``report["test_metrics"][key]`` as a finite float.

    Args:
        report: Decoded report object.
        key: Metric name inside the report's ``test_metrics`` object.

    Returns:
        The metric value converted with ``float()``.

    Raises:
        ValueError: If ``test_metrics`` or the metric is missing/null, if
            ``test_metrics`` is not an object, or if the value is a boolean,
            not convertible to float, or not finite.
    """
    metrics = report.get("test_metrics")
    if metrics is None:
        raise ValueError(f"missing test metric: {key}")
    if not isinstance(metrics, dict):
        raise ValueError(f"test_metrics must be a JSON object, got {type(metrics).__name__}")

    value = metrics.get(key)
    if value is None:
        raise ValueError(f"missing test metric: {key}")
    # bool is an int subclass: float(True) == 1.0 would silently pass a gate.
    if isinstance(value, bool):
        raise ValueError(f"invalid test metric {key}: {value!r}")
    try:
        number = float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"invalid test metric {key}: {value!r}") from err
    # json.load accepts NaN/Infinity literals. Infinity satisfies every ">="
    # floor (and -Infinity every "<=" ceiling), so a broken report could pass
    # the gate; NaN fails every comparison without saying why.
    if not math.isfinite(number):
        raise ValueError(f"non-finite test metric {key}: {number!r}")
    return number
|
|
|
|
|
|
# Absolute slack used when comparing a metric against a threshold. Thresholds
# are derived with float arithmetic, which can land one ulp away from the
# intended decimal value: 0.7 + 0.1 evaluates to 0.7999999999999999, so a
# candidate Brier score of exactly 0.8 would fail a +0.1 allowance while the
# report printed "0.8000 <= 0.8000".
_THRESHOLD_ABS_TOL = 1e-9


def _at_least(value: float, bound: float) -> bool:
    """Return whether ``value >= bound``, treating values within tolerance as equal."""
    return value >= bound or math.isclose(value, bound, rel_tol=0.0, abs_tol=_THRESHOLD_ABS_TOL)


def _at_most(value: float, bound: float) -> bool:
    """Return whether ``value <= bound``, treating values within tolerance as equal."""
    return value <= bound or math.isclose(value, bound, rel_tol=0.0, abs_tol=_THRESHOLD_ABS_TOL)


def _build_checks(
    baseline: dict[str, Any],
    candidate: dict[str, Any],
    args: argparse.Namespace,
) -> list[tuple[str, bool, str]]:
    """Evaluate every gate rule for the two reports.

    Returns:
        ``(name, passed, detail)`` tuples in report order. ``detail`` shows the
        comparison made, with both sides formatted to 4 decimals.

    Raises:
        ValueError: Propagated from ``metric`` when a test metric is missing.
    """
    b_precision = metric(baseline, "precision")
    c_precision = metric(candidate, "precision")
    b_pr_auc = metric(baseline, "pr_auc")
    c_pr_auc = metric(candidate, "pr_auc")
    b_roc_auc = metric(baseline, "roc_auc")
    c_roc_auc = metric(candidate, "roc_auc")
    b_brier = metric(baseline, "brier")
    c_brier = metric(candidate, "brier")

    precision_floor = b_precision - args.max_precision_drop
    pr_auc_floor = b_pr_auc - args.max_pr_auc_drop
    roc_auc_floor = b_roc_auc - args.max_roc_auc_drop
    # Brier is a loss (lower is better), so the candidate gets a ceiling.
    brier_ceiling = b_brier + args.max_brier_increase

    return [
        (
            "candidate_precision_floor",
            _at_least(c_precision, args.min_candidate_precision),
            f"{c_precision:.4f} >= {args.min_candidate_precision:.4f}",
        ),
        (
            "precision_drop",
            _at_least(c_precision, precision_floor),
            f"{c_precision:.4f} >= {precision_floor:.4f}",
        ),
        (
            "pr_auc_drop",
            _at_least(c_pr_auc, pr_auc_floor),
            f"{c_pr_auc:.4f} >= {pr_auc_floor:.4f}",
        ),
        (
            "roc_auc_drop",
            _at_least(c_roc_auc, roc_auc_floor),
            f"{c_roc_auc:.4f} >= {roc_auc_floor:.4f}",
        ),
        (
            "brier_increase",
            _at_most(c_brier, brier_ceiling),
            f"{c_brier:.4f} <= {brier_ceiling:.4f}",
        ),
    ]


def main() -> int:
    """Run the cutover gate: load both reports, evaluate checks, print a summary.

    Returns:
        Process exit code: 0 when every check passes, 1 when any check fails.
        Errors raised while loading reports or reading metrics are not caught
        here and propagate to the caller.
    """
    args = parse_args()
    baseline = load_report(args.baseline)
    candidate = load_report(args.candidate)
    checks = _build_checks(baseline, candidate, args)

    print("Rain cutover gate:")
    print(f" baseline: {args.baseline}")
    print(f" candidate: {args.candidate}")
    print(f" baseline_version={baseline.get('model_version')} candidate_version={candidate.get('model_version')}")

    failures: list[str] = []
    for name, ok, detail in checks:
        status = "ok" if ok else "fail"
        print(f" {name}: {status} ({detail})")
        if not ok:
            failures.append(name)

    if failures:
        print(f"cutover_decision: FAIL ({', '.join(failures)})")
        return 1

    print("cutover_decision: PASS")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the gate and hand its integer result to the interpreter as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|