"""
train.py  —  Rejection Risk ML Trainer
=======================================
Reads historical heat records from your existing rejection-analysis API
(or a local CSV), trains 3 models, and saves them as pickle files.

Usage — from API:
    pip install scikit-learn xgboost pandas numpy requests joblib
    python train.py --plant HPML --days 730
    python train.py --plant RHINO --days 730
    python train.py --plant KOYO  --days 730

Usage — from CSV (offline / no API):
    python train.py --plant HPML --csv /path/to/heats.csv

    Recognised CSV columns (any subset is fine; a missing or empty feature
    column is filled with the column median, or 0 when no median exists):
        heatno, date, plant, prd_qty, rej_qty, rej_pct,
        c, si, mn, mg, cu, ce, cr, mo, s, p, sn, ti, ni, al, v,
        uts, ys, elong, noudularity, nodulecount, perlite, ferite, hardness,
        gcs, compactibility, wts, active_clay, dead_clay, permeability,
        moisture, moisture_return, loi, volatile_matter,
        pour_temp_first, pour_temp_last, pour_temp_delta,
        defect_bh, defect_porocity, defect_sw, ...  (one binary col per defect)

Models saved to ./models/{PLANT}/
    risk_model.pkl        — RandomForestRegressor or XGBRegressor (lower CV MAE wins): predicted rejection rate 0–100
    defect_model.pkl      — XGBoost MultiOutput    : top defect probabilities
    anomaly_model.pkl     — IsolationForest + scaler: unusual heat flag
    feature_columns.json  — ordered feature list used at training time
    threshold.json        — per-defect optimal classification thresholds
    training_report.json  — metrics snapshot from the last training run
"""

from __future__ import annotations

import argparse
import json
import os
import warnings
from datetime import datetime, UTC
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import requests
from joblib import dump

import storage  # S3-or-local artifact storage (see storage.py)
from sklearn.ensemble import IsolationForest, RandomForestRegressor
from sklearn.metrics import (
    classification_report,
    f1_score,
    mean_absolute_error,
    precision_recall_curve,
)
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier, XGBRegressor

# Silence every warning category for the whole process (library warnings included).
warnings.filterwarnings("ignore")

# ── Config ─────────────────────────────────────────────────────────────────────
# Read once at import time from the environment; main() may overwrite API_BASE
# and API_TOKEN from the --api / --token command-line flags.
API_BASE   = os.environ.get("API_BASE",  "http://localhost:4000")
API_TOKEN  = os.environ.get("API_TOKEN", "")
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "./models"))
# .env file that sits next to this script; used to cache the login token.
ENV_FILE   = Path(__file__).parent / ".env"

# ── Feature definitions (MUST stay in sync with server.py) ────────────────────
# NOTE(review): spellings such as "noudularity", "ferite" and "porocity" are
# used as dictionary keys / column names throughout this file — presumably they
# mirror the upstream API/CSV schema, so do not "correct" them without checking.
CHEM_FEATURES = ["c","si","mn","mg","cu","ce","cr","mo","s","p","sn","ti","ni","al","v"]
MECH_FEATURES = ["uts","ys","elong","noudularity","nodulecount","perlite","ferite","hardness"]
SAND_FEATURES = [
    "gcs","compactibility","wts","active_clay","dead_clay","permeability",
    "moisture","moisture_return","loi","volatile_matter",
]
TEMP_FEATURES = ["pour_temp_first","pour_temp_last","pour_temp_delta"]
# Columns computed by prepare_features() from the raw chemistry values.
DERIVED_CHEM_FEATURES = [
    "ce_target","ce_deviation","mn_s_ratio","mn_s_product",
    "carbide_tendency","graphitizer_balance","mg_efficiency","al_ti_sum",
]
# Columns computed by prepare_features() from the grade and first pour temperature.
DERIVED_PROCESS_FEATURES = ["pour_temp_target","pour_temp_deviation"]
# Ordered feature list; prepare_features() appends "pour_temp_drop_pct" to it.
ALL_FEATURES  = CHEM_FEATURES + MECH_FEATURES + SAND_FEATURES + TEMP_FEATURES + DERIVED_CHEM_FEATURES + DERIVED_PROCESS_FEATURES

# Defect codes; each one maps to a "defect_<code>" column in the training frame.
DEFECT_KEYS = [
    "bh","porocity","sw","swl","crack","cs","csh","inc","slag","hard",
    "high_bhn","ph","sf","scab","misrun","rough_surface","lkg","mould_crush","dm","dim",
]
# Human-readable names used when printing per-defect metrics.
DEFECT_LABELS = {
    "bh":"Blow Hole","porocity":"Porosity","sw":"Shrinkage","swl":"Shrinkage (SWL)",
    "crack":"Crack","cs":"Cold Shut","csh":"Chilling","inc":"Inclusion","slag":"Slag",
    "hard":"Hardness NG","high_bhn":"High BHN","ph":"Pin Hole","sf":"Sand Fusion",
    "scab":"Scab","misrun":"Misrun","rough_surface":"Rough Surface","lkg":"Leakage",
    "mould_crush":"Mould Crush","dm":"Dimension","dim":"Dimension OOS",
}

