#!/usr/bin/env python3
"""Safe toy poisoning/backdoor demo using scikit-learn digits."""

import argparse
import csv
import os

import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# Feature indices overwritten by the backdoor trigger. In the 64-feature
# (8x8, row-major) digits layout loaded by main(), these four indices form
# the 2x2 patch in the bottom-right corner of the image.
TRIGGER_PIXELS = [54, 55, 62, 63]


def add_trigger(x):
    """Return a copy of ``x`` with the trigger pixels set to 1.0.

    Args:
        x: Array-like of flattened samples. The last axis is the feature
            axis and must be long enough to contain every index in
            ``TRIGGER_PIXELS``. Both a batch ``(n_samples, n_features)`` and
            a single sample ``(n_features,)`` are accepted.

    Returns:
        A new ``numpy.ndarray`` with the same shape and dtype as ``x``; the
        input is never modified.

    Raises:
        IndexError: If the last axis is too short for the trigger indices.
    """
    # np.array(..., copy=True) accepts lists as well as ndarrays and always
    # yields an independent buffer, so the caller's data stays untouched.
    modified = np.array(x, copy=True)
    # Ellipsis indexes the last axis regardless of rank, so a single 1-D
    # sample works the same way as a 2-D batch.
    modified[..., TRIGGER_PIXELS] = 1.0
    return modified


def train_and_measure(x_train, y_train, x_test, y_test, source_class=1, target_class=7):
    """Train a classifier and measure clean accuracy and trigger success.

    A ``StandardScaler`` is fitted on the training data and a
    ``LogisticRegression`` model is trained on the scaled features. The
    model is then evaluated twice: on the unmodified test set, and on the
    test samples of ``source_class`` after the trigger has been applied.

    Args:
        x_train: Training features (possibly poisoned), one row per sample.
        y_train: Training labels (possibly poisoned).
        x_test: Clean test features, one row per sample.
        y_test: Clean test labels.
        source_class: Label whose test samples receive the trigger.
        target_class: Label the backdoor is meant to induce.

    Returns:
        A ``(clean_accuracy, attack_success_rate)`` tuple. The attack
        success rate is the fraction of triggered ``source_class`` test
        samples predicted as ``target_class``. It is ``nan`` when the test
        set holds no ``source_class`` samples.
    """
    # Coerce to arrays so the boolean masking below works for list inputs.
    x_test = np.asarray(x_test)
    y_test = np.asarray(y_test)

    scaler = StandardScaler()
    x_train_s = scaler.fit_transform(x_train)
    x_test_s = scaler.transform(x_test)
    model = LogisticRegression(max_iter=1000, random_state=42)
    model.fit(x_train_s, y_train)
    clean_acc = accuracy_score(y_test, model.predict(x_test_s))

    source_mask = y_test == source_class
    if not np.any(source_mask):
        # scikit-learn rejects 0-row inputs, and a rate over zero samples is
        # undefined, so report NaN instead of crashing.
        return clean_acc, float("nan")

    # The trigger is applied in raw feature space and scaled with the
    # already-fitted scaler, the same path clean test data takes.
    triggered = add_trigger(x_test[source_mask])
    target_predictions = model.predict(scaler.transform(triggered))
    attack_success = float(np.mean(target_predictions == target_class))
    return clean_acc, attack_success


def main():
    """Run the poisoning sweep and write the results to a CSV file.

    The digits dataset is split into stratified train/test sets. For each
    poison rate, that fraction of the training samples of the source class
    gets the trigger applied and is relabelled as the target class. A model
    is trained on each poisoned training set, and the clean accuracy and
    trigger attack success rate are recorded, written to ``--out`` and
    printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--quick",
        action="store_true",
        help="use only the first 1000 samples for a faster run",
    )
    parser.add_argument(
        "--out",
        default="results/poisoning-results.csv",
        help="path of the CSV file to write",
    )
    args = parser.parse_args()

    digits = load_digits()
    # Pixel intensities range from 0 to 16; rescale them to [0, 1] so the
    # trigger value of 1.0 corresponds to a fully "on" pixel.
    x = digits.data.astype(float) / 16.0
    y = digits.target.copy()
    if args.quick:
        x = x[:1000]
        y = y[:1000]

    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=7, stratify=y
    )
    rows = []
    source_class = 1
    target_class = 7
    source_indices = np.where(y_train == source_class)[0]

    for poison_rate in [0.00, 0.05, 0.10, 0.20]:
        # Start every run from the clean training data so the poison rates
        # do not accumulate across iterations.
        poisoned_x = x_train.copy()
        poisoned_y = y_train.copy()
        poison_count = int(len(source_indices) * poison_rate)
        selected = source_indices[:poison_count]
        if poison_count > 0:
            poisoned_x[selected] = add_trigger(poisoned_x[selected])
            poisoned_y[selected] = target_class

        clean_acc, attack_success = train_and_measure(
            poisoned_x, poisoned_y, x_test, y_test, source_class, target_class
        )
        rows.append(
            {
                "poison_rate": poison_rate,
                "poisoned_rows": poison_count,
                "clean_accuracy": round(clean_acc, 4),
                "trigger_attack_success_rate": round(float(attack_success), 4),
            }
        )

    # os.path.dirname() returns "" for a bare filename such as "out.csv",
    # and os.makedirs("") raises FileNotFoundError, so only create the
    # directory when the path actually has one.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    for row in rows:
        print(row)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
