Files
go-weatherstation/scripts/compare_rain_reports.py

98 lines
3.3 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
def parse_args() -> argparse.Namespace:
    """Parse the command line and return the two report paths.

    Both ``--baseline`` and ``--candidate`` are required options; the parsed
    namespace exposes them as ``baseline`` and ``candidate``.
    """
    cli = argparse.ArgumentParser(description="Compare two rain-model training reports.")
    # The two options differ only in flag name and help text, so register
    # them from a small table (insertion order keeps the help output stable).
    option_help = {
        "--baseline": "Baseline report path (for example 1h).",
        "--candidate": "Candidate report path (for example 4h).",
    }
    for flag, help_text in option_help.items():
        cli.add_argument(flag, required=True, help=help_text)
    return cli.parse_args()
def load_json(path: str) -> dict[str, Any]:
    """Read the UTF-8 encoded JSON document at ``path`` and return it decoded.

    Raises:
        FileNotFoundError: if nothing exists at ``path``; the exception
            carries the path exactly as it was passed in.
    """
    report_file = Path(path)
    if report_file.exists():
        return json.loads(report_file.read_text(encoding="utf-8"))
    raise FileNotFoundError(path)
def to_float(v: Any) -> float | None:
    """Convert ``v`` to a float, or return ``None`` when that is not possible.

    ``None``, values that ``float()`` rejects (wrong type or unparsable
    text), and integers too large to be represented as a float all yield
    ``None`` so that callers can treat them uniformly as "missing".
    """
    if v is None:
        return None
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # OverflowError is what float() raises for an int outside the float
        # range; without catching it such a value would abort the caller.
        return None
def metric(report: dict[str, Any], split: str, key: str) -> float | None:
    """Look up ``report[split][key]`` and return it as a float.

    Returns ``None`` when the split or the key is absent, when the split
    entry is not a mapping (for example JSON ``null``), or when ``to_float``
    cannot convert the stored value.
    """
    section = report.get(split)
    if not isinstance(section, dict):
        # A missing, null, or otherwise malformed split is treated like a
        # missing metric instead of failing on ``.get``.
        return None
    return to_float(section.get(key))
def delta_str(base: float | None, cand: float | None) -> str:
    """Format ``cand - base`` with an explicit sign and four decimal places.

    Returns the literal ``"n/a"`` when either value is ``None``.
    """
    if base is not None and cand is not None:
        return format(cand - base, "+.4f")
    return "n/a"
def main() -> int:
    """Compare the test metrics of two rain-model reports and print the deltas.

    Reads the ``--baseline`` and ``--candidate`` paths from the command line,
    loads both JSON reports, and prints every test metric side by side with
    the signed difference ``candidate - baseline``.

    Returns:
        ``0`` on success; ``2`` when a report file is missing, cannot be
        decoded as UTF-8 JSON, or does not contain a JSON object at the top
        level.
    """
    args = parse_args()

    try:
        baseline = load_json(args.baseline)
    except FileNotFoundError:
        print(f"error: baseline report not found: {args.baseline}")
        # Help the user pick a valid path by listing the reports that exist.
        model_dir = Path("models")
        if model_dir.exists():
            candidates = sorted(model_dir.glob("rain_model_report*.json"))
            if candidates:
                print("available report files:")
                for c in candidates:
                    print(f" - {c}")
        print("hint: provide an existing 1h report path, or train a new 1h report first.")
        return 2
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # A corrupt report is a user-facing error, not a reason for a traceback.
        print(f"error: baseline report is not valid JSON: {args.baseline} ({exc})")
        return 2

    try:
        candidate = load_json(args.candidate)
    except FileNotFoundError:
        print(f"error: candidate report not found: {args.candidate}")
        return 2
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        print(f"error: candidate report is not valid JSON: {args.candidate} ({exc})")
        return 2

    # Everything below calls ``.get`` on the reports, so reject documents
    # whose top level is a list, string, number, etc.
    for label, report, path in (
        ("baseline", baseline, args.baseline),
        ("candidate", candidate, args.candidate),
    ):
        if not isinstance(report, dict):
            print(f"error: {label} report is not a JSON object: {path}")
            return 2

    metric_names = ("precision", "recall", "f1", "pr_auc", "roc_auc", "brier")
    pairs = [
        (name, metric(baseline, "test_metrics", name), metric(candidate, "test_metrics", name))
        for name in metric_names
    ]

    print("Rain report comparison:")
    print(
        f" baseline: version={baseline.get('model_version')} "
        f"horizon={baseline.get('horizon_hours')}h "
        f"target={baseline.get('target_definition')}"
    )
    print(
        f" candidate: version={candidate.get('model_version')} "
        f"horizon={candidate.get('horizon_hours')}h "
        f"target={candidate.get('target_definition')}"
    )
    print(" metrics (candidate - baseline):")
    for name, base, cand in pairs:
        # Metrics missing from either report are shown as "n/a".
        base_txt = "n/a" if base is None else f"{base:.4f}"
        cand_txt = "n/a" if cand is None else f"{cand:.4f}"
        print(f" {name}: baseline={base_txt} candidate={cand_txt} delta={delta_str(base, cand)}")
    return 0
# Run the comparison only when executed as a script and hand main()'s
# return value to the interpreter as the process exit status.
if __name__ == "__main__":
    exit_code = main()
    raise SystemExit(exit_code)