# Accepted values for the --plant command-line flag.
VALID_PLANTS = {"HPML", "RHINO", "KOYO"}

# (grade marker, target) pairs consumed by _target_from_grade(): the first pair
# whose marker occurs as a substring of the upper-cased grade text wins, so the
# list order is significant.
GRADE_CE_TARGETS = [
    ("700", 4.20),
    ("600", 4.25),
    ("500", 4.30),
    ("450", 4.35),
    ("400", 4.35),
]

# Same lookup scheme as GRADE_CE_TARGETS, for the target first-pour temperature.
# NOTE(review): values look like °C — confirm units against server.py.
GRADE_POUR_TEMP_TARGETS = [
    ("700", 1415.0),
    ("600", 1410.0),
    ("500", 1405.0),
    ("450", 1400.0),
    ("400", 1400.0),
]

def _num_series(df: pd.DataFrame, col: str) -> pd.Series:
    if col not in df.columns:
        return pd.Series(np.nan, index=df.index, dtype="float64")
    return pd.to_numeric(df[col], errors="coerce")


def _target_from_grade(value, targets, default):
    g = "" if value is None or pd.isna(value) else str(value).upper()
    for prefix, target in targets:
        if prefix in g:
            return target
    return default


# ── .env helpers ───────────────────────────────────────────────────────────────
def load_env_file() -> None:
    """Copy ``KEY=VALUE`` pairs from ENV_FILE into ``os.environ``.

    Does nothing when the file is missing.  Blank lines, lines starting with
    ``#`` and lines without ``=`` are ignored.  Variables that are already set
    in the environment are left untouched (``setdefault``), so for a key that
    appears several times in the file the first occurrence wins.
    """
    if not ENV_FILE.exists():
        return
    with open(ENV_FILE) as handle:
        for raw in handle:
            entry = raw.strip()
            is_assignment = bool(entry) and not entry.startswith("#") and "=" in entry
            if is_assignment:
                name, _sep, val = entry.partition("=")
                os.environ.setdefault(name.strip(), val.strip())


def save_env_file(key: str, value: str) -> None:
    """Insert or replace the ``key=value`` entry in ENV_FILE.

    Every existing assignment to *key* is dropped and a single fresh
    ``key=value`` line is appended; all other lines are preserved in order.
    The file is created when it does not exist.

    Two details matter for correctness:
      * an existing entry is recognised by the text before the first ``=``
        with surrounding whitespace removed, i.e. the same way load_env_file()
        parses it.  Otherwise a stale ``KEY = old`` line would survive and,
        because the loader keeps the first occurrence, shadow the new value;
      * a last line without a trailing newline gets one, so the appended entry
        is never glued onto it.

    NOTE: the value is stored in plain text; callers use this for the API token.
    """
    kept: list[str] = []
    if ENV_FILE.exists():
        with open(ENV_FILE) as fh:
            for ln in fh:
                name, sep, _rest = ln.partition("=")
                if sep and name.strip() == key:
                    continue  # old assignment to the same key — replaced below
                kept.append(ln if ln.endswith("\n") else ln + "\n")
    kept.append(f"{key}={value}\n")
    with open(ENV_FILE, "w") as fh:
        fh.writelines(kept)


def login_interactive() -> str:
    """Prompt for credentials, call /api/auth/login, cache token in .env.

    The password is read with ``getpass`` so it is not echoed to the terminal.

    Returns:
        The token string taken from the login response (first non-empty of
        ``token``, ``access_token``, ``accessToken``, ``jwt``).

    Raises:
        RuntimeError: empty credentials, a non-200 response, a response body
            that is not a JSON object, or a response without a token.
    """
    # Local import: only needed on the interactive first-login path.
    import getpass

    print(f"\nNo token found. Enter your app login credentials:")
    print(f"(Will be saved to {ENV_FILE} so you won't be asked again)")
    username = input("  Username: ").strip()
    password = getpass.getpass("  Password: ").strip()
    if not username or not password:
        raise RuntimeError("Username and password cannot be empty.")

    r = requests.post(
        f"{API_BASE}/api/auth/login",
        json={"username": username, "password": password},
        timeout=15,
    )
    if r.status_code != 200:
        raise RuntimeError(f"Login failed: {r.status_code} — {r.text[:200]}")

    # A 200 with an HTML/plain-text body (proxy page, wrong URL) should yield a
    # readable error rather than a raw JSON decode traceback.
    try:
        data = r.json()
    except ValueError as exc:
        raise RuntimeError(f"Login ok but response is not JSON: {r.text[:200]}") from exc
    if not isinstance(data, dict):
        raise RuntimeError(f"Login ok but response is not a JSON object: {r.text[:200]}")

    token = (
        data.get("token") or data.get("access_token") or
        data.get("accessToken") or data.get("jwt") or ""
    )
    if not token:
        raise RuntimeError(f"Login ok but no token in response: {list(data.keys())}")

    save_env_file("API_TOKEN", token)
    print(f"  Logged in as '{username}' — token saved to {ENV_FILE} ✓")
    return token


