98 lines
3.3 KiB
Python
98 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
import json
import sys
from pathlib import Path
from typing import Any
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(description="Compare two rain-model training reports.")
|
|
parser.add_argument("--baseline", required=True, help="Baseline report path (for example 1h).")
|
|
parser.add_argument("--candidate", required=True, help="Candidate report path (for example 4h).")
|
|
return parser.parse_args()
|
|
|
|
|
|
def load_json(path: str) -> dict[str, Any]:
    """Read a JSON report from *path* and return its top-level object.

    Args:
        path: Filesystem path to a UTF-8 encoded JSON file.

    Returns:
        The parsed top-level JSON object.

    Raises:
        FileNotFoundError: if *path* does not exist or is not a regular file
            (for example a directory), so callers get a single "not found"
            signal for unusable paths.
        ValueError: if the file is not valid JSON or not valid UTF-8
            (``json.JSONDecodeError`` and ``UnicodeDecodeError`` both
            subclass ValueError), or if the top-level value is not a JSON
            object.
    """
    p = Path(path)
    # is_file() rather than exists(): a directory would otherwise pass the
    # check and then fail inside open() with a platform-specific OSError.
    if not p.is_file():
        raise FileNotFoundError(path)
    with p.open("r", encoding="utf-8") as f:
        data = json.load(f)
    # Callers read the result with .get(); reject arrays/scalars up front
    # instead of failing later with an AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"expected a JSON object at the top level of {path}, got {type(data).__name__}"
        )
    return data
|
|
|
|
|
|
def to_float(v: Any) -> float | None:
|
|
if v is None:
|
|
return None
|
|
try:
|
|
return float(v)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
|
|
def metric(report: dict[str, Any], split: str, key: str) -> float | None:
    """Return ``report[split][key]`` coerced to float, or None if unavailable.

    Args:
        report: Parsed report object.
        split: Name of the section holding the metrics (e.g. "test_metrics").
        key: Metric name within that section.

    Returns:
        The float value, or None when the section is missing, is not a JSON
        object (e.g. ``null``), lacks *key*, or holds a non-numeric value.
    """
    section = report.get(split)
    # A present-but-null (or non-object) section would make a chained .get()
    # raise AttributeError; treat it the same as a missing section.
    if not isinstance(section, dict):
        return None
    return to_float(section.get(key))
|
|
|
|
|
|
def delta_str(base: float | None, cand: float | None) -> str:
|
|
if base is None or cand is None:
|
|
return "n/a"
|
|
d = cand - base
|
|
return f"{d:+.4f}"
|
|
|
|
|
|
# Test-split metrics compared between the two reports, in display order.
_TEST_METRICS = ("precision", "recall", "f1", "pr_auc", "roc_auc", "brier")

# Directory (relative to the current working directory) scanned for existing
# reports; only used to suggest paths when a requested report is missing.
_MODEL_DIR = Path("models")

_BASELINE_HINT = "hint: provide an existing 1h report path, or train a new 1h report first."


def _print_available_reports() -> None:
    """List ``rain_model_report*.json`` files under ``_MODEL_DIR`` on stderr, if any."""
    if not _MODEL_DIR.is_dir():
        return
    candidates = sorted(_MODEL_DIR.glob("rain_model_report*.json"))
    if not candidates:
        return
    print("available report files:", file=sys.stderr)
    for c in candidates:
        print(f" - {c}", file=sys.stderr)


def _load_report(role: str, path: str, hint: str | None = None) -> dict[str, Any] | None:
    """Load the *role* report from *path*, or explain the failure on stderr.

    Args:
        role: Label used in messages ("baseline" or "candidate").
        path: Report path given on the command line.
        hint: Optional extra line printed when the file is not found.

    Returns:
        The parsed report, or None if it could not be loaded.
    """
    try:
        return load_json(path)
    except FileNotFoundError:
        print(f"error: {role} report not found: {path}", file=sys.stderr)
        _print_available_reports()
        if hint is not None:
            print(hint, file=sys.stderr)
    except (OSError, ValueError) as exc:
        # OSError: path exists but cannot be read (e.g. a directory, no permission).
        # ValueError: malformed JSON or undecodable UTF-8 (json.JSONDecodeError
        # and UnicodeDecodeError are both ValueError subclasses).
        print(f"error: could not read {role} report {path}: {exc}", file=sys.stderr)
    return None


def _describe(report: dict[str, Any]) -> str:
    """Summarise a report's version, horizon and target on one line."""
    return (
        f"version={report.get('model_version')} "
        f"horizon={report.get('horizon_hours')}h "
        f"target={report.get('target_definition')}"
    )


def _fmt(value: float | None) -> str:
    """Format a metric value to four decimals, or "n/a" when missing."""
    return "n/a" if value is None else f"{value:.4f}"


def main() -> int:
    """Compare the test metrics of two rain-model reports and print the deltas.

    The comparison table goes to stdout; load failures are reported on stderr.

    Returns:
        0 on success, 2 if either report cannot be loaded.
    """
    args = parse_args()

    baseline = _load_report("baseline", args.baseline, hint=_BASELINE_HINT)
    if baseline is None:
        return 2
    candidate = _load_report("candidate", args.candidate)
    if candidate is None:
        return 2

    # Collect every metric before printing anything.
    rows = [
        (name, metric(baseline, "test_metrics", name), metric(candidate, "test_metrics", name))
        for name in _TEST_METRICS
    ]

    print("Rain report comparison:")
    print(f" baseline: {_describe(baseline)}")
    print(f" candidate: {_describe(candidate)}")
    # Deltas are raw differences. NOTE(review): "brier" is presumably the
    # Brier score (lower is better), so a negative delta there would be an
    # improvement — confirm against the training code.
    print(" metrics (candidate - baseline):")
    for name, base, cand in rows:
        print(
            f" {name}: baseline={_fmt(base)} candidate={_fmt(cand)} "
            f"delta={delta_str(base, cand)}"
        )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|