"""
server.py  —  ML Prediction Service
=====================================
Lightweight Flask API that loads trained models and serves predictions.

Usage:
    pip install flask scikit-learn xgboost pandas numpy joblib
    python server.py

Endpoints:
    POST /predict          — full prediction (all features)
    POST /predict/early    — pre-cast early warning (caller sends the sparse
                             sand + PLC features available before casting)
    GET  /health           — storage backend and per-plant model status
    GET  /features/<plant> — list the feature columns expected for a plant
"""

import json
import os
from pathlib import Path
from datetime import datetime, UTC

import numpy as np
from flask import Flask, jsonify, request

import storage  # S3-or-local artifact storage (see storage.py)

# Flask application object; every route in this module registers against it.
app = Flask(__name__)

# Listening port, overridable through the ML_PORT environment variable.
PORT = int(os.environ.get("ML_PORT", 5001))

# NOTE(review): MODELS_DIR is not referenced anywhere else in this file —
# artifact access goes through the `storage` module. Confirm it is still needed.
MODELS_DIR = Path(os.environ.get("MODELS_DIR", "./models"))

# ── Feature definitions (must match train.py) ─────────────────────────────────
# Order matters: ALL_FEATURES (+ "pour_temp_drop_pct") is the fallback column
# order load_models uses when a plant has no feature_columns.json. Entries are
# column keys, not prose — spellings such as "noudularity", "perlite" and
# "ferite" presumably mirror train.py / the source data; confirm before renaming.
CHEM_FEATURES = ["c","si","mn","mg","cu","ce","cr","mo","s","p","sn","ti","ni","al","v"]
MECH_FEATURES = ["uts","ys","elong","noudularity","nodulecount","perlite","ferite","hardness"]
SAND_FEATURES = ["gcs","compactibility","wts","active_clay","dead_clay","permeability",
                  "moisture","moisture_return","loi","volatile_matter"]
TEMP_FEATURES = ["pour_temp_first","pour_temp_last","pour_temp_delta"]
# Computed in build_feature_vector from the raw chemistry and the "grade" field.
DERIVED_CHEM_FEATURES = [
    "ce_target","ce_deviation","mn_s_ratio","mn_s_product",
    "carbide_tendency","graphitizer_balance","mg_efficiency","al_ti_sum",
]
# Computed in build_feature_vector from pour_temp_first and the "grade" field.
DERIVED_PROCESS_FEATURES = ["pour_temp_target","pour_temp_deviation"]
ALL_FEATURES = CHEM_FEATURES + MECH_FEATURES + SAND_FEATURES + TEMP_FEATURES + DERIVED_CHEM_FEATURES + DERIVED_PROCESS_FEATURES

# (grade substring, target) pairs. build_feature_vector picks the FIRST pair
# whose substring occurs in the upper-cased "grade" value, so order is
# significant. Pour temperatures are presumably °C — confirm with train.py.
GRADE_CE_TARGETS = [("700", 4.20), ("600", 4.25), ("500", 4.30), ("450", 4.35), ("400", 4.35)]
GRADE_POUR_TEMP_TARGETS = [("700", 1415.0), ("600", 1410.0), ("500", 1405.0), ("450", 1400.0), ("400", 1400.0)]

# Features available BEFORE casting (for early warning).
# NOTE(review): not referenced anywhere else in this file — /predict/early does
# not currently restrict its input to this list.
EARLY_FEATURES = SAND_FEATURES + TEMP_FEATURES + [
    "c","si","mn","mg",         # from previous heat / SAP lookup
]

# Defect key -> display label. compute_defects falls back to the raw key, so a
# label only applies when its key equals a name in the defect bundle's
# "defects" list — keep key spellings (e.g. "porocity") as they are.
DEFECT_LABELS = {
    "bh":"Blow Hole","porocity":"Porosity","sw":"Shrinkage","swl":"Shrinkage (SWL)",
    "crack":"Crack","cs":"Cold Shut","csh":"Chilling","inc":"Inclusion","slag":"Slag",
    "hard":"Hardness NG","high_bhn":"High BHN","ph":"Pin Hole","sf":"Sand Fusion",
    "scab":"Scab","misrun":"Misrun","rough_surface":"Rough Surface","lkg":"Leakage",
    "mould_crush":"Mould Crush","dm":"Dimension","dim":"Dimension OOS",
}

# ── Model registry ────────────────────────────────────────────────────────────
# Loaded on first request per plant, then kept in memory for the process lifetime.
_model_cache = {}