def resolve_token() -> str:
    """
    Return the API token, obtaining it if necessary, and store it in the
    module-level API_TOKEN.

    Lookup order:
      1. API_TOKEN already set in this module (environment variable at import
         time, or the --token flag)
      2. API_TOKEN found in the environment after loading the .env file
         (cached by a previous login run)
      3. Interactive login prompt
    """
    global API_TOKEN
    if not API_TOKEN:
        load_env_file()
        API_TOKEN = os.environ.get("API_TOKEN", "")
        if API_TOKEN:
            print("  Using cached token from .env ✓")
        else:
            API_TOKEN = login_interactive()
    return API_TOKEN


# ── Data loading: API ──────────────────────────────────────────────────────────
def fetch_heat_records_api(plant: str, days: int) -> pd.DataFrame:
    """
    Calls /api/metallurgy/rejection-analysis-meta to discover product codes,
    then fetches heat records for each product.

    Args:
        plant: plant code; compared case-insensitively with each product's
            ``plant`` field.
        days:  lookback window, sent to both endpoints as ``lookbackDays``.

    Returns:
        One row per heat record (see _flatten_record), or an empty DataFrame
        when no product has at least 20 lots or no records were returned.

    Raises:
        RuntimeError: the meta endpoint answered with a non-200 status.
        requests.RequestException: the meta request itself failed.
    """
    headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}

    # ── 1. Discover product codes for this plant
    print(f"\n  Fetching product codes for {plant} (last {days} days) …")
    meta_r = requests.get(
        f"{API_BASE}/api/metallurgy/rejection-analysis-meta",
        params={"lookbackDays": days},
        headers=headers,
        timeout=30,
    )
    if meta_r.status_code != 200:
        raise RuntimeError(f"Meta fetch failed: {meta_r.status_code} — {meta_r.text[:200]}")

    meta = meta_r.json()
    plant_code = plant.upper()
    # `or` fallbacks: a JSON null for products / plant / lot_count must not
    # abort the whole fetch with an AttributeError or TypeError.
    plant_products = [
        p["material"]
        for p in (meta.get("products") or [])
        if p.get("material")
        and str(p.get("plant") or "").upper() == plant_code
        and (p.get("lot_count") or 0) >= 20     # skip very sparse products
    ]

    if not plant_products:
        print(f"  No products with ≥20 heats found for plant {plant}.")
        return pd.DataFrame()

    print(f"  Found {len(plant_products)} eligible products.")

    # ── 2. Fetch heat records per product
    all_records: list[dict] = []
    for pc in plant_products:
        try:
            r = requests.get(
                f"{API_BASE}/api/metallurgy/rejection-analysis",
                params={"productcode": pc, "plant": plant, "lookbackDays": days},
                headers=headers,
                timeout=60,
            )
            if r.status_code != 200:
                print(f"    {pc}: HTTP {r.status_code} — skipped")
                continue

            records = r.json().get("heat_records") or []
            if not records:
                continue

            all_records.extend(_flatten_record(rec, pc) for rec in records)
            print(f"    {pc}: {len(records)} heats")

        except Exception as exc:
            # Deliberately broad: one failing product (timeout, bad payload)
            # is reported and skipped so the remaining products still load.
            print(f"    {pc}: SKIP ({exc})")

    if not all_records:
        return pd.DataFrame()

    df = pd.DataFrame(all_records)
    print(f"\n  Total: {len(df)} heat records loaded from API.")
    return df


def _flatten_record(rec: dict, productcode: str) -> dict:
    """Convert a single heat_record dict to a flat row dict."""
    flat: dict = {
        "productcode":  productcode,
        "heatno":       rec.get("heatno"),
        "date":         rec.get("date"),
        "grade":        rec.get("grade"),
        "prd_qty":      rec.get("prd_qty", 0),
        "rej_qty":      rec.get("rej_qty", 0),
        "rej_pct":      rec.get("rej_pct"),
        "has_rej_data": rec.get("has_rejection_data", False),
    }
    chem = rec.get("chemistry",  {})
    mech = rec.get("mechanical", {})
    sand = rec.get("sand",       {})

    for f in CHEM_FEATURES:
        flat[f] = chem.get(f)
    for f in MECH_FEATURES:
        flat[f] = mech.get(f)
    for f in SAND_FEATURES:
        flat[f] = sand.get(f)

    flat["pour_temp_first"] = rec.get("pour_temp_first")
    flat["pour_temp_last"]  = rec.get("pour_temp_last")
    flat["pour_temp_delta"] = rec.get("pour_temp_delta")

    defects = rec.get("defects", {})
    for dk in DEFECT_KEYS:
        flat[f"defect_{dk}"] = defects.get(dk, {}).get("pct", 0) or 0

    return flat


