"""
ERGON -- Evaluador del parser de bitacoras.

Corre 10 casos sinteticos y reporta:
- Tasa de acierto de tools esperadas (recall por caso + global)
- Campos bien extraidos por tool
- Falsos positivos (tools no esperadas disparadas)
- Latencia y uso de tokens
- Preguntas pendientes detectadas

No toca la DB productiva. Usa un parser "modo test" que inyecta el contexto
sinteticamente y llama a la Anthropic API directa con el system prompt real.

Uso:
    cd db/
    python parser_eval.py            # corre los 10 casos
    python parser_eval.py --case 3   # corre solo el caso 3
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Any

# Windows console: force UTF-8 so print does not crash on accents / unicode arrows
if sys.stdout.encoding and sys.stdout.encoding.lower() != "utf-8":
    try:
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")
        sys.stderr.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        pass

try:
    import anthropic
except ImportError:
    print("ERROR: anthropic SDK no instalado. pip install anthropic")
    sys.exit(1)

# Import the parser module without depending on a real DB
sys.path.insert(0, str(Path(__file__).resolve().parent))
from parser_bitacora import TOOLS, MODEL, PROMPT_VERSION, build_system_prompt
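# TOOLS, MODEL, PROMPT_VERSION and build_system_prompt come straight from the
# production parser module, so this eval exercises the exact system prompt and
# tool schema used in production.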

BASE = Path(__file__).resolve().parent
DATASET = BASE / "test_bitacoras.json"
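# Illustrative shape of test_bitacoras.json, inferred from how this script
# reads it below (all values are made up; tool and field names are hypothetical):
#
# {
#   "_hoy": "2025-01-15",
#   "_obra_contexto": {
#     "codigo": "DEMO",
#     "rubros_activos": ["Mamposteria", "Hormigon"],
#     "subcontratistas_conocidos": ["Constructora X"]
#   },
#   "casos": [
#     {
#       "id": 1,
#       "categoria": "avance simple",
#       "input": {"fecha": "2025-01-14", "actividades": "...",
#                 "observaciones": "...", "clima": "..."},
#       "expected_tools": ["tool_a"],
#       "expected_fields": {"tool_a": {"rubro_contains": "mamposteria"}}
#     }
#   ]
# }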


# ---------------------------------------------------------------------------
# Synthetic parser -- same core as parser_bitacora.parse_bitacora,
# but without touching SQLite
# ---------------------------------------------------------------------------

def parse_sintetico(caso: dict, ctx: dict, hoy_iso: str, client: anthropic.Anthropic) -> dict:
    system_prompt = build_system_prompt(ctx, hoy_iso)
    entry = caso["input"]
    user_text = f"Bitacora del {entry['fecha']}:\n\n"
    if entry.get("actividades"):
        user_text += f"ACTIVIDADES: {entry['actividades']}\n\n"
    if entry.get("observaciones"):
        user_text += f"OBSERVACIONES: {entry['observaciones']}\n\n"
    if entry.get("clima"):
        user_text += f"CLIMA: {entry['clima']}\n"

    t0 = time.time()
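    # One Messages API call with the real system prompt. The "ephemeral"
    # cache_control marker opts the system block into Anthropic prompt caching,
    # so back-to-back eval cases read the cached prefix instead of paying full
    # input-token price (visible in the cache_* usage fields collected below).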
    resp = client.messages.create(
        model=MODEL,
        max_tokens=2048,
        system=[{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}],
        tools=TOOLS,
        messages=[{"role": "user", "content": user_text}],
    )
    dt = time.time() - t0

    acciones = []
    resumen = []
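    # A response may interleave tool_use blocks with free-text commentary:
    # collect the tool calls as acciones and the text as the model's resumen.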
    for block in resp.content:
        if block.type == "tool_use":
            acciones.append({"tool": block.name, "input": dict(block.input) if block.input else {}})
        elif block.type == "text" and block.text.strip():
            resumen.append(block.text.strip())

    return {
        "acciones": acciones,
        "resumen": "\n\n".join(resumen),
        "latency_s": round(dt, 2),
        "usage": {
            "input_tokens": resp.usage.input_tokens,
            "output_tokens": resp.usage.output_tokens,
            "cache_creation_input_tokens": getattr(resp.usage, "cache_creation_input_tokens", 0) or 0,
            "cache_read_input_tokens": getattr(resp.usage, "cache_read_input_tokens", 0) or 0,
        },
        "stop_reason": resp.stop_reason,
    }


# ---------------------------------------------------------------------------
# Per-case evaluation
# ---------------------------------------------------------------------------

def eval_caso(caso: dict, resultado: dict) -> dict:
    """Compara tools esperadas vs obtenidas. Devuelve diagnostico."""
    expected_tools = caso.get("expected_tools", [])
    obtained_tools = [a["tool"] for a in resultado["acciones"]]

    exp_set = set(expected_tools)
    obt_set = set(obtained_tools)

    matched = exp_set & obt_set
    faltantes = exp_set - obt_set
    extras = obt_set - exp_set
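    # matched drives the per-case recall; extras are the false positives
    # reported in the global summary.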

    # For the "expected_tools=[]" cases (empty / future-dated / ambiguous log):
    # success means firing no tool at all
    if not expected_tools:
        tools_ok = len(obtained_tools) == 0
    else:
        tools_ok = exp_set == obt_set

    # Check key fields when expected_fields is present
    fields_ok = True
    field_details = []
    if "expected_fields" in caso:
        for tool_name, expected in caso["expected_fields"].items():
            obtained = next((a["input"] for a in resultado["acciones"] if a["tool"] == tool_name), None)
            if obtained is None:
                fields_ok = False
                field_details.append(f"  {tool_name}: NO DISPARADA")
                continue
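            # Two assertion modes: a "<field>_contains" key checks a
            # case-insensitive substring; any other key checks equality
            # (with a small tolerance for numerics, below).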
            for k, v in expected.items():
                if k.endswith("_contains"):
                    real_key = k.replace("_contains", "")
                    got = str(obtained.get(real_key, ""))
                    if str(v).lower() not in got.lower():
                        fields_ok = False
                        field_details.append(f"  {tool_name}.{real_key}: obtuvo '{got}', esperaba contener '{v}'")
                    else:
                        field_details.append(f"  {tool_name}.{real_key}: OK ('{got}' contiene '{v}')")
                else:
                    got = obtained.get(k)
                    if got != v:
                        # Tolerate near-equal numerics (e.g. 83.1 vs 83): relative error under 1%
                        if isinstance(v, (int, float)) and isinstance(got, (int, float)):
                            if abs(got - v) / max(abs(v), 1) < 0.01:
                                field_details.append(f"  {tool_name}.{k}: OK (~{got})")
                                continue
                        fields_ok = False
                        field_details.append(f"  {tool_name}.{k}: obtuvo {got!r}, esperaba {v!r}")
                    else:
                        field_details.append(f"  {tool_name}.{k}: OK ({got!r})")

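    # Example (illustrative, hypothetical tool names): expected_tools=["tool_a"]
    # and obtained tools ["tool_a", "tool_b"] yield matched=["tool_a"],
    # faltantes=[], extras=["tool_b"], tools_ok=False.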
    return {
        "tools_ok": tools_ok,
        "fields_ok": fields_ok,
        "matched": sorted(matched),
        "faltantes": sorted(faltantes),
        "extras": sorted(extras),
        "field_details": field_details,
    }


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--case", type=int, help="Correr solo un caso por id")
    ap.add_argument("--verbose", action="store_true", help="Imprimir input/output completo de cada caso")
    args = ap.parse_args()

    if not os.environ.get("ANTHROPIC_API_KEY"):
        # Try reading .env, the same way servidor_local does
        env_file = BASE.parent / ".env"
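        # Minimal .env parsing (no python-dotenv dependency): the first
        # ANTHROPIC_API_KEY= line wins; no quoting or export rules.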
        if env_file.exists():
            for line in env_file.read_text(encoding="utf-8").splitlines():
                if line.startswith("ANTHROPIC_API_KEY="):
                    os.environ["ANTHROPIC_API_KEY"] = line.split("=", 1)[1].strip()
                    break
    if not os.environ.get("ANTHROPIC_API_KEY"):
        print("ERROR: ANTHROPIC_API_KEY no encontrada (ni en env ni en .env)")
        sys.exit(1)

    dataset = json.loads(DATASET.read_text(encoding="utf-8"))
    ctx_raw = dataset["_obra_contexto"]
    hoy_iso = dataset["_hoy"]
    # Build the context in the shape that build_system_prompt expects
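    # Rubro weights are split evenly across the active rubros: the eval only
    # needs a plausible context shape, not real budget data.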
    ctx = {
        "obra_id": 1,
        "obra_codigo": ctx_raw["codigo"],
        "obra_nombre": "Obra Demo (DEMO)",
        "obra_cliente": "Grupo DG Desarrollo",
        "moneda": "PYG",
        "fecha_corte": hoy_iso,
        "rubros": [{"nombre": r, "peso_pct": 100.0 / len(ctx_raw["rubros_activos"])}
                   for r in ctx_raw["rubros_activos"]],
        "subcontratistas": [{"nombre": s} for s in ctx_raw["subcontratistas_conocidos"]],
    }

    casos = dataset["casos"]
    if args.case is not None:
        casos = [c for c in casos if c["id"] == args.case]
        if not casos:
            print(f"ERROR: caso {args.case} no encontrado")
            sys.exit(1)

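    # The SDK client picks up ANTHROPIC_API_KEY from the environment set above.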
    client = anthropic.Anthropic()

    print(f"\n{'=' * 70}")
    print(f"ERGON Parser Eval -- {MODEL} -- prompt {PROMPT_VERSION}")
    print(f"Corriendo {len(casos)} casos...\n")

    resultados = []
    for caso in casos:
        print(f"--- Caso {caso['id']}: {caso['categoria']} ---")
        try:
            r = parse_sintetico(caso, ctx, hoy_iso, client)
        except anthropic.APIError as e:
            print(f"  API ERROR: {e}\n")
            resultados.append({"caso": caso, "error": str(e)})
            continue
        except Exception as e:
            print(f"  EXCEPTION: {e}\n")
            resultados.append({"caso": caso, "error": str(e)})
            continue

        diag = eval_caso(caso, r)

        status = "PASS" if diag["tools_ok"] and diag["fields_ok"] else "FAIL"
        print(f"  [{status}] latency={r['latency_s']}s in={r['usage']['input_tokens']} "
              f"out={r['usage']['output_tokens']} "
              f"cache_read={r['usage']['cache_read_input_tokens']}")
        print(f"  Tools esperadas: {caso.get('expected_tools', [])}")
        print(f"  Tools obtenidas: {[a['tool'] for a in r['acciones']]}")
        if diag["faltantes"]:
            print(f"  FALTANTES: {diag['faltantes']}")
        if diag["extras"]:
            print(f"  EXTRAS: {diag['extras']}")
        for line in diag["field_details"]:
            print(line)
        if r["resumen"]:
            print(f"  Resumen del modelo: {r['resumen'][:200]}")
        if args.verbose:
            print(f"  INPUT: {json.dumps(caso['input'], ensure_ascii=False)}")
            print(f"  ACCIONES: {json.dumps(r['acciones'], ensure_ascii=False, indent=2)}")
        print()
        resultados.append({"caso": caso, "resultado": r, "diagnostico": diag})

    # Global summary
    total = len(resultados)
    pass_count = sum(1 for x in resultados if "diagnostico" in x
                     and x["diagnostico"]["tools_ok"] and x["diagnostico"]["fields_ok"])
    tools_ok_count = sum(1 for x in resultados if "diagnostico" in x and x["diagnostico"]["tools_ok"])
    fields_ok_count = sum(1 for x in resultados if "diagnostico" in x and x["diagnostico"]["fields_ok"])
    errors = sum(1 for x in resultados if "error" in x)

    total_in = sum(x.get("resultado", {}).get("usage", {}).get("input_tokens", 0) for x in resultados)
    total_out = sum(x.get("resultado", {}).get("usage", {}).get("output_tokens", 0) for x in resultados)
    total_cache_read = sum(x.get("resultado", {}).get("usage", {}).get("cache_read_input_tokens", 0) for x in resultados)
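    # Average only over cases that actually returned; max(..., 1) guards the
    # division when every case errored out.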
    avg_latency = sum(x.get("resultado", {}).get("latency_s", 0) for x in resultados) / max(total - errors, 1)

    print("=" * 70)
    print(f"RESUMEN")
    print(f"  Total casos:        {total}")
    print(f"  Pass completo:      {pass_count}/{total}  ({100*pass_count/total:.0f}%)")
    print(f"  Tools ok:           {tools_ok_count}/{total}  ({100*tools_ok_count/total:.0f}%)")
    print(f"  Fields ok:          {fields_ok_count}/{total}  ({100*fields_ok_count/total:.0f}%)")
    print(f"  Errores API:        {errors}")
    print(f"  Tokens input total: {total_in}")
    print(f"  Tokens output:      {total_out}")
    print(f"  Cache hits:         {total_cache_read} (ahorro)")
    print(f"  Latencia media:     {avg_latency:.2f}s")
    print("=" * 70)


if __name__ == "__main__":
    main()
