chundoong-lab-ta/SamsungDS22/final-example/opt/convolution.cu

370 lines
11 KiB
Plaintext

#include "convolution.h"
#include <mpi.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include "util.h"
// Host tensors for the current convolution() call. On rank 0 they alias the
// caller's buffers; on other ranks convolution() allocates them with
// alloc_tensor and releases them before returning.
static float *input, *output, *filter;
// Per-device buffers (fixed capacity of 4 devices). With TR defined, the
// underscore-prefixed buffers hold the original layouts (input [C][H][W],
// filter [K][C][R][S], output [K][OH][OW]) and the plain ones hold the
// transposed layouts conv_kernal works on; without TR each pair aliases one
// allocation.
static float *_d_input[4], *d_input[4], *_d_output[4], *d_output[4], *d_filter[4], *_d_filter[4];
// Problem shape, set by convolution_init.
static int N, C, H, W;
static int K, R, S;
static int OH, OW;
static int pad;
static int dilation;
static int stride;
static int mpi_rank, mpi_world_size;
static int num_devices;
// conv_kernal indexes input as [H][W][C], weights as [R][S][C][K] and writes
// [OH][OW][K], so the transpose kernels around it are required for correct
// results. TR enables them; leaving it undefined feeds the untransposed host
// layouts straight into conv_kernal.
#define TR
// Number of consecutive output channels handled by one conv / weight-transpose thread.
#define rowblocksz 16
// Abort with file, line and the CUDA error description if a runtime call fails.
#define CHECK_ERROR(cond)                                                  \
  do {                                                                     \
    cudaError_t check_error_status_ = (cond);                              \
    if (check_error_status_ != cudaSuccess) {                              \
      fprintf(stderr, "[%s:%d] err: %s\n", __FILE__, __LINE__,             \
              cudaGetErrorString(check_error_status_));                    \
      exit(EXIT_FAILURE);                                                  \
    }                                                                      \
  } while (false)