# ── Data loading: CSV ──────────────────────────────────────────────────────────
def load_heat_records_csv(csv_path: str) -> pd.DataFrame:
    """
    Load heat records from a local CSV.

    Processing steps:
      * column names are stripped, lower-cased and spaces become underscores;
      * ``rej_pct`` is coerced to numeric, or added as an all-NaN column when
        the CSV has none, so downstream code can always index it;
      * ``has_rej_data`` is always a real boolean column: derived from
        ``rej_pct`` being non-null when absent, otherwise normalised from
        0/1 or true/false/yes/no style values;
      * every ``defect_<key>`` column from DEFECT_KEYS that is absent is added
        with value 0.

    Feature columns that are absent are NOT added here; prepare_features()
    creates and fills them.
    """
    df = pd.read_csv(csv_path, low_memory=False)
    print(f"\n  Loaded {len(df)} rows from {csv_path}")

    # Normalise column names to lowercase + strip spaces
    df.columns = [str(c).strip().lower().replace(" ", "_") for c in df.columns]

    # Without this a CSV lacking rej_pct produced an all-NaN has_rej_data
    # column and a KeyError on "rej_pct" once training started.
    if "rej_pct" in df.columns:
        df["rej_pct"] = pd.to_numeric(df["rej_pct"], errors="coerce")
    else:
        df["rej_pct"] = np.nan

    if "has_rej_data" in df.columns:
        flag = df["has_rej_data"]
        if flag.dtype != bool:
            # 0/1 or text flags are unusable as a boolean row mask; blanks
            # (NaN) fall through to False.
            truthy = {"true", "t", "yes", "y", "1", "1.0"}
            flag = flag.map(lambda v: str(v).strip().lower() in truthy)
        df["has_rej_data"] = flag.astype(bool)
    else:
        # Derive has_rej_data: True if rej_pct is present and not null
        df["has_rej_data"] = df["rej_pct"].notna()

    # Ensure defect_* columns exist (default 0 if absent)
    for dk in DEFECT_KEYS:
        col = f"defect_{dk}"
        if col not in df.columns:
            df[col] = 0

    return df


# ── Feature engineering ────────────────────────────────────────────────────────
def prepare_features(df: pd.DataFrame) -> tuple[pd.DataFrame, list[str]]:
    """
    Build the model feature matrix on a copy of *df*.

    1. Fill ``ce`` (carbon equivalent) from ``c + si/3 + p/3`` where it is
       missing.
    2. Add the DERIVED_CHEM_FEATURES / DERIVED_PROCESS_FEATURES columns; grade
       targets come from GRADE_CE_TARGETS / GRADE_POUR_TEMP_TARGETS and fall
       back to the data median (or 4.3 / 1405.0 when there is no median).
    3. Add ``pour_temp_drop_pct`` = pour_temp_delta / pour_temp_first * 100
       for rows with a positive first-pour temperature.
    4. Coerce every feature column to numeric, fill NaN with the column median
       (0.0 if no median) and create absent columns as 0.0.
    5. Coerce ``rej_pct`` (when present) to numeric and clip it to [0, 100].

    Any subset of input columns is accepted; none of the raw feature columns
    is required.

    Returns:
        (cleaned_df, ordered_feature_cols) where the feature list is
        ALL_FEATURES + ["pour_temp_drop_pct"].
    """
    df = df.copy()
    c  = _num_series(df, "c")
    si = _num_series(df, "si")
    p  = _num_series(df, "p")
    mn = _num_series(df, "mn")
    s  = _num_series(df, "s")
    mg = _num_series(df, "mg")
    cr = _num_series(df, "cr")
    mo = _num_series(df, "mo")
    v  = _num_series(df, "v")
    sn = _num_series(df, "sn")
    al = _num_series(df, "al")
    ti = _num_series(df, "ti")
    # Read through _num_series so an absent or text-typed pour-temperature
    # column yields NaN instead of a KeyError / TypeError further down.
    pour_first = _num_series(df, "pour_temp_first")
    pour_delta = _num_series(df, "pour_temp_delta")

    ce_calc = c + (si / 3.0) + (p / 3.0)
    if "ce" not in df.columns:
        df["ce"] = ce_calc
    else:
        df["ce"] = _num_series(df, "ce").fillna(ce_calc)
    ce = pd.to_numeric(df["ce"], errors="coerce")

    ce_median = ce.median()
    default_ce_target = float(ce_median) if pd.notna(ce_median) else 4.3
    first_median = pour_first.median()
    default_temp_target = float(first_median) if pd.notna(first_median) else 1405.0

    grades = df["grade"] if "grade" in df.columns else pd.Series("", index=df.index)
    df["ce_target"] = grades.map(lambda g: _target_from_grade(g, GRADE_CE_TARGETS, default_ce_target))
    df["ce_deviation"] = ce - df["ce_target"]
    # Ratios are left NaN (imputed below) when the denominator is ~0.
    df["mn_s_ratio"] = np.where(s.abs() > 1e-9, mn / s, np.nan)
    df["mn_s_product"] = mn * s
    df["carbide_tendency"] = cr.fillna(0) + mo.fillna(0) + v.fillna(0) + sn.fillna(0)
    df["graphitizer_balance"] = si - (3.0 * (cr.fillna(0) + mo.fillna(0) + v.fillna(0)))
    df["mg_efficiency"] = np.where((s * 1.7 + mg).abs() > 1e-9, mg / (s * 1.7 + mg), np.nan)
    df["al_ti_sum"] = al.fillna(0) + ti.fillna(0)
    df["pour_temp_target"] = grades.map(lambda g: _target_from_grade(g, GRADE_POUR_TEMP_TARGETS, default_temp_target))
    df["pour_temp_deviation"] = pour_first - df["pour_temp_target"]

    # Derived: % drop from first to last pour temp
    valid_mask = pour_first.notna() & (pour_first > 0) & pour_delta.notna()
    df["pour_temp_drop_pct"] = (pour_delta / pour_first * 100).where(valid_mask)

    feature_cols = ALL_FEATURES + ["pour_temp_drop_pct"]

    for col in feature_cols:
        if col in df.columns:
            # Coerce first: a text/object column has no reliable median and
            # would reach the estimators as a non-numeric array.
            values = pd.to_numeric(df[col], errors="coerce")
            median = values.median()
            df[col] = values.fillna(median if pd.notna(median) else 0.0)
        else:
            df[col] = 0.0

    # Clip rej_pct to [0, 100]
    if "rej_pct" in df.columns:
        df["rej_pct"] = pd.to_numeric(df["rej_pct"], errors="coerce").clip(0, 100)

    return df, feature_cols


