#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#define MAX_LINE_LEN 20000
#define MAX_TRAIN_SAMPLES 42000
#define MAX_TEST_SAMPLES 28000
#define FEATURES 784
#define CLASSES 10
#define EPOCHS 20
#define LEARNING_RATE 0.005
#define HIDDEN_UNITS 128

static double X_train[MAX_TRAIN_SAMPLES][FEATURES];
static int y_train[MAX_TRAIN_SAMPLES];
static double X_test[MAX_TEST_SAMPLES][FEATURES];

// MLP Parameters
static double W1[HIDDEN_UNITS][FEATURES];
static double b1[HIDDEN_UNITS];
static double W2[CLASSES][HIDDEN_UNITS];
static double b2[CLASSES];

/*
 * Split one CSV line in place into at most max_tokens fields.
 * Trailing CR/LF characters are stripped first; each comma then
 * delimits exactly one field, so EMPTY FIELDS ARE PRESERVED as
 * empty strings (unlike strtok, which silently collapses runs of
 * delimiters and would under-count fields on a row like "1,,3").
 * The input buffer is modified (commas become NUL terminators) and
 * tokens[] receives pointers into it.
 * Returns the number of fields stored (0 for an empty line).
 */
int split_csv_line(char *line, char *tokens[], int max_tokens) {
    size_t len = strlen(line);
    /* Strip the line terminator so the last field parses clean. */
    while (len > 0 && (line[len - 1] == '\n' || line[len - 1] == '\r')) {
        line[--len] = '\0';
    }
    if (len == 0) {
        return 0;
    }

    int count = 0;
    char *field = line;
    while (count < max_tokens) {
        tokens[count++] = field;
        char *comma = strchr(field, ',');
        if (comma == NULL) {
            break; /* last field on the line */
        }
        *comma = '\0';
        field = comma + 1;
    }
    return count;
}

/*
 * Heuristic header detection: tokenize a (mutable) copy of the line
 * and report 1 when the first field contains any character other
 * than a digit or '-', i.e. it cannot be read as a numeric value.
 * Returns 0 for empty lines or numeric-looking first fields.
 */
int is_header_line(char *line_copy) {
    char *fields[FEATURES + 1];
    int n = split_csv_line(line_copy, fields, FEATURES + 1);

    if (n <= 0) {
        return 0;
    }

    const char *first = fields[0];
    while (*first != '\0') {
        char c = *first++;
        int numeric_char = (c >= '0' && c <= '9') || c == '-';
        if (!numeric_char) {
            return 1; /* non-numeric text: treat as header */
        }
    }
    return 0;
}

int load_train_data(const char *filename) {
    FILE *fp = fopen(filename, "r");
    if (fp == NULL) {
        printf("无法打开训练集文件: %s\n", filename);
        return -1;
    }

    char line[MAX_LINE_LEN];
    int sample_count = 0;
    int first_line_checked = 0;

    while (fgets(line, sizeof(line), fp) != NULL) {
        char temp[MAX_LINE_LEN];
        strcpy(temp, line);

        if (!first_line_checked) {
            first_line_checked = 1;
            if (is_header_line(temp)) {
                continue;
            }
        }

        char *tokens[FEATURES + 1];
        int count = split_csv_line(line, tokens, FEATURES + 1);
        if (count != FEATURES + 1) {
            continue;
        }

        y_train[sample_count] = atoi(tokens[0]);
        for (int j = 0; j < FEATURES; j++) {
            X_train[sample_count][j] = atof(tokens[j + 1]) / 255.0;
        }

        sample_count++;
        if (sample_count >= MAX_TRAIN_SAMPLES) {
            break;
        }
    }

    fclose(fp);
    return sample_count;
}

/*
 * Load unlabeled test rows (pixel0,...,pixel783) from a CSV file,
 * scaling each pixel to [0,1].  A leading header row is skipped;
 * rows with the wrong field count are ignored.  At most
 * MAX_TEST_SAMPLES rows are read into the X_test global.
 * Returns the number of samples loaded, or -1 if the file cannot
 * be opened.
 */
int load_test_data(const char *filename) {
    FILE *fp = fopen(filename, "r");
    if (fp == NULL) {
        printf("无法打开测试集文件: %s\n", filename);
        return -1;
    }

    char line[MAX_LINE_LEN];
    int sample_count = 0;
    int first_line_checked = 0;

    while (fgets(line, sizeof(line), fp) != NULL) {
        if (!first_line_checked) {
            first_line_checked = 1;
            /* Tokenize a scratch copy so `line` stays intact for parsing;
             * only the first line needs this copy. */
            char temp[MAX_LINE_LEN];
            strcpy(temp, line);
            if (is_header_line(temp)) {
                continue;
            }
        }

        char *tokens[FEATURES];
        int count = split_csv_line(line, tokens, FEATURES);
        if (count != FEATURES) {
            continue; /* malformed row */
        }

        for (int j = 0; j < FEATURES; j++) {
            X_test[sample_count][j] = atof(tokens[j]) / 255.0;
        }

        sample_count++;
        if (sample_count >= MAX_TEST_SAMPLES) {
            break;
        }
    }

    fclose(fp);
    return sample_count;
}

