#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>

/* Capacity of the sample buffers; load_iris stores at most this many rows. */
#define MAX_SAMPLES 150
/* Number of numeric feature columns per sample. */
#define FEATURES 4
/* Number of clusters. */
#define K 3
/* Upper bound on assign/update rounds within one K-means run. */
#define MAX_ITER 1000
/* Size of the line buffer passed to fgets in load_iris. */
#define LINE_LEN 256
/* Number of independent K-means runs; the run with the lowest SSE is kept. */
#define RESTARTS 2000
/* Threshold below which a standard deviation or a total seeding distance is treated as zero. */
#define EPSILON 1e-20

/*
 * One Iris record as loaded from the CSV: its numeric id, the raw feature
 * values, their standardized counterparts, the species label, and the index
 * of the cluster it is currently assigned to (-1 means unassigned).
 */
struct iris_sample {
    int id;                                   /* first CSV column */
    double x[FEATURES], x_scaled[FEATURES];   /* raw and standardized features */
    char species[32];                         /* NUL-terminated species label */
    int cluster;                              /* assigned cluster, -1 if none */
};
typedef struct iris_sample Sample;

/*
 * Squared Euclidean distance between two FEATURES-length vectors.
 * No square root is taken; callers only compare or sum these values.
 */
double distance_sq(const double a[], const double b[]) {
    double acc = 0.0;
    int k = 0;

    while (k < FEATURES) {
        double delta = a[k] - b[k];
        acc += delta * delta;
        ++k;
    }
    return acc;
}

/*
 * Z-score standardize every feature column of the first n samples:
 * x_scaled[j] = (x[j] - mean[j]) / std[j], using the population standard
 * deviation (sum of squared deviations divided by n).
 * A column whose std falls below EPSILON (effectively constant) uses
 * std = 1.0, so it is only centered and no division by zero occurs.
 * Does nothing when n <= 0.
 */
void standardize_features(Sample data[], int n) {
    double mean[FEATURES] = {0};
    double std[FEATURES] = {0};
    int i, j;

    if (n <= 0) {
        return; /* mean[j] /= n below would divide by zero */
    }

    for (j = 0; j < FEATURES; j++) {
        for (i = 0; i < n; i++) {
            mean[j] += data[i].x[j];
        }
        mean[j] /= n;
    }

    for (j = 0; j < FEATURES; j++) {
        for (i = 0; i < n; i++) {
            double diff = data[i].x[j] - mean[j];
            std[j] += diff * diff;
        }
        std[j] = sqrt(std[j] / n);
        if (std[j] < EPSILON) {
            std[j] = 1.0; /* constant column: center only */
        }
    }

    for (i = 0; i < n; i++) {
        for (j = 0; j < FEATURES; j++) {
            data[i].x_scaled[j] = (data[i].x[j] - mean[j]) / std[j];
        }
    }
}

double total_sse(const Sample data[], int n, double centroids[K][FEATURES]) {
    double sse = 0.0;
    int i;
    for (i = 0; i < n; i++) {
        sse += distance_sq(data[i].x_scaled, centroids[data[i].cluster]);
    }
    return sse;
}

/* Copy every element of the K x FEATURES centroid matrix src into dest. */
void copy_centroids(double dest[K][FEATURES], double src[K][FEATURES]) {
    int row;

    for (row = 0; row < K; ++row) {
        int col = 0;
        while (col < FEATURES) {
            dest[row][col] = src[row][col];
            col++;
        }
    }
}

/*
 * Copy only the cluster assignment of the first n samples from src to dest.
 * All other fields of dest are left untouched.
 */
void copy_clusters(Sample dest[], const Sample src[], int n) {
    int k = 0;

    while (k < n) {
        dest[k].cluster = src[k].cluster;
        ++k;
    }
}

/*
 * Load Iris records from a CSV file into data.
 *
 * The first line is always treated as a header and discarded. Every
 * following line must look like "id,f0,f1,f2,f3,species"; lines that do not
 * yield all six fields are skipped. At most MAX_SAMPLES records are stored
 * and any further rows are ignored. Each stored record starts with
 * cluster = -1 (unassigned).
 *
 * Returns the number of records stored, or -1 if the file cannot be opened
 * or contains no lines at all.
 */
int load_iris(const char *filename, Sample data[]) {
    FILE *fp = fopen(filename, "r");
    char line[LINE_LEN];
    int count = 0;

    if (fp == NULL) {
        printf("无法打开文件: %s\n", filename);
        return -1;
    }

    /* Skip the header row; a completely empty file is an error. */
    if (fgets(line, sizeof(line), fp) == NULL) {
        fclose(fp);
        return -1;
    }

    /* Check capacity first so no extra line is consumed once the buffer is full. */
    while (count < MAX_SAMPLES && fgets(line, sizeof(line), fp) != NULL) {
        Sample s = {0};
        /* '\r' is excluded from the species scanset. Files with CRLF line
           endings (when the stream does not translate them) would otherwise
           keep a trailing carriage return in the label, and it would never
           compare equal to the known species names. */
        int parsed = sscanf(line, "%d,%lf,%lf,%lf,%lf,%31[^\r\n,]",
                            &s.id,
                            &s.x[0],
                            &s.x[1],
                            &s.x[2],
                            &s.x[3],
                            s.species);

        if (parsed == 6) {
            s.cluster = -1;
            data[count++] = s;
        }
    }

    fclose(fp);
    return count;
}

