// chundoong-lab-ta/SamsungDS22/submissions/final/jh8846.kim/tmp-A/convolution.cpp
//
// 251 lines
// 7.0 KiB
// C++

#include "convolution.h"
#include <mpi.h>
#include <stdio.h>
///추가
#include <cstdlib>
#include <cstdio>
#include <omp.h>
#include <cstdio>
#include <cstdlib>
#include <omp.h>
#include <getopt.h>
#include <stdbool.h>
#include <stdlib.h>
#include "util.h"
#include <mpi.h>
#include <immintrin.h>
// Tensor pointers and problem geometry used by convolution(). On worker ranks
// (mpi_rank != 0) convolution() points input/output/filter at rank-local
// buffers instead of the caller's arguments.
static float *input, *output, *filter;
static int N, C, H, W;
static int K, R, S;
static int OH, OW;
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;
// MPI tag for the point-to-point input/output messages.
static int tag = 0;
// OpenMP thread count requested for the convolution loop; a fixed value, not
// derived from the machine's core count.
static int num_threads = 100;
// Upper bound on the number of MPI ranks the per-rank tables below can describe.
enum { MAX_MPI_RANKS = 64 };
// Number of batch images assigned to each rank.
static int N_nums[MAX_MPI_RANKS] = {0};
// First batch index of each rank's slice, plus one extra end-of-range slot at
// index mpi_world_size (hence MAX_MPI_RANKS + 1 entries).
static int N_offset[MAX_MPI_RANKS + 1] = {0};
void convolution(
float *_input, float *_output, float *_filter,
int _N, int _C, int _H, int _W,
int _K, int _R, int _S,
int _pad, int _dilation, int _stride) {
input = _input;
output = _output;
filter = _filter;
N =_N; C =_C; H =_H; W=_W; K=_K; R=_R; S=_S;
pad=_pad; dilation=_dilation; stride=_stride;
OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
MPI_Request request;
MPI_Status status;
int N_num = N/mpi_world_size;
// 노드별 N 개수 정의
for(int i=0; i<mpi_world_size; i++)
{
if(i==mpi_world_size - 1)
{
N_nums[i] = (N - (N_num * (mpi_world_size-1)));
//printf("mpi_world_size : %d, N_nums[0]:%d\n", mpi_world_size, N_nums[0]);
}
else
{
N_nums[i] = N_num ; //master에 job 추가 할당을 위해 +1
//printf("mpi_world_size : %d, N_nums[1]:%d\n", mpi_world_size, N_nums[1]);
}
}
// 노드별 시작 위치 정의 for N
for(int i=0; i<mpi_world_size; i++)
{
N_offset[i+1] = N_offset[i] + N_nums[i];
}
//timer_start(Timer_ID);
// 행렬 alloc @ task 노드
if(mpi_rank != 0)
{
N = N_nums[mpi_rank];
alloc_tensor(&input, N_nums[mpi_rank], C, H, W);
alloc_tensor(&output, N_nums[mpi_rank], K, OH, OW);
alloc_tensor(&filter, K, C, R, S);
//zero_tensor(input, N_nums[mpi_rank], C, H, W);
//zero_tensor(output, N_nums[mpi_rank], K, OH, OW);
//zero_tensor(filter, K, C, R, S);
}
// elapsed_time = timer_stop(Timer_ID);
//printf("1. alloc time %f sec\n", elapsed_time);
//timer_start(Timer_ID);
// INPUT 전송/수신
if(mpi_rank == 0)
{
for(int i=1; i<mpi_world_size; i++)
{
MPI_Isend(&input[N_offset[i]*C*H*W], N_nums[i]*C*H*W, MPI_FLOAT, i, tag, MPI_COMM_WORLD, &request);
}
}
else
{
MPI_Recv(input, N_nums[mpi_rank]*C*H*W, MPI_FLOAT, 0, tag, MPI_COMM_WORLD, &status);
}
// elapsed_time = timer_stop(Timer_ID);
// printf("2. input 전송/수신 시간 %f sec\n", elapsed_time);
// timer_start(Timer_ID);
// Filter 전송/수신
MPI_Bcast(filter, R*S*C*K, MPI_FLOAT, 0, MPI_COMM_WORLD);
// elapsed_time = timer_stop(Timer_ID);
// printf("3. filter 전송/수신 시간 %f sec\n", elapsed_time);
// convolution 수행 시작
// timer_start(Timer_ID);
//int start = 0;
int end_N = N_nums[mpi_rank];
#pragma omp parallel for num_threads(num_threads) collapse(3) schedule(dynamic)
for (int n = 0; n < end_N; n++)
{
for (int k = 0; k < K; k++)
{
for (int oh = 0; oh < OH; oh++)
{
for (int ow = 0; ow < OW; ow++)
{
float o = 0.f;
__m512 o_result = _mm512_set1_ps(0.0f);
for (int c = 0; c < C; ++c)
{
for (int r = 0; r < R; ++r)
{
float input_temp[16];
int h = oh * stride - pad + r * dilation;
int h_re = n * _C * _H * _W + c * _H * _W + h * _W;
int ow_re = ow * stride - pad;
int fil_re =k * _C * _R * _S + c * _R * _S + r * _S;
int w;
if(pad==0 && dilation==1 && stride==1 && R%16 ==0 && S%16 ==0)
{
for (int s = 0; s < S; s=s+16)
{
w= ow * stride - pad + (s+0) * dilation;
if (h < 0 || h >= H || w < 0 || w >= W) continue;
input_temp[0] = input[h_re + (ow_re +(s+0) * dilation)];
input_temp[1] = input[h_re + (ow_re + (s+1) * dilation)];
input_temp[2] = input[h_re + (ow_re + (s+2) * dilation)];
input_temp[3] = input[h_re + (ow_re + (s+3) * dilation)];
input_temp[4] = input[h_re + (ow_re + (s+4) * dilation)];
input_temp[5] = input[h_re + (ow_re + (s+5) * dilation)];
input_temp[6] = input[h_re + (ow_re + (s+6) * dilation)];
input_temp[7] = input[h_re + (ow_re + (s+7) * dilation)];
input_temp[8] = input[h_re + (ow_re + (s+8)* dilation)];
input_temp[9] = input[h_re + (ow_re + (s+9) * dilation)];
input_temp[10] = input[h_re + (ow_re + (s+10) * dilation)];
input_temp[11] = input[h_re + (ow_re + (s+11) * dilation)];
input_temp[12] = input[h_re + (ow_re + (s+12) * dilation)];
input_temp[13] = input[h_re + (ow_re + (s+13) * dilation)];
input_temp[14] = input[h_re + (ow_re + (s+14) * dilation)];
input_temp[15] = input[h_re + (ow_re + (s+15) * dilation)];
__m512 i0 = _mm512_loadu_ps(input_temp);
__m512 f0 = _mm512_loadu_ps(&filter[fil_re + (s+0)]);
o_result = _mm512_fmadd_ps(i0, f0, o_result);
}
/*
float o_result_sum = 0.0f;
int tt_end;
if(S%16 == 0) tt_end=16;
else tt_end=S%16;
for(int tt=0; tt<tt_end; tt++)
{
o_result_sum +=o_result[tt];
}
*/
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = _mm512_reduce_add_ps(o_result); //o_result_sum;
}
else
{
for (int s = 0; s < S; ++s)
{
w = ow * stride - pad + s * dilation;
if (h < 0 || h >= H || w < 0 || w >= W) continue;
float i = input[n * C * H * W + c * H * W + h * W + w];
float f = filter[k * C * R * S + c * R * S + r * S + s];
o += i * f;
}
output[n * K * OH * OW + k * OH * OW + oh * OW + ow] = o;
}
}
}
}
}
}
}
// convolution 수행 끝
// elapsed_time = timer_stop(Timer_ID);
// printf("4. 계산 시간 %f sec\n", elapsed_time);
//timer_start(Timer_ID);
// 계산결과 output 전송/수신
if(mpi_rank != 0)
{
MPI_Isend(output, (N_nums[mpi_rank]*OH*OW*K), MPI_FLOAT, 0, 0, MPI_COMM_WORLD, &request);
}
else
{
for(int i=1; i<mpi_world_size; i++)
{
MPI_Recv(&output[N_offset[i]*OH*OW*K], (N_nums[i]*OH*OW*K), MPI_FLOAT, i, 0, MPI_COMM_WORLD, &status);
}
}
//elapsed_time = timer_stop(Timer_ID);
// printf("5. 계산결과 전송/수신 시간 %f sec\n", elapsed_time);
}
// Stores the convolution geometry in the file-scope globals and caches this
// process's rank and the size of MPI_COMM_WORLD for later use by convolution().
// NOTE(review): MPI must already be initialized by the caller — confirm.
void convolution_init(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  // Problem geometry.
  N = _N;
  C = _C;
  H = _H;
  W = _W;
  K = _K;
  R = _R;
  S = _S;
  stride = _stride;
  dilation = _dilation;
  pad = _pad;

  // Position of this process in the communicator.
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
}
// Counterpart hook to convolution_init() (same parameter list). It performs no
// work; every argument is intentionally unused.
void convolution_final(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  (void)_N; (void)_C; (void)_H; (void)_W;
  (void)_K; (void)_R; (void)_S;
  (void)_pad; (void)_dilation; (void)_stride;
}