# ── Model 1: Risk scorer ───────────────────────────────────────────────────────
def train_risk_model(
    df: pd.DataFrame,
    feature_cols: list[str],
    plant: str,
) -> tuple[Optional[object], Optional[list[str]]]:
    """
    Train the rejection-% regressor for a heat.

    A RandomForestRegressor and an XGBRegressor are both fitted on an 80 %
    train split; the candidate with the lower 5-fold CV MAE (hold-out MAE as
    tie-breaker) is returned.  The winner's name is stored on the estimator as
    ``_casting_model_name``.  The returned estimator is the one fitted on the
    train split — it is not refitted on all labelled heats.

    Only heats with joined rejection data (``has_rej_data`` true and a
    non-null ``rej_pct``) are used as training labels.

    Returns:
        (model, feature_cols), or (None, None) when fewer than 30 labelled
        heats are available.
    """
    # .eq(True) keeps the mask strictly boolean even when the flag column
    # contains None/NaN.
    has_labels = df["has_rej_data"].eq(True) & df["rej_pct"].notna()
    df_lab = df[has_labels].copy()

    if len(df_lab) < 30:
        print(f"  [Risk] Only {len(df_lab)} labelled heats — need ≥30. Skipping.")
        return None, None

    X = df_lab[feature_cols].values
    y = df_lab["rej_pct"].values

    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

    rf_model = RandomForestRegressor(
        n_estimators=300,
        max_depth=8,
        min_samples_leaf=3,
        max_features="sqrt",
        n_jobs=-1,
        random_state=42,
    )
    xgb_model = XGBRegressor(
        n_estimators=350,
        max_depth=4,
        learning_rate=0.04,
        subsample=0.85,
        colsample_bytree=0.85,
        objective="reg:squarederror",
        n_jobs=-1,
        random_state=42,
        verbosity=0,
    )

    scored = []
    for model_name, candidate in [("RandomForestRegressor", rf_model), ("XGBRegressor", xgb_model)]:
        candidate.fit(X_tr, y_tr)
        preds = candidate.predict(X_te)
        mae = mean_absolute_error(y_te, preds)
        try:
            # cross_val_score fits clones, so `candidate` stays fitted on X_tr.
            cv_scores = cross_val_score(candidate, X, y, cv=5, scoring="neg_mean_absolute_error")
            cv_mean, cv_std = -cv_scores.mean(), cv_scores.std()
        except Exception:
            # Best effort: rank by the hold-out MAE when CV cannot run.
            cv_mean, cv_std = mae, 0.0
        scored.append((cv_mean, mae, cv_std, model_name, candidate))
        print(f"  [Risk:{model_name}] MAE={mae:.2f}%  CV MAE={cv_mean:.2f}±{cv_std:.2f}")

    scored.sort(key=lambda row: (row[0], row[1]))
    cv_mean, mae, cv_std, model_name, model = scored[0]
    model._casting_model_name = model_name
    print(f"  [Risk] Selected {model_name}")

    # Report the winner's real CV spread; the previous code rebuilt a
    # one-element array here and therefore always printed "±0.00".
    print(
        f"  [Risk] MAE={mae:.2f}%  CV MAE={cv_mean:.2f}±{cv_std:.2f}"
        f"  n_train={len(df_lab)}"
    )

    top10 = sorted(zip(feature_cols, model.feature_importances_), key=lambda t: t[1], reverse=True)[:10]
    print(f"  [Risk] Top features: {[(f, round(v, 3)) for f, v in top10]}")

    return model, feature_cols