# (bundle key, artifact file name), loaded in this order.
_MODEL_ARTIFACTS = (
    ("risk", "risk_model.pkl"),
    ("defect", "defect_model.pkl"),
    ("anomaly", "anomaly_model.pkl"),
)


def load_models(plant: str) -> dict | None:
    """
    Return the model bundle for ``plant``, loading and caching it on first use.

    The bundle always has ``"plant"`` and ``"feature_columns"``. It also holds
    ``"risk"``, ``"defect"`` and/or ``"anomaly"`` for each artifact that
    ``storage.load_pickle`` returned (``None`` results are skipped).
    ``feature_columns`` falls back to ``ALL_FEATURES + ["pour_temp_drop_pct"]``
    when ``feature_columns.json`` is missing or empty.

    Args:
        plant: Plant code; matched case-insensitively (upper-cased).

    Returns:
        The bundle dict, or ``None`` when ``storage.plant_has_models`` reports
        no models. ``None`` is not cached, so models that appear later are
        picked up on a subsequent call.
    """
    plant = plant.upper()
    cached = _model_cache.get(plant)
    if cached is not None:
        return cached

    if not storage.plant_has_models(plant):
        return None

    bundle = {"plant": plant}
    for key, filename in _MODEL_ARTIFACTS:
        # NOTE: these are unpickled — storage must only ever point at trusted artifacts.
        obj = storage.load_pickle(plant, filename)
        if obj is not None:
            bundle[key] = obj

    feats = storage.load_json(plant, "feature_columns.json")
    bundle["feature_columns"] = feats if feats else (ALL_FEATURES + ["pour_temp_drop_pct"])

    _model_cache[plant] = bundle
    app.logger.info("Loaded models for %s: %s", plant, list(bundle.keys()))
    return bundle


# ── Feature vector builder ────────────────────────────────────────────────────
def build_feature_vector(features_in: dict, feature_cols: list) -> np.ndarray:
    """
    Build the single model-input row for one heat.

    Derived features are computed first:
      - CE, when it was not supplied;
      - grade-based CE / pour-temperature targets and their deviations;
      - chemistry ratios and sums;
      - pour-temperature drop %.
    Then one value is emitted per entry of ``feature_cols``, in that exact order.

    Args:
        features_in: Raw feature mapping from the request (``None`` is treated
            as empty). It is copied, never mutated. ``"grade"`` is only used to
            select the CE / pour-temperature targets.
        feature_cols: Column names in the order the models expect.

    Returns:
        Float array of shape ``(1, len(feature_cols))``. Missing, non-numeric
        and non-finite (NaN / ±inf) values become 0.0.
    """
    features_in = dict(features_in or {})
    grade = str(features_in.get("grade") or "").upper()

    def _num(key):
        """Return ``features_in[key]`` as a float, or None if absent/unparseable."""
        val = features_in.get(key)
        if val is None:
            return None
        try:
            return float(val)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: JSON integers too large to represent as a float.
            return None

    def _target(targets, default):
        """First target whose grade substring occurs in ``grade``, else ``default``."""
        for prefix, value in targets:
            if prefix in grade:
                return value
        return default

    # Carbon equivalent: fill in CE = C + Si/3 + P/3 when it was not supplied.
    c, si, p = _num("c"), _num("si"), _num("p")
    if _num("ce") is None and c is not None and si is not None:
        features_in["ce"] = c + (si / 3.0) + ((p or 0.0) / 3.0)
    ce = _num("ce")
    mn, s, mg = _num("mn"), _num("s"), _num("mg")
    cr, mo, v, sn = _num("cr"), _num("mo"), _num("v"), _num("sn")
    al, ti = _num("al"), _num("ti")

    # With no matching grade, the CE target falls back to the heat's own CE
    # (so ce_deviation is 0), or to 4.3 when CE is unknown.
    ce_target = _target(GRADE_CE_TARGETS, ce if ce is not None else 4.3)
    pour_temp_target = _target(GRADE_POUR_TEMP_TARGETS, 1405.0)
    features_in["ce_target"] = ce_target
    features_in["ce_deviation"] = (ce - ce_target) if ce is not None else 0.0
    features_in["mn_s_ratio"] = (mn / s) if mn is not None and s not in (None, 0) else 0.0
    features_in["mn_s_product"] = (mn or 0.0) * (s or 0.0)
    features_in["carbide_tendency"] = (cr or 0.0) + (mo or 0.0) + (v or 0.0) + (sn or 0.0)
    features_in["graphitizer_balance"] = (si or 0.0) - 3.0 * ((cr or 0.0) + (mo or 0.0) + (v or 0.0))
    denom = ((s or 0.0) * 1.7) + (mg or 0.0)
    features_in["mg_efficiency"] = ((mg or 0.0) / denom) if abs(denom) > 1e-9 else 0.0
    features_in["al_ti_sum"] = (al or 0.0) + (ti or 0.0)
    features_in["pour_temp_target"] = pour_temp_target
    first = _num("pour_temp_first")
    features_in["pour_temp_deviation"] = (first - pour_temp_target) if first is not None else 0.0

    # Pour-temperature drop as a percentage of the first pour temperature.
    delta = _num("pour_temp_delta")
    if first and first > 0 and delta is not None:
        features_in["pour_temp_drop_pct"] = (delta / first) * 100
    else:
        features_in["pour_temp_drop_pct"] = 0.0

    # Missing, unparseable ("abc") and non-finite ("nan", inf) values all map
    # to 0.0. This way one malformed field neither crashes the request nor
    # feeds NaN/inf into the scaler and models.
    vec = []
    for col in feature_cols:
        val = _num(col)
        vec.append(val if val is not None and np.isfinite(val) else 0.0)

    return np.array(vec, dtype=float).reshape(1, -1)


