/*
 * 2-D convolution over NCHW tensors, split across at most two MPI ranks
 * (rank 0 owns the caller's tensors, rank 1 computes the second half of the
 * batch) and parallelised inside each rank with OpenMP.
 *
 * NOTE(review): the names of the three system headers below were missing in
 * the reviewed text; they were restored from the APIs this file uses
 * (MPI_*, _mm512_*). <stdio.h> is a guess -- confirm against the repository.
 */
#include "convolution.h"
#include <mpi.h>
#include <stdio.h>
#include "util.h"
#include <immintrin.h>
#include <stddef.h>

enum {
  CONV_NUM_THREADS = 100, /* OpenMP team size requested by both kernels */
  FAST_KERNEL_DIM = 16,   /* filter height/width handled by perform_cal() */
  TAG_INPUT = 0,          /* MPI tag: input slice, rank 0 -> rank 1 */
  TAG_FILTER = 1,         /* MPI tag: full filter, rank 0 -> rank 1 */
  TAG_OUTPUT = 2          /* MPI tag: output slice, rank 1 -> rank 0 */
};

/* Signature shared by the two compute kernels below. */
typedef void (*conv_kernel_fn)(int begin, int end);

static float *input, *output, *filter;
static int N, C, H, W;
static int K, R, S;
static int OH, OW;
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;

/*
 * Specialised kernel for batch entries [begin, end).
 *
 * Each filter row and each input row segment is read as one 16-float vector,
 * so this kernel is only valid for R == S == 16 with pad == 0, dilation == 1
 * and stride == 1; convolution() checks those conditions before selecting it.
 */
static void perform_cal(int begin, int end)
{
#pragma omp parallel for collapse(3) num_threads(CONV_NUM_THREADS)
  for (int n = begin; n < end; ++n) {
    for (int k = 0; k < K; ++k) {
      for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
          __m512 acc = _mm512_setzero_ps();

          for (int c = 0; c < C; ++c) {
            /* Offsets are formed in size_t so large tensors do not overflow int. */
            const float *in_p =
                input + (((size_t)n * C + c) * H + oh) * W + ow;
            const float *filt_p = filter + ((size_t)k * C + c) * R * S;

            /*
             * Unaligned loads on both operands: nothing visible here
             * guarantees that the caller-supplied filter is 64-byte aligned,
             * which the aligned load form would require.
             */
            for (int r = 0; r < FAST_KERNEL_DIM; ++r) {
              acc = _mm512_fmadd_ps(_mm512_loadu_ps(in_p + (size_t)r * W),
                                    _mm512_loadu_ps(filt_p + (size_t)r * S),
                                    acc);
            }
          }

          output[(((size_t)n * K + k) * OH + oh) * OW + ow] =
              _mm512_reduce_add_ps(acc);
        }
      }
    }
  }
}

/*
 * General kernel for batch entries [begin, end).
 *
 * Honours pad, dilation and stride; taps that fall outside the input plane
 * are skipped, i.e. they contribute zero to the sum.
 */
static void calc(int begin, int end)
{
#pragma omp parallel for collapse(3) num_threads(CONV_NUM_THREADS)
  for (int n = begin; n < end; ++n) {
    for (int k = 0; k < K; ++k) {
      for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
          float acc = 0.f;

          for (int c = 0; c < C; ++c) {
            const float *in_plane = input + ((size_t)n * C + c) * H * W;
            const float *filt_plane = filter + ((size_t)k * C + c) * R * S;

            for (int r = 0; r < R; ++r) {
              const int h = oh * stride - pad + r * dilation;
              /* The row test does not depend on s, so it is done once per row. */
              if (h < 0 || h >= H)
                continue;

              for (int s = 0; s < S; ++s) {
                const int w = ow * stride - pad + s * dilation;
                if (w < 0 || w >= W)
                  continue;
                acc += in_plane[(size_t)h * W + w] * filt_plane[r * S + s];
              }
            }
          }

          output[(((size_t)n * K + k) * OH + oh) * OW + ow] = acc;
        }
      }
    }
  }
}

/*
 * Run the convolution configured by the last convolution_init() call.
 *
 * On rank 0, _input/_output/_filter are used as the tensors; other ranks
 * ignore them. The problem sizes and the rank/world size are read from the
 * file-scope state set by convolution_init(); of the scalar parameters only
 * _pad, _dilation and _stride are consulted, to select the kernel.
 *
 * With more than one rank and N / 2 > 0, rank 0 sends the last N / 2 batch
 * entries plus the whole filter to rank 1, computes the first N - N / 2
 * entries itself and receives rank 1's results into the tail of the output.
 * Ranks other than 0 and 1 do nothing: rank 0 never sends to them, so
 * letting them post receives would block them forever.
 */