# ── Model 2: Defect classifier (XGBoost MultiOutput) ──────────────────────────
def _oversample_binary(X: np.ndarray, y: np.ndarray, target_pos_ratio: float = 0.25) -> tuple[np.ndarray, np.ndarray]:
    pos_idx = np.where(y == 1)[0]
    neg_idx = np.where(y == 0)[0]
    if len(pos_idx) == 0 or len(neg_idx) == 0:
        return X, y
    target_pos = int(np.ceil((target_pos_ratio * len(neg_idx)) / max(1 - target_pos_ratio, 1e-6)))
    extra = max(0, target_pos - len(pos_idx))
    if extra == 0:
        return X, y
    rng = np.random.default_rng(42)
    sampled = rng.choice(pos_idx, size=extra, replace=True)
    idx = np.concatenate([np.arange(len(y)), sampled])
    rng.shuffle(idx)
    return X[idx], y[idx]


def train_defect_model(
    df: pd.DataFrame,
    feature_cols: list[str],
    plant: str,
    min_prevalence: float = 0.03,
    min_positive: int = 3,
) -> Optional[dict]:
    """
    XGBoost multi-output classifier — one binary head per defect type.

    A heat counts as positive for a defect when its ``defect_<key>`` value is
    greater than 0.  Only defects appearing in at least *min_prevalence*
    (default 3 %) of the labelled heats, and in no fewer than *min_positive*
    heats, get a head.  (Previously the code compared the raw count with 3,
    contradicting the documented "≥3 %" rule and admitting arbitrarily rare
    defects on large datasets.)

    Per-defect thresholds are chosen to maximise F1 on the 20 % hold-out
    split; the F1 printed is measured on that same split, so it is an
    optimistic estimate.

    Returns:
        ``{"model", "defects", "thresholds"}``, or None when there are fewer
        than 30 labelled heats or no defect passes the prevalence filter.
    """
    # .eq(True) keeps the mask strictly boolean even with None/NaN flags.
    df_lab = df[df["has_rej_data"].eq(True)].copy()
    n_lab = len(df_lab)

    if n_lab < 30:
        print(f"  [Defect] Only {n_lab} labelled heats — need ≥30. Skipping.")
        return None

    X = df_lab[feature_cols].values

    # Select defects with sufficient prevalence
    active_defects: list[str] = []
    positive_masks: dict[str, np.ndarray] = {}
    for dk in DEFECT_KEYS:
        col = f"defect_{dk}"
        if col not in df_lab.columns:
            continue
        is_pos = (pd.to_numeric(df_lab[col], errors="coerce") > 0).to_numpy()
        n_pos = int(is_pos.sum())
        if n_pos >= min_positive and n_pos / n_lab >= min_prevalence:
            active_defects.append(dk)
            positive_masks[dk] = is_pos

    if not active_defects:
        print(f"  [Defect] No defects with ≥{min_prevalence:.0%} prevalence — skipping.")
        return None

    y = np.column_stack([positive_masks[dk] for dk in active_defects]).astype(int)

    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42, stratify=None)
    # One global weight for all heads: ratio of negative to positive cells in
    # the whole label matrix, clipped to [1, 50].
    pos_total = max(1, int(y_tr.sum()))
    neg_total = max(1, int(y_tr.size - y_tr.sum()))
    imbalance_weight = float(np.clip(neg_total / pos_total, 1.0, 50.0))

    # XGBoost wrapped in MultiOutputClassifier
    base = XGBClassifier(
        n_estimators=200,
        max_depth=5,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        use_label_encoder=False,
        eval_metric="logloss",
        scale_pos_weight=imbalance_weight,
        n_jobs=-1,
        random_state=42,
        verbosity=0,
    )
    model = MultiOutputClassifier(base, n_jobs=-1)
    model.fit(X_tr, y_tr)

    # Per-defect metrics and optimal threshold search
    proba_matrix = model.predict_proba(X_te)   # list of [n_test, 2] arrays
    thresholds: dict[str, float] = {}

    print(f"  [Defect] {len(active_defects)} active defect classes:")
    for i, dk in enumerate(active_defects):
        # A head trained on a single class exposes only one probability column.
        prob_pos = proba_matrix[i][:, 1] if proba_matrix[i].shape[1] > 1 else np.zeros(len(X_te))
        true_col = y_te[:, i]
        prev     = true_col.mean()

        # Find threshold that maximises F1 on test set
        if prev > 0:
            precisions, recalls, thresh_vals = precision_recall_curve(true_col, prob_pos)
            f1_vals  = np.where(
                (precisions + recalls) > 0,
                2 * precisions * recalls / (precisions + recalls + 1e-9),
                0,
            )
            best_idx   = int(np.argmax(f1_vals[:-1]))           # last element has no threshold
            best_thresh = float(thresh_vals[best_idx]) if len(thresh_vals) > 0 else 0.5
            best_f1    = float(f1_vals[best_idx])
        else:
            # No positives in the hold-out split: keep the neutral default.
            best_thresh = 0.5
            best_f1     = 0.0

        thresholds[dk] = round(best_thresh, 3)
        label = DEFECT_LABELS.get(dk, dk)
        print(f"    {label:<22} prev={prev:.0%}  best_threshold={best_thresh:.2f}  test_F1={best_f1:.2f}")

    return {
        "model":      model,
        "defects":    active_defects,
        "thresholds": thresholds,
    }