// Reorders a filter bank from [K][C][R][S] to [R][S][C][K] (output channel
// fastest), the layout conv_kernal reads.
// Thread mapping: x selects input channel c; y selects a group of rowblocksz
// consecutive output channels starting at k = y * rowblocksz. Threads past C
// or K do nothing; the last group is clipped at K.
__global__ void weighttranspose_kernal(float* input, float* output, int K, int C, int R, int S) {
  const int c = threadIdx.x + blockIdx.x * blockDim.x;
  const int k_begin = (threadIdx.y + blockIdx.y * blockDim.y) * rowblocksz;
  if (c >= C || k_begin >= K) return;
  const int k_end = (k_begin + rowblocksz < K) ? k_begin + rowblocksz : K;

  const int taps = R * S;  // filter positions (r, s) per (k, c) pair
  for (int tap = 0; tap < taps; ++tap) {
    float* dst = output + (tap * C * K + c * K);
    for (int k = k_begin; k < k_end; ++k) {
      dst[k] = input[k * C * taps + c * taps + tap];
    }
  }
}
// Converts a conv_kernal result from [OH][OW][K] (channel fastest) back to
// [K][OH][OW]. One thread per element: x -> ow, y -> oh, z -> k; threads
// outside the tensor do nothing.
__global__ void outputtranspose_kernal(float* input, float* output, int K, int OH, int OW) {
  const int ow = threadIdx.x + blockIdx.x * blockDim.x;
  const int oh = threadIdx.y + blockIdx.y * blockDim.y;
  const int k = threadIdx.z + blockIdx.z * blockDim.z;
  if (ow < OW && oh < OH && k < K) {
    const int pixel = oh * OW + ow;
    output[k * OH * OW + pixel] = input[pixel * K + k];
  }
}
// Converts one image from [C][H][W] to [H][W][C] (channel fastest), the
// layout conv_kernal reads. One thread per element: x -> w, y -> h, z -> c;
// threads outside the tensor do nothing.
__global__ void inputtranspose_kernal(float* input, float* output, int C, int H, int W) {
  const int w = threadIdx.x + blockIdx.x * blockDim.x;
  const int h = threadIdx.y + blockIdx.y * blockDim.y;
  const int c = threadIdx.z + blockIdx.z * blockDim.z;
  if (w < W && h < H && c < C) {
    const int pixel = h * W + w;
    output[pixel * C + c] = input[c * H * W + pixel];
  }
}
// Direct convolution of one image in transposed layouts:
//   input  : [H][W][C]     (channel fastest)
//   weight : [R][S][C][K]  (output channel fastest)
//   output : [OH][OW][K]
// Thread mapping: x -> oh, y -> ow, z -> a group of rowblocksz consecutive
// output channels starting at k = z * rowblocksz. Threads outside
// OH x OW x K return immediately.
// For each output element:
//   out[oh][ow][k] = sum over (r, s, c) of in[h][w][c] * wt[r][s][c][k]
//   with h = oh*stride - pad + r*dilation, w = ow*stride - pad + s*dilation,
// and taps whose (h, w) falls outside the image are skipped (zero padding).
// Full channel groups accumulate into a fixed-size local array with fully
// unrolled loops (intended to let the compiler keep it in registers); a
// trailing partial group falls back to one accumulator per channel.
__global__ void conv_kernal(float* input, float* weight, float* output,
                            int C, int H, int W, int K, int R, int S, int OH, int OW,
                            int stride, int pad, int dilation) {
  const int oh = blockIdx.x * blockDim.x + threadIdx.x;
  const int ow = blockIdx.y * blockDim.y + threadIdx.y;
  const int k_first = (blockIdx.z * blockDim.z + threadIdx.z) * rowblocksz;
  if (oh >= OH || ow >= OW || k_first >= K) return;
  // Exclusive end of this thread's channel group, clipped at K.
  const int k_last = (k_first + rowblocksz < K) ? k_first + rowblocksz : K;

  if (k_last - k_first != rowblocksz) {
    // Partial group at the tail of K: one channel at a time.
    for (int k = k_first; k < k_last; ++k) {
      float acc = 0;
      for (int r = 0; r < R; ++r) {
        const int h = oh * stride - pad + r * dilation;
        for (int s = 0; s < S; ++s) {
          const int w = ow * stride - pad + s * dilation;
          if (h < 0 || h >= H || w < 0 || w >= W) continue;
          for (int c = 0; c < C; ++c) {
            acc += input[h * W * C + w * C + c] * weight[r * S * C * K + s * C * K + c * K + k];
          }
        }
      }
      output[oh * OW * K + ow * K + k] = acc;
    }
    return;
  }

  float acc[rowblocksz];
#pragma unroll
  for (int t = 0; t < rowblocksz; ++t) acc[t] = 0;

  for (int r = 0; r < R; ++r) {
    const int h = oh * stride - pad + r * dilation;
    for (int s = 0; s < S; ++s) {
      const int w = ow * stride - pad + s * dilation;
      if (h < 0 || h >= H || w < 0 || w >= W) continue;
      // Weights for this tap: rowblocksz contiguous channels, stride K per input channel.
      const float* wt = weight + (r * S * C * K + s * C * K + k_first);
      const float* px = input + (h * W * C + w * C);
      for (int c = 0; c < C; ++c) {
        const float v = px[c];
#pragma unroll
        for (int t = 0; t < rowblocksz; ++t) acc[t] += v * wt[t];
        wt += K;
      }
    }
  }

  float* out = output + (oh * OW * K + ow * K + k_first);
#pragma unroll
  for (int t = 0; t < rowblocksz; ++t) out[t] = acc[t];
}
// Returns, via *start / *num, the contiguous run of images that MPI rank
// `rank` processes out of the N-image batch (floor-based block partition).
static void conv_rank_images(int rank, int *start, int *num) {
  const int first = N * rank / mpi_world_size;
  const int last = N * (rank + 1) / mpi_world_size;
  *start = first;
  *num = last - first;
}