void convolution(float *_input, float *_output, float *_filter,
                 int _N, int _C, int _H, int _W,
                 int _K, int _R, int _S,
                 int _pad, int _dilation, int _stride)
{
  if (mpi_rank == 0) {
    input = _input;
    output = _output;
    filter = _filter;
  }

  /* The vectorised kernel is used only for the configuration it was written for. */
  const int use_fast_kernel =
      _pad == 0 && _dilation == 1 && _stride == 1 &&
      mpi_world_size == 2 &&
      R == FAST_KERNEL_DIM && S == FAST_KERNEL_DIM &&
      (N % 32 == 0) && (C % 32 == 0) && (H % 32 == 0) &&
      (W % 32 == 0) && (K % 32 == 0);
  const conv_kernel_fn kernel = use_fast_kernel ? perform_cal : calc;

  /* Batch split: rank 1 gets N / 2 entries, rank 0 keeps the remainder. */
  const int remote_n = mpi_world_size > 1 ? N / 2 : 0;
  const int local_n = N - remote_n;

  /* MPI element counts are int by API. */
  const int in_count = remote_n * C * H * W;
  const int filt_count = K * C * R * S;
  const int out_count = remote_n * K * OH * OW;

  if (mpi_rank == 0) {
    MPI_Request reqs[3];

    if (remote_n > 0) {
      MPI_Isend(input + (size_t)local_n * C * H * W, in_count, MPI_FLOAT,
                1, TAG_INPUT, MPI_COMM_WORLD, &reqs[0]);
      MPI_Isend(filter, filt_count, MPI_FLOAT,
                1, TAG_FILTER, MPI_COMM_WORLD, &reqs[1]);
      MPI_Irecv(output + (size_t)local_n * K * OH * OW, out_count, MPI_FLOAT,
                1, TAG_OUTPUT, MPI_COMM_WORLD, &reqs[2]);
    }

    /* Rank 0 computes its own share while the transfers are in flight. */
    kernel(0, local_n);

    if (remote_n > 0)
      MPI_Waitall(3, reqs, MPI_STATUSES_IGNORE);
  } else if (mpi_rank == 1 && remote_n > 0) {
    MPI_Request reqs[2];
    MPI_Request out_req;

    MPI_Irecv(input, in_count, MPI_FLOAT,
              0, TAG_INPUT, MPI_COMM_WORLD, &reqs[0]);
    MPI_Irecv(filter, filt_count, MPI_FLOAT,
              0, TAG_FILTER, MPI_COMM_WORLD, &reqs[1]);
    MPI_Waitall(2, reqs, MPI_STATUSES_IGNORE);

    kernel(0, remote_n);

    MPI_Isend(output, out_count, MPI_FLOAT,
              0, TAG_OUTPUT, MPI_COMM_WORLD, &out_req);
    MPI_Wait(&out_req, MPI_STATUS_IGNORE);
  }
}

/*
 * Record the problem configuration, derive the output height/width, query
 * the MPI rank and world size, and allocate the worker-side buffers.
 *
 * Only rank 1 allocates: it is the only rank besides 0 that convolution()
 * gives work to, and rank 0 uses the caller's tensors instead.
 */
void convolution_init(int _N, int _C, int _H, int _W,
                      int _K, int _R, int _S,
                      int _pad, int _dilation, int _stride)
{
  N = _N;
  C = _C;
  H = _H;
  W = _W;
  K = _K;
  R = _R;
  S = _S;
  pad = _pad;
  dilation = _dilation;
  stride = _stride;

  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);

  if (mpi_rank == 1 && N / 2 > 0) {
    alloc_tensor(&input, N / 2, _C, _H, _W);
    alloc_tensor(&output, N / 2, _K, OH, OW);
    alloc_tensor(&filter, _K, _C, _R, _S);
  }
}

/*
 * Teardown hook; currently does nothing.
 *
 * NOTE(review): the buffers allocated on rank 1 by convolution_init() are
 * never released. The matching deallocator for alloc_tensor() is not visible
 * from this file -- confirm which one applies before adding a free here.
 */
void convolution_final(int _N, int _C, int _H, int _W,
                       int _K, int _R, int _S,
                       int _pad, int _dilation, int _stride)
{
}