# ── Model 3: Anomaly detector ─────────────────────────────────────────────────
def train_anomaly_model(
    df: pd.DataFrame,
    feature_cols: list[str],
    plant: str,
) -> dict:
    """
    Fit the unsupervised "unusual heat" detector on every heat in *df*.

    The StandardScaler is fitted on an 85 % random split only; the scaled
    version of ALL rows is then used to fit and score the IsolationForest.

    contamination is the share of labelled heats whose rej_pct lies above the
    75th percentile of those heats, clipped to [0.05, 0.25]; 0.10 is used when
    no heat carries rejection data.
    NOTE(review): the share above a sample's own 75th percentile cannot exceed
    25 % by construction — confirm this estimate is what is intended.

    Returns:
        ``{"model": IsolationForest, "scaler": StandardScaler}``.
    """
    features = df[feature_cols].values

    # Estimate contamination from rejection distribution
    contamination = 0.10
    if df["has_rej_data"].any():
        rejections = df.loc[df["has_rej_data"], "rej_pct"].fillna(0)
        upper_quartile = rejections.quantile(0.75)
        high_share = (rejections > upper_quartile).mean()
        contamination = float(np.clip(high_share, 0.05, 0.25))

    # Scaler statistics come from the train split only
    train_part, _holdout = train_test_split(features, test_size=0.15, random_state=42)
    scaler = StandardScaler()
    scaler.fit(train_part)
    scaled = scaler.transform(features)

    detector = IsolationForest(
        n_estimators=300,
        contamination=contamination,
        max_samples="auto",
        random_state=42,
        n_jobs=-1,
    )
    detector.fit(scaled)

    # predict() marks outliers with -1
    flagged_pct = (detector.predict(scaled) == -1).mean()
    print(
        f"  [Anomaly] contamination={contamination:.0%}  "
        f"flagged={flagged_pct:.0%}  n={len(features)}"
    )

    return {"model": detector, "scaler": scaler}


# ── Save models ────────────────────────────────────────────────────────────────
def save_models(
    plant: str,
    risk,
    defect,
    anomaly,
    feature_cols: list[str],
    metrics: dict,
) -> None:
    """
    Persist the trained artifacts for *plant* (upper-cased) through `storage`.

    Each of risk / defect / anomaly is written only when it is not None; the
    defect bundle additionally writes its thresholds to threshold.json.
    feature_columns.json and training_report.json are always written.  The
    report merges *metrics* over the plant, UTC timestamp and feature count.
    """
    plant = plant.upper()

    if risk is not None:
        storage.save_pickle(plant, "risk_model.pkl", risk)
        print("  Saved risk_model.pkl")

    if defect is not None:
        storage.save_pickle(plant, "defect_model.pkl", defect)
        head_count = len(defect["defects"])
        print(f"  Saved defect_model.pkl  ({head_count} defect heads)")

        # Classification cut-offs, one per defect head
        storage.save_json(plant, "threshold.json", defect["thresholds"])
        print("  Saved threshold.json")

    if anomaly is not None:
        storage.save_pickle(plant, "anomaly_model.pkl", anomaly)
        print("  Saved anomaly_model.pkl")

    feature_total = len(feature_cols)
    storage.save_json(plant, "feature_columns.json", feature_cols)
    print(f"  Saved feature_columns.json  ({feature_total} features)")

    # Lightweight audit trail of this training run
    report = {
        "plant": plant,
        "trained_at": datetime.now(UTC).isoformat(),
        "feature_count": feature_total,
    }
    report.update(metrics)
    storage.save_json(plant, "training_report.json", report)
    print("  Saved training_report.json")

    destination = storage.backend_description()
    print(f"\n  ✓ All models saved → {destination}/{plant}")


# ── Sanity check: feature alignment ───────────────────────────────────────────
def verify_feature_alignment(feature_cols: list[str]) -> None:
    """
    Warn if feature_cols diverges from the canonical feature list.

    The canonical list is this module's ``ALL_FEATURES + ["pour_temp_drop_pct"]``,
    which is only a mirror of server.py (it must be kept in sync by hand), so a
    clean result here does not prove that server.py itself agrees.

    Besides extra and missing names, the ORDER (and any duplicates) is checked:
    models are trained on positional arrays, so a reordered list causes the
    same silent wrong predictions as a missing column.
    """
    server_canonical = ALL_FEATURES + ["pour_temp_drop_pct"]
    extras   = set(feature_cols) - set(server_canonical)
    missing  = set(server_canonical) - set(feature_cols)
    if extras:
        print(f"  ⚠  Extra features not in server.py: {sorted(extras)}")
    if missing:
        print(f"  ⚠  Features in server.py but not trained: {sorted(missing)}")
    if extras or missing:
        return

    if list(feature_cols) != server_canonical:
        # Same names, different positions or repeated entries.
        print("  ⚠  Feature order differs from server.py (or contains duplicates)")
        return

    print("  [Align] Feature list matches server.py ✓")


