another bugfix
This commit is contained in:
@@ -9,6 +9,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
from sklearn.calibration import CalibratedClassifierCV
|
||||
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||
@@ -499,6 +500,81 @@ def evaluate_naive_baselines(test_df, y_test: np.ndarray) -> dict[str, Any]:
|
||||
return out
|
||||
|
||||
|
||||
def evaluate_sliced_performance(
|
||||
test_df,
|
||||
y_true: np.ndarray,
|
||||
y_prob: np.ndarray,
|
||||
threshold: float,
|
||||
min_rows_per_slice: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
frame = pd.DataFrame(
|
||||
{
|
||||
"y_true": y_true.astype(int),
|
||||
"y_prob": y_prob.astype(float),
|
||||
},
|
||||
index=test_df.index,
|
||||
)
|
||||
overall_rate = float(np.mean(y_true))
|
||||
hour = frame.index.hour
|
||||
is_day = (hour >= 6) & (hour < 18)
|
||||
|
||||
weekly_key = frame.index.to_series().dt.isocalendar()
|
||||
week_label = weekly_key["year"].astype(str) + "-W" + weekly_key["week"].astype(str).str.zfill(2)
|
||||
weekly_positive_rate = frame.groupby(week_label)["y_true"].transform("mean")
|
||||
rainy_week = weekly_positive_rate >= overall_rate
|
||||
|
||||
rain_context = test_df["rain_last_1h_mm"].to_numpy(dtype=float) if "rain_last_1h_mm" in test_df.columns else np.zeros(len(test_df))
|
||||
wet_context = rain_context >= RAIN_EVENT_THRESHOLD_MM
|
||||
|
||||
wind_values = test_df["wind_max_m_s"].to_numpy(dtype=float) if "wind_max_m_s" in test_df.columns else np.full(len(test_df), np.nan)
|
||||
if np.isfinite(wind_values).any():
|
||||
wind_q75 = float(np.nanquantile(wind_values, 0.75))
|
||||
windy = np.nan_to_num(wind_values, nan=wind_q75) >= wind_q75
|
||||
else:
|
||||
windy = np.zeros(len(test_df), dtype=bool)
|
||||
|
||||
definitions: list[tuple[str, np.ndarray, str]] = [
|
||||
("daytime_utc", np.asarray(is_day, dtype=bool), "06:00-17:59 UTC"),
|
||||
("nighttime_utc", np.asarray(~is_day, dtype=bool), "18:00-05:59 UTC"),
|
||||
("rainy_weeks", np.asarray(rainy_week, dtype=bool), "weeks with positive-rate >= test positive-rate"),
|
||||
("non_rainy_weeks", np.asarray(~rainy_week, dtype=bool), "weeks with positive-rate < test positive-rate"),
|
||||
("wet_context_last_1h", np.asarray(wet_context, dtype=bool), f"rain_last_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}"),
|
||||
("dry_context_last_1h", np.asarray(~wet_context, dtype=bool), f"rain_last_1h_mm < {RAIN_EVENT_THRESHOLD_MM:.2f}"),
|
||||
("windy_q75", np.asarray(windy, dtype=bool), "wind_max_m_s >= test 75th percentile"),
|
||||
("calm_below_q75", np.asarray(~windy, dtype=bool), "wind_max_m_s < test 75th percentile"),
|
||||
]
|
||||
|
||||
out: dict[str, Any] = {}
|
||||
for name, mask, description in definitions:
|
||||
rows = int(np.sum(mask))
|
||||
if rows == 0:
|
||||
out[name] = {
|
||||
"rows": rows,
|
||||
"description": description,
|
||||
"status": "empty",
|
||||
}
|
||||
continue
|
||||
y_slice = y_true[mask]
|
||||
p_slice = y_prob[mask]
|
||||
if rows < min_rows_per_slice:
|
||||
out[name] = {
|
||||
"rows": rows,
|
||||
"description": description,
|
||||
"status": "insufficient_rows",
|
||||
"min_rows_required": min_rows_per_slice,
|
||||
}
|
||||
continue
|
||||
metrics = evaluate_probs(y_true=y_slice, y_prob=p_slice, threshold=threshold)
|
||||
out[name] = {
|
||||
"rows": rows,
|
||||
"description": description,
|
||||
"status": "ok",
|
||||
"metrics": metrics,
|
||||
"ece_10": expected_calibration_error(y_true=y_slice, y_prob=p_slice, bins=10),
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
def walk_forward_backtest(
|
||||
model_df,
|
||||
feature_cols: list[str],
|
||||
@@ -683,6 +759,25 @@ def write_model_card(path: str, report: dict[str, Any]) -> None:
|
||||
f"PR-AUC `{report['test_metrics']['pr_auc']}`, "
|
||||
f"ROC-AUC `{report['test_metrics']['roc_auc']}`, "
|
||||
f"Brier `{report['test_metrics']['brier']:.4f}`",
|
||||
"",
|
||||
"## Sliced Performance (Test)",
|
||||
"",
|
||||
]
|
||||
)
|
||||
for slice_name, info in report.get("sliced_performance_test", {}).items():
|
||||
if info.get("status") != "ok":
|
||||
continue
|
||||
metrics = info["metrics"]
|
||||
lines.append(
|
||||
f"- `{slice_name}` ({info['rows']} rows): "
|
||||
f"precision `{metrics['precision']:.3f}`, "
|
||||
f"recall `{metrics['recall']:.3f}`, "
|
||||
f"PR-AUC `{metrics['pr_auc']}`, "
|
||||
f"Brier `{metrics['brier']:.4f}`"
|
||||
)
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## Known Limitations",
|
||||
"",
|
||||
@@ -862,6 +957,12 @@ def main() -> int:
|
||||
"ece_10": expected_calibration_error(y_true=y_test, y_prob=y_test_prob, bins=10),
|
||||
}
|
||||
naive_baselines_test = evaluate_naive_baselines(test_df=test_df, y_test=y_test)
|
||||
sliced_performance = evaluate_sliced_performance(
|
||||
test_df=test_df,
|
||||
y_true=y_test,
|
||||
y_prob=y_test_prob,
|
||||
threshold=chosen_threshold,
|
||||
)
|
||||
walk_forward = walk_forward_backtest(
|
||||
model_df=model_df,
|
||||
feature_cols=feature_cols,
|
||||
@@ -941,6 +1042,7 @@ def main() -> int:
|
||||
"test_metrics": test_metrics,
|
||||
"test_calibration_quality": test_calibration,
|
||||
"naive_baselines_test": naive_baselines_test,
|
||||
"sliced_performance_test": sliced_performance,
|
||||
"walk_forward_backtest": walk_forward,
|
||||
}
|
||||
report = to_builtin(report)
|
||||
@@ -1002,6 +1104,21 @@ def main() -> int:
|
||||
f"pr_auc={m['pr_auc'] if m['pr_auc'] is not None else 'n/a'} "
|
||||
f"brier={m['brier']:.4f}"
|
||||
)
|
||||
sliced_ok = [
|
||||
(name, item)
|
||||
for name, item in report["sliced_performance_test"].items()
|
||||
if item.get("status") == "ok"
|
||||
]
|
||||
if sliced_ok:
|
||||
print(" sliced performance (test):")
|
||||
for name, item in sliced_ok:
|
||||
m = item["metrics"]
|
||||
print(
|
||||
f" {name}: rows={item['rows']} "
|
||||
f"precision={m['precision']:.3f} recall={m['recall']:.3f} "
|
||||
f"pr_auc={m['pr_auc'] if m['pr_auc'] is not None else 'n/a'} "
|
||||
f"brier={m['brier']:.4f}"
|
||||
)
|
||||
|
||||
if args.report_out:
|
||||
report_dir = os.path.dirname(args.report_out)
|
||||
|
||||
Reference in New Issue
Block a user