/* Source-viewer header: 250 lines, 8.2 KiB, C++. */
#include "convolution.h"

#include <mpi.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
static float *input, *output, *filter;
|
||
|
static int N, C, H, W;
|
||
|
static int K, R, S;
|
||
|
static int OH, OW;
|
||
|
static int pad;
|
||
|
static int dilation;
|
||
|
static int stride;
|
||
|
static int mpi_rank, mpi_world_size;
|
||
|
int portion, low_bound, upper_bound;
|
||
|
MPI_Status status;
|
||
|
MPI_Request request;
|
||
|
|
||
|
void convolution(
|
||
|
float *_input, float *_output, float *_filter,
|
||
|
int _N, int _C, int _H, int _W,
|
||
|
int _K, int _R, int _S,
|
||
|
int _pad, int _dilation, int _stride) {
|
||
|
input = _input;
|
||
|
output = _output;
|
||
|
filter = _filter;
|
||
|
|
||
|
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
|
||
|
MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
|
||
|
OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
|
||
|
OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
|
||
|
|
||
|
int num_threads= K < 40 ? K : 40;
|
||
|
|
||
|
int OHB,OWB,SB,RB,CB,NB;
|
||
|
int block = 32;
|
||
|
int rsblock = 4;
|
||
|
int ccblock = 4;
|
||
|
int nnblock = 16;
|
||
|
int ohblock = (OH / block ) * block;
|
||
|
int owblock = (OW / block ) * block;
|
||
|
//int nblock = ((upper_bound-low_bound) / nblock ) * nblock + low_bound;
|
||
|
int sblock = (S / rsblock ) * rsblock;
|
||
|
int rblock = (R / rsblock ) * rsblock;
|
||
|
int cblock = (C / ccblock ) * ccblock;
|
||
|
|
||
|
size_t bytes_nchw = N*C*H*W*sizeof(float);
|
||
|
size_t bytes_kcrs = K*C*R*S*sizeof(float);
|
||
|
size_t bytes_nkohow = N*K*OH*OW*sizeof(float);
|
||
|
|
||
|
if (mpi_rank == 0) {
|
||
|
portion = (N / mpi_world_size);
|
||
|
for(int i=1; i < mpi_world_size; i++){
|
||
|
low_bound = i*portion;
|
||
|
if( ((i+1)==mpi_world_size) && (( N % mpi_world_size) != 0 ) ){
|
||
|
upper_bound = N;
|
||
|
}
|
||
|
else{
|
||
|
upper_bound = low_bound + portion;
|
||
|
}
|
||
|
|
||
|
MPI_Isend(&low_bound, 1, MPI_INT, i, 1, MPI_COMM_WORLD, &request);
|
||
|
MPI_Isend(&upper_bound, 1, MPI_INT, i, 2, MPI_COMM_WORLD, &request);
|
||
|
|
||
|
MPI_Isend(&input[low_bound*C*H*W], ( (upper_bound - low_bound)*C*H*W ), MPI_FLOAT, i, 3, MPI_COMM_WORLD, &request);
|
||
|
MPI_Isend(&filter[0], (K*C*R*S), MPI_FLOAT, i, 4, MPI_COMM_WORLD, &request);
|
||
|
}
|
||
|
|
||
|
low_bound = 0;
|
||
|
upper_bound = portion;
|
||
|
//int nblock = ((upper_bound-low_bound) / nnblock ) * nnblock + low_bound;
|
||
|
#pragma omp parallel num_threads(num_threads)
|
||
|
{
|
||
|
int num = omp_get_thread_num();
|
||
|
int slice = K / num_threads;
|
||
|
int start = num*slice;
|
||
|
int end = (num == num_threads - 1)? K : (num+1)*slice;
|
||
|
|
||
|
#pragma omp parallel for schedule(static) private(OHB,OWB)
|
||
|
/////////////////////////
|
||
|
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
|
||
|
for (int ohh = 0; ohh < ohblock; ohh+=block) {
|
||
|
for (int oww = 0; oww < owblock; oww+=block) {
|
||
|
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
|
||
|
//for (int cc = 0; cc < cblock; cc+=ccblock) {
|
||
|
//for (int rr = 0; rr < rblock; rr+=rsblock) {
|
||
|
//for (int ss = 0; ss < sblock; ss+=rsblock) {
|
||
|
///////////////
|
||
|
for (int k = start; k < end; ++k) {
|
||
|
// if(nn + nnblock < nblock) NB = nn + nnblock;
|
||
|
// else NB = upper_bound;
|
||
|
// for (int n = nn; n < NB; ++n) {
|
||
|
for (int n = low_bound; n < upper_bound; ++n) {
|
||
|
if(ohh + block < ohblock) OHB = ohh + block;
|
||
|
else OHB = OH;
|
||
|
for (int oh = ohh; oh < OHB; ++oh) {
|
||
|
if(oww + block < owblock) OWB = oww + block;
|
||
|
else OWB = OW;
|
||
|
for (int ow = oww; ow < OWB; ++ow) {
|
||
|
//for (int rr = 0; rr < rblock; rr+=rsblock) {
|
||
|
//for (int ss = 0; ss < sblock; ss+=rsblock) {
|
||
|
float o = 0.f;
|
||
|
|
||
|
// if(cc + ccblock < cblock) CB = cc + ccblock;
|
||
|
// else CB = C;
|
||
|
for (int c = 0; c < C; ++c) {
|
||
|
// if(rr + rsblock < rblock) RB = rr + rsblock;
|
||
|
// else RB = R;
|
||
|
for (int r = 0; r < R; ++r) {
|
||
|
// if(ss + rsblock < sblock) SB = ss + rsblock;
|
||
|
// else SB = S;
|
||
|
for (int s = 0; s < S; ++s) {
|
||
|
int h = oh * stride - pad + r * dilation;
|
||
|
int w = ow * stride - pad + s * dilation;
|
||
|
if (h < 0 || h >= H || w < 0 || w >= W) continue;
|
||
|
float i = input[n * C * H * W + c * H * W + h * W + w];
|
||
|
float f = filter[k * C * R * S + c * R * S + r * S + s];
|
||
|
o += i * f;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
///////////
|
||
|
}
|
||
|
}
|
||
|
//}
|
||
|
//}
|
||
|
//}
|
||
|
/////////////////////////
|
||
|
} // omp
|
||
|
|
||
|
for(int i = 1; i < mpi_world_size; i++){
|
||
|
|
||
|
MPI_Recv(&low_bound, 1, MPI_INT, i, 4, MPI_COMM_WORLD, &status);
|
||
|
MPI_Recv(&upper_bound, 1, MPI_INT, i, 5, MPI_COMM_WORLD, &status);
|
||
|
MPI_Recv(&output[low_bound*K*OH*OW], ( (upper_bound - low_bound)*K*OH*OW ), MPI_FLOAT, i, 6, MPI_COMM_WORLD, &status);
|
||
|
}
|
||
|
|
||
|
} // mpi_rank == 0
|
||
|
else{
|
||
|
|
||
|
MPI_Recv(&low_bound, 1, MPI_INT, 0, 1, MPI_COMM_WORLD, &status);
|
||
|
MPI_Recv(&upper_bound, 1, MPI_INT, 0, 2, MPI_COMM_WORLD, &status);
|
||
|
|
||
|
input = (float*)malloc(bytes_nchw);
|
||
|
filter = (float*)malloc(bytes_kcrs);
|
||
|
output = (float*)malloc(bytes_nkohow);
|
||
|
MPI_Recv(&input[low_bound*C*H*W], ( (upper_bound - low_bound)*C*H*W ), MPI_FLOAT, 0, 3, MPI_COMM_WORLD, &status);
|
||
|
MPI_Recv(&filter[0], (K*C*R*S ), MPI_FLOAT, 0, 4, MPI_COMM_WORLD, &status);
|
||
|
//int nblock = ((upper_bound-low_bound) / nnblock ) * nnblock + low_bound;
|
||
|
|
||
|
#pragma omp parallel num_threads(num_threads)
|
||
|
{
|
||
|
int num = omp_get_thread_num();
|
||
|
int slice = K / num_threads;
|
||
|
int start = num*slice;
|
||
|
int end = (num == num_threads - 1)? K : (num+1)*slice;
|
||
|
|
||
|
#pragma omp parallel for schedule(static) private(OHB,OWB)
|
||
|
/////////////////////////
|
||
|
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
|
||
|
for (int ohh = 0; ohh < ohblock; ohh+=block) {
|
||
|
for (int oww = 0; oww < owblock; oww+=block) {
|
||
|
//for (int cc = 0; cc < cblock; cc+=ccblock) {
|
||
|
//for (int rr = 0; rr < rblock; rr+=rsblock) {
|
||
|
//for (int ss = 0; ss < sblock; ss+=rsblock) {
|
||
|
/////////////////////////
|
||
|
// for (int n = low_bound; n < upper_bound; ++n) {
|
||
|
for (int k = start; k < end; ++k) {
|
||
|
// for (int k = 0; k < K; ++k) {
|
||
|
// if(nn + nnblock < nblock) NB = nn + nnblock;
|
||
|
// else NB = upper_bound;
|
||
|
// for (int n = nn; n < NB; ++n) {
|
||
|
for (int n = low_bound; n < upper_bound; ++n) {
|
||
|
if(ohh + block < ohblock) OHB = ohh + block;
|
||
|
else OHB = OH;
|
||
|
for (int oh = ohh; oh < OHB; ++oh) {
|
||
|
if(oww + block < owblock) OWB = oww + block;
|
||
|
else OWB = OW;
|
||
|
for (int ow = oww; ow < OWB; ++ow) {
|
||
|
//for (int rr = 0; rr < rblock; rr+=rsblock) {
|
||
|
//for (int ss = 0; ss < sblock; ss+=rsblock) {
|
||
|
float o = 0.f;
|
||
|
// for (int c = start; c < end; ++c) {
|
||
|
// if(cc + ccblock < cblock) CB = cc + ccblock;
|
||
|
// else CB = C;
|
||
|
for (int c = 0; c < C; ++c) {
|
||
|
// if(rr + rsblock < rblock) RB = rr + rsblock;
|
||
|
// else RB = R;
|
||
|
for (int r = 0; r < R; ++r) {
|
||
|
// if(ss + rsblock < sblock) SB = ss + rsblock;
|
||
|
// else SB = S;
|
||
|
for (int s = 0; s < S; ++s) {
|
||
|
int h = oh * stride - pad + r * dilation;
|
||
|
int w = ow * stride - pad + s * dilation;
|
||
|
if (h < 0 || h >= H || w < 0 || w >= W) continue;
|
||
|
float i = input[n * C * H * W + c * H * W + h * W + w];
|
||
|
float f = filter[k * C * R * S + c * R * S + r * S + s];
|
||
|
o += i * f;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
/////////////////////////
|
||
|
}
|
||
|
}
|
||
|
//}
|
||
|
//}
|
||
|
//}
|
||
|
////////////////////
|
||
|
} // omp
|
||
|
MPI_Isend(&low_bound, 1, MPI_INT, 0, 4, MPI_COMM_WORLD, &request);
|
||
|
MPI_Isend(&upper_bound, 1, MPI_INT, 0, 5, MPI_COMM_WORLD, &request);
|
||
|
MPI_Isend(&output[low_bound*K*OH*OW], ( (upper_bound - low_bound)*K*OH*OW ), MPI_FLOAT, 0, 6, MPI_COMM_WORLD, &request);
|
||
|
|
||
|
free(input);
|
||
|
free(filter);
|
||
|
free(output);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
void convolution_init(
|
||
|
int _N, int _C, int _H, int _W,
|
||
|
int _K, int _R, int _S,
|
||
|
int _pad, int _dilation, int _stride) {
|
||
|
N = _N; C = _C; H = _H; W = _W;
|
||
|
K = _K; R = _R; S = _S;
|
||
|
pad = _pad;
|
||
|
dilation = _dilation;
|
||
|
stride = _stride;
|
||
|
|
||
|
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
|
||
|
MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
|
||
|
|
||
|
}
|
||
|
|
||
|
void convolution_final(
    int batch, int channels, int height, int width,
    int out_channels, int kernel_h, int kernel_w,
    int padding, int dil, int step) {
  /* Teardown hook, intentionally empty: convolution() frees its temporary
   * buffers before returning, so nothing is left to release here. The
   * arguments exist only to mirror convolution_init()'s signature. */
  (void)batch; (void)channels; (void)height; (void)width;
  (void)out_channels; (void)kernel_h; (void)kernel_w;
  (void)padding; (void)dil; (void)step;
}