void initialize_parameters(void) {
    // He initialization for W1 (ReLU)
    double stddev1 = sqrt(2.0 / FEATURES);
    for (int k = 0; k < HIDDEN_UNITS; k++) {
        b1[k] = 0.0;
        for (int j = 0; j < FEATURES; j++) {
            // Box-Muller transform for normal distribution
            double u1 = ((double)rand() / RAND_MAX) + 1e-15;
            double u2 = ((double)rand() / RAND_MAX) + 1e-15;
            double z0 = sqrt(-2.0 * log(u1)) * cos(2.0 * M_PI * u2);
            W1[k][j] = z0 * stddev1;
        }
    }
    
    // Xavier initialization for W2 (Softmax)
    double stddev2 = sqrt(1.0 / HIDDEN_UNITS);
    for (int k = 0; k < CLASSES; k++) {
        b2[k] = 0.0;
        for (int j = 0; j < HIDDEN_UNITS; j++) {
            double u1 = ((double)rand() / RAND_MAX) + 1e-15;
            double u2 = ((double)rand() / RAND_MAX) + 1e-15;
            double z0 = sqrt(-2.0 * log(u1)) * cos(2.0 * M_PI * u2);
            W2[k][j] = z0 * stddev2;
        }
    }
}

void softmax(const double z[], double p[], int n) {
    double max_z = z[0];
    for (int i = 1; i < n; i++) {
        if (z[i] > max_z) {
            max_z = z[i];
        }
    }

    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        p[i] = exp(z[i] - max_z);
        sum += p[i];
    }

    for (int i = 0; i < n; i++) {
        p[i] /= sum;
    }
}

int predict_one(const double x[]) {
    double z1[HIDDEN_UNITS];
    double a1[HIDDEN_UNITS];
    double z2[CLASSES];
    double p[CLASSES];

    for (int k = 0; k < HIDDEN_UNITS; k++) {
        z1[k] = b1[k];
        for (int j = 0; j < FEATURES; j++) {
            z1[k] += W1[k][j] * x[j];
        }
        a1[k] = z1[k] > 0 ? z1[k] : 0.0; // ReLU
    }

    for (int k = 0; k < CLASSES; k++) {
        z2[k] = b2[k];
        for (int j = 0; j < HIDDEN_UNITS; j++) {
            z2[k] += W2[k][j] * a1[j];
        }
    }

    softmax(z2, p, CLASSES);

    int best_class = 0;
    for (int k = 1; k < CLASSES; k++) {
        if (p[k] > p[best_class]) {
            best_class = k;
        }
    }
    return best_class;
}

/*
 * Train the MLP with plain per-sample SGD for EPOCHS passes over
 * the first n training samples, printing the average cross-entropy
 * loss and training accuracy after each epoch.
 *
 * Updates the globals W1/b1/W2/b2 in place.  NOTE: the hidden-layer
 * gradient da1 is computed from W2 BEFORE W2 is updated, so the
 * statement order below is load-bearing.
 */