# ── Main ───────────────────────────────────────────────────────────────────────
def main() -> None:
    """
    Command-line entry point: load data, build features, train and save models.

    Data comes from --csv when given, otherwise from the API (after resolving
    a token).  Failures while loading — unreadable/malformed CSV, token
    resolution, HTTP or network errors — are reported with a short message and
    end the run without a traceback.  --api / --token override the module-level
    API_BASE / API_TOKEN.
    """
    parser = argparse.ArgumentParser(description="Train casting rejection ML models")
    parser.add_argument("--plant",  default="HPML",   choices=sorted(VALID_PLANTS),
                        help="Plant code to train for")
    parser.add_argument("--days",   type=int, default=730,
                        help="API lookback window in days (default 730)")
    parser.add_argument("--csv",    default=None,
                        help="Path to local CSV — skips API fetch entirely")
    parser.add_argument("--api",    default=None,
                        help="Override API_BASE URL")
    parser.add_argument("--token",  default=None,
                        help="JWT auth token — skips interactive login")
    args = parser.parse_args()

    # The lookback window is only meaningful for the API path.
    if not args.csv and args.days <= 0:
        parser.error("--days must be a positive integer")

    global API_BASE, API_TOKEN
    if args.api:
        API_BASE = args.api
    if args.token:
        API_TOKEN = args.token

    print(f"\n{'='*62}")
    print(f"  Casting Rejection ML Trainer")
    print(f"  Plant: {args.plant}   |  Source: {'CSV' if args.csv else f'API ({args.days}d)'}")
    print(f"{'='*62}")

    # ── 1. Load data
    if args.csv:
        try:
            df_raw = load_heat_records_csv(args.csv)
        except (OSError, ValueError) as exc:
            # OSError: missing/unreadable file; ValueError: pandas parser,
            # empty-data and decoding errors.
            print(f"\n  ERROR reading CSV '{args.csv}': {exc}")
            return
    else:
        try:
            resolve_token()
        except Exception as exc:
            print(f"\n  ERROR resolving token: {exc}")
            return
        try:
            df_raw = fetch_heat_records_api(args.plant, args.days)
        except (RuntimeError, requests.RequestException, ValueError) as exc:
            # RuntimeError: non-200 meta response; RequestException: network
            # failure; ValueError: response body that is not valid JSON.
            print(f"\n  ERROR fetching data from API: {exc}")
            print(
                "  • Verify API_BASE and that the server is running.\n"
                f"  • If the cached token has expired, remove API_TOKEN from {ENV_FILE} "
                "or pass --token.\n"
            )
            return

    if df_raw.empty:
        print(
            "\n  No data loaded — cannot train.\n"
            "  • For API mode: verify API_BASE and that the server is running.\n"
            "  • For CSV mode: check the --csv path and column names.\n"
        )
        return

    # ── 2. Feature engineering
    df, feature_cols = prepare_features(df_raw)
    n_total    = len(df)
    n_labelled = int(df["has_rej_data"].sum())
    print(f"\n  Dataset: {n_total} total heats  |  {n_labelled} with rejection labels")
    print(f"  Feature set: {len(feature_cols)} features")

    verify_feature_alignment(feature_cols)

    # ── 3. Train
    print(f"\n{'─'*40}")
    print(f"  Training models …")
    print(f"{'─'*40}")

    risk_model,  _     = train_risk_model(df, feature_cols, args.plant)
    defect_bundle      = train_defect_model(df, feature_cols, args.plant)
    anomaly_bundle     = train_anomaly_model(df, feature_cols, args.plant)

    # ── 4. Collect lightweight metrics for audit trail
    metrics: dict = {
        "n_heats_total":    n_total,
        "n_heats_labelled": n_labelled,
        "risk_model":       getattr(risk_model, "_casting_model_name", type(risk_model).__name__) if risk_model else None,
        "defect_model":     "XGBoost+MultiOutputClassifier" if defect_bundle else None,
        "defect_heads":     defect_bundle["defects"] if defect_bundle else [],
        "anomaly_model":    "IsolationForest" if anomaly_bundle else None,
    }

    # ── 5. Save
    print(f"\n{'─'*40}")
    print(f"  Saving …")
    print(f"{'─'*40}")
    save_models(args.plant, risk_model, defect_bundle, anomaly_bundle, feature_cols, metrics)

    print(f"\n  Done. Start the prediction server with:\n")
    print(f"      python server.py\n")


# Run the trainer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