def clamp(v, lo, hi):
    """
    Limit ``v`` to the range [lo, hi].

    The upper bound is applied first. Because ``min(hi, nan)`` returns ``hi``,
    a NaN input yields ``hi``.
    """
    capped = min(hi, v)
    return max(lo, capped)


# ── Risk score helper ─────────────────────────────────────────────────────────
def compute_risk(bundle: dict, X: np.ndarray) -> dict:
    """
    Predict the rejection-rate risk for one feature row.

    Returns ``score`` (predicted rejection %, clamped to 0–100 and rounded to
    one decimal), ``level`` (low < 2.0 <= medium < 5.0 <= high), a display
    ``color`` and a human-readable ``description``. When the bundle has no
    ``"risk"`` model, returns ``score`` None, ``level`` "unknown" and a
    ``confidence`` note instead.
    """
    if "risk" in bundle:
        predicted = float(bundle["risk"].predict(X)[0])
        score = round(max(0, min(100, predicted)), 1)   # clamp to [0, 100]

        # Bands are (exclusive upper bound, level, color); anything at or above
        # the last bound is "high". Thresholds should be tuned once the real
        # score distribution is known.
        level, color = "high", "#ff3b30"
        for upper, band_level, band_color in ((2.0, "low", "#30d158"), (5.0, "medium", "#ff9f0a")):
            if score < upper:
                level, color = band_level, band_color
                break

        return {
            "score": score,
            "level": level,
            "color": color,
            "description": f"Predicted rejection rate: {score:.1f}%",
        }

    return {"score": None, "level": "unknown", "confidence": "no model"}


# ── Defect prediction helper ──────────────────────────────────────────────────
def compute_defects(bundle: dict, X: np.ndarray) -> list:
    """
    Return up to five likely defects for one feature row.

    A defect is reported when its predicted probability reaches
    ``min(threshold, 0.10)``. ``threshold`` comes from the bundle's
    ``"thresholds"`` mapping (default 0.10), so any defect at or above 10 % is
    always surfaced and a per-defect threshold can only lower that bar.

    Args:
        bundle: Model bundle; uses ``bundle["defect"]``, a dict with
            ``"model"``, ``"defects"`` (output names, indexed in model output
            order) and optional ``"thresholds"``.
        X: Feature row as built by ``build_feature_vector`` (one sample).

    Returns:
        List of ``{defect, label, probability, threshold}`` dicts
        (percentages), sorted by probability descending. The list is empty
        when there is no defect model or the model fails to predict
        (the failure is logged).
    """
    if "defect" not in bundle:
        return []

    d = bundle["defect"]
    model = d["model"]
    defects = d["defects"]
    thresholds = d.get("thresholds", {})

    try:
        proba_matrix = model.predict_proba(X)
    except Exception:
        # Best-effort: a broken defect model must not fail the whole
        # prediction, but the failure has to be visible in the logs.
        app.logger.exception("Defect model predict_proba failed for plant %s", bundle.get("plant"))
        return []

    results = []
    for i, dk in enumerate(defects):
        if isinstance(proba_matrix, list):
            # MultiOutputClassifier-style output: one [n_samples, n_classes]
            # array per defect; column 1 is the positive class.
            # NOTE(review): an estimator that saw a single class in training
            # has one column and is reported as 0.0 here.
            col = proba_matrix[i]
            prob = float(col[0][1]) if col.shape[1] > 1 else 0.0
        else:
            prob = float(proba_matrix[0][i])

        threshold = float(thresholds.get(dk, 0.10))
        if prob >= min(threshold, 0.10):
            results.append({
                "defect":      dk,
                "label":       DEFECT_LABELS.get(dk, dk),
                "probability": round(prob * 100, 1),
                "threshold":   round(threshold * 100, 1),
            })

    results.sort(key=lambda x: x["probability"], reverse=True)
    return results[:5]   # top 5


