#include "convolution.h"
|
|
#include <mpi.h>
|
|
#include <stdio.h>
|
|
#include <CL/cl.h>
|
|
#include <cuda_runtime.h>
|
|
#include <immintrin.h>
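
// Tensor layouts (row-major, innermost dimension last):
//   input  : N x C x H x W
//   filter : K x C x R x S
//   output : N x K x OH x OW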
static float *input, *output, *filter;
static int N, C, H, W;
static int K, R, S;
static int OH, OW;
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;

static int my_min(int x, int y) {
  return x < y ? x : y;
}

// Allocate a 32-byte-aligned D0 x D1 x D2 x D3 float tensor; abort on failure.
void my_alloc_tensor(float **t, int D0, int D1, int D2, int D3) {
  size_t bytes = sizeof(float) * D0 * D1 * D2 * D3;
  bytes = (bytes + 31) / 32 * 32;  // aligned_alloc requires size % alignment == 0
  *t = (float *) aligned_alloc(32, bytes);
  if (*t == NULL) {
    fprintf(stderr, "Failed to allocate memory for tensor.\n");
    exit(EXIT_FAILURE);
  }
}

// Per-rank batch partition (assumes at most 4 MPI ranks):
// rank idx processes batches ns[idx] .. ne[idx]-1, i.e. ncounts[idx] of them.
int ns[4];
int ne[4];
int ncounts[4];

#define NUM_THREADS 40  // OpenMP threads per MPI rank
#define TSR 4           // filter rows processed together in the AVX-512 path
#define TSS 16          // filter row width handled by one __m512 (16 floats)
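
// convolution: partitions the batch over MPI ranks, broadcasts the filter,
// computes a direct convolution on each rank with OpenMP threads (taking an
// AVX-512 fast path when S == 16 and R is a multiple of TSR), and gathers the
// per-rank output slices back to rank 0.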
void convolution(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  // The shape arguments mirror convolution_init(); the statics cached there
  // are what the code below actually uses.
  input = _input;
  output = _output;
  filter = _filter;

  MPI_Status status;
  MPI_Request request;

  // Output spatial dimensions for the given padding, dilation and stride.
  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
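
  // Partition the batch [0, N) into contiguous per-rank blocks: rank idx gets
  // batches ns[idx] .. ne[idx]-1 (ncounts[idx] of them); the first
  // N % mpi_world_size ranks receive one extra batch.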
  for (int idx = 0; idx < mpi_world_size; idx++) {
    ns[idx] = N / mpi_world_size * idx + my_min(idx, N % mpi_world_size);
    if (idx == (mpi_world_size - 1))
      ne[idx] = N;
    else
      ne[idx] = N / mpi_world_size * (idx + 1) + my_min(idx + 1, N % mpi_world_size);
    ncounts[idx] = ne[idx] - ns[idx];
  }

  // Non-root ranks allocate their own buffers; rank 0 uses the caller's.
  if (mpi_rank != 0) {
    my_alloc_tensor(&input, N, C, H, W);
    my_alloc_tensor(&output, N, K, OH, OW);
    my_alloc_tensor(&filter, K, C, R, S);
  }

  // Scatter the input: rank 0 sends each rank its contiguous batch slice.
  if (mpi_rank == 0) {
    MPI_Request send_req[4];
    int nreq = 0;
    for (int proc = 1; proc < mpi_world_size; proc++)
      if (ncounts[proc] != 0)
        MPI_Isend(input + ns[proc] * C * H * W, ncounts[proc] * C * H * W,
                  MPI_FLOAT, proc /*dst*/, proc /*tag*/, MPI_COMM_WORLD,
                  &send_req[nreq++]);
    // Complete the non-blocking sends before moving on.
    MPI_Waitall(nreq, send_req, MPI_STATUSES_IGNORE);
  }
  else {
    if (ncounts[mpi_rank] != 0)
      MPI_Recv(input, ncounts[mpi_rank] * C * H * W, MPI_FLOAT, 0 /*src*/,
               mpi_rank /*tag*/, MPI_COMM_WORLD, &status);
  }

  // Broadcast the filter to every rank.
  MPI_Bcast(filter, K * C * R * S, MPI_FLOAT, 0 /*root*/, MPI_COMM_WORLD);
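
  // Direct convolution over this rank's batch slice. The three outer loops are
  // collapsed and dynamically scheduled across OpenMP threads; each output
  // element output[n][k][oh][ow] is accumulated independently.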
  #pragma omp parallel for num_threads(NUM_THREADS) collapse(3) schedule(dynamic)
  for (int n = 0; n < ncounts[mpi_rank]; ++n) {
    for (int k = 0; k < K; ++k) {
      for (int oh = 0; oh < OH; ++oh) {
        int h_start = oh * stride - pad;
        for (int ow = 0; ow < OW; ++ow) {
          int w_start = ow * stride - pad;
          float o = 0.0f;
          for (int c = 0; c < C; ++c) {
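            // Fast path: when the filter row is exactly 16 floats (one __m512)
            // and R is a multiple of TSR, process TSR filter rows at a time
            // with vector multiplies and horizontal reductions; otherwise use
            // the scalar fallback below.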
            if ((S == 16) && (R % TSR) == 0) {
              for (int r = 0; r < R; r += TSR) {
                int h0 = h_start + (r + 0) * dilation;
                int h1 = h_start + (r + 1) * dilation;
                int h2 = h_start + (r + 2) * dilation;
                int h3 = h_start + (r + 3) * dilation;
                // Gather a TSR x TSS input patch, writing zeros where the
                // receptive field falls outside the image (padding region).
                float local_input[TSR][TSS];
                for (int s = 0; s < S; s++) {
                  int w = w_start + s * dilation;
                  if (w < 0 || w >= W) {
                    local_input[0][s] = 0;
                    local_input[1][s] = 0;
                    local_input[2][s] = 0;
                    local_input[3][s] = 0;
                  }
                  else {
                    if (h0 < 0 || h0 >= H)
                      local_input[0][s] = 0;
                    else
                      local_input[0][s] = input[n * C * H * W + c * H * W + h0 * W + w];
                    if (h1 < 0 || h1 >= H)
                      local_input[1][s] = 0;
                    else
                      local_input[1][s] = input[n * C * H * W + c * H * W + h1 * W + w];
                    if (h2 < 0 || h2 >= H)
                      local_input[2][s] = 0;
                    else
                      local_input[2][s] = input[n * C * H * W + c * H * W + h2 * W + w];
                    if (h3 < 0 || h3 >= H)
                      local_input[3][s] = 0;
                    else
                      local_input[3][s] = input[n * C * H * W + c * H * W + h3 * W + w];
                  }
                }
                // Multiply the four gathered input rows against the matching
                // filter rows and accumulate via horizontal reductions.
                __m512 i0 = _mm512_loadu_ps(&local_input[0][0]);
                __m512 i1 = _mm512_loadu_ps(&local_input[1][0]);
                __m512 i2 = _mm512_loadu_ps(&local_input[2][0]);
                __m512 i3 = _mm512_loadu_ps(&local_input[3][0]);
                __m512 f0 = _mm512_loadu_ps(&filter[k * C * R * S + c * R * S + (r + 0) * S + 0]);
                __m512 f1 = _mm512_loadu_ps(&filter[k * C * R * S + c * R * S + (r + 1) * S + 0]);
                __m512 f2 = _mm512_loadu_ps(&filter[k * C * R * S + c * R * S + (r + 2) * S + 0]);
                __m512 f3 = _mm512_loadu_ps(&filter[k * C * R * S + c * R * S + (r + 3) * S + 0]);
                __m512 to0 = _mm512_mul_ps(i0, f0);
                __m512 to1 = _mm512_mul_ps(i1, f1);
                __m512 to2 = _mm512_mul_ps(i2, f2);
                __m512 to3 = _mm512_mul_ps(i3, f3);
                o += _mm512_reduce_add_ps(to0);
                o += _mm512_reduce_add_ps(to1);
                o += _mm512_reduce_add_ps(to2);
                o += _mm512_reduce_add_ps(to3);
              }
            }
            else {
              for (int r = 0; r < R; ++r) {
                int h = h_start + r * dilation;
                for (int s = 0; s < S; ++s) {
                  int w = w_start + s * dilation;
                  if (h < 0 || h >= H || w < 0 || w >= W) continue;
                  float i = input[n * C * H * W + c * H * W + h * W + w];
                  float f = filter[k * C * R * S + c * R * S + r * S + s];
                  o += i * f;
                }
              }
            }
          }
          output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
        }
      }
    }
  }

  // Gather the output: each non-root rank sends its slice back to rank 0.
  if (mpi_rank == 0) {
    for (int proc = 1; proc < mpi_world_size; proc++)
      if (ncounts[proc] != 0)
        MPI_Recv(output + ns[proc] * K * OH * OW, ncounts[proc] * K * OH * OW,
                 MPI_FLOAT, proc /*src*/, proc /*tag*/, MPI_COMM_WORLD, &status);
  }
  else {
    if (ncounts[mpi_rank] != 0) {
      MPI_Isend(output, ncounts[mpi_rank] * K * OH * OW, MPI_FLOAT, 0 /*dst*/,
                mpi_rank /*tag*/, MPI_COMM_WORLD, &request);
      // Make sure the send has completed before returning to the caller.
      MPI_Wait(&request, MPI_STATUS_IGNORE);
    }
  }
}

void convolution_init(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  // Cache the problem shape and the MPI topology for convolution().
  N = _N; C = _C; H = _H; W = _W;
  K = _K; R = _R; S = _S;
  pad = _pad;
  dilation = _dilation;
  stride = _stride;

  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);

  // The OH/OW computation was moved into convolution().
}

void convolution_final(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  // Nothing to do here: the output gather that used to live in this function
  // was moved into convolution() above.
}
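
/*
 * Hypothetical usage sketch (assumes convolution.h declares the three entry
 * points defined in this file and that the program is launched with MPI;
 * main(), argument parsing, and tensor setup are not shown here):
 *
 *   MPI_Init(&argc, &argv);
 *   convolution_init(N, C, H, W, K, R, S, pad, dilation, stride);
 *   // Rank 0 fills input (N x C x H x W) and filter (K x C x R x S);
 *   // output receives N x K x OH x OW floats on rank 0.
 *   convolution(input, output, filter, N, C, H, W, K, R, S, pad, dilation, stride);
 *   convolution_final(N, C, H, W, K, R, S, pad, dilation, stride);
 *   MPI_Finalize();
 */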