#!/usr/bin/env python3
from __future__ import annotations

import argparse
import glob
import json
import math
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
@dataclass
class Candidate:
    """One training report, scored and gated for deploy recommendation.

    Metric fields are None when the source report omitted them or held
    non-numeric values; `score`, `eligible`, and `ineligible_reasons` are
    produced by score_candidate().
    """

    # Filesystem path the report JSON was loaded from.
    path: Path
    # Identity fields pulled from the report ("unknown" when absent).
    model_version: str
    feature_set: str
    model_family: str
    # Raw timestamp string from the report, if present (parsed lazily for sorting).
    generated_at: str | None
    # Held-out test-set metrics.
    test_precision: float | None
    test_recall: float | None
    test_pr_auc: float | None
    test_roc_auc: float | None
    test_brier: float | None
    # Walk-forward backtest summary metrics (report's "mean_*" keys —
    # presumably means across backtest folds; confirm against the trainer).
    wf_precision: float | None
    wf_recall: float | None
    wf_pr_auc: float | None
    wf_brier: float | None
    # Weighted utility score; higher is better.
    score: float
    # True when every configured minimum threshold passed.
    eligible: bool
    # Human-readable gate failures (empty when eligible).
    ineligible_reasons: list[str]
    # Full parsed report, retained for downstream consumers.
    report: dict[str, Any]
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the report-ranking tool."""
    parser = argparse.ArgumentParser(
        description="Rank rain-model training reports and recommend a deploy candidate."
    )
    parser.add_argument(
        "--reports-glob",
        default="models/rain_model_report*.json",
        help="Glob for report JSON files.",
    )
    # Metric thresholds share a shape, so declare them as data.
    for flag, default in (
        ("--min-test-precision", 0.65),
        ("--min-test-recall", 0.50),
        ("--min-test-pr-auc", 0.40),
        ("--min-walk-forward-precision", 0.30),
        ("--min-walk-forward-recall", 0.25),
    ):
        parser.add_argument(flag, type=float, default=default)
    parser.add_argument(
        "--require-walk-forward",
        action="store_true",
        help="Require walk-forward summary metrics to be present and pass minimums.",
    )
    parser.add_argument("--top-k", type=int, default=5)
    parser.add_argument("--json-out", help="Optional output JSON path.")
    return parser.parse_args()
def as_float(v: Any) -> float | None:
|
|
if v is None:
|
|
return None
|
|
try:
|
|
out = float(v)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
if math.isnan(out):
|
|
return None
|
|
return out
|
|
|
|
|
|
def load_report(path: Path) -> dict[str, Any]:
    """Read one report JSON file and return its parsed top-level object."""
    return json.loads(path.read_text(encoding="utf-8"))
def naive_precision_baseline(report: dict[str, Any]) -> float | None:
    """Return the best precision among the report's naive test baselines.

    Baselines with a missing or non-numeric precision are skipped; None is
    returned when no baseline contributes a usable value.
    """
    baselines = report.get("naive_baselines_test") or {}
    precisions = [
        p
        for p in (
            as_float(entry.get("metrics", {}).get("precision"))
            for entry in baselines.values()
        )
        if p is not None
    ]
    return max(precisions) if precisions else None
def score_candidate(
    report: dict[str, Any],
    min_test_precision: float,
    min_test_recall: float,
    min_test_pr_auc: float,
    min_wf_precision: float,
    min_wf_recall: float,
    require_walk_forward: bool,
) -> tuple[float, bool, list[str], dict[str, float | None]]:
    """Gate a report against the configured minimums and compute its score.

    Returns ``(score, eligible, ineligible_reasons, metrics)`` where
    ``metrics`` maps metric names to parsed floats (None when absent or
    non-numeric).  Test-set gates treat a missing metric as a failure;
    walk-forward metrics only fail when present and below threshold, unless
    ``require_walk_forward`` forces their presence.
    """

    def _clamp01(x: float) -> float:
        return min(max(x, 0.0), 1.0)

    test = report.get("test_metrics") or {}
    wf = (report.get("walk_forward_backtest") or {}).get("summary") or {}

    metrics: dict[str, float | None] = {
        "test_precision": as_float(test.get("precision")),
        "test_recall": as_float(test.get("recall")),
        "test_pr_auc": as_float(test.get("pr_auc")),
        "test_roc_auc": as_float(test.get("roc_auc")),
        "test_brier": as_float(test.get("brier")),
        "wf_precision": as_float(wf.get("mean_precision")),
        "wf_recall": as_float(wf.get("mean_recall")),
        "wf_pr_auc": as_float(wf.get("mean_pr_auc")),
        "wf_brier": as_float(wf.get("mean_brier")),
    }

    reasons: list[str] = []
    # Test-set gates: a missing value counts as "below threshold".
    for label, floor in (
        ("test_precision", min_test_precision),
        ("test_recall", min_test_recall),
        ("test_pr_auc", min_test_pr_auc),
    ):
        value = metrics[label]
        if value is None or value < floor:
            reasons.append(f"{label}<{floor:.2f}")

    # Walk-forward gates: absence only fails when explicitly required.
    if require_walk_forward and (metrics["wf_precision"] is None or metrics["wf_recall"] is None):
        reasons.append("walk_forward_missing")
    for label, floor in (
        ("wf_precision", min_wf_precision),
        ("wf_recall", min_wf_recall),
    ):
        value = metrics[label]
        if value is not None and value < floor:
            reasons.append(f"{label}<{floor:.2f}")

    eligible = not reasons

    # Weighted utility score with stability penalty.
    score = 0.0
    for key, weight in (
        ("test_precision", 3.0),
        ("test_recall", 2.5),
        ("test_pr_auc", 2.5),
        ("test_roc_auc", 1.0),
    ):
        value = metrics[key]
        if value is not None:
            score += weight * value
    if metrics["test_brier"] is not None:
        # Brier is a loss, so reward its complement (clamped to [0, 1]).
        score += 1.5 * (1.0 - _clamp01(metrics["test_brier"]))

    if metrics["wf_precision"] is not None:
        score += 2.0 * metrics["wf_precision"]
    else:
        # Small penalty for lacking walk-forward evidence.
        score -= 0.25
    if metrics["wf_recall"] is not None:
        score += 1.5 * metrics["wf_recall"]
    if metrics["wf_pr_auc"] is not None:
        score += 1.0 * metrics["wf_pr_auc"]
    if metrics["wf_brier"] is not None:
        score += 1.0 * (1.0 - _clamp01(metrics["wf_brier"]))

    # Stability penalty: large test-vs-walk-forward gaps suggest overfitting.
    if metrics["test_precision"] is not None and metrics["wf_precision"] is not None:
        score -= 1.5 * abs(metrics["test_precision"] - metrics["wf_precision"])
    if metrics["test_recall"] is not None and metrics["wf_recall"] is not None:
        score -= 1.0 * abs(metrics["test_recall"] - metrics["wf_recall"])

    # Reward beating the strongest naive baseline's precision (may be negative).
    best_naive_precision = naive_precision_baseline(report)
    if best_naive_precision is not None and metrics["test_precision"] is not None:
        score += 0.5 * (metrics["test_precision"] - best_naive_precision)

    return score, eligible, reasons, metrics
def parse_generated_at(value: str | None) -> datetime:
|
|
if not value:
|
|
return datetime.min
|
|
try:
|
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
|
except ValueError:
|
|
return datetime.min
|
|
|
|
|
|
def build_candidate(path: Path, report: dict[str, Any], args: argparse.Namespace) -> Candidate:
    """Score one parsed report and wrap the result in a Candidate record."""
    score, eligible, reasons, metrics = score_candidate(
        report=report,
        min_test_precision=args.min_test_precision,
        min_test_recall=args.min_test_recall,
        min_test_pr_auc=args.min_test_pr_auc,
        min_wf_precision=args.min_walk_forward_precision,
        min_wf_recall=args.min_walk_forward_recall,
        require_walk_forward=args.require_walk_forward,
    )
    # score_candidate()'s metric keys match Candidate field names one-to-one,
    # so the whole dict can be unpacked straight into the constructor.
    return Candidate(
        path=path,
        model_version=str(report.get("model_version") or "unknown"),
        feature_set=str(report.get("feature_set") or "unknown"),
        model_family=str(report.get("model_family") or "unknown"),
        generated_at=report.get("generated_at"),
        score=score,
        eligible=eligible,
        ineligible_reasons=reasons,
        report=report,
        **metrics,
    )
def _print_ranking(candidates: list[Candidate], top_k: int) -> None:
    """Print one summary line per candidate for the top *top_k* entries."""
    print(f"Scanned {len(candidates)} report(s). Top {min(top_k, len(candidates))}:")
    for idx, c in enumerate(candidates[:top_k], start=1):
        wf_part = (
            f"wf_prec={c.wf_precision:.3f} wf_rec={c.wf_recall:.3f}"
            if c.wf_precision is not None and c.wf_recall is not None
            else "wf=n/a"
        )
        gate_part = "eligible" if c.eligible else f"ineligible({','.join(c.ineligible_reasons)})"
        print(
            f"{idx}. {gate_part} score={c.score:.3f} "
            f"version={c.model_version} feature_set={c.feature_set} family={c.model_family} "
            f"test_prec={c.test_precision if c.test_precision is not None else 'n/a'} "
            f"test_rec={c.test_recall if c.test_recall is not None else 'n/a'} "
            f"test_pr_auc={c.test_pr_auc if c.test_pr_auc is not None else 'n/a'} "
            f"{wf_part} "
            f"path={c.path}"
        )


def _candidate_payload(c: Candidate) -> dict[str, Any]:
    """JSON-serializable summary of one candidate (full report body omitted)."""
    return {
        "model_version": c.model_version,
        "feature_set": c.feature_set,
        "model_family": c.model_family,
        "report_path": str(c.path),
        "generated_at": c.generated_at,
        "score": c.score,
        "eligible": c.eligible,
        "ineligible_reasons": c.ineligible_reasons,
        "test_precision": c.test_precision,
        "test_recall": c.test_recall,
        "test_pr_auc": c.test_pr_auc,
        "test_roc_auc": c.test_roc_auc,
        "test_brier": c.test_brier,
        "wf_precision": c.wf_precision,
        "wf_recall": c.wf_recall,
        "wf_pr_auc": c.wf_pr_auc,
        "wf_brier": c.wf_brier,
    }


def _write_json(args: argparse.Namespace, recommendation: Candidate, candidates: list[Candidate]) -> None:
    """Write the recommendation plus full ranking to args.json_out."""
    payload = {
        # datetime.utcnow() is deprecated (3.12+); use an aware UTC timestamp
        # and keep the original trailing-"Z" format.
        "generated_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "reports_glob": args.reports_glob,
        "recommendation": {
            "model_version": recommendation.model_version,
            "feature_set": recommendation.feature_set,
            "model_family": recommendation.model_family,
            "report_path": str(recommendation.path),
            "score": recommendation.score,
            "eligible": recommendation.eligible,
            "ineligible_reasons": recommendation.ineligible_reasons,
        },
        "ranked": [_candidate_payload(c) for c in candidates],
    }
    out_dir = os.path.dirname(args.json_out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.json_out, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"Saved recommendation JSON to {args.json_out}")


def main() -> int:
    """Load, rank, and report on training reports; return a process exit code.

    Exit status is 1 when no report files match the glob or none parse, else 0.
    """
    args = parse_args()
    paths = sorted(Path(p) for p in glob.glob(args.reports_glob))
    if not paths:
        print(f"No report files matched: {args.reports_glob}")
        return 1

    candidates: list[Candidate] = []
    for path in paths:
        try:
            report = load_report(path)
        except Exception as exc:
            # Deliberately broad: one unreadable/invalid report must not
            # abort the whole scan — announce the skip and move on.
            print(f"skip {path}: {exc}")
            continue
        candidates.append(build_candidate(path=path, report=report, args=args))

    if not candidates:
        print("No valid reports loaded.")
        return 1

    # Rank: eligible first, then highest score, then most recent report.
    candidates.sort(
        key=lambda c: (
            1 if c.eligible else 0,
            c.score,
            parse_generated_at(c.generated_at),
        ),
        reverse=True,
    )

    _print_ranking(candidates, args.top_k)

    # Prefer the best eligible candidate; otherwise fall back to the top scorer.
    recommendation = next((c for c in candidates if c.eligible), candidates[0])
    print("")
    print("Recommended candidate:")
    print(f"  model_version={recommendation.model_version}")
    print(f"  feature_set={recommendation.feature_set}")
    print(f"  model_family={recommendation.model_family}")
    print(f"  report_path={recommendation.path}")
    print(f"  score={recommendation.score:.3f}")
    if not recommendation.eligible:
        print(f"  note=no fully eligible report; selected highest score with reasons={recommendation.ineligible_reasons}")

    if args.json_out:
        _write_json(args, recommendation, candidates)

    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())