# ── Anomaly detection helper ──────────────────────────────────────────────────
def compute_anomaly(bundle: dict, X: np.ndarray, feature_cols: list, features_in: dict) -> dict:
    """
    Score how unusual a heat is and name its most deviant features.

    Args:
        bundle: Model bundle; uses ``bundle["anomaly"]``, a dict with a fitted
            ``"model"`` and the ``"scaler"`` applied before it. The model's
            ``predict`` returns 1 = normal / -1 = anomaly, and
            ``decision_function`` is negative for anomalies. The scaler must
            expose ``mean_`` and ``scale_``.
        X: Unscaled feature row (one sample), columns in ``feature_cols`` order.
        feature_cols: Column names matching the columns of ``X``.
        features_in: Raw request features. Kept for call compatibility only.
            z-scores are computed from ``X``, so derived features and
            missing/coerced values match what the model actually saw.

    Returns:
        ``{is_anomaly, anomaly_score, outlier_params, description}``.
        ``anomaly_score`` is 0–100, higher meaning more anomalous.
        ``outlier_params`` lists up to five features with |z| > 2, computed
        only for anomalous heats. Without an anomaly model, ``anomaly_score``
        is None and a legacy ``score`` key is also present.
    """
    if "anomaly" not in bundle:
        return {
            "is_anomaly": False,
            "score": None,            # legacy key, kept for existing clients
            "anomaly_score": None,
            "outlier_params": [],
            "description": "No anomaly model available",
        }

    a = bundle["anomaly"]
    model = a["model"]
    scaler = a["scaler"]

    X_scaled = scaler.transform(X)
    pred = model.predict(X_scaled)[0]          # 1=normal, -1=anomaly
    raw_score = float(model.decision_function(X_scaled)[0])
    is_anomaly = bool(pred == -1)

    # decision_function is negative for anomalies, typically within -0.5..+0.5;
    # map it onto 0–100 so that higher = more anomalous.
    anomaly_pct = round(clamp((-raw_score + 0.5) * 100, 0, 100), 1)

    outlier_params = []
    if is_anomaly:
        # Use the model-input row itself: derived features are not present in
        # the raw request dict, and raw values may be strings.
        row = X[0]
        means = scaler.mean_
        stds = scaler.scale_
        for i, col in enumerate(feature_cols):
            val = float(row[i])
            z = abs((val - means[i]) / (stds[i] + 1e-9))
            if z > 2.0:
                outlier_params.append({
                    "feature":    col,
                    "value":      round(val, 3),
                    "z_score":    round(float(z), 2),
                    "direction":  "high" if val > means[i] else "low",
                })
        outlier_params.sort(key=lambda x: x["z_score"], reverse=True)
        outlier_params = outlier_params[:5]

    return {
        "is_anomaly":    is_anomaly,
        "anomaly_score": anomaly_pct,
        "outlier_params": outlier_params,
        "description": "Unusual heat — parameters deviate significantly from historical norms" if is_anomaly else "Normal heat",
    }


# ── POST /predict ─────────────────────────────────────────────────────────────
def _predict_response(body, *, early_mode: bool):
    """
    Validate one parsed request body and run every model on it.

    Shared by ``/predict`` and ``/predict/early``, so the early-mode flag is
    passed explicitly instead of being stashed in the parsed JSON.

    Args:
        body: Parsed JSON request body; must be a JSON object.
        early_mode: True for the pre-cast endpoint; echoed in the response.

    Returns:
        A Flask JSON response, or ``(response, status)`` for errors:
        400 for a non-object body/features or a missing plant,
        404 when the plant has no trained models.
    """
    if not isinstance(body, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    # Treat explicit JSON nulls like absent fields (avoids plant "NONE").
    raw_plant = body.get("plant")
    plant = "" if raw_plant is None else str(raw_plant).upper()
    raw_heatno = body.get("heatno")
    heatno = "" if raw_heatno is None else str(raw_heatno)
    features_in = body.get("features")
    if features_in is None:
        features_in = {}

    if not plant:
        return jsonify({"error": "plant is required"}), 400
    if not isinstance(features_in, dict):
        return jsonify({"error": "features must be a JSON object"}), 400

    bundle = load_models(plant)
    if bundle is None:
        return jsonify({
            "error": f"No models found for plant {plant}. Run train.py first.",
            "hint":  f"python train.py --plant {plant}"
        }), 404

    feature_cols = bundle["feature_columns"]
    X = build_feature_vector(features_in, feature_cols)

    risk    = compute_risk(bundle, X)
    defects = compute_defects(bundle, X)
    anomaly = compute_anomaly(bundle, X, feature_cols, features_in)

    # Overall alert level: any anomaly or high risk is "danger".
    alert = "ok"
    if anomaly["is_anomaly"] or risk.get("level") == "high":
        alert = "danger"
    elif risk.get("level") == "medium":
        alert = "warning"

    return jsonify({
        "heatno":     heatno,
        "plant":      plant,
        "alert":      alert,
        "risk":       risk,
        "defects":    defects,
        "anomaly":    anomaly,
        "features_received": len([v for v in features_in.values() if v is not None]),
        "early_mode": early_mode,
        "timestamp":  datetime.now(UTC).isoformat(),
    })


@app.route("/predict", methods=["POST"])
def predict():
    """
    Full prediction for one heat.

    Body: {
        plant:    "HPML",
        heatno:   "25H07070",          (optional, echoed back)
        features: { c: 3.5, si: 2.1, ... }
    }
    Response: {
        heatno, plant, alert,
        risk:    { score, level, color, description },
        defects: [ { defect, label, probability, threshold } ],
        anomaly: { is_anomaly, anomaly_score, outlier_params, description },
        features_received, early_mode, timestamp
    }
    """
    return _predict_response(request.get_json(silent=True) or {}, early_mode=False)


# ── POST /predict/early ───────────────────────────────────────────────────────
@app.route("/predict/early", methods=["POST"])
def predict_early():
    """
    Pre-cast prediction from sparse input (sand + PLC features, with chemistry
    taken from the PREVIOUS heat for this product, as passed by Node).

    Body and response are the same as /predict; ``early_mode`` is true.
    It uses the same models, and missing features become 0.
    NOTE(review): EARLY_FEATURES is not applied as a filter here — the caller
    is trusted to send only pre-cast features.
    """
    return _predict_response(request.get_json(silent=True) or {}, early_mode=True)


# ── GET /health ───────────────────────────────────────────────────────────────
@app.route("/health", methods=["GET"])
def health():
    """
    Health/status endpoint; ``status`` is always "ok".

    Reports the storage backend and the listening port. For each known plant
    that storage says has models, it also reports which artifacts exist and
    whether the bundle is already in the in-process cache.
    """
    def _plant_status(code):
        artifacts = {
            "risk":    storage.artifact_exists(code, "risk_model.pkl"),
            "defect":  storage.artifact_exists(code, "defect_model.pkl"),
            "anomaly": storage.artifact_exists(code, "anomaly_model.pkl"),
        }
        return {"plant": code, "models": artifacts, "cached": code in _model_cache}

    plants_loaded = [
        _plant_status(code)
        for code in ("HPML", "RHINO", "KOYO")
        if storage.plant_has_models(code)
    ]

    payload = {
        "status":  "ok",
        "storage": storage.backend_description(),
        "plants":  plants_loaded,
        "port":    PORT,
    }
    return jsonify(payload)


# ── GET /features ─────────────────────────────────────────────────────────────
@app.route("/features/<plant>", methods=["GET"])
def get_features(plant):
    """
    List the feature columns the models for ``plant`` expect, in order.

    Returns 404 (echoing the plant code as given) when no models are available.
    """
    code = plant.upper()
    bundle = load_models(code)
    if bundle is None:
        return jsonify({"error": f"No models for {plant}"}), 404

    columns = bundle["feature_columns"]
    return jsonify({"plant": code, "feature_columns": columns, "count": len(columns)})


# ── Startup ───────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Startup banner.
    print(f"\n ML Prediction Service starting on port {PORT}")
    print(f"   Model storage:    {storage.backend_description()}")
    print(f"   Health check: http://localhost:{PORT}/health\n")

    # Warm the model cache for every known plant that has trained models, so
    # the first request per plant does not pay the loading cost.
    for plant_code in ("HPML", "RHINO", "KOYO"):
        if not storage.plant_has_models(plant_code):
            continue
        load_models(plant_code)

    app.run(host="0.0.0.0", port=PORT, debug=False)