/*
 * Draw an index in [0, n) with probability weights[p] / total, where total is
 * the positive sum of weights[0..n-1]. Entries with zero weight are never
 * returned.
 */
static int pick_weighted_index(const double weights[], int n, double total) {
    /* Dividing by RAND_MAX + 1.0 keeps r in [0, total). Combined with the
       strict '>' below, a zero-weight entry (a sample that already is a
       centroid) cannot be chosen, even when rand() returns 0. */
    double r = ((double)rand() / ((double)RAND_MAX + 1.0)) * total;
    double cumulative = 0.0;
    int last_positive = 0;
    int p;

    for (p = 0; p < n; p++) {
        if (weights[p] <= 0.0) {
            continue;
        }
        cumulative += weights[p];
        last_positive = p;
        if (cumulative > r) {
            return p;
        }
    }

    /* Only reachable if rounding leaves cumulative <= r; fall back to the
       last entry that carries weight. */
    return last_positive;
}

/*
 * Seed the K centroids from the standardized samples (k-means++ style).
 * Centroid 0 is a uniformly random sample. Each later centroid is a sample
 * drawn with probability proportional to its squared distance to the nearest
 * centroid chosen so far. If that total distance is below EPSILON (every
 * sample coincides with an existing centroid), a uniformly random sample is
 * used instead. Randomness comes from rand().
 *
 * n must not exceed MAX_SAMPLES, because per-sample distances are kept in a
 * MAX_SAMPLES-sized buffer. Does nothing when n <= 0.
 */
void init_centroids(const Sample data[], int n, double centroids[K][FEATURES]) {
    int first_index;
    int i, j;

    if (n <= 0) {
        return; /* rand() % n would divide by zero */
    }

    first_index = rand() % n;
    for (j = 0; j < FEATURES; j++) {
        centroids[0][j] = data[first_index].x_scaled[j];
    }

    for (i = 1; i < K; i++) {
        double min_dist[MAX_SAMPLES];
        double dist_sum = 0.0;
        int chosen_index;
        int p;

        /* Squared distance from each sample to its nearest already-chosen centroid. */
        for (p = 0; p < n; p++) {
            double best = distance_sq(data[p].x_scaled, centroids[0]);
            int c;
            for (c = 1; c < i; c++) {
                double d = distance_sq(data[p].x_scaled, centroids[c]);
                if (d < best) {
                    best = d;
                }
            }
            min_dist[p] = best;
            dist_sum += best;
        }

        if (dist_sum < EPSILON) {
            chosen_index = rand() % n;
        } else {
            chosen_index = pick_weighted_index(min_dist, n, dist_sum);
        }

        for (j = 0; j < FEATURES; j++) {
            centroids[i][j] = data[chosen_index].x_scaled[j];
        }
    }
}

/*
 * Assign each of the first n samples to its nearest centroid, measured by
 * squared Euclidean distance on the standardized features. Ties go to the
 * lowest cluster index.
 * Returns 1 if at least one sample's cluster changed, otherwise 0.
 */
int assign_clusters(Sample data[], int n, double centroids[K][FEATURES]) {
    int any_moved = 0;
    int s;

    for (s = 0; s < n; ++s) {
        int nearest = 0;
        double nearest_d = distance_sq(data[s].x_scaled, centroids[0]);
        int c;

        for (c = 1; c < K; ++c) {
            double d = distance_sq(data[s].x_scaled, centroids[c]);
            if (d < nearest_d) {
                nearest = c;
                nearest_d = d;
            }
        }

        if (data[s].cluster == nearest) {
            continue;
        }
        data[s].cluster = nearest;
        any_moved = 1;
    }

    return any_moved;
}

/*
 * Move each centroid to the mean standardized feature vector of the samples
 * currently assigned to it. Samples with a cluster index outside [0, K) are
 * ignored. A centroid with no assigned samples keeps its previous value.
 */
void update_centroids(const Sample data[], int n, double centroids[K][FEATURES]) {
    double acc[K][FEATURES] = {{0}};
    int members[K] = {0};
    int s, f, c;

    for (s = 0; s < n; ++s) {
        int label = data[s].cluster;
        if (label < 0 || label >= K) {
            continue;
        }
        members[label] += 1;
        for (f = 0; f < FEATURES; ++f) {
            acc[label][f] += data[s].x_scaled[f];
        }
    }

    for (c = 0; c < K; ++c) {
        if (members[c] <= 0) {
            continue; /* empty cluster: keep the old centroid */
        }
        for (f = 0; f < FEATURES; ++f) {
            centroids[c][f] = acc[c][f] / members[c];
        }
    }
}