void train_model(int n) {
    for (int epoch = 0; epoch < EPOCHS; epoch++) {
        int correct = 0;
        double total_loss = 0.0;

        for (int i = 0; i < n; i++) {
            // Forward pass, hidden layer: z1 = W1*x + b1, a1 = ReLU(z1)
            double z1[HIDDEN_UNITS];
            double a1[HIDDEN_UNITS];
            for (int k = 0; k < HIDDEN_UNITS; k++) {
                z1[k] = b1[k];
                for (int j = 0; j < FEATURES; j++) {
                    z1[k] += W1[k][j] * X_train[i][j];
                }
                a1[k] = z1[k] > 0 ? z1[k] : 0.0; // ReLU
            }

            // Output layer logits: z2 = W2*a1 + b2
            double z2[CLASSES];
            double p[CLASSES];
            for (int k = 0; k < CLASSES; k++) {
                z2[k] = b2[k];
                for (int j = 0; j < HIDDEN_UNITS; j++) {
                    z2[k] += W2[k][j] * a1[j];
                }
            }

            softmax(z2, p, CLASSES);

            // Argmax prediction for the running accuracy count
            int pred = 0;
            for (int k = 1; k < CLASSES; k++) {
                if (p[k] > p[pred]) {
                    pred = k;
                }
            }
            if (pred == y_train[i]) {
                correct++;
            }
            // Cross-entropy loss; the 1e-15 guards against log(0)
            total_loss += -log(p[y_train[i]] + 1e-15);

            // Backward pass: softmax + cross-entropy combine to
            // dz2 = p - onehot(y)
            double dz2[CLASSES];
            for (int k = 0; k < CLASSES; k++) {
                dz2[k] = p[k] - (k == y_train[i] ? 1.0 : 0.0);
            }

            // Backprop into hidden activations: da1 = W2^T * dz2
            // (uses the PRE-update W2 — must precede the W2 update below)
            double da1[HIDDEN_UNITS] = {0};
            for (int k = 0; k < CLASSES; k++) {
                for (int j = 0; j < HIDDEN_UNITS; j++) {
                    da1[j] += W2[k][j] * dz2[k];
                }
            }

            // ReLU derivative: gradient passes only where z1 > 0
            double dz1[HIDDEN_UNITS];
            for (int k = 0; k < HIDDEN_UNITS; k++) {
                dz1[k] = z1[k] > 0 ? da1[k] : 0.0;
            }

            // SGD update, output layer: W2 -= lr * dz2 * a1^T
            for (int k = 0; k < CLASSES; k++) {
                for (int j = 0; j < HIDDEN_UNITS; j++) {
                    W2[k][j] -= LEARNING_RATE * dz2[k] * a1[j];
                }
                b2[k] -= LEARNING_RATE * dz2[k];
            }

            // SGD update, hidden layer: W1 -= lr * dz1 * x^T
            for (int k = 0; k < HIDDEN_UNITS; k++) {
                for (int j = 0; j < FEATURES; j++) {
                    W1[k][j] -= LEARNING_RATE * dz1[k] * X_train[i][j];
                }
                b1[k] -= LEARNING_RATE * dz1[k];
            }
        }

        // Epoch summary: mean loss and accuracy over all n samples
        printf("Epoch %d/%d, loss = %.6f, accuracy = %.2f%%\n",
               epoch + 1,
               EPOCHS,
               total_loss / n,
               100.0 * correct / n);
    }
}

/*
 * Report training-set accuracy over the first n samples and print
 * a CLASSES x CLASSES confusion matrix (rows = true label,
 * columns = predicted label).
 */
void evaluate_model(int n) {
    int conf[CLASSES][CLASSES] = {0};
    int hits = 0;

    for (int i = 0; i < n; i++) {
        int truth = y_train[i];
        int guess = predict_one(X_train[i]);
        if (guess == truth) {
            hits++;
        }
        conf[truth][guess]++;
    }

    printf("\n训练集准确率: %.2f%%\n", 100.0 * hits / n);
    printf("\n混淆矩阵:\n");
    for (int row = 0; row < CLASSES; row++) {
        for (int col = 0; col < CLASSES; col++) {
            printf("%6d", conf[row][col]);
        }
        printf("\n");
    }
}

/*
 * Predict every test sample and write a Kaggle-style submission
 * CSV with header "ImageId,Label" (ImageId is 1-based).
 * Prints a message on success or if the file cannot be created.
 */
void save_submission(const char *filename, int test_count) {
    FILE *out = fopen(filename, "w");
    if (out == NULL) {
        printf("无法创建提交文件: %s\n", filename);
        return;
    }

    fprintf(out, "ImageId,Label\n");
    for (int idx = 0; idx < test_count; idx++) {
        fprintf(out, "%d,%d\n", idx + 1, predict_one(X_test[idx]));
    }

    fclose(out);
    printf("已生成提交文件: %s\n", filename);
}

/*
 * Entry point: load train.csv and test.csv, train the MLP,
 * report training accuracy, and emit submission.csv.
 * Returns 1 if either data file fails to load.
 */
int main(void) {
    srand(42); /* fixed seed so initialization is reproducible */

    int n_train = load_train_data("train.csv");
    if (n_train <= 0) {
        printf("训练集加载失败，或者 train.csv 内容为空。\n");
        return 1;
    }

    int n_test = load_test_data("test.csv");
    if (n_test <= 0) {
        printf("测试集加载失败，或者 test.csv 内容为空。\n");
        return 1;
    }

    printf("成功加载 %d 条训练样本。\n", n_train);
    printf("成功加载 %d 条测试样本。\n", n_test);
    printf("每条样本特征数: %d\n", FEATURES);

    initialize_parameters();
    train_model(n_train);
    evaluate_model(n_train);
    save_submission("submission.csv", n_test);

    return 0;
}