#include <stdio.h>
#include <exception>
#include <iostream>
#include <arrayfire.h>
#include "ppm_utils.h"
using namespace af;
// Number of clusters used by kmeans_demo (the kmeans/get_mask functions take
// their own `k` parameter, which shadows this global inside them).
int k(3);
// Cluster the values of `data` into `k` groups with 1-D k-means.
//
// Rather than iterating over every element, the values are shifted so the
// minimum is 0 and binned into a histogram with `nbins` bins. The centroids are
// then refined over the occupied bins only, weighting each bin by its count.
//
// data  : input array; it is flattened, so any shape is accepted. Because the
//         histogram has max-min+1 bins, this presumably expects integer-valued
//         intensities (the demo scales a gray image by 255) — TODO confirm for
//         other inputs.
// k     : number of clusters.
// means : output; receives the k centroids, shifted back into the original
//         value range.
void kmeans(array& data, int k, array& means) {
    // Upper bound on refinement passes, so a run that never meets the
    // convergence test cannot loop forever.
    const int max_iter = 100;

    array datavec = flat(data);
    float minimum = min<float>(datavec);
    datavec = datavec - minimum;          // shift so the histogram starts at 0
    int nbins = max<float>(datavec) + 1;  // (max - min) + 1 bins
    // Initial centroids, evenly spaced: c * nbins / (k + 1) for c = 0..k-1.
    array means_vec = array(seq(k)) * nbins / (k + 1);
    // NOTE: the "max" printed here is nbins, i.e. (max - min) + 1, not the data maximum.
    printf("min %g\nmax %d\n", minimum, nbins);

    // hist_counts(b) holds the cluster label currently assigned to bin b.
    array hist_counts = constant(0,1, nbins);
    array hist = histogram(datavec, nbins);
    array hist_idx = where(hist);         // indices of the occupied bins
    int num_uniq = hist_idx.elements();

    for (int iter = 0; iter < max_iter; ++iter) {
        array prev_means = means_vec;

        // Assignment step: label each occupied bin with its nearest centroid.
        gfor(array i, num_uniq) {
            array diffs = abs(hist_idx(i) - means_vec);
            array val, idx;
            min(val, idx, diffs);
            hist_counts(hist_idx(i)) = idx;
        }

        // Update step: move each centroid to the count-weighted mean of its bins.
        for (int i = 0; i < k; ++i) {
            array m = where(hist_counts == i);
            if (m.elements() == 0) { continue; }  // no bins assigned: keep centroid
            // A cluster can also own only empty bins (every bin starts labelled 0).
            // Dividing by a zero weight would give NaN, which never passes the
            // convergence test below, so keep the previous centroid instead.
            float weight = sum<float>(hist(m));
            if (weight == 0) { continue; }
            // Float denominator, so the mean is not truncated if the reductions
            // happen to be integer-typed.
            means_vec(i) = sum(m * hist(m)) / weight;
        }

        if (norm<float>(means_vec - prev_means) < 1) { break; }
    }
    means = means_vec + minimum;
}
// Build a quantized image: every element of `data` is replaced by the value of
// the centroid in `means_vec` closest to it (by absolute difference).
//
// data      : input; dims 0 and 1 size the per-centroid distance volume.
// means_vec : centroids, indexed 0..k-1.
// k         : number of centroids.
// Returns an array with data's dims holding the nearest centroid's value.
array get_mask(array& data, array& means_vec, int k) {
    // Slice c of `dist` holds |data - centroid c| for every element.
    array dist = constant(0,data.dims()[0], data.dims()[1], k);
    gfor(array c, k) {
        dist(span, span, c) = abs(data - means_vec(c));
    }

    // Index of the nearest centroid along the centroid dimension (dim 2).
    array nearest, best;
    min(best, nearest, dist, 2);

    // `nearest` holds one index per element, so the (nearest == c) masks are
    // disjoint and the sum picks exactly one centroid value per element.
    array quantized = constant(0,data.dims());
    for (int c = 0; c < k; c++) {
        quantized = quantized + means_vec(c) * (nearest == c);
    }
    return quantized;
}
// Convenience overload: cluster `data` into `k` groups, store the centroids in
// `means`, and store in `mask` the quantized image produced by get_mask
// (each element replaced by its nearest centroid's value).
void kmeans(array& data, int k, array& means, array& mask) {
    kmeans(data, k, means);
    array quantized = get_mask(data, means, k);
    mask = quantized;
}
void kmeans_demo(bool console) {
printf("k = %d\n", k);
array img = load_gray_ppm("../image_processing/len_std.ppm") * 255.f;
array means, mask;
kmeans(img, k, means, mask);
print(means);
if (!console) {
fig("color","gray");
fig("sub",2,1,1); image(img); fig("title","input");
fig("sub",2,1,2); image(mask); fig("title","kmeans shift");
fig("draw");
}
}
int main(int argc, char** argv) {
bool console = false;
if (argc > 2 || (argc == 2 && argv[1][0] != '-')) {
printf("usage: kmeans_demo [-]\n");
return -1;
} else if (argc == 2 && argv[1][0] == '-') {
console = true;
}
try {
printf("** ArrayFire K-Means Demo **\n\n");
kmeans_demo(console);
} catch (af::exception& e) {
fprintf(stderr, "%s\n", e.what());
}
if (!console) {
printf("hit [enter]...");
getchar();
}
return 0;
}