void print_centroids(double centroids[K][FEATURES]) {
    int i;
    printf("标准化空间中的聚类中心：\n");
    for (i = 0; i < K; i++) {
        printf("Cluster %d center = (%.3f, %.3f, %.3f, %.3f)\n",
               i,
               centroids[i][0], centroids[i][1], centroids[i][2], centroids[i][3]);
    }
    printf("\n");
}

/*
 * Count the samples in each cluster and, within each cluster, how many carry
 * each of the three known species labels, then print one summary line per
 * cluster followed by a blank line.
 * Samples with any other label count toward their cluster's total only.
 * Samples whose cluster index is outside [0, K) (e.g. -1, unassigned) are
 * skipped instead of being used as out-of-bounds array indices.
 */
void print_cluster_result(const Sample data[], int n) {
    enum { NUM_SPECIES = 3 };
    static const char *const names[NUM_SPECIES] = {
        "Iris-setosa", "Iris-versicolor", "Iris-virginica"
    };
    int counts[K] = {0};
    int label_count[K][NUM_SPECIES] = {{0}};
    int i, j;

    for (i = 0; i < n; i++) {
        int c = data[i].cluster;
        if (c < 0 || c >= K) {
            continue;
        }
        counts[c]++;
        for (j = 0; j < NUM_SPECIES; j++) {
            if (strcmp(data[i].species, names[j]) == 0) {
                label_count[c][j]++;
                break; /* labels are distinct; at most one can match */
            }
        }
    }

    printf("每个簇的样本数量与真实标签分布：\n");
    for (i = 0; i < K; i++) {
        printf("Cluster %d: total=%d, setosa=%d, versicolor=%d, virginica=%d\n",
               i,
               counts[i],
               label_count[i][0],
               label_count[i][1],
               label_count[i][2]);
    }
    printf("\n");
}

/*
 * Print one line per sample (first n entries): id, species label
 * left-aligned in a 15-character field, and the assigned cluster index.
 */
void print_assignments(const Sample data[], int n) {
    int idx = 0;

    printf("样本的聚类结果：\n");
    while (idx < n) {
        const Sample *s = &data[idx];
        printf("id=%d, species=%-15s -> cluster %d\n",
               s->id, s->species, s->cluster);
        idx++;
    }
}

/*
 * Entry point. Steps:
 *   1. Load the Iris CSV (path from argv[1], default "Iris.csv").
 *   2. Standardize its features.
 *   3. Run K-means RESTARTS times, each from a fresh init_centroids seeding,
 *      and keep the run with the lowest SSE.
 *   4. Print that run's centroids, per-cluster species counts and
 *      per-sample assignments.
 * Returns 1 if no data could be loaded, otherwise 0.
 */
int main(int argc, char *argv[]) {
    Sample data[MAX_SAMPLES];
    double centroids[K][FEATURES];
    int n;
    int iter;
    Sample best_data[MAX_SAMPLES];
    double best_centroids[K][FEATURES];
    double best_sse = -1.0; /* negative: no run recorded yet */
    int best_iter = 0;
    int run;
    const char *filename = (argc > 1) ? argv[1] : "Iris.csv";

    n = load_iris(filename, data);
    if (n <= 0) {
        printf("读取数据失败。\n");
        return 1;
    }

    srand((unsigned)time(NULL));
    standardize_features(data, n);
    printf("成功读取 %d 条 Iris 数据。\n\n", n);

    for (run = 0; run < RESTARTS; run++) {
        int i;
        /* Assign/update rounds actually executed in this run. Tracked
           separately because iter == MAX_ITER after a non-converging run,
           and iter + 1 would then over-count by one. */
        int iterations = 0;

        for (i = 0; i < n; i++) {
            data[i].cluster = -1;
        }

        init_centroids(data, n, centroids);

        for (iter = 0; iter < MAX_ITER; iter++) {
            int changed = assign_clusters(data, n, centroids);
            update_centroids(data, n, centroids);
            iterations = iter + 1;
            if (!changed) {
                break;
            }
        }

        {
            double sse = total_sse(data, n, centroids);
            if (best_sse < 0 || sse < best_sse) {
                best_sse = sse;
                best_iter = iterations;
                copy_centroids(best_centroids, centroids);
                copy_clusters(best_data, data, n);
            }
        }
    }

    copy_centroids(centroids, best_centroids);
    copy_clusters(data, best_data, n);

    printf("K-means 重启次数: %d\n", RESTARTS);
    printf("最佳结果迭代次数: %d\n", best_iter);
    printf("最佳 SSE: %.6f\n\n", best_sse);
    print_centroids(centroids);
    print_cluster_result(data, n);
    print_assignments(data, n);
    return 0;
}
