#!/usr/bin/env python3 from __future__ import annotations import argparse import glob import json import math import os from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any @dataclass class Candidate: path: Path model_version: str feature_set: str model_family: str generated_at: str | None test_precision: float | None test_recall: float | None test_pr_auc: float | None test_roc_auc: float | None test_brier: float | None wf_precision: float | None wf_recall: float | None wf_pr_auc: float | None wf_brier: float | None score: float eligible: bool ineligible_reasons: list[str] report: dict[str, Any] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Rank rain-model training reports and recommend a deploy candidate.") parser.add_argument( "--reports-glob", default="models/rain_model_report*.json", help="Glob for report JSON files.", ) parser.add_argument("--min-test-precision", type=float, default=0.65) parser.add_argument("--min-test-recall", type=float, default=0.50) parser.add_argument("--min-test-pr-auc", type=float, default=0.40) parser.add_argument("--min-walk-forward-precision", type=float, default=0.30) parser.add_argument("--min-walk-forward-recall", type=float, default=0.25) parser.add_argument( "--require-walk-forward", action="store_true", help="Require walk-forward summary metrics to be present and pass minimums.", ) parser.add_argument("--top-k", type=int, default=5) parser.add_argument("--json-out", help="Optional output JSON path.") return parser.parse_args() def as_float(v: Any) -> float | None: if v is None: return None try: out = float(v) except (TypeError, ValueError): return None if math.isnan(out): return None return out def load_report(path: Path) -> dict[str, Any]: with open(path, "r", encoding="utf-8") as f: return json.load(f) def naive_precision_baseline(report: dict[str, Any]) -> float | None: baselines = report.get("naive_baselines_test") or {} out: float | None = None for baseline in baselines.values(): metrics = baseline.get("metrics", {}) precision = as_float(metrics.get("precision")) if precision is None: continue if out is None or precision > out: out = precision return out def score_candidate( report: dict[str, Any], min_test_precision: float, min_test_recall: float, min_test_pr_auc: float, min_wf_precision: float, min_wf_recall: float, require_walk_forward: bool, ) -> tuple[float, bool, list[str], dict[str, float | None]]: test = report.get("test_metrics") or {} wf_summary = (report.get("walk_forward_backtest") or {}).get("summary") or {} test_precision = as_float(test.get("precision")) test_recall = as_float(test.get("recall")) test_pr_auc = as_float(test.get("pr_auc")) test_roc_auc = as_float(test.get("roc_auc")) test_brier = as_float(test.get("brier")) wf_precision = as_float(wf_summary.get("mean_precision")) wf_recall = as_float(wf_summary.get("mean_recall")) wf_pr_auc = as_float(wf_summary.get("mean_pr_auc")) wf_brier = as_float(wf_summary.get("mean_brier")) metrics = { "test_precision": test_precision, "test_recall": test_recall, "test_pr_auc": test_pr_auc, "test_roc_auc": test_roc_auc, "test_brier": test_brier, "wf_precision": wf_precision, "wf_recall": wf_recall, "wf_pr_auc": wf_pr_auc, "wf_brier": wf_brier, } reasons: list[str] = [] if test_precision is None or test_precision < min_test_precision: reasons.append(f"test_precision<{min_test_precision:.2f}") if test_recall is None or test_recall < min_test_recall: reasons.append(f"test_recall<{min_test_recall:.2f}") if test_pr_auc is None or test_pr_auc < min_test_pr_auc: reasons.append(f"test_pr_auc<{min_test_pr_auc:.2f}") if require_walk_forward and (wf_precision is None or wf_recall is None): reasons.append("walk_forward_missing") if wf_precision is not None and wf_precision < min_wf_precision: reasons.append(f"wf_precision<{min_wf_precision:.2f}") if wf_recall is not None and wf_recall < min_wf_recall: reasons.append(f"wf_recall<{min_wf_recall:.2f}") eligible = len(reasons) == 0 # Weighted utility score with stability penalty. score = 0.0 if test_precision is not None: score += 3.0 * test_precision if test_recall is not None: score += 2.5 * test_recall if test_pr_auc is not None: score += 2.5 * test_pr_auc if test_roc_auc is not None: score += 1.0 * test_roc_auc if test_brier is not None: score += 1.5 * (1.0 - min(max(test_brier, 0.0), 1.0)) if wf_precision is not None: score += 2.0 * wf_precision else: score -= 0.25 if wf_recall is not None: score += 1.5 * wf_recall if wf_pr_auc is not None: score += 1.0 * wf_pr_auc if wf_brier is not None: score += 1.0 * (1.0 - min(max(wf_brier, 0.0), 1.0)) if test_precision is not None and wf_precision is not None: score -= 1.5 * abs(test_precision - wf_precision) if test_recall is not None and wf_recall is not None: score -= 1.0 * abs(test_recall - wf_recall) best_naive_precision = naive_precision_baseline(report) if best_naive_precision is not None and test_precision is not None: gap = test_precision - best_naive_precision score += 0.5 * gap return score, eligible, reasons, metrics def parse_generated_at(value: str | None) -> datetime: if not value: return datetime.min try: return datetime.fromisoformat(value.replace("Z", "+00:00")) except ValueError: return datetime.min def build_candidate(path: Path, report: dict[str, Any], args: argparse.Namespace) -> Candidate: score, eligible, reasons, metrics = score_candidate( report=report, min_test_precision=args.min_test_precision, min_test_recall=args.min_test_recall, min_test_pr_auc=args.min_test_pr_auc, min_wf_precision=args.min_walk_forward_precision, min_wf_recall=args.min_walk_forward_recall, require_walk_forward=args.require_walk_forward, ) return Candidate( path=path, model_version=str(report.get("model_version") or "unknown"), feature_set=str(report.get("feature_set") or "unknown"), model_family=str(report.get("model_family") or "unknown"), generated_at=report.get("generated_at"), test_precision=metrics["test_precision"], test_recall=metrics["test_recall"], test_pr_auc=metrics["test_pr_auc"], test_roc_auc=metrics["test_roc_auc"], test_brier=metrics["test_brier"], wf_precision=metrics["wf_precision"], wf_recall=metrics["wf_recall"], wf_pr_auc=metrics["wf_pr_auc"], wf_brier=metrics["wf_brier"], score=score, eligible=eligible, ineligible_reasons=reasons, report=report, ) def main() -> int: args = parse_args() paths = sorted(Path(p) for p in glob.glob(args.reports_glob)) if not paths: print(f"No report files matched: {args.reports_glob}") return 1 candidates: list[Candidate] = [] for path in paths: try: report = load_report(path) except Exception as exc: print(f"skip {path}: {exc}") continue candidates.append(build_candidate(path=path, report=report, args=args)) if not candidates: print("No valid reports loaded.") return 1 candidates.sort( key=lambda c: ( 1 if c.eligible else 0, c.score, parse_generated_at(c.generated_at), ), reverse=True, ) print(f"Scanned {len(candidates)} report(s). Top {min(args.top_k, len(candidates))}:") for idx, c in enumerate(candidates[: args.top_k], start=1): wf_part = ( f"wf_prec={c.wf_precision:.3f} wf_rec={c.wf_recall:.3f}" if c.wf_precision is not None and c.wf_recall is not None else "wf=n/a" ) gate_part = "eligible" if c.eligible else f"ineligible({','.join(c.ineligible_reasons)})" print( f"{idx}. {gate_part} score={c.score:.3f} " f"version={c.model_version} feature_set={c.feature_set} family={c.model_family} " f"test_prec={c.test_precision if c.test_precision is not None else 'n/a'} " f"test_rec={c.test_recall if c.test_recall is not None else 'n/a'} " f"test_pr_auc={c.test_pr_auc if c.test_pr_auc is not None else 'n/a'} " f"{wf_part} " f"path={c.path}" ) recommendation = next((c for c in candidates if c.eligible), candidates[0]) print("") print("Recommended candidate:") print(f" model_version={recommendation.model_version}") print(f" feature_set={recommendation.feature_set}") print(f" model_family={recommendation.model_family}") print(f" report_path={recommendation.path}") print(f" score={recommendation.score:.3f}") if not recommendation.eligible: print(f" note=no fully eligible report; selected highest score with reasons={recommendation.ineligible_reasons}") if args.json_out: payload = { "generated_at": datetime.utcnow().isoformat() + "Z", "reports_glob": args.reports_glob, "recommendation": { "model_version": recommendation.model_version, "feature_set": recommendation.feature_set, "model_family": recommendation.model_family, "report_path": str(recommendation.path), "score": recommendation.score, "eligible": recommendation.eligible, "ineligible_reasons": recommendation.ineligible_reasons, }, "ranked": [ { "model_version": c.model_version, "feature_set": c.feature_set, "model_family": c.model_family, "report_path": str(c.path), "generated_at": c.generated_at, "score": c.score, "eligible": c.eligible, "ineligible_reasons": c.ineligible_reasons, "test_precision": c.test_precision, "test_recall": c.test_recall, "test_pr_auc": c.test_pr_auc, "test_roc_auc": c.test_roc_auc, "test_brier": c.test_brier, "wf_precision": c.wf_precision, "wf_recall": c.wf_recall, "wf_pr_auc": c.wf_pr_auc, "wf_brier": c.wf_brier, } for c in candidates ], } out_dir = os.path.dirname(args.json_out) if out_dir: os.makedirs(out_dir, exist_ok=True) with open(args.json_out, "w", encoding="utf-8") as f: json.dump(payload, f, indent=2) print(f"Saved recommendation JSON to {args.json_out}") return 0 if __name__ == "__main__": raise SystemExit(main())