#include "util.h"
#include "convolution.h"
/*
 * NOTE(review): the names of the four system headers were lost in the copy of
 * this file that was reviewed; the list below is reconstructed from what the
 * code uses (AVX-512 intrinsics, MPI, OpenMP, malloc/free). Confirm against
 * the original.
 */
#include <immintrin.h>
#include <mpi.h>
#include <omp.h>
#include <stdlib.h>

/* Tensors used by the current convolution() call on this rank. */
static float *input, *output, *filter;
/* Problem geometry, captured by convolution_init(). */
static int N, C, H, W;
static int K, R, S;
static int OH, OW;
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;

/*
 * Scratch tensors owned by this file on the worker rank. They are allocated
 * once in convolution_init() and released in convolution_final().
 */
static float *worker_input, *worker_output, *worker_filter;

enum {
    LANES = 16,       /* floats per __m512 register */
    ROOT_RANK = 0,    /* rank that owns the caller's tensors */
    WORKER_RANK = 1,  /* rank that receives part of the batch */
    OMP_THREADS = 40  /* thread count requested for the compute loop */
};

/* True on the rank that works on file-owned buffers (two-rank runs only). */
static int is_worker(void)
{
    return mpi_world_size == 2 && mpi_rank == WORKER_RANK;
}

/*
 * Number of batch items this rank computes.
 * Two ranks: rank 1 gets N/2, rank 0 gets N/2 plus the remainder.
 * Any other world size: no split, every rank is assigned all N items.
 */
static int local_chunk(void)
{
    if (mpi_world_size != 2) {
        return N;
    }
    if (mpi_rank == WORKER_RANK) {
        return N / mpi_world_size;
    }
    return N / mpi_world_size + N % mpi_world_size;
}

/* Allocates `count` floats; aborts the MPI job if memory is exhausted. */
static float *alloc_floats(size_t count)
{
    /* malloc(0) may legally return NULL, so never request zero bytes. */
    float *p = malloc((count ? count : 1) * sizeof *p);
    if (p == NULL) {
        MPI_Abort(MPI_COMM_WORLD, 1);
    }
    return p;
}

/* Frees the worker-rank scratch tensors (safe to call when none exist). */
static void release_worker_buffers(void)
{
    free(worker_input);
    free(worker_output);
    free(worker_filter);
    worker_input = NULL;
    worker_output = NULL;
    worker_filter = NULL;
}

/*
 * Computes one output element (n, k, oh, ow) from the file-scope `input`
 * and `filter` tensors. Filter taps that would read outside the input image
 * are skipped, which is equivalent to zero padding.
 */
static float conv_at(int n, int k, int oh, int ow)
{
    const int hpos = oh * stride - pad;
    const int wpos = ow * stride - pad;

    /* First/last filter tap whose input coordinate lies inside [0, H) / [0, W). */
    int r_begin = (dilation - hpos - 1) / dilation;
    int s_begin = (dilation - wpos - 1) / dilation;
    int r_end = (H - hpos + dilation - 1) / dilation;
    int s_end = (W - wpos + dilation - 1) / dilation;
    if (r_begin < 0) {
        r_begin = 0;
    }
    if (s_begin < 0) {
        s_begin = 0;
    }
    if (r_end > R) {
        r_end = R;
    }
    if (s_end > S) {
        s_end = S;
    }

    __m512 acc = _mm512_setzero_ps();
    for (int c = 0; c < C; ++c) {
        for (int r = r_begin; r < r_end; ++r) {
            const int h = hpos + r * dilation; /* >= 0 because r >= r_begin */
            const float *in_row =
                input + ((size_t)n * C + c) * H * W + (size_t)h * W;
            const float *flt_row =
                filter + (((size_t)k * C + c) * R + r) * S;

            for (int s = 0; s < s_end; s += LANES) {
                float v_i[LANES];
                float v_f[LANES];
                for (int lane = 0; lane < LANES; ++lane) {
                    const int tap = s + lane;
                    if (tap < s_begin || tap >= s_end) {
                        /* Out-of-range tap: contributes 0 * 0. */
                        v_i[lane] = 0;
                        v_f[lane] = 0;
                    } else {
                        v_i[lane] = in_row[wpos + tap * dilation];
                        v_f[lane] = flt_row[tap];
                    }
                }
                /*
                 * The staging arrays are ordinary stack arrays with no
                 * 64-byte alignment guarantee, so unaligned loads are needed.
                 */
                acc = _mm512_fmadd_ps(_mm512_loadu_ps(v_i),
                                      _mm512_loadu_ps(v_f), acc);
            }
        }
    }

    /* Pairwise (tree) reduction of the 16 lanes into lane 0. */
    float lanes[LANES];
    _mm512_storeu_ps(lanes, acc);
    for (int width = 1; width < LANES; width *= 2) {
        for (int i = 0; i + width < LANES; i += 2 * width) {
            lanes[i] += lanes[i + width];
        }
    }
    return lanes[0];
}