// Runs the batched convolution across all MPI ranks and all local GPUs.
//
// Rank 0 owns the full batch. It sends each peer N, the filter and that
// peer's slice of the input. It then computes its own slice and receives every
// peer's output slice directly into _output.
//
// Ranks other than 0 do not use the pointer arguments. They allocate host
// buffers for their slice with alloc_tensor, compute it, send it to rank 0,
// and free the buffers.
//
// Within a rank, images are split evenly across the local devices, with one
// OpenMP thread per device. Each device processes its images one at a time.
//
// The shape parameters are not read here; shapes come from convolution_init,
// presumably called with the same values (confirm at the caller).
void convolution(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  input = _input;
  output = _output;
  filter = _filter;
  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

  // Element counts as size_t so offsets into a large batch cannot overflow int.
  const size_t image_elems = (size_t)C * H * W;
  const size_t out_image_elems = (size_t)K * OH * OW;
  const size_t filter_elems = (size_t)K * C * R * S;

  int cimgidx_start;
  int cimgidx_num;

  // Scatter: rank 0 sends each peer N, the filter and its input slice.
  if (mpi_rank == 0) {
    conv_rank_images(mpi_rank, &cimgidx_start, &cimgidx_num);
    for (int i = 1; i < mpi_world_size; ++i) {
      int pimgidx_start, pimgidx_num;
      conv_rank_images(i, &pimgidx_start, &pimgidx_num);
      MPI_Send(&N, 1, MPI_INT, i, 0, MPI_COMM_WORLD);
      MPI_Send(filter, (int)filter_elems, MPI_FLOAT, i, 1, MPI_COMM_WORLD);
      MPI_Send(input + pimgidx_start * image_elems, (int)(pimgidx_num * image_elems),
               MPI_FLOAT, i, 2, MPI_COMM_WORLD);
    }
  } else {
    N = 0;
    // MPI_STATUS_IGNORE, not NULL: NULL is not a valid status argument and
    // is rejected by implementations where MPI_STATUS_IGNORE != NULL.
    MPI_Recv(&N, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    conv_rank_images(mpi_rank, &cimgidx_start, &cimgidx_num);
    alloc_tensor(&input, cimgidx_num, C, H, W);
    alloc_tensor(&output, cimgidx_num, K, OH, OW);
    alloc_tensor(&filter, K, C, R, S);
    MPI_Recv(filter, (int)filter_elems, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    MPI_Recv(input, (int)(cimgidx_num * image_elems), MPI_FLOAT, 0, 2, MPI_COMM_WORLD,
             MPI_STATUS_IGNORE);
  }

  // Upload the filter to every device (and reorder it to [R][S][C][K] when TR is set).
  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    CHECK_ERROR(cudaMemcpy(_d_filter[i], filter, filter_elems * sizeof(float),
                           cudaMemcpyHostToDevice));
#ifdef TR
    dim3 block(1, 64, 1);
    dim3 grid(C, K / 64 + 1, 1);
    weighttranspose_kernal<<<grid, block>>>(_d_filter[i], d_filter[i], K, C, R, S);
    CHECK_ERROR(cudaGetLastError());
#endif
  }

  // One OpenMP thread per device; each handles a contiguous run of this rank's images.
#pragma omp parallel for
  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    const int j_start = i * cimgidx_num / num_devices;
    const int j_end = (i + 1) * cimgidx_num / num_devices;
    for (int j = j_start; j < j_end; ++j) {
      CHECK_ERROR(cudaMemcpy(_d_input[i], input + j * image_elems, image_elems * sizeof(float),
                             cudaMemcpyHostToDevice));
#ifdef TR
      {
        dim3 block(16, 16, 2);
        dim3 grid((W + 15) / 16, (H + 15) / 16, (C + 1) / 2);
        inputtranspose_kernal<<<grid, block>>>(_d_input[i], d_input[i], C, H, W);
        CHECK_ERROR(cudaGetLastError());
      }
#endif
      {
        // x: one output row per block; y: 16 output columns; z: 2 channel
        // groups of rowblocksz each.
        dim3 block(1, 16, 2);
        dim3 grid(OH, (OW + 15) / 16, (K + rowblocksz * 2 - 1) / (rowblocksz * 2));
        conv_kernal<<<grid, block>>>(d_input[i], d_filter[i], _d_output[i],
                                     C, H, W, K, R, S, OH, OW, stride, pad, dilation);
        CHECK_ERROR(cudaGetLastError());
      }
#ifdef TR
      {
        dim3 block(16, 16, 2);
        dim3 grid((OW + 15) / 16, (OH + 15) / 16, (K + 1) / 2);
        outputtranspose_kernal<<<grid, block>>>(_d_output[i], d_output[i], K, OH, OW);
        CHECK_ERROR(cudaGetLastError());
      }
#endif
      CHECK_ERROR(cudaMemcpy(output + j * out_image_elems, d_output[i],
                             out_image_elems * sizeof(float), cudaMemcpyDeviceToHost));
    }
  }
  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    CHECK_ERROR(cudaDeviceSynchronize());
  }

  // Gather: rank 0 receives each peer's output slice in place.
  if (mpi_rank == 0) {
    for (int i = 1; i < mpi_world_size; ++i) {
      int pimgidx_start, pimgidx_num;
      conv_rank_images(i, &pimgidx_start, &pimgidx_num);
      MPI_Recv(output + pimgidx_start * out_image_elems, (int)(pimgidx_num * out_image_elems),
               MPI_FLOAT, i, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
  } else {
    MPI_Send(output, (int)(cimgidx_num * out_image_elems), MPI_FLOAT, 0, 0, MPI_COMM_WORLD);
    // Release all three alloc_tensor buffers, including the filter copy.
    // NOTE(review): assumes alloc_tensor memory is free()-able, as already
    // assumed for input/output.
    free(input);
    free(output);
    free(filter);
    input = output = filter = NULL;
  }
}
// Records the problem shape, queries the local GPUs and the MPI topology, and
// allocates the per-device buffers that every convolution() call reuses.
// Calls MPI_Comm_rank/MPI_Comm_size, so MPI must already be initialized.
// Exits via CHECK_ERROR if the device query or any allocation fails.
void convolution_init(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  N = _N; C = _C; H = _H; W = _W;
  K = _K; R = _R; S = _S;
  pad = _pad;
  dilation = _dilation;
  stride = _stride;
  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

  CHECK_ERROR(cudaGetDeviceCount(&num_devices));
  // The per-device pointer tables are fixed-size arrays; using more GPUs than
  // they hold would write past their end.
  const int max_devices = (int)(sizeof(d_input) / sizeof(d_input[0]));
  if (num_devices > max_devices) num_devices = max_devices;
  // num_devices = 1;

  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);

  // Byte sizes in size_t so large shapes do not overflow int before scaling.
  const size_t input_bytes = (size_t)C * H * W * sizeof(float);
  const size_t output_bytes = (size_t)K * OH * OW * sizeof(float);
  const size_t filter_bytes = (size_t)K * C * R * S * sizeof(float);
  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    CHECK_ERROR(cudaMalloc(&d_input[i], input_bytes));
    CHECK_ERROR(cudaMalloc(&d_output[i], output_bytes));
    CHECK_ERROR(cudaMalloc(&d_filter[i], filter_bytes));
#ifdef TR
    // Separate buffers for the untransposed layouts fed to/from the transpose kernels.
    CHECK_ERROR(cudaMalloc(&_d_input[i], input_bytes));
    CHECK_ERROR(cudaMalloc(&_d_output[i], output_bytes));
    CHECK_ERROR(cudaMalloc(&_d_filter[i], filter_bytes));
#else
    // No transposes: both names refer to the same allocation.
    _d_input[i] = d_input[i];
    _d_output[i] = d_output[i];
    _d_filter[i] = d_filter[i];
#endif
  }
}
// Frees the per-device buffers allocated by convolution_init. The parameters
// are unused; they mirror convolution_init's signature. Without TR the
// underscore-prefixed pointers alias the plain ones, so only the plain ones
// are freed.
void convolution_final(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  for (int dev = 0; dev < num_devices; ++dev) {
    CHECK_ERROR(cudaSetDevice(dev));
    float* primary[] = {d_input[dev], d_output[dev], d_filter[dev]};
    for (float* buf : primary) CHECK_ERROR(cudaFree(buf));
#ifdef TR
    float* staging[] = {_d_input[dev], _d_output[dev], _d_filter[dev]};
    for (float* buf : staging) CHECK_ERROR(cudaFree(buf));
#endif
  }
}