More work on model training: inference now fetches forecast data when the model artifact's features require it
This commit is contained in:
@@ -8,7 +8,16 @@ from datetime import datetime, timedelta, timezone
|
||||
import psycopg2
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from rain_model_common import build_dataset, fetch_baro, fetch_ws90, model_frame, parse_time, to_builtin
|
||||
from rain_model_common import (
|
||||
build_dataset,
|
||||
feature_columns_need_forecast,
|
||||
fetch_baro,
|
||||
fetch_forecast,
|
||||
fetch_ws90,
|
||||
model_frame,
|
||||
parse_time,
|
||||
to_builtin,
|
||||
)
|
||||
|
||||
try:
|
||||
import joblib
|
||||
@@ -33,6 +42,11 @@ def parse_args() -> argparse.Namespace:
|
||||
default=6,
|
||||
help="History lookback window used to build features.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--forecast-model",
|
||||
default="ecmwf",
|
||||
help="Forecast model name when inference features require forecast columns.",
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Do not write prediction to DB.")
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -65,9 +79,12 @@ def main() -> int:
|
||||
at = parse_at(args.at)
|
||||
artifact = load_artifact(args.model_path)
|
||||
model = artifact["model"]
|
||||
features = artifact["features"]
|
||||
features = list(artifact["features"])
|
||||
feature_set = artifact.get("feature_set")
|
||||
needs_forecast = feature_columns_need_forecast(features)
|
||||
threshold = float(artifact.get("threshold", 0.5))
|
||||
model_version = args.model_version or artifact.get("model_version") or "unknown"
|
||||
forecast_model = str(artifact.get("forecast_model") or args.forecast_model)
|
||||
|
||||
fetch_start = (at - timedelta(hours=args.history_hours)).isoformat()
|
||||
fetch_end = (at + timedelta(hours=1, minutes=5)).isoformat()
|
||||
@@ -75,8 +92,11 @@ def main() -> int:
|
||||
with psycopg2.connect(args.db_url) as conn:
|
||||
ws90 = fetch_ws90(conn, args.site, fetch_start, fetch_end)
|
||||
baro = fetch_baro(conn, args.site, fetch_start, fetch_end)
|
||||
forecast = None
|
||||
if needs_forecast:
|
||||
forecast = fetch_forecast(conn, args.site, fetch_start, fetch_end, model=forecast_model)
|
||||
|
||||
full_df = build_dataset(ws90, baro)
|
||||
full_df = build_dataset(ws90, baro, forecast=forecast)
|
||||
feature_df = model_frame(full_df, feature_cols=features, require_target=False)
|
||||
candidates = feature_df.loc[feature_df.index <= at]
|
||||
if candidates.empty:
|
||||
@@ -105,6 +125,9 @@ def main() -> int:
|
||||
metadata = {
|
||||
"artifact_path": args.model_path,
|
||||
"artifact_model_version": artifact.get("model_version"),
|
||||
"artifact_feature_set": feature_set,
|
||||
"forecast_model": forecast_model if needs_forecast else None,
|
||||
"needs_forecast_features": needs_forecast,
|
||||
"feature_values": {col: float(row.iloc[0][col]) for col in features},
|
||||
"source_window_start": fetch_start,
|
||||
"source_window_end": fetch_end,
|
||||
@@ -117,6 +140,8 @@ def main() -> int:
|
||||
print(f" site: {args.site}")
|
||||
print(f" model_name: {args.model_name}")
|
||||
print(f" model_version: {model_version}")
|
||||
if feature_set:
|
||||
print(f" feature_set: {feature_set}")
|
||||
print(f" pred_ts: {pred_ts.isoformat()}")
|
||||
print(f" threshold: {threshold:.3f}")
|
||||
print(f" probability: {probability:.4f}")
|
||||
|
||||
Reference in New Issue
Block a user