#!/usr/bin/env python3
"""Safe FGSM-style robustness demo on scikit-learn digits."""

import argparse
import csv
import os

import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def softmax(logits):
    """Return the softmax of ``logits`` taken along axis 1.

    The per-row maximum is subtracted before exponentiating so ``np.exp``
    cannot overflow on large logits; the shift cancels in the ratio and
    leaves the probabilities unchanged.
    """
    row_peak = logits.max(axis=1, keepdims=True)
    weights = np.exp(logits - row_peak)
    normalizer = weights.sum(axis=1, keepdims=True)
    return weights / normalizer


def one_hot(y, classes):
    """Encode labels ``y`` as one-hot float rows ordered by ``classes``.

    Returns an array of shape ``(y.shape[0], len(classes))`` where row ``r``
    holds a single 1.0 in the column matching ``y[r]``'s position in
    ``classes``. A label that does not appear in ``classes`` raises
    ``KeyError``.
    """
    column_of = dict(zip(classes, range(len(classes))))
    encoded = np.zeros((y.shape[0], len(classes)))
    columns = np.array([column_of[label] for label in y], dtype=np.intp)
    encoded[np.arange(columns.shape[0]), columns] = 1.0
    return encoded


def _load_split(quick):
    """Load the digits data, scale pixels by 1/16, and split 70/30 (stratified).

    When ``quick`` is true only the first 900 samples are kept, for a faster run.
    """
    digits = load_digits()
    x = digits.data.astype(float) / 16.0
    y = digits.target
    if quick:
        x, y = x[:900], y[:900]
    return train_test_split(x, y, test_size=0.3, random_state=42, stratify=y)


def _input_gradient_sign(model, scaler, x_scaled, y_true):
    """Return sign(d loss / d x) in the ORIGINAL (unscaled) feature space.

    With logits = x_s @ W.T + b and softmax cross-entropy, the gradient with
    respect to x_s is (p - onehot(y)) @ W. Because x_s = (x - mean_) / scale_,
    the chain rule divides that gradient elementwise by ``scaler.scale_``.
    NOTE(review): this assumes the fitted model uses a softmax (multinomial)
    formulation; for a one-vs-rest model it is only an approximation.
    """
    probs = softmax(x_scaled @ model.coef_.T + model.intercept_)
    grad_scaled = (probs - one_hot(y_true, model.classes_)) @ model.coef_
    return np.sign(grad_scaled / scaler.scale_)


def _write_csv(path, rows, fieldnames):
    """Write ``rows`` (dicts) to ``path`` as CSV, creating parent dirs if any."""
    directory = os.path.dirname(path)
    # A bare filename has dirname "" and os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main(argv=None):
    """Train a logistic-regression digit classifier and report FGSM robustness.

    For each epsilon, every test image is moved by ``epsilon`` times the sign
    of the loss gradient with respect to its pixels, clipped back to [0, 1],
    and re-classified. One row per epsilon is written as CSV to ``--out`` and
    echoed to stdout.

    Args:
        argv: Optional argument list for argparse; ``None`` (the default)
            parses the process command line, exactly as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quick", action="store_true",
                        help="use only the first 900 samples")
    parser.add_argument("--out", default="results/fgsm-results.csv")
    parser.add_argument("--epsilons", type=float, nargs="+",
                        default=[0.00, 0.03, 0.06, 0.10, 0.15],
                        help="perturbation sizes to evaluate (non-negative)")
    args = parser.parse_args(argv)
    if any(eps < 0 for eps in args.epsilons):
        parser.error("--epsilons must be non-negative")

    x_train, x_test, y_train, y_test = _load_split(args.quick)
    scaler = StandardScaler()
    x_train_s = scaler.fit_transform(x_train)
    x_test_s = scaler.transform(x_test)

    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(x_train_s, y_train)
    clean_acc = float(accuracy_score(y_test, model.predict(x_test_s)))

    # The gradient depends only on the clean inputs, not on epsilon: compute once.
    direction = _input_gradient_sign(model, scaler, x_test_s, y_test)

    rows = []
    for epsilon in args.epsilons:
        x_adv = np.clip(x_test + epsilon * direction, 0.0, 1.0)
        adv_acc = float(accuracy_score(y_test, model.predict(scaler.transform(x_adv))))
        rows.append(
            {
                "epsilon": epsilon,
                "clean_accuracy": round(clean_acc, 4),
                "perturbed_accuracy": round(adv_acc, 4),
                "accuracy_drop": round(clean_acc - adv_acc, 4),
            }
        )

    fieldnames = ["epsilon", "clean_accuracy", "perturbed_accuracy", "accuracy_drop"]
    _write_csv(args.out, rows, fieldnames)
    for row in rows:
        print(row)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
