/*
  CIFAR-10 tiny CNN in pure C (single file)
  Data format: CIFAR-10 binary batches.

  Compile:
    gcc -O2 -std=c11 cifar10_tiny_cnn.c -lm -o cifar10_tiny_cnn

  Run:
    ./cifar10_tiny_cnn /path/to/cifar-10-batches-bin 1 2000 1000

  Arguments:
    argv[1] data directory containing data_batch_1.bin ... data_batch_5.bin and test_batch.bin; default ./cifar-10-batches-bin
    argv[2] epochs, default 1
    argv[3] train_limit, default 5000; use 50000 for full training set
    argv[4] test_limit, default 1000; use 10000 for full test set

  Architecture:
    input 3x32x32
    conv: 16 filters (NF), 3x3 valid convolution -> 16x30x30
    ReLU
    maxpool 2x2 stride 2 -> 16x15x15
    fully connected -> 10 classes
    softmax cross-entropy

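  This is a minimal educational CNN, not a high-accuracy production model.
  On completion it writes model_weights.bin (raw parameter dump) and
  test_predictions.csv to the current working directory.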
*/

#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

/* Input image geometry (CIFAR-10: 3 channels, 32x32 pixels). */
#define IMG_C 3
#define IMG_H 32
#define IMG_W 32
#define IMG_SIZE (IMG_C * IMG_H * IMG_W)
#define NCLASS 10

/* Convolution layer: NF filters of size K x K, stride 1, no padding ("valid"). */
#define NF 16
#define K 3
#define CONV_H (IMG_H - K + 1)
#define CONV_W (IMG_W - K + 1)

/* 2x2 max pooling with stride 2; an odd trailing row/column is dropped. */
#define POOL_H (CONV_H / 2)
#define POOL_W (CONV_W / 2)

/* Length of the flattened pooled feature vector fed to the FC layer. */
#define FEAT (NF * POOL_H * POOL_W)

#define LR 0.01f          /* SGD step size */
#define MAX_TRAIN 50000   /* 5 training batch files x 10000 records */
#define MAX_TEST 10000    /* records in test_batch.bin */

_Static_assert(CONV_H > 0 && CONV_W > 0, "convolution kernel larger than the image");
_Static_assert(POOL_H > 0 && POOL_W > 0, "convolution output too small to pool");

/*
 * One CIFAR-10 image together with its class label.
 * x is stored planar / channel-major: all IMG_H*IMG_W values of channel 0,
 * then channel 1, then channel 2 (element c*IMG_H*IMG_W + h*IMG_W + w).
 * load_batch() fills x with pixels rescaled from [0, 255] to [-1, 1].
 */
typedef struct {
    uint8_t label;      /* class byte from the file; train_one() uses it as an index into NCLASS-sized arrays */
    float x[IMG_SIZE];  /* normalized pixel values in planar CHW order */
} Sample;

/*
 * All trainable parameters of the network.
 * save_model() writes this struct to disk as a raw byte dump, so changing the
 * field order, element types, or the NF/K/FEAT sizes changes the file format.
 */
typedef struct {
    float conv_w[NF][IMG_C][K][K]; /* conv kernels: [filter][input channel][kernel row][kernel col] */
    float conv_b[NF];              /* one bias per conv filter */
    float fc_w[NCLASS][FEAT];      /* fully connected weights: [class][flattened pooled feature] */
    float fc_b[NCLASS];            /* one bias per output class */
} Net;

/* Pseudo-random float in the closed interval [a, b], drawn from the C library rand(). */
static float frand_uniform(float a, float b) {
    const float unit = (float)rand() / (float)RAND_MAX; /* in [0, 1] */
    const float span = b - a;
    return a + span * unit;
}

/*
 * Reset every parameter of *net.
 * Reseeds rand() with the fixed value 1, so each run starts from the same
 * weights. Biases are zeroed; each weight is drawn uniformly from
 * [-sqrt(2/fan_in), +sqrt(2/fan_in)], with fan_in = IMG_C*K*K for the conv
 * layer and FEAT for the fully connected layer. Conv weights are drawn before
 * FC weights, each in row-major order.
 */
static void init_net(Net *net) {
    srand(1);

    const float conv_bound = sqrtf(2.0f / (IMG_C * K * K));
    for (int filt = 0; filt < NF; ++filt) {
        float (*kernel)[K][K] = net->conv_w[filt];
        net->conv_b[filt] = 0.0f;
        for (int ch = 0; ch < IMG_C; ++ch)
            for (int kr = 0; kr < K; ++kr)
                for (int kc = 0; kc < K; ++kc)
                    kernel[ch][kr][kc] = frand_uniform(-conv_bound, conv_bound);
    }

    const float fc_bound = sqrtf(2.0f / FEAT);
    for (int cls = 0; cls < NCLASS; ++cls) {
        float *row = net->fc_w[cls];
        net->fc_b[cls] = 0.0f;
        for (int i = 0; i < FEAT; ++i)
            row[i] = frand_uniform(-fc_bound, fc_bound);
    }
}

/*
 * Read up to max_count records from one CIFAR-10 binary batch file into
 * dst[offset], dst[offset + 1], ...
 *
 * Each record is one label byte followed by IMG_SIZE pixel bytes
 * (1024 R, then 1024 G, then 1024 B). Pixels are mapped from [0, 255] to [-1, 1].
 *
 * Returns the number of complete records stored, or -1 if the file cannot be
 * opened, a read error occurs, or a label is outside [0, NCLASS). A trailing
 * partial record is reported on stderr and ignored.
 */
static int load_batch(const char *filename, Sample *dst, int offset, int max_count) {
    FILE *fp = fopen(filename, "rb");
    if (!fp) {
        fprintf(stderr, "Cannot open %s\n", filename);
        return -1;
    }

    int count = 0;
    while (count < max_count) {
        int label = fgetc(fp);
        if (label == EOF) break;

        uint8_t buf[IMG_SIZE];
        size_t got = fread(buf, 1, IMG_SIZE, fp);
        if (got != IMG_SIZE) {
            if (!ferror(fp))
                fprintf(stderr, "Warning: truncated record %d in %s ignored\n", count, filename);
            break;
        }

        /* The label is later used as an index into NCLASS-sized arrays
           (prob[] in train_one), so reject anything out of range here. */
        if (label >= NCLASS) {
            fprintf(stderr, "Invalid label %d in record %d of %s\n", label, count, filename);
            fclose(fp);
            return -1;
        }

        Sample *out = &dst[offset + count];
        out->label = (uint8_t)label;
        /* CIFAR binary order: label, then 1024 R, 1024 G, 1024 B */
        for (int i = 0; i < IMG_SIZE; ++i) {
            out->x[i] = ((float)buf[i] / 255.0f - 0.5f) / 0.5f;
        }
        count++;
    }

    if (ferror(fp)) {
        fprintf(stderr, "Read error on %s\n", filename);
        fclose(fp);
        return -1;
    }
    fclose(fp);
    return count;
}

/*
 * Load up to limit training samples from data_batch_1.bin .. data_batch_5.bin
 * in dir, stopping early once limit samples have been read.
 * Returns the number of samples loaded, or -1 if a path does not fit the
 * buffer or any batch file fails to load.
 */
static int load_train(const char *dir, Sample *train, int limit) {
    enum { TRAIN_BATCHES = 5, RECORDS_PER_BATCH = 10000 };
    int total = 0;
    for (int b = 1; b <= TRAIN_BATCHES && total < limit; ++b) {
        char path[512];
        int len = snprintf(path, sizeof(path), "%s/data_batch_%d.bin", dir, b);
        /* A truncated path would open the wrong file (or none) silently. */
        if (len < 0 || (size_t)len >= sizeof(path)) {
            fprintf(stderr, "Data path too long: %s\n", dir);
            return -1;
        }
        int need = limit - total;
        if (need > RECORDS_PER_BATCH) need = RECORDS_PER_BATCH;
        int n = load_batch(path, train, total, need);
        if (n < 0) return -1;
        total += n;
    }
    return total;
}

/*
 * Load up to limit samples from dir/test_batch.bin.
 * Returns the number of samples loaded, or -1 if the path does not fit the
 * buffer or the file fails to load.
 */
static int load_test(const char *dir, Sample *test, int limit) {
    char path[512];
    int len = snprintf(path, sizeof(path), "%s/test_batch.bin", dir);
    /* A truncated path would open the wrong file (or none) silently. */
    if (len < 0 || (size_t)len >= sizeof(path)) {
        fprintf(stderr, "Data path too long: %s\n", dir);
        return -1;
    }
    return load_batch(path, test, 0, limit);
}

/* Value at channel c, row h, column w of a planar (channel-major) IMG_H x IMG_W image. */
static inline float get_pixel(const float *x, int c, int h, int w) {
    const float *plane = x + c * (IMG_H * IMG_W);
    return plane[h * IMG_W + w];
}

static void forward(
    const Net *net,
    const Sample *s,
    float conv[NF][CONV_H][CONV_W],
    float pool[NF][POOL_H][POOL_W],
    int argmax_idx[NF][POOL_H][POOL_W],
    float feat[FEAT],
    float logits[NCLASS],
    float prob[NCLASS]
) {
    for (int f = 0; f < NF; ++f) {
        for (int i = 0; i < CONV_H; ++i) {
            for (int j = 0; j < CONV_W; ++j) {
                float sum = net->conv_b[f];
                for (int c = 0; c < IMG_C; ++c)
                    for (int r = 0; r < K; ++r)
                        for (int t = 0; t < K; ++t)
                            sum += net->conv_w[f][c][r][t] * get_pixel(s->x, c, i + r, j + t);
                conv[f][i][j] = sum > 0.0f ? sum : 0.0f; /* ReLU */
            }
        }
    }

    int p = 0;
    for (int f = 0; f < NF; ++f) {
        for (int i = 0; i < POOL_H; ++i) {
            for (int j = 0; j < POOL_W; ++j) {
                int base_i = 2 * i, base_j = 2 * j;
                float best = conv[f][base_i][base_j];
                int best_idx = 0;
                for (int di = 0; di < 2; ++di) {
                    for (int dj = 0; dj < 2; ++dj) {
                        int idx = di * 2 + dj;
                        float v = conv[f][base_i + di][base_j + dj];
                        if (v > best) {
                            best = v;
                            best_idx = idx;
                        }
                    }
                }
                pool[f][i][j] = best;
                argmax_idx[f][i][j] = best_idx;
                feat[p++] = best;
            }
        }
    }

    for (int k = 0; k < NCLASS; ++k) {
        float z = net->fc_b[k];
        for (int i = 0; i < FEAT; ++i) z += net->fc_w[k][i] * feat[i];
        logits[k] = z;
    }

    float maxz = logits[0];
    for (int k = 1; k < NCLASS; ++k) if (logits[k] > maxz) maxz = logits[k];
    float denom = 0.0f;
    for (int k = 0; k < NCLASS; ++k) {
        prob[k] = expf(logits[k] - maxz);
        denom += prob[k];
    }
    for (int k = 0; k < NCLASS; ++k) prob[k] /= denom;
}

/*
 * Classify one sample: run forward() and return the index of the most
 * probable class (the lowest index wins ties).
 * Scratch buffers have static storage, so this function is not reentrant.
 */
static int predict(const Net *net, const Sample *s) {
    static float conv[NF][CONV_H][CONV_W];
    static float pool[NF][POOL_H][POOL_W];
    static int argmax_idx[NF][POOL_H][POOL_W];
    static float feat[FEAT];
    static float logits[NCLASS];
    static float prob[NCLASS];

    forward(net, s, conv, pool, argmax_idx, feat, logits, prob);

    int top_class = 0;
    float top_prob = prob[0];
    for (int k = 1; k < NCLASS; ++k) {
        if (prob[k] > top_prob) {
            top_prob = prob[k];
            top_class = k;
        }
    }
    return top_class;
}

/*
 * One plain-SGD step (batch size 1) on a single sample.
 * Runs forward(), computes the softmax cross-entropy loss for s->label,
 * backpropagates through the FC layer, max-pool, ReLU and convolution, and
 * updates every parameter in *net in place with step size LR.
 * Returns the loss measured before the update.
 *
 * NOTE(review): s->label is used directly as an index into prob[]; a value
 * >= NCLASS would read out of bounds. Nothing in this function checks it.
 * Scratch buffers have static storage, so this function is not reentrant.
 */
static float train_one(Net *net, const Sample *s) {
    static float conv[NF][CONV_H][CONV_W];
    static float pool[NF][POOL_H][POOL_W];
    static int argmax_idx[NF][POOL_H][POOL_W];
    static float feat[FEAT], logits[NCLASS], prob[NCLASS];

    forward(net, s, conv, pool, argmax_idx, feat, logits, prob);
    int y = s->label;
    /* Cross-entropy; the 1e-8 guards against log(0). */
    float loss = -logf(prob[y] + 1e-8f);

    /* Gradient of softmax + cross-entropy w.r.t. the logits: prob - one_hot(y). */
    float dz[NCLASS];
    for (int k = 0; k < NCLASS; ++k) dz[k] = prob[k] - (k == y ? 1.0f : 0.0f);

    static float dfeat[FEAT];
    memset(dfeat, 0, sizeof(dfeat));

    /* dL/dfeat = fc_w^T * dz. Must run BEFORE fc_w is updated below,
       so it uses the weights that produced this forward pass. */
    for (int k = 0; k < NCLASS; ++k) {
        for (int i = 0; i < FEAT; ++i) {
            dfeat[i] += net->fc_w[k][i] * dz[k];
        }
    }

    /* SGD update of the FC layer: dL/dfc_w[k][i] = dz[k] * feat[i], dL/dfc_b[k] = dz[k]. */
    for (int k = 0; k < NCLASS; ++k) {
        for (int i = 0; i < FEAT; ++i) {
            net->fc_w[k][i] -= LR * dz[k] * feat[i];
        }
        net->fc_b[k] -= LR * dz[k];
    }

    static float dconv[NF][CONV_H][CONV_W];
    memset(dconv, 0, sizeof(dconv));

    /* Max-pool backward: each pooled gradient goes only to the cell that won
       its 2x2 window (argmax_idx = di*2 + dj); all other cells stay 0.
       p walks feat/dfeat in the same [f][i][j] order forward() used. */
    int p = 0;
    for (int f = 0; f < NF; ++f) {
        for (int i = 0; i < POOL_H; ++i) {
            for (int j = 0; j < POOL_W; ++j) {
                int idx = argmax_idx[f][i][j];
                int di = idx / 2;
                int dj = idx % 2;
                dconv[f][2 * i + di][2 * j + dj] += dfeat[p++];
            }
        }
    }

    for (int f = 0; f < NF; ++f) {
        float db = 0.0f;
        for (int i = 0; i < CONV_H; ++i) {
            for (int j = 0; j < CONV_W; ++j) {
                /* conv holds post-ReLU values, so <= 0 means the unit was clamped
                   and passes no gradient back. */
                if (conv[f][i][j] <= 0.0f) dconv[f][i][j] = 0.0f; /* ReLU backprop */
                db += dconv[f][i][j];
            }
        }
        /* dL/dconv_w[f][c][r][t] = sum over output positions of
           dconv[f][i][j] * input[c][i + r][j + t]. */
        for (int c = 0; c < IMG_C; ++c) {
            for (int r = 0; r < K; ++r) {
                for (int t = 0; t < K; ++t) {
                    float dw = 0.0f;
                    for (int i = 0; i < CONV_H; ++i)
                        for (int j = 0; j < CONV_W; ++j)
                            dw += dconv[f][i][j] * get_pixel(s->x, c, i + r, j + t);
                    net->conv_w[f][c][r][t] -= LR * dw;
                }
            }
        }
        net->conv_b[f] -= LR * db;
    }

    return loss;
}

/*
 * Pseudo-random integer in [0, bound) from rand(); bound must be >= 1.
 * When bound - 1 fits in rand()'s range this is exactly rand() % bound.
 * Otherwise (RAND_MAX may be as small as 32767, e.g. on MSVC) two draws are
 * combined, giving at least 2^30 distinct values. The modulo reduction keeps a
 * small bias, which is acceptable for shuffling training data.
 */
static int rand_below(int bound) {
    if (bound - 1 <= RAND_MAX)
        return rand() % bound;
    unsigned long long hi = (unsigned long long)rand();
    unsigned long long lo = (unsigned long long)rand();
    unsigned long long r = hi * ((unsigned long long)RAND_MAX + 1ULL) + lo;
    return (int)(r % (unsigned long long)bound);
}

/*
 * In-place Fisher-Yates shuffle of a[0..n-1] driven by rand().
 * n <= 1 leaves the array untouched (a may be NULL when n == 0).
 */
static void shuffle_int(int *a, int n) {
    for (int i = n - 1; i > 0; --i) {
        /* A plain rand() % (i + 1) could never pick j > RAND_MAX. */
        int j = rand_below(i + 1);
        int tmp = a[i]; a[i] = a[j]; a[j] = tmp;
    }
}

/*
 * Write *net to filename as a raw memory dump of the Net struct.
 * The file can only be read back by a build with the same NF/K/FEAT sizes,
 * struct layout, float representation and byte order.
 * Reports success on stdout, or failure (open, short write, or close/flush
 * error) on stderr.
 */
static void save_model(const Net *net, const char *filename) {
    FILE *fp = fopen(filename, "wb");
    if (!fp) {
        fprintf(stderr, "Failed to save model to %s\n", filename);
        return;
    }
    size_t written = fwrite(net, sizeof(Net), 1, fp);
    int close_rc = fclose(fp); /* fclose flushes; a failure here means lost data */
    if (written != 1 || close_rc != 0) {
        fprintf(stderr, "Failed to save model to %s\n", filename);
        return;
    }
    printf("Model saved to %s\n", filename);
}

/*
 * Write a CSV with columns id,prediction,true_label for the first n samples
 * of data, running predict() on each. Reports success on stdout, or failure
 * (open, write, or close/flush error) on stderr.
 */
static void save_predictions(const Net *net, const Sample *data, int n, const char *filename) {
    FILE *fp = fopen(filename, "w");
    if (!fp) {
        fprintf(stderr, "Failed to save predictions to %s\n", filename);
        return;
    }
    fprintf(fp, "id,prediction,true_label\n");
    for (int i = 0; i < n; ++i) {
        int pred = predict(net, &data[i]);
        fprintf(fp, "%d,%d,%d\n", i, pred, data[i].label);
    }
    /* Check the stream error flag before closing, then the close (flush) itself. */
    int write_failed = ferror(fp);
    if (fclose(fp) != 0 || write_failed) {
        fprintf(stderr, "Failed to save predictions to %s\n", filename);
        return;
    }
    printf("Predictions saved to %s\n", filename);
}

/*
 * Fraction of the first n samples whose predicted class equals the stored
 * label. Returns 0.0f for n <= 0 instead of computing 0/0 (NaN).
 */
static float evaluate(const Net *net, const Sample *data, int n) {
    if (n <= 0) return 0.0f;
    int ok = 0;
    for (int i = 0; i < n; ++i) {
        int pred = predict(net, &data[i]);
        if (pred == data[i].label) ok++;
    }
    return (float)ok / (float)n;
}

int main(int argc, char **argv) {
    const char *dir = argc > 1 ? argv[1] : "./cifar-10-batches-bin";
    int epochs = argc > 2 ? atoi(argv[2]) : 1;
    int train_limit = argc > 3 ? atoi(argv[3]) : 5000;
    int test_limit = argc > 4 ? atoi(argv[4]) : 1000;
    if (train_limit > MAX_TRAIN) train_limit = MAX_TRAIN;
    if (test_limit > MAX_TEST) test_limit = MAX_TEST;

    Sample *train = (Sample *)malloc(sizeof(Sample) * train_limit);
    Sample *test = (Sample *)malloc(sizeof(Sample) * test_limit);
    if (!train || !test) {
        fprintf(stderr, "Memory allocation failed.\n");
        return 1;
    }

    int ntrain = load_train(dir, train, train_limit);
    int ntest = load_test(dir, test, test_limit);
    if (ntrain <= 0 || ntest <= 0) {
        fprintf(stderr, "Failed to load CIFAR-10 data. dir=%s\n", dir);
        return 1;
    }
    printf("Loaded train=%d, test=%d\n", ntrain, ntest);

    Net net;
    init_net(&net);

    int *idx = (int *)malloc(sizeof(int) * ntrain);
    for (int i = 0; i < ntrain; ++i) idx[i] = i;

    for (int e = 1; e <= epochs; ++e) {
        shuffle_int(idx, ntrain);
        float sum_loss = 0.0f;
        int correct = 0;
        for (int it = 0; it < ntrain; ++it) {
            const Sample *s = &train[idx[it]];
            sum_loss += train_one(&net, s);
            int pred = predict(&net, s);
            if (pred == s->label) correct++;

            if ((it + 1) % 1000 == 0) {
                printf("epoch %d step %d/%d loss=%.4f train_acc_recent_est=%.3f\n",
                       e, it + 1, ntrain, sum_loss / (it + 1), (float)correct / (it + 1));
            }
        }
        float test_acc = evaluate(&net, test, ntest);
        printf("epoch %d done: avg_loss=%.4f train_acc=%.3f test_acc=%.3f\n",
               e, sum_loss / ntrain, (float)correct / ntrain, test_acc);
    }

    save_model(&net, "model_weights.bin");
    save_predictions(&net, test, ntest, "test_predictions.csv");

    free(idx);
    free(train);
    free(test);
    return 0;
}
