#!/usr/bin/env python3
"""Safe model privacy and local surrogate extraction simulation."""

import argparse
import csv
import os

import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def main():
    """Run the membership-inference and surrogate-extraction simulation.

    Trains a logistic-regression "target" model on the scikit-learn digits
    data and then measures two things against it:

    * Membership signal: the ROC AUC obtained when the target's maximum
      predicted probability is used as a score for "this sample was in the
      training split" (training rows labelled 1, held-out rows labelled 0).
    * Extraction: a second logistic regression (the surrogate) is fitted to
      the labels the target predicts for uniformly random queries, and is
      then compared with the target and with the true labels on the
      held-out split.

    Command-line options:
        --quick  Keep only the first 1000 samples and use 500 synthetic
                 queries instead of 1200.
        --out    Path of the CSV file to write (parent directory is created
                 if the path has one).

    The four metric rows are written to the CSV and printed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quick",
        action="store_true",
        help="Use the first 1000 samples and 500 synthetic queries.",
    )
    parser.add_argument(
        "--out",
        default="results/privacy-extraction-results.csv",
        help="Path of the CSV file to write.",
    )
    args = parser.parse_args()

    digits = load_digits()
    # Dividing by 16.0 rescales the features; load_digits pixel intensities
    # are documented as integers in 0..16, so this maps them into [0, 1].
    x = digits.data.astype(float) / 16.0
    y = digits.target
    if args.quick:
        x = x[:1000]
        y = y[:1000]

    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.4, random_state=11, stratify=y
    )
    # The scaler is fitted on the training split only and is treated as part
    # of the target: every input the target sees goes through it.
    scaler = StandardScaler()
    x_train_s = scaler.fit_transform(x_train)
    x_test_s = scaler.transform(x_test)
    target = LogisticRegression(max_iter=1000, random_state=11)
    target.fit(x_train_s, y_train)

    # Membership inference from confidence alone: score each sample by the
    # target's top-class probability, label training rows 1 and test rows 0.
    train_conf = target.predict_proba(x_train_s).max(axis=1)
    test_conf = target.predict_proba(x_test_s).max(axis=1)
    membership_labels = np.r_[np.ones_like(train_conf), np.zeros_like(test_conf)]
    membership_scores = np.r_[train_conf, test_conf]
    membership_auc = roc_auc_score(membership_labels, membership_scores)

    # Extraction: query the target with uniform random inputs drawn from the
    # same [0, 1) range as the rescaled features and keep only its labels.
    rng = np.random.default_rng(11)
    query_count = 1200 if not args.quick else 500
    synthetic_queries = rng.uniform(0.0, 1.0, size=(query_count, x.shape[1]))
    pseudo_labels = target.predict(scaler.transform(synthetic_queries))
    # The surrogate is fitted on the raw (unscaled) queries, so it must also
    # be evaluated on raw x_test below, not on x_test_s.
    surrogate = LogisticRegression(max_iter=1000, random_state=12)
    surrogate.fit(synthetic_queries, pseudo_labels)

    target_test_pred = target.predict(x_test_s)
    surrogate_test_pred = surrogate.predict(x_test)
    fidelity = accuracy_score(target_test_pred, surrogate_test_pred)
    surrogate_accuracy = accuracy_score(y_test, surrogate_test_pred)

    rows = [
        {
            "metric": "membership_auc_from_max_confidence",
            "value": round(float(membership_auc), 4),
            "interpretation": "0.5 is random guessing; larger values indicate stronger membership signal.",
        },
        {
            "metric": "target_clean_accuracy",
            "value": round(float(accuracy_score(y_test, target_test_pred)), 4),
            "interpretation": "Target model accuracy on held-out test data.",
        },
        {
            "metric": "surrogate_fidelity_to_target",
            "value": round(float(fidelity), 4),
            "interpretation": "Agreement between local surrogate and target predictions.",
        },
        {
            "metric": "surrogate_accuracy",
            "value": round(float(surrogate_accuracy), 4),
            "interpretation": "Surrogate accuracy against true labels.",
        },
    ]

    # os.path.dirname returns "" for a bare filename such as "results.csv",
    # and os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    for row in rows:
        print(row)


# Run the simulation only when the file is executed as a script, so that
# importing the module does not parse argv or write any files.
if __name__ == "__main__":
    main()
