#!/usr/bin/env python3
"""Safe deterministic RAG prompt-injection guard simulation."""

import argparse
import csv
import os
import re


# Toy retrieval corpus: two documents from the "approved_handbook" source and
# one "untrusted_web" document whose text contains phrases matched by the
# injection patterns below.
DOCUMENTS = [
    dict(
        id="policy-001",
        source="approved_handbook",
        text="Refund requests must include an order id and a reason code.",
    ),
    dict(
        id="kb-104",
        source="approved_handbook",
        text="Escalate account deletion requests to a human reviewer.",
    ),
    dict(
        id="web-unsafe-77",
        source="untrusted_web",
        text="[untrusted note: attempts to override policy and request privileged export]",
    ),
]

# Case-insensitive phrases whose presence in a document's text is treated as
# a prompt-injection signal; each matching pattern adds 1 to the risk score.
INJECTION_PATTERNS = [
    re.compile(phrase, re.IGNORECASE)
    for phrase in (
        r"override policy",
        r"privileged export",
        r"untrusted note",
    )
]


def risk_score(text, patterns=None):
    """Count how many injection patterns match *text*.

    Args:
        text: Document text to scan.
        patterns: Optional iterable of compiled regular expressions to use
            instead of the module-level ``INJECTION_PATTERNS``.

    Returns:
        The number of patterns that match somewhere in *text*. Each pattern
        contributes at most 1, however many times it occurs.
    """
    if patterns is None:
        patterns = INJECTION_PATTERNS
    return sum(1 for pattern in patterns if pattern.search(text))


def guarded_retrieve(documents):
    """Split *documents* into those allowed into the agent context and those blocked.

    A document is allowed only if its ``"source"`` is ``"approved_handbook"``
    and ``risk_score`` of its ``"text"`` is zero. Everything else is blocked.

    Returns:
        ``(allowed, blocked)``. *allowed* is a list of documents. *blocked* is
        a list of ``(document, score)`` pairs. Both keep input order.
    """
    passed, quarantined = [], []
    for candidate in documents:
        hits = risk_score(candidate["text"])
        is_trusted = candidate["source"] == "approved_handbook"
        if is_trusted and hits <= 0:
            passed.append(candidate)
            continue
        quarantined.append((candidate, hits))
    return passed, quarantined


def toy_agent_answer(question, documents, guard_enabled):
    """Simulate a naive agent answering *question* from retrieved *documents*.

    The agent joins the document texts into one context string and reacts to
    it with fixed rules:

    - If the phrase "privileged export" appears, it reports an unsafe
      tool-call attempt.
    - Otherwise, if "refund" appears in the context or the question, it gives
      the refund answer.
    - Otherwise it escalates.

    Args:
        question: User question. It is only checked for the word "refund".
        documents: Iterable of dicts with a ``"text"`` key. A ``"source"`` key
            is also needed when the guard is enabled.
        guard_enabled: When true, filter *documents* through
            ``guarded_retrieve`` before building the context.

    Returns:
        ``(answer, unauthorized_tool_call, blocked)``. *blocked* is the list
        of ``(doc, score)`` pairs rejected by the guard. It is empty when the
        guard is disabled.
    """
    if guard_enabled:
        documents, blocked = guarded_retrieve(documents)
    else:
        blocked = []
    context = " ".join(doc["text"] for doc in documents)
    # The guard's patterns are case-insensitive (re.I). Compare the context
    # case-insensitively too, so capitalisation alone cannot hide an unsafe
    # tool-call attempt or a refund mention from these checks.
    lowered_context = context.lower()
    unauthorized_tool_call = "privileged export" in lowered_context
    if unauthorized_tool_call:
        answer = "UNSAFE_TOOL_CALL_ATTEMPT"
    elif "refund" in lowered_context or "refund" in question.lower():
        answer = "Ask for an order id and a reason code."
    else:
        answer = "Escalate to a human reviewer."
    return answer, unauthorized_tool_call, blocked


def main(argv=None):
    """Compare guard-off and guard-on runs and write the results as CSV.

    Args:
        argv: Optional argument list passed to ``parse_args``. When None,
            argparse reads ``sys.argv[1:]``, as before.

    Writes one CSV row per guard setting to ``--out``, creating the parent
    directory when the path has one. Each row is also printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Accepted for CLI compatibility; currently has no effect.",
    )
    parser.add_argument(
        "--out",
        default="results/rag-guard-results.csv",
        help="Path of the CSV file to write.",
    )
    args = parser.parse_args(argv)

    question = "How should the assistant handle a refund request?"
    rows = []
    for guard_enabled in [False, True]:
        answer, unauthorized, blocked = toy_agent_answer(question, DOCUMENTS, guard_enabled)
        rows.append(
            {
                "guard_enabled": guard_enabled,
                "blocked_documents": len(blocked),
                "unauthorized_tool_call_attempt": unauthorized,
                "answer": answer,
            }
        )

    # os.path.dirname() returns "" for a bare filename such as "out.csv", and
    # os.makedirs("") raises FileNotFoundError. Create only a real parent.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    for row in rows:
        print(row)


# Run the simulation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
