import argparse
import json
import os
import pathlib

import joblib
import numpy as np


# Directory that --modelFile must resolve inside.
# joblib.load() uses pickle internally, so loading an attacker-chosen file can
# execute arbitrary code. Confining model paths to this directory blocks
# directory-traversal paths (e.g. "models/../../tmp/evil.joblib") and symlinks
# that point outside it. It does NOT protect against a malicious file placed
# inside models/ itself -- that directory must be write-protected separately.
# __file__ is resolved *before* taking parents: it can be a bare relative name
# (e.g. the __main__ script on Python < 3.9), and Path("x.py").parent.parent
# is still ".", which would silently anchor the check to the current working
# directory. The final path is resolved too, so a symlinked models/ directory
# compares correctly against fully-resolved candidate paths.
ALLOWED_MODELS_DIR = (
    pathlib.Path(__file__).resolve().parent.parent / "models"
).resolve()


def validate_model_path(
    raw_path: str, allowed_dir: "pathlib.Path | str | None" = None
) -> pathlib.Path:
    """
    Resolve *raw_path* and confirm it is a regular file inside the allowed directory.

    The path is fully resolved (``..`` segments collapsed, symlinks followed)
    before the containment check, so neither traversal sequences nor symlinks
    pointing outside the directory can escape it.

    Args:
        raw_path: Path as given on the command line (absolute, or relative
            to the current working directory).
        allowed_dir: Directory the model file must live in. Defaults to
            ``ALLOWED_MODELS_DIR``; when given, it is resolved the same way
            as *raw_path*.

    Returns:
        The resolved path of the model file.

    Raises:
        SystemExit: if the path escapes the allowed directory, does not exist,
            or is not a regular file (e.g. the allowed directory itself).
    """
    # Looked up lazily so the module-level constant is only needed when the
    # caller relies on the default.
    base = (
        ALLOWED_MODELS_DIR
        if allowed_dir is None
        else pathlib.Path(allowed_dir).resolve()
    )
    resolved = pathlib.Path(raw_path).resolve()
    try:
        resolved.relative_to(base)
    except ValueError:
        raise SystemExit(
            f"Model file path is outside the allowed directory.\n"
            f"  Given:   {resolved}\n"
            f"  Allowed: {base}"
        ) from None
    if not resolved.exists():
        raise SystemExit(f"Model file not found: {resolved}")
    # exists() alone also accepts directories -- including the allowed
    # directory itself, for which relative_to() succeeds -- which would then
    # fail later inside joblib.load().
    if not resolved.is_file():
        raise SystemExit(f"Model path is not a regular file: {resolved}")
    return resolved


def parse_json_list(s):
    """
    Return *s* as a Python list.

    Accepts ``None`` (returns ``[]``), an already-parsed list (returned
    unchanged), or a JSON string that must decode to an array. A JSON
    ``null`` is treated the same as ``None``.

    Raises:
        ValueError: if *s* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or decodes to something other than an
            array.
    """
    if s is None:
        return []
    if isinstance(s, list):
        return s
    parsed = json.loads(s)
    if parsed is None:
        return []
    if not isinstance(parsed, list):
        # Without this check a JSON string such as '"12345"' would be returned
        # as a str and later sliced / float-converted character by character.
        raise ValueError(
            f"expected a JSON array, got {type(parsed).__name__}"
        )
    return parsed


def _recursive_forecast(estimator, history, window, horizon):
    """
    Forecast *horizon* steps ahead by feeding each prediction back as input.

    Each step builds a (1, window) feature row from the last *window* values
    (the seed history first, then earlier predictions) and calls
    ``estimator.predict`` on it.

    Args:
        estimator: Object exposing ``predict(X)`` on a 2-D array (presumably a
            scikit-learn-style regressor -- confirm against the training
            code), or ``None`` when the model file has no model for this
            series.
        history: Seed values; only the last *window* are used for step one.
        window: Number of lagged values per feature row (must be >= 1).
        horizon: Number of steps to forecast.

    Returns:
        A list of *horizon* floats, or *horizon* ``None`` entries when
        *estimator* is ``None``.
    """
    if estimator is None:
        return [None] * max(horizon, 0)
    series = list(history)
    preds = []
    for _ in range(horizon):
        features = np.asarray(series[-window:], dtype=float).reshape(1, -1)
        pred = float(estimator.predict(features)[0])
        preds.append(pred)
        series.append(pred)  # the prediction becomes input for the next step
    return preds


def main():
    """
    Command-line entry point: load a trained model and print a JSON forecast.

    Reads the model path, two JSON-encoded history series and a horizon from
    the command line. The model path is validated before the file is
    unpickled. Both series are then forecast recursively, and one JSON
    object is printed to stdout with the keys ``window``,
    ``model_version``, ``predictedQtyPct`` and ``predictedWtPct``.

    Raises:
        SystemExit: on an invalid model path, malformed or too-short history,
            a non-positive model window, or a negative horizon.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelFile",   required=True)
    parser.add_argument("--historyQty",  required=True, help="JSON array of floats")
    parser.add_argument("--historyWt",   required=True, help="JSON array of floats")
    parser.add_argument("--horizonDays", type=int, required=True)
    args = parser.parse_args()

    # joblib.load() unpickles the file and can execute arbitrary code, so the
    # path must be confined to the models directory *before* loading.
    model_path = validate_model_path(args.modelFile)
    model      = joblib.load(model_path)

    window  = int(model["window"])
    horizon = args.horizonDays  # argparse already converted this to int
    # window == 0 would make history[-0:] select the ENTIRE list, so the
    # feature row would grow every step instead of staying fixed-width.
    if window < 1:
        raise SystemExit(f"model window must be a positive integer, got {window}")
    if horizon < 0:
        raise SystemExit(f"horizonDays must be non-negative, got {horizon}")

    try:
        history_qty = parse_json_list(args.historyQty)
        history_wt  = parse_json_list(args.historyWt)
    except ValueError as exc:  # includes json.JSONDecodeError
        raise SystemExit(f"invalid history JSON: {exc}") from exc

    if len(history_qty) < window or len(history_wt) < window:
        raise SystemExit("history too short for window")

    try:
        qty_hist = [float(v) for v in history_qty[-window:]]
        wt_hist  = [float(v) for v in history_wt[-window:]]
    except (TypeError, ValueError) as exc:
        raise SystemExit(f"history values must be numeric: {exc}") from exc

    # Each series is forecast independently from its own history.
    preds_qty = _recursive_forecast(model.get("model_qty"), qty_hist, window, horizon)
    preds_wt  = _recursive_forecast(model.get("model_wt"),  wt_hist,  window, horizon)

    print(json.dumps({
        "window":        window,
        "model_version": model.get("model_version"),
        "predictedQtyPct": preds_qty,
        "predictedWtPct":  preds_wt,
    }))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
