import argparse
import datetime as dt
import json
import os

import joblib
import numpy as np
import psycopg2
from dotenv import load_dotenv

from sklearn.linear_model import Ridge


# Plants a model can be trained for; this order is used for `--plant ALL`.
PLANTS = ["HPML", "RHINO", "KOYO"]
# Version tag written into every saved model bundle.
MODEL_VERSION = "autoreg_ridge_v1"

# Whitelist checked in main() before any database access, since the plant name
# ends up (as a bound parameter) in the SQL LIKE pattern.
VALID_PLANTS = frozenset(PLANTS)


def get_db_conn():
    """Open a psycopg2 connection configured from environment variables.

    ``../.env`` and then ``../.env.local`` (relative to this file) are loaded
    into the process environment without overriding variables that are
    already set. If ``DATABASE_URL`` is non-empty it is used as the connection
    string. Otherwise the connection is built from ``DB_HOST``, ``DB_PORT``
    (default 5432), ``DB_NAME``, ``DB_USER`` and ``DB_PASSWORD``. SSL is
    required unless ``DB_SSL`` is "false" (case-insensitive; that is the
    default).

    Raises:
        ValueError: if ``DB_PORT`` is not an integer.
    """
    project_root = os.path.join(os.path.dirname(__file__), "..")
    # NOTE(review): both files load with override=False, so a key defined in
    # .env wins over the same key in .env.local (.env.local can only add new
    # keys). If .env.local is meant to take precedence, load it first.
    for env_file in (".env", ".env.local"):
        load_dotenv(os.path.join(project_root, env_file), override=False)

    url = os.getenv("DATABASE_URL")
    if url:
        return psycopg2.connect(url)

    # Any DB_SSL value other than "false" (e.g. "true", "1", "") enables SSL.
    sslmode = "disable" if os.getenv("DB_SSL", "false").lower() == "false" else "require"
    return psycopg2.connect(
        host=os.getenv("DB_HOST"),
        port=int(os.getenv("DB_PORT", "5432")),
        dbname=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        sslmode=sslmode,
    )


def fetch_daily_rejection_pct(conn, plant, start_date, end_date):
    """Return per-day rejection percentages for one plant.

    Rejection rows between ``start_date`` and ``end_date`` (inclusive) are
    matched to a deduplicated metallurgy lot with the same 9-character code
    base, created within one day, whose plant name contains ``plant``. Rows
    without such a match are dropped. Quantities and weights are summed per
    day.

    ``plant`` is only ever sent to the database as a bound parameter.
    main() additionally checks it against VALID_PLANTS before calling this.

    Args:
        conn: an open DB-API connection whose cursor supports ``with``.
        plant: plant code; upper-cased and wrapped in ``%...%`` for LIKE.
        start_date, end_date: ``datetime.date`` bounds of the production range.

    Returns:
        A list of ``(YYYY-MM-DD, qty_pct, wt_pct)`` tuples ordered by day.
        Each percentage is ``None`` when that day's production total is 0.
        Days with no matching rows are absent, so the list is not necessarily
        a contiguous calendar series.

    NOTE(review): ``DATEDIFF('day', ...)`` is not built into stock PostgreSQL
    (it exists in Redshift and similar engines). Confirm the target database
    supports it.
    """
    def iso(day):
        return day.strftime("%Y-%m-%d")

    # Lots are fetched with a 2-day margin on both sides so that the +/-1 day
    # join below can see lots just outside the production range.
    margin = dt.timedelta(days=2)
    met_start, met_end = iso(start_date - margin), iso(end_date + margin)
    day_from, day_to = iso(start_date), iso(end_date)
    plant_pattern = f"%{plant.upper()}%"

    query = """
        WITH met_ranked AS (
            SELECT
                LEFT(TRIM(fcodebno), 9) AS base_code,
                plant_name,
                lotcreatdate,
                ROW_NUMBER() OVER (
                    PARTITION BY LEFT(TRIM(fcodebno), 9), DATE(lotcreatdate)
                    ORDER BY COALESCE(materialdoc, '') DESC
                ) AS rn
            FROM sap_metallurgy
            WHERE lotcreatdate IS NOT NULL
              AND lotcreatdate >= DATE %s
              AND lotcreatdate <= DATE %s
        ),
        met_dedup AS (
            SELECT * FROM met_ranked WHERE rn = 1
        ),
        base AS (
            SELECT
                r.p_date,
                COALESCE(r.prd_qty, 0) AS prd_qty,
                COALESCE(r.rej_qty, 0) AS rej_qty,
                COALESCE(r.prd_wt,  0) AS prd_wt,
                COALESCE(r.rej_wt,  0) AS rej_wt
            FROM rejection_data r
            LEFT JOIN met_dedup s
              ON  LEFT(TRIM(r.fcodebno), 9) = s.base_code
              AND ABS(DATEDIFF('day', DATE(s.lotcreatdate), DATE(r.p_date))) <= 1
              AND UPPER(s.plant_name) LIKE %s
            WHERE r.p_date >= DATE %s
              AND r.p_date <= DATE %s
              AND s.plant_name IS NOT NULL
        )
        SELECT
            DATE(p_date) AS d,
            SUM(prd_qty) AS prd_qty,
            SUM(rej_qty) AS rej_qty,
            SUM(prd_wt)  AS prd_wt,
            SUM(rej_wt)  AS rej_wt
        FROM base
        GROUP BY 1
        ORDER BY 1
    """

    # Placeholder order: met start, met end, plant LIKE, production start/end.
    with conn.cursor() as cur:
        cur.execute(query, (met_start, met_end, plant_pattern, day_from, day_to))
        rows = cur.fetchall()

    def pct(rejected, produced):
        # NULL sums (None) count as 0; no production means no percentage.
        produced = float(produced or 0)
        rejected = float(rejected or 0)
        return (rejected / produced) * 100.0 if produced > 0 else None

    return [
        (iso(day), pct(rej_qty, prd_qty), pct(rej_wt, prd_wt))
        for day, prd_qty, rej_qty, prd_wt, rej_wt in rows
    ]


