more work on model training
This commit is contained in:
@@ -11,12 +11,15 @@ from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from rain_model_common import (
|
||||
FEATURE_COLUMNS,
|
||||
AVAILABLE_FEATURE_SETS,
|
||||
RAIN_EVENT_THRESHOLD_MM,
|
||||
build_dataset,
|
||||
evaluate_probs,
|
||||
fetch_baro,
|
||||
fetch_forecast,
|
||||
fetch_ws90,
|
||||
feature_columns_for_set,
|
||||
feature_columns_need_forecast,
|
||||
model_frame,
|
||||
parse_time,
|
||||
select_threshold,
|
||||
@@ -46,12 +49,31 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
parser.add_argument("--threshold", type=float, help="Optional fixed classification threshold.")
|
||||
parser.add_argument("--min-rows", type=int, default=200, help="Minimum model-ready rows required.")
|
||||
parser.add_argument(
|
||||
"--feature-set",
|
||||
default="baseline",
|
||||
choices=AVAILABLE_FEATURE_SETS,
|
||||
help="Named feature set to train with.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forecast-model",
|
||||
default="ecmwf",
|
||||
help="Forecast model name when feature set requires forecast columns.",
|
||||
)
|
||||
parser.add_argument("--out", default="models/rain_model.pkl", help="Path to save model.")
|
||||
parser.add_argument(
|
||||
"--report-out",
|
||||
default="models/rain_model_report.json",
|
||||
help="Path to save JSON training report.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-out",
|
||||
default="models/datasets/rain_dataset_{model_version}_{feature_set}.csv",
|
||||
help=(
|
||||
"Path (or template) for model-ready dataset snapshot. "
|
||||
"Supports {model_version} and {feature_set}. Use empty string to disable."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-version",
|
||||
default="rain-logreg-v1",
|
||||
@@ -76,13 +98,18 @@ def main() -> int:
|
||||
|
||||
start = parse_time(args.start) if args.start else ""
|
||||
end = parse_time(args.end) if args.end else ""
|
||||
feature_cols = feature_columns_for_set(args.feature_set)
|
||||
needs_forecast = feature_columns_need_forecast(feature_cols)
|
||||
|
||||
with psycopg2.connect(args.db_url) as conn:
|
||||
ws90 = fetch_ws90(conn, args.site, start, end)
|
||||
baro = fetch_baro(conn, args.site, start, end)
|
||||
forecast = None
|
||||
if needs_forecast:
|
||||
forecast = fetch_forecast(conn, args.site, start, end, model=args.forecast_model)
|
||||
|
||||
full_df = build_dataset(ws90, baro, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
|
||||
model_df = model_frame(full_df, FEATURE_COLUMNS, require_target=True)
|
||||
full_df = build_dataset(ws90, baro, forecast=forecast, rain_event_threshold_mm=RAIN_EVENT_THRESHOLD_MM)
|
||||
model_df = model_frame(full_df, feature_cols, require_target=True)
|
||||
if len(model_df) < args.min_rows:
|
||||
raise RuntimeError(f"not enough model-ready rows after filtering (need >= {args.min_rows})")
|
||||
|
||||
@@ -92,11 +119,11 @@ def main() -> int:
|
||||
val_ratio=args.val_ratio,
|
||||
)
|
||||
|
||||
x_train = train_df[FEATURE_COLUMNS]
|
||||
x_train = train_df[feature_cols]
|
||||
y_train = train_df["rain_next_1h"].astype(int).to_numpy()
|
||||
x_val = val_df[FEATURE_COLUMNS]
|
||||
x_val = val_df[feature_cols]
|
||||
y_val = val_df["rain_next_1h"].astype(int).to_numpy()
|
||||
x_test = test_df[FEATURE_COLUMNS]
|
||||
x_test = test_df[feature_cols]
|
||||
y_test = test_df["rain_next_1h"].astype(int).to_numpy()
|
||||
|
||||
base_model = make_model()
|
||||
@@ -119,7 +146,7 @@ def main() -> int:
|
||||
val_metrics = evaluate_probs(y_true=y_val, y_prob=y_val_prob, threshold=chosen_threshold)
|
||||
|
||||
train_val_df = model_df.iloc[: len(train_df) + len(val_df)]
|
||||
x_train_val = train_val_df[FEATURE_COLUMNS]
|
||||
x_train_val = train_val_df[feature_cols]
|
||||
y_train_val = train_val_df["rain_next_1h"].astype(int).to_numpy()
|
||||
|
||||
final_model = make_model()
|
||||
@@ -131,8 +158,10 @@ def main() -> int:
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"site": args.site,
|
||||
"model_version": args.model_version,
|
||||
"feature_set": args.feature_set,
|
||||
"target_definition": f"rain_next_1h_mm >= {RAIN_EVENT_THRESHOLD_MM:.2f}",
|
||||
"feature_columns": FEATURE_COLUMNS,
|
||||
"feature_columns": feature_cols,
|
||||
"forecast_model": args.forecast_model if needs_forecast else None,
|
||||
"data_window": {
|
||||
"requested_start": start or None,
|
||||
"requested_end": end or None,
|
||||
@@ -141,11 +170,13 @@ def main() -> int:
|
||||
"model_rows": len(model_df),
|
||||
"ws90_rows": len(ws90),
|
||||
"baro_rows": len(baro),
|
||||
"forecast_rows": int(len(forecast)) if forecast is not None else 0,
|
||||
},
|
||||
"label_quality": {
|
||||
"rain_reset_count": int(np.nansum(full_df["rain_reset"].fillna(False).to_numpy(dtype=int))),
|
||||
"rain_spike_5m_count": int(np.nansum(full_df["rain_spike_5m"].fillna(False).to_numpy(dtype=int))),
|
||||
},
|
||||
"feature_missingness_ratio": {col: float(full_df[col].isna().mean()) for col in feature_cols if col in full_df.columns},
|
||||
"split": {
|
||||
"train_ratio": args.train_ratio,
|
||||
"val_ratio": args.val_ratio,
|
||||
@@ -171,6 +202,7 @@ def main() -> int:
|
||||
print("Rain model training summary:")
|
||||
print(f" site: {args.site}")
|
||||
print(f" model_version: {args.model_version}")
|
||||
print(f" feature_set: {args.feature_set} ({len(feature_cols)} features)")
|
||||
print(f" rows: total={report['data_window']['model_rows']} train={report['split']['train_rows']} val={report['split']['val_rows']} test={report['split']['test_rows']}")
|
||||
print(
|
||||
" threshold: "
|
||||
@@ -200,6 +232,15 @@ def main() -> int:
|
||||
json.dump(report, f, indent=2)
|
||||
print(f"Saved report to {args.report_out}")
|
||||
|
||||
if args.dataset_out:
|
||||
dataset_out = args.dataset_out.format(model_version=args.model_version, feature_set=args.feature_set)
|
||||
dataset_dir = os.path.dirname(dataset_out)
|
||||
if dataset_dir:
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
snapshot_cols = list(dict.fromkeys(feature_cols + ["rain_next_1h", "rain_next_1h_mm"]))
|
||||
model_df[snapshot_cols].to_csv(dataset_out, index=True, index_label="ts")
|
||||
print(f"Saved dataset snapshot to {dataset_out}")
|
||||
|
||||
if args.out:
|
||||
out_dir = os.path.dirname(args.out)
|
||||
if out_dir:
|
||||
@@ -209,7 +250,9 @@ def main() -> int:
|
||||
else:
|
||||
artifact = {
|
||||
"model": final_model,
|
||||
"features": FEATURE_COLUMNS,
|
||||
"features": feature_cols,
|
||||
"feature_set": args.feature_set,
|
||||
"forecast_model": args.forecast_model if needs_forecast else None,
|
||||
"threshold": float(chosen_threshold),
|
||||
"target_mm": float(RAIN_EVENT_THRESHOLD_MM),
|
||||
"model_version": args.model_version,
|
||||
|
||||
Reference in New Issue
Block a user