chundoong-lab-ta/SamsungDS22/submissions/final/hkyoo.kim/A/convolution.cpp

250 lines
8.2 KiB
C++

#include "convolution.h"
#include <mpi.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
/*
 * File-private state. Shapes and hyper-parameters are stored by
 * convolution_init(); the tensor pointers, MPI rank/size and OH/OW are
 * (re)assigned on every convolution() call.
 */
static float *input, *output, *filter;
static int N, C, H, W;   /* input: N images of C x H x W */
static int K, R, S;      /* filter: K output channels of C x R x S */
static int OH, OW;       /* output spatial size */
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;
/*
 * Scratch variables for the batch split across ranks. Internal linkage so
 * these generic names cannot clash with symbols in other translation units.
 */
static int portion, low_bound, upper_bound;
static MPI_Status status;
static MPI_Request request;
void convolution(
float *_input, float *_output, float *_filter,
int _N, int _C, int _H, int _W,
int _K, int _R, int _S,
int _pad, int _dilation, int _stride) {
input = _input;
output = _output;
filter = _filter;
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
int num_threads= K < 40 ? K : 40;
int OHB,OWB,SB,RB,CB,NB;
int block = 32;
int rsblock = 4;
int ccblock = 4;
int nnblock = 16;
int ohblock = (OH / block ) * block;
int owblock = (OW / block ) * block;
//int nblock = ((upper_bound-low_bound) / nblock ) * nblock + low_bound;
int sblock = (S / rsblock ) * rsblock;
int rblock = (R / rsblock ) * rsblock;
int cblock = (C / ccblock ) * ccblock;
size_t bytes_nchw = N*C*H*W*sizeof(float);
size_t bytes_kcrs = K*C*R*S*sizeof(float);
size_t bytes_nkohow = N*K*OH*OW*sizeof(float);
if (mpi_rank == 0) {
portion = (N / mpi_world_size);
for(int i=1; i < mpi_world_size; i++){
low_bound = i*portion;
if( ((i+1)==mpi_world_size) && (( N % mpi_world_size) != 0 ) ){
upper_bound = N;
}
else{
upper_bound = low_bound + portion;
}
MPI_Isend(&low_bound, 1, MPI_INT, i, 1, MPI_COMM_WORLD, &request);
MPI_Isend(&upper_bound, 1, MPI_INT, i, 2, MPI_COMM_WORLD, &request);
MPI_Isend(&input[low_bound*C*H*W], ( (upper_bound - low_bound)*C*H*W ), MPI_FLOAT, i, 3, MPI_COMM_WORLD, &request);
MPI_Isend(&filter[0], (K*C*R*S), MPI_FLOAT, i, 4, MPI_COMM_WORLD, &request);
}
low_bound = 0;
upper_bound = portion;
//int nblock = ((upper_bound-low_bound) / nnblock ) * nnblock + low_bound;
#pragma omp parallel num_threads(num_threads)
{
int num = omp_get_thread_num();
int slice = K / num_threads;
int start = num*slice;
int end = (num == num_threads - 1)? K : (num+1)*slice;
#pragma omp parallel for schedule(static) private(OHB,OWB)
/////////////////////////
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
for (int ohh = 0; ohh < ohblock; ohh+=block) {
for (int oww = 0; oww < owblock; oww+=block) {
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
//for (int cc = 0; cc < cblock; cc+=ccblock) {
//for (int rr = 0; rr < rblock; rr+=rsblock) {
//for (int ss = 0; ss < sblock; ss+=rsblock) {
///////////////
for (int k = start; k < end; ++k) {
// if(nn + nnblock < nblock) NB = nn + nnblock;
// else NB = upper_bound;
// for (int n = nn; n < NB; ++n) {
for (int n = low_bound; n < upper_bound; ++n) {
if(ohh + block < ohblock) OHB = ohh + block;
else OHB = OH;
for (int oh = ohh; oh < OHB; ++oh) {
if(oww + block < owblock) OWB = oww + block;
else OWB = OW;
for (int ow = oww; ow < OWB; ++ow) {
//for (int rr = 0; rr < rblock; rr+=rsblock) {
//for (int ss = 0; ss < sblock; ss+=rsblock) {
float o = 0.f;
// if(cc + ccblock < cblock) CB = cc + ccblock;
// else CB = C;
for (int c = 0; c < C; ++c) {
// if(rr + rsblock < rblock) RB = rr + rsblock;
// else RB = R;
for (int r = 0; r < R; ++r) {
// if(ss + rsblock < sblock) SB = ss + rsblock;
// else SB = S;
for (int s = 0; s < S; ++s) {
int h = oh * stride - pad + r * dilation;
int w = ow * stride - pad + s * dilation;
if (h < 0 || h >= H || w < 0 || w >= W) continue;
float i = input[n * C * H * W + c * H * W + h * W + w];
float f = filter[k * C * R * S + c * R * S + r * S + s];
o += i * f;
}
}
}
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
}
}
}
}
///////////
}
}
//}
//}
//}
/////////////////////////
} // omp
for(int i = 1; i < mpi_world_size; i++){
MPI_Recv(&low_bound, 1, MPI_INT, i, 4, MPI_COMM_WORLD, &status);
MPI_Recv(&upper_bound, 1, MPI_INT, i, 5, MPI_COMM_WORLD, &status);
MPI_Recv(&output[low_bound*K*OH*OW], ( (upper_bound - low_bound)*K*OH*OW ), MPI_FLOAT, i, 6, MPI_COMM_WORLD, &status);
}
} // mpi_rank == 0
else{
MPI_Recv(&low_bound, 1, MPI_INT, 0, 1, MPI_COMM_WORLD, &status);
MPI_Recv(&upper_bound, 1, MPI_INT, 0, 2, MPI_COMM_WORLD, &status);
input = (float*)malloc(bytes_nchw);
filter = (float*)malloc(bytes_kcrs);
output = (float*)malloc(bytes_nkohow);
MPI_Recv(&input[low_bound*C*H*W], ( (upper_bound - low_bound)*C*H*W ), MPI_FLOAT, 0, 3, MPI_COMM_WORLD, &status);
MPI_Recv(&filter[0], (K*C*R*S ), MPI_FLOAT, 0, 4, MPI_COMM_WORLD, &status);
//int nblock = ((upper_bound-low_bound) / nnblock ) * nnblock + low_bound;
#pragma omp parallel num_threads(num_threads)
{
int num = omp_get_thread_num();
int slice = K / num_threads;
int start = num*slice;
int end = (num == num_threads - 1)? K : (num+1)*slice;
#pragma omp parallel for schedule(static) private(OHB,OWB)
/////////////////////////
//for (int nn = low_bound; nn < nblock; nn+=nnblock) {
for (int ohh = 0; ohh < ohblock; ohh+=block) {
for (int oww = 0; oww < owblock; oww+=block) {
//for (int cc = 0; cc < cblock; cc+=ccblock) {
//for (int rr = 0; rr < rblock; rr+=rsblock) {
//for (int ss = 0; ss < sblock; ss+=rsblock) {
/////////////////////////
// for (int n = low_bound; n < upper_bound; ++n) {
for (int k = start; k < end; ++k) {
// for (int k = 0; k < K; ++k) {
// if(nn + nnblock < nblock) NB = nn + nnblock;
// else NB = upper_bound;
// for (int n = nn; n < NB; ++n) {
for (int n = low_bound; n < upper_bound; ++n) {
if(ohh + block < ohblock) OHB = ohh + block;
else OHB = OH;
for (int oh = ohh; oh < OHB; ++oh) {
if(oww + block < owblock) OWB = oww + block;
else OWB = OW;
for (int ow = oww; ow < OWB; ++ow) {
//for (int rr = 0; rr < rblock; rr+=rsblock) {
//for (int ss = 0; ss < sblock; ss+=rsblock) {
float o = 0.f;
// for (int c = start; c < end; ++c) {
// if(cc + ccblock < cblock) CB = cc + ccblock;
// else CB = C;
for (int c = 0; c < C; ++c) {
// if(rr + rsblock < rblock) RB = rr + rsblock;
// else RB = R;
for (int r = 0; r < R; ++r) {
// if(ss + rsblock < sblock) SB = ss + rsblock;
// else SB = S;
for (int s = 0; s < S; ++s) {
int h = oh * stride - pad + r * dilation;
int w = ow * stride - pad + s * dilation;
if (h < 0 || h >= H || w < 0 || w >= W) continue;
float i = input[n * C * H * W + c * H * W + h * W + w];
float f = filter[k * C * R * S + c * R * S + r * S + s];
o += i * f;
}
}
}
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
}
}
}
}
/////////////////////////
}
}
//}
//}
//}
////////////////////
} // omp
MPI_Isend(&low_bound, 1, MPI_INT, 0, 4, MPI_COMM_WORLD, &request);
MPI_Isend(&upper_bound, 1, MPI_INT, 0, 5, MPI_COMM_WORLD, &request);
MPI_Isend(&output[low_bound*K*OH*OW], ( (upper_bound - low_bound)*K*OH*OW ), MPI_FLOAT, 0, 6, MPI_COMM_WORLD, &request);
free(input);
free(filter);
free(output);
}
}
/*
 * Stores the problem shape and convolution hyper-parameters in file-level
 * state for later convolution() calls, and caches this process's MPI rank
 * and world size.
 */
void convolution_init(
int _N, int _C, int _H, int _W,
int _K, int _R, int _S,
int _pad, int _dilation, int _stride) {
  /* Input tensor extents. */
  N = _N;
  C = _C;
  H = _H;
  W = _W;

  /* Filter extents. */
  K = _K;
  R = _R;
  S = _S;

  /* Hyper-parameters. */
  stride = _stride;
  dilation = _dilation;
  pad = _pad;

  /* MPI identity of this process. */
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
}
/*
 * Teardown hook paired with convolution_init(). It is intentionally a no-op,
 * and all of its arguments are ignored.
 */
void convolution_final(
int _N, int _C, int _H, int _W,
int _K, int _R, int _S,
int _pad, int _dilation, int _stride) {
  (void)_N; (void)_C; (void)_H; (void)_W;
  (void)_K; (void)_R; (void)_S;
  (void)_pad; (void)_dilation; (void)_stride;
}