Improve model training: add a walk-forward threshold-tuning policy (--threshold-policy) alongside the existing validation-based selection
This commit is contained in:
@@ -45,6 +45,7 @@ except ImportError: # pragma: no cover - optional dependency
|
||||
|
||||
# Accepted model-family identifiers. NOTE(review): presumably the choices for
# a --model-family CLI flag (parse_args references args.model_family) — confirm.
MODEL_FAMILIES = ("logreg", "hist_gb", "auto")
# Probability-calibration strategies; "none" skips calibration entirely.
CALIBRATION_METHODS = ("none", "sigmoid", "isotonic")
# How the operating threshold is chosen when --threshold is not fixed;
# used as the choices of --threshold-policy in parse_args.
THRESHOLD_POLICIES = ("validation", "walk_forward")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
@@ -62,6 +63,12 @@ def parse_args() -> argparse.Namespace:
|
||||
help="Minimum validation precision for threshold selection.",
|
||||
)
|
||||
parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
|
||||
parser.add_argument(
|
||||
"--threshold-policy",
|
||||
default="validation",
|
||||
choices=THRESHOLD_POLICIES,
|
||||
help="How to choose operating threshold when --threshold is not set.",
|
||||
)
|
||||
parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
|
||||
parser.set_defaults(allow_empty=True)
|
||||
parser.add_argument(
|
||||
@@ -575,6 +582,127 @@ def evaluate_sliced_performance(
|
||||
return out
|
||||
|
||||
|
||||
def tune_threshold_walk_forward(
    model_df,
    feature_cols: list[str],
    model_family: str,
    model_params: dict[str, Any],
    calibration_method: str,
    random_state: int,
    min_precision: float,
    folds: int,
) -> dict[str, Any]:
    """Tune the classification threshold on pooled out-of-fold predictions.

    The frame is walked chronologically: an expanding training window is
    followed by ``folds`` consecutive test slices.  Each fold fits a fresh
    (optionally calibrated) model and contributes its test-slice
    probabilities to one pooled out-of-fold set, on which the threshold is
    selected once via ``select_threshold``.

    Returns a status dict: ``status`` is ``"disabled"``,
    ``"insufficient_data"``, ``"failed"``, or ``"ok"`` (the latter carrying
    the tuned threshold, selection info, OOF metrics, and per-fold details).
    """
    if folds <= 0:
        return {
            "enabled": False,
            "status": "disabled",
            "reason": "walk_forward_folds <= 0",
        }

    total_rows = len(model_df)
    # Reserve at least 40% of the data (never fewer than 200 rows) for the
    # first training window; the rest is carved into test slices.
    min_train_rows = max(200, int(0.4 * total_rows))
    eval_rows = total_rows - min_train_rows
    if eval_rows < 50:
        return {
            "enabled": True,
            "status": "insufficient_data",
            "reason": "not enough rows for walk-forward threshold tuning",
            "requested_folds": folds,
            "min_train_rows": min_train_rows,
        }

    fold_size = max(25, eval_rows // folds)
    fold_details: list[dict[str, Any]] = []
    oof_true_parts: list[np.ndarray] = []
    oof_prob_parts: list[np.ndarray] = []

    for fold_idx in range(folds):
        train_end = min_train_rows + fold_idx * fold_size
        if fold_idx == folds - 1:
            # Last fold absorbs any remainder rows.
            test_end = total_rows
        else:
            test_end = min(train_end + fold_size, total_rows)
        if train_end >= test_end:
            continue

        fold_train = model_df.iloc[:train_end]
        fold_test = model_df.iloc[train_end:test_end]
        if len(fold_train) < 160 or len(fold_test) < 25:
            continue

        y_fold_train = fold_train["rain_next_1h"].astype(int).to_numpy()
        y_fold_test = fold_test["rain_next_1h"].astype(int).to_numpy()
        if len(np.unique(y_fold_train)) < 2:
            # A single-class training window cannot yield a usable model.
            continue

        try:
            fold_model, fold_fit = fit_with_optional_calibration(
                model_family=model_family,
                model_params=model_params,
                random_state=random_state,
                x_train=fold_train[feature_cols],
                y_train=y_fold_train,
                calibration_method=calibration_method,
                fallback_to_none=True,
            )
            fold_test_prob = fold_model.predict_proba(fold_test[feature_cols])[:, 1]
        except Exception as exc:
            # Best-effort: record the failure for the report and keep walking.
            fold_details.append(
                {
                    "fold_index": fold_idx + 1,
                    "error": str(exc),
                }
            )
            continue

        oof_true_parts.append(y_fold_test)
        oof_prob_parts.append(fold_test_prob)
        fold_details.append(
            {
                "fold_index": fold_idx + 1,
                "train_rows": len(fold_train),
                "test_rows": len(fold_test),
                "train_start": fold_train.index.min(),
                "train_end": fold_train.index.max(),
                "test_start": fold_test.index.min(),
                "test_end": fold_test.index.max(),
                "fit": fold_fit,
                "test_positive_rate": float(np.mean(y_fold_test)),
            }
        )

    if not oof_true_parts:
        return {
            "enabled": True,
            "status": "failed",
            "reason": "no successful folds produced out-of-fold predictions",
            "requested_folds": folds,
            "folds": fold_details,
        }

    y_oof_true = np.concatenate(oof_true_parts)
    y_oof_prob = np.concatenate(oof_prob_parts)
    tuned_threshold, tuned_info = select_threshold(
        y_true=y_oof_true,
        y_prob=y_oof_prob,
        min_precision=min_precision,
    )
    # Tag the selection rule so the report shows the threshold came from
    # walk-forward tuning rather than plain validation selection.
    tuned_info = dict(tuned_info)
    tuned_info["selection_rule"] = f"walk_forward_{tuned_info['selection_rule']}"

    return {
        "enabled": True,
        "status": "ok",
        "requested_folds": folds,
        "successful_folds": int(len(oof_true_parts)),
        "rows_used": int(len(y_oof_true)),
        "threshold": float(tuned_threshold),
        "threshold_selection": tuned_info,
        "oof_metrics_at_threshold": evaluate_probs(
            y_true=y_oof_true,
            y_prob=y_oof_prob,
            threshold=tuned_threshold,
        ),
        "folds": fold_details,
    }
|
||||
|
||||
|
||||
def walk_forward_backtest(
|
||||
model_df,
|
||||
feature_cols: list[str],
|
||||
@@ -935,7 +1063,32 @@ def main() -> int:
|
||||
selected_model_params = best_candidate["model_params"]
|
||||
selected_calibration_method = str(best_candidate["calibration_method"])
|
||||
chosen_threshold = float(best_candidate["threshold"])
|
||||
threshold_info = best_candidate["threshold_info"]
|
||||
threshold_info = dict(best_candidate["threshold_info"])
|
||||
threshold_policy_applied = "fixed" if args.threshold is not None else "validation"
|
||||
threshold_tuning_walk_forward = {
|
||||
"enabled": args.threshold_policy == "walk_forward",
|
||||
"status": "not_run",
|
||||
}
|
||||
if args.threshold is None and args.threshold_policy == "walk_forward":
|
||||
threshold_tuning_walk_forward = tune_threshold_walk_forward(
|
||||
model_df=model_df.iloc[: len(train_df) + len(val_df)],
|
||||
feature_cols=feature_cols,
|
||||
model_family=selected_model_family,
|
||||
model_params=selected_model_params,
|
||||
calibration_method=selected_calibration_method,
|
||||
random_state=args.random_state,
|
||||
min_precision=args.min_precision,
|
||||
folds=args.walk_forward_folds,
|
||||
)
|
||||
if threshold_tuning_walk_forward.get("status") == "ok":
|
||||
chosen_threshold = float(threshold_tuning_walk_forward["threshold"])
|
||||
threshold_info = dict(threshold_tuning_walk_forward["threshold_selection"])
|
||||
threshold_policy_applied = "walk_forward"
|
||||
else:
|
||||
threshold_info["warning"] = (
|
||||
"walk-forward threshold tuning unavailable; fell back to validation-selected threshold"
|
||||
)
|
||||
threshold_policy_applied = "validation_fallback"
|
||||
val_metrics = best_candidate["validation_metrics"]
|
||||
|
||||
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
||||
@@ -971,7 +1124,7 @@ def main() -> int:
|
||||
calibration_method=selected_calibration_method,
|
||||
random_state=args.random_state,
|
||||
min_precision=args.min_precision,
|
||||
fixed_threshold=args.threshold,
|
||||
fixed_threshold=chosen_threshold if threshold_policy_applied == "walk_forward" else args.threshold,
|
||||
folds=args.walk_forward_folds,
|
||||
)
|
||||
|
||||
@@ -989,6 +1142,8 @@ def main() -> int:
|
||||
"calibration_method_requested": calibration_methods,
|
||||
"calibration_method": selected_calibration_method,
|
||||
"calibration_fit": final_fit_info,
|
||||
"threshold_policy_requested": args.threshold_policy,
|
||||
"threshold_policy_applied": threshold_policy_applied,
|
||||
"data_window": {
|
||||
"requested_start": start or None,
|
||||
"requested_end": end or None,
|
||||
@@ -1043,6 +1198,7 @@ def main() -> int:
|
||||
"test_calibration_quality": test_calibration,
|
||||
"naive_baselines_test": naive_baselines_test,
|
||||
"sliced_performance_test": sliced_performance,
|
||||
"threshold_tuning_walk_forward": threshold_tuning_walk_forward,
|
||||
"walk_forward_backtest": walk_forward,
|
||||
}
|
||||
report = to_builtin(report)
|
||||
@@ -1053,6 +1209,10 @@ def main() -> int:
|
||||
print(f" model_family: {selected_model_family} (requested={args.model_family})")
|
||||
print(f" model_params: {selected_model_params}")
|
||||
print(f" calibration_method: {report['calibration_method']}")
|
||||
print(
|
||||
f" threshold_policy: requested={report['threshold_policy_requested']} "
|
||||
f"applied={report['threshold_policy_applied']}"
|
||||
)
|
||||
print(f" feature_set: {args.feature_set} ({len(feature_cols)} features)")
|
||||
print(
|
||||
" rows: "
|
||||
|
||||
Reference in New Issue
Block a user