#include <assert.h>
#include <stdio.h>

#include <iostream>
#include <vector>

#include <arrayfire.h>

using namespace af;
static void multi_Sgemv(int iterations, int ngpu, array *AMatrix, int m, const float* X, int n, float *Y)
{
array *YVector = new array[ngpu];
float **YVector_host = new float* [ngpu];
for (int i = 0 ; i < iterations; i++) {
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
YVector[idx] = matmul(AMatrix[idx], array(n/ngpu, X + idx*(n/ngpu), afHost));
}
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
YVector_host[idx] = YVector[idx].host<float>();
}
for (int i = 0; i < m; i++) {
Y[i] = 0;
for (int j = 0; j < ngpu; j++)
Y[i] += YVector_host[j][i];
}
for (int idx = 0; idx < ngpu; idx++)
array::free(YVector_host[idx]);
}
delete [] YVector;
delete [] YVector_host;
}
/**
 * Sets the first n elements of X to 1.
 *
 * @param X destination buffer with room for at least n floats
 * @param n number of elements to set; zero or negative writes nothing
 */
static void ones(float *X, int n)
{
    // A `while (n--)` countdown never terminates for negative n and runs past the
    // buffer; bounding with `i < n` makes any n <= 0 a no-op instead.
    for (int i = 0; i < n; ++i)
        X[i] = 1;
}
// Number of bytes in one mebibyte (2^20).
#define MB (1024 * 1024)
// Converts a byte count x to mebibytes, rounding any partial MB up (the !! term
// adds 1 when there is a remainder); the result is cast to unsigned.
// NOTE: x is evaluated twice, so do not pass an expression with side effects.
#define mb(x) (unsigned)((x) / MB + !!((x) % MB))
int main(int argc, char **argv)
{
try {
printf("Multi-GPU Matrix-Vector Multiply: y = A*x\n\n"
"The system matrix 'A' is distributed across the available devices.\n"
"Each iteration pushes 'x' to the devices, multiplies against the matrix 'A',\n"
"and pulls the result 'y' back to the host.\n\n");
info();
int iterations = 1000;
int ngpu = devicecount();
if (ngpu == 1) {
printf("found one device, exiting example\n");
return 0;
}
int n = ngpu*7000;
printf("size(A)=[%d,%d] (%u mb)\n", n, n, mb(n * n * sizeof(float)));
printf("\nBenchmarking........\n\n");
float *A = new float[n*n], *X = new float[n], *Y = new float[n];
ones(A, n*n);
ones(X, n*1);
array *AMatrix = new array[ngpu];
for (int idx = 0; idx < ngpu; idx++) {
deviceset(idx);
AMatrix[idx] = array(n, n/ngpu, A + (n*n/ngpu * idx), afHost);
}
delete[] A;
af::sync();
timer::start();
multi_Sgemv(iterations, ngpu, AMatrix, n, X, n, Y);
af::sync();
printf("Average time for %d iterations : %g seconds\n", iterations, timer::stop() / iterations);
for (int i = 0; i < n; ++i)
assert(Y[i] == n);
delete[] X; delete[] Y;
} catch (af::exception& e) {
fprintf(stderr, "%s\n", e.what());
throw;
}
#ifdef WIN32 // pause in Windows
if (!(argc == 2 && argv[1][0] == '-')) {
printf("hit [enter]...");
fflush(stdout);
getchar();
}
#endif
return 0;
}