/*
 * Runs the convolution using the geometry captured by convolution_init().
 *
 * _input, _output, _filter: tensors used directly on every rank except the
 *     worker rank of a two-rank run, which uses file-owned scratch buffers.
 * The remaining parameters are not read; the values passed to
 * convolution_init() are used instead.
 *
 * With exactly two MPI ranks the batch is split: rank 0 sends the tail of the
 * batch to rank 1, broadcasts the filter, and receives rank 1's results into
 * the tail of its output tensor. With any other world size no communication
 * takes place and each rank processes all N items.
 */
void convolution(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride)
{
    const int two_ranks = (mpi_world_size == 2);
    const int chunk = local_chunk();
    const int remain = N - chunk;
    const size_t in_image = (size_t)C * H * W;
    const size_t out_image = (size_t)K * OH * OW;

    if (is_worker()) {
        input = worker_input;
        output = worker_output;
        filter = worker_filter;
    } else {
        input = _input;
        output = _output;
        filter = _filter;
    }

    if (two_ranks) {
        if (mpi_rank == ROOT_RANK) {
            MPI_Request send_req;
            MPI_Isend(input + (size_t)chunk * in_image, remain * C * H * W,
                      MPI_FLOAT, WORKER_RANK, 0, MPI_COMM_WORLD, &send_req);
            MPI_Bcast(filter, K * C * R * S, MPI_FLOAT, ROOT_RANK,
                      MPI_COMM_WORLD);
            /* A non-blocking send must be completed before the request is dropped. */
            MPI_Wait(&send_req, MPI_STATUS_IGNORE);
        } else {
            MPI_Recv(input, chunk * C * H * W, MPI_FLOAT, ROOT_RANK, 0,
                     MPI_COMM_WORLD, MPI_STATUS_IGNORE);
            MPI_Bcast(filter, K * C * R * S, MPI_FLOAT, ROOT_RANK,
                      MPI_COMM_WORLD);
        }
    }

    const int out_channels = K;
    const int out_h = OH;
    const int out_w = OW;

#pragma omp parallel for num_threads(OMP_THREADS) collapse(4) schedule(dynamic)
    for (int n = 0; n < chunk; ++n) {
        for (int k = 0; k < out_channels; ++k) {
            for (int oh = 0; oh < out_h; ++oh) {
                for (int ow = 0; ow < out_w; ++ow) {
                    output[(size_t)n * out_image
                           + ((size_t)k * out_h + oh) * out_w + ow] =
                        conv_at(n, k, oh, ow);
                }
            }
        }
    }

    if (two_ranks) {
        if (mpi_rank == ROOT_RANK) {
            MPI_Recv(output + (size_t)chunk * out_image, remain * K * OH * OW,
                     MPI_FLOAT, WORKER_RANK, 0, MPI_COMM_WORLD,
                     MPI_STATUS_IGNORE);
        } else {
            /* Blocking send: the buffer must stay untouched until delivery. */
            MPI_Send(output, chunk * K * OH * OW, MPI_FLOAT, ROOT_RANK, 0,
                     MPI_COMM_WORLD);
        }
    }
}

/*
 * Records the problem geometry and MPI layout, derives the output size, and
 * (on the worker rank of a two-rank run) allocates the scratch tensors that
 * convolution() uses. Must be called before convolution().
 */
void convolution_init(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride)
{
    N = _N;
    C = _C;
    H = _H;
    W = _W;
    K = _K;
    R = _R;
    S = _S;
    pad = _pad;
    dilation = _dilation;
    stride = _stride;

    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);

    OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
    OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

    /* Drop buffers from a previous init that was never finalized. */
    release_worker_buffers();
    if (is_worker()) {
        const size_t chunk = (size_t)local_chunk();
        worker_input = alloc_floats(chunk * C * H * W);
        worker_output = alloc_floats(chunk * K * OH * OW);
        worker_filter = alloc_floats((size_t)K * C * R * S);
    }
}

/*
 * Releases the scratch tensors allocated by convolution_init().
 * The parameters are not used.
 */
void convolution_final(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride)
{
    release_worker_buffers();
}