def fill_null_forward(values):
    """Replace ``None`` entries with the most recent preceding non-None value.

    ``None`` entries that come before any real value are filled with the
    mean of all non-None values, or with 0.0 if every entry is ``None``.
    The result has the same length as ``values`` and contains only floats.
    """
    present = [v for v in values if v is not None]
    seed = float(np.mean(present)) if present else 0.0

    filled = []
    carry = None  # last real value seen so far, as float
    for v in values:
        if v is not None:
            carry = float(v)
        filled.append(seed if carry is None else carry)
    return filled


def make_autoreg_dataset(series, window):
    """Build a lagged (X, y) training set from a 1-D sequence.

    For each position ``t`` from ``window`` to ``len(series) - 1``, the row
    ``X[i]`` holds ``series[t - window:t]`` and ``y[i]`` is ``series[t]``.
    Both are returned as float NumPy arrays. Returns ``(None, None)`` when
    ``series`` has no more than ``window`` elements.
    """
    targets = range(window, len(series))
    if not targets:
        return None, None

    lags = [series[t - window:t] for t in targets]
    labels = [series[t] for t in targets]
    return np.asarray(lags, dtype=float), np.asarray(labels, dtype=float)


def train_for_plant(conn, plant, window, days_back):
    """Fit the two Ridge autoregressive models for one plant.

    The daily series runs from ``days_back`` days before yesterday through
    yesterday. Missing percentages are filled with ``fill_null_forward``.
    One model is trained on rejection-by-quantity and one on
    rejection-by-weight, each predicting a day from the preceding ``window``
    values.

    Returns:
        A dict with ``plant``, ``window``, ``model_version``, ``model_qty`` and
        ``model_wt``, or ``None`` when fewer than ``window + 2`` daily rows
        are available. That threshold guarantees at least two training
        samples.
    """
    yesterday = dt.date.today() - dt.timedelta(days=1)
    history = fetch_daily_rejection_pct(
        conn, plant, yesterday - dt.timedelta(days=days_back), yesterday
    )
    if len(history) < window + 2:
        return None

    qty_series = fill_null_forward([row[1] for row in history])
    wt_series = fill_null_forward([row[2] for row in history])

    datasets = {
        "model_qty": make_autoreg_dataset(qty_series, window),
        "model_wt": make_autoreg_dataset(wt_series, window),
    }
    if any(features is None for features, _ in datasets.values()):
        return None

    models = {name: Ridge(alpha=1.0, random_state=42) for name in datasets}
    for name, (features, target) in datasets.items():
        models[name].fit(features, target)

    return {
        "plant": plant,
        "window": window,
        "model_version": MODEL_VERSION,
        **models,
    }


def main():
    """Command-line entry point: train and save one model bundle per plant.

    Options:
        --plant     HPML / RHINO / KOYO or ALL (case-insensitive).
        --window    number of lagged days per sample, 7..120.
        --daysBack  length of the training history in days.

    Each trained bundle is written with joblib to
    ``../models/rejection_autoreg_<PLANT>_w<window>.joblib`` (relative to this
    file). A JSON summary with ``trained`` and ``failed`` lists is printed to
    stdout.

    Raises:
        SystemExit: for an unknown ``--plant`` or an out-of-range ``--window``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--plant",    default="ALL", help="HPML / RHINO / KOYO or ALL")
    parser.add_argument("--window",   type=int, default=30)
    parser.add_argument("--daysBack", type=int, default=365)
    args = parser.parse_args()

    # Validate --plant against the explicit whitelist before any DB access.
    requested = args.plant.strip().upper()
    if requested == "ALL":
        plants = list(PLANTS)
    elif requested in VALID_PLANTS:
        plants = [requested]
    else:
        # List names from PLANTS rather than VALID_PLANTS: frozenset iteration
        # order depends on string hashing and can change between runs.
        raise SystemExit(
            f"Invalid --plant '{args.plant}'. Must be one of: {', '.join(PLANTS)} or ALL"
        )

    window = args.window  # already an int via argparse type=int
    if window < 7 or window > 120:
        raise SystemExit("--window must be between 7 and 120")

    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    # Create the output directory before connecting, so a failure here cannot
    # leave an open database connection behind.
    os.makedirs(models_dir, exist_ok=True)

    out = {"trained": [], "failed": []}
    conn = get_db_conn()
    try:
        for plant in plants:
            model = train_for_plant(conn, plant, window, args.daysBack)
            if model is None:
                out["failed"].append({"plant": plant, "reason": "not enough training data"})
                continue

            model_file = os.path.join(
                models_dir, f"rejection_autoreg_{plant}_w{window}.joblib"
            )
            # Persist an explicit, fixed set of keys rather than whatever
            # train_for_plant happens to return.
            joblib.dump(
                {
                    "plant":         model["plant"],
                    "window":        model["window"],
                    "model_version": model["model_version"],
                    "model_qty":     model["model_qty"],
                    "model_wt":      model["model_wt"],
                },
                model_file,
            )
            out["trained"].append({"plant": plant, "modelFile": model_file})
    finally:
        conn.close()

    print(json.dumps(out))


# Run the training CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
