// Paste metadata (not part of the program): 370 lines, 11 KiB, Plaintext.
#include "convolution.h"

#include <cuda_runtime.h>
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

#include "util.h"

// Host tensors used by the current convolution() call. On rank 0 they alias
// the caller's buffers; other ranks allocate their own slices.
static float *input;
static float *output;
static float *filter;

// Per-device buffers (room for 4 devices). When TR is defined the _d_* and
// d_* sets are distinct and the transpose kernels convert between them;
// otherwise convolution_init makes each _d_* pointer alias its d_* partner.
static float *_d_input[4];
static float *d_input[4];
static float *_d_output[4];
static float *d_output[4];
static float *d_filter[4];
static float *_d_filter[4];

// Problem shape: batch/channels/height/width, then filter count and size.
static int N, C, H, W;
static int K, R, S;

// Output spatial size, derived from the shape and the conv parameters.
static int OH, OW;

// Convolution hyper-parameters.
static int pad, dilation, stride;

// MPI process identity and the number of CUDA devices in use on this rank.
static int mpi_rank, mpi_world_size;
static int num_devices;

// Enable the layout-transposition path. conv_kernal indexes its input as
// [H][W][C], its weights as [R][S][C][K] and writes [OH][OW][K]; only the TR
// path converts the host NCHW / KCRS tensors into those layouts (and the
// result back). Without it the kernel reads the wrong layout. Guarded so a
// build that already passes -DTR does not trigger a redefinition warning.
#ifndef TR
#define TR
#endif

// Number of consecutive output channels (k) handled by one thread in
// weighttranspose_kernal and conv_kernal.
#define rowblocksz 16

// Evaluates a CUDA runtime call once; on failure prints file, line and the
// CUDA error string to stderr and terminates the process.
#define CHECK_ERROR(cond)                                                \
  do {                                                                   \
    cudaError_t check_error_status_ = (cond);                            \
    if (check_error_status_ != cudaSuccess) {                            \
      fprintf(stderr, "[%s:%d] CUDA error: %s\n", __FILE__, __LINE__,    \
              cudaGetErrorString(check_error_status_));                  \
      exit(EXIT_FAILURE);                                                \
    }                                                                    \
  } while (false)

/*
 * Filter layout transpose: input [K][C][R][S] -> output [R][S][C][K].
 * Thread mapping: x picks the input channel c; y picks a group of rowblocksz
 * consecutive output channels starting at k_first (the last group may be
 * shorter when K is not a multiple of rowblocksz).
 */
__global__ void weighttranspose_kernal(float* input, float* output, int K, int C, int R, int S) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  const int k_first = (blockIdx.y * blockDim.y + threadIdx.y) * rowblocksz;

  if (c >= C || k_first >= K) return;

  const int k_stop = min(k_first + rowblocksz, K);

  for (int r = 0; r < R; r++) {
    for (int s = 0; s < S; s++) {
      // Source offset of (c, r, s) inside the k = 0 filter.
      const int src_crs = c * R * S + r * S + s;
      // Destination row for (r, s, c); k indexes contiguously within it.
      float* dst_row = output + r * S * C * K + s * C * K + c * K;
      for (int k = k_first; k < k_stop; k++) {
        dst_row[k] = input[k * C * R * S + src_crs];
      }
    }
  }
}

/*
 * Output layout transpose: input [OH][OW][K] -> output [K][OH][OW].
 * One thread per element: x -> ow, y -> oh, z -> k.
 */
__global__ void outputtranspose_kernal(float* input, float* output, int K, int OH, int OW) {
  const int ow = threadIdx.x + blockIdx.x * blockDim.x;
  const int oh = threadIdx.y + blockIdx.y * blockDim.y;
  const int k = threadIdx.z + blockIdx.z * blockDim.z;

  const bool in_range = (k < K) && (oh < OH) && (ow < OW);
  if (in_range) {
    const int pixel = oh * OW + ow;
    output[k * OH * OW + pixel] = input[pixel * K + k];
  }
}

/*
 * Input layout transpose: input [C][H][W] -> output [H][W][C].
 * One thread per element: x -> w, y -> h, z -> c.
 */
__global__ void inputtranspose_kernal(float* input, float* output, int C, int H, int W) {
  const int w = threadIdx.x + blockIdx.x * blockDim.x;
  const int h = threadIdx.y + blockIdx.y * blockDim.y;
  const int c = threadIdx.z + blockIdx.z * blockDim.z;

  const bool in_range = (c < C) && (h < H) && (w < W);
  if (in_range) {
    const int pixel = h * W + w;
    output[pixel * C + c] = input[c * H * W + pixel];
  }
}

/*
 * Direct convolution of one image.
 *
 * Layouts indexed by this kernel:
 *   input  : [H][W][C]
 *   weight : [R][S][C][K]
 *   output : [OH][OW][K]
 *
 * Thread mapping: x -> output row oh, y -> output column ow, z -> a group of
 * rowblocksz consecutive output channels starting at k_base. A full group
 * keeps rowblocksz accumulators in a local array and walks the channel
 * dimension (weights advance by K per input channel). A shorter tail group
 * (K not a multiple of rowblocksz) computes one output channel at a time.
 *
 * Reference formulation on NCHW / KCRS data (n = 0), kept for comparison:
 *   for c, r, s:
 *     h = oh * stride - pad + r * dilation;
 *     w = ow * stride - pad + s * dilation;
 *     if (h < 0 || h >= H || w < 0 || w >= W) continue;
 *     o += input[c*H*W + h*W + w] * weight[k*C*R*S + c*R*S + r*S + s];
 *   output[k*OH*OW + oh*OW + ow] = o;
 */
__global__ void conv_kernal(float* input, float* weight, float* output,
                            int C, int H, int W, int K, int R, int S, int OH, int OW,
                            int stride, int pad, int dilation) {
  const int oh = blockIdx.x * blockDim.x + threadIdx.x;
  const int ow = blockIdx.y * blockDim.y + threadIdx.y;
  const int k_base = (blockIdx.z * blockDim.z + threadIdx.z) * rowblocksz;

  if (oh >= OH || ow >= OW || k_base >= K) return;

  const int k_limit = (k_base + rowblocksz < K) ? k_base + rowblocksz : K;

  if (k_limit - k_base == rowblocksz) {
    // Full group: rowblocksz accumulators, fully unrolled.
    float acc[rowblocksz];
#pragma unroll
    for (int t = 0; t < rowblocksz; ++t) acc[t] = 0;

    for (int r = 0; r < R; ++r) {
      for (int s = 0; s < S; ++s) {
        const int h = oh * stride - pad + r * dilation;
        const int w = ow * stride - pad + s * dilation;
        if (h < 0 || h >= H || w < 0 || w >= W) continue;

        const float* wrow = weight + r * S * C * K + s * C * K + k_base;
        const float* pix = input + h * W * C + w * C;
        for (int c = 0; c < C; ++c, wrow += K, ++pix) {
          const float x = *pix;
#pragma unroll
          for (int t = 0; t < rowblocksz; ++t) acc[t] += x * wrow[t];
        }
      }
    }

    const int out_base = oh * OW * K + ow * K + k_base;
#pragma unroll
    for (int t = 0; t < rowblocksz; ++t) output[out_base + t] = acc[t];
    return;
  }

  // Tail group: fewer than rowblocksz channels remain.
  for (int k = k_base; k < k_limit; k++) {
    float sum = 0;
    for (int r = 0; r < R; ++r) {
      for (int s = 0; s < S; ++s) {
        const int h = oh * stride - pad + r * dilation;
        const int w = ow * stride - pad + s * dilation;
        if (h < 0 || h >= H || w < 0 || w >= W) continue;
        for (int c = 0; c < C; ++c) {
          const float x = input[h * W * C + w * C + c];
          const float f = weight[r * S * C * K + s * C * K + c * K + k];
          sum += x * f;
        }
      }
    }
    output[oh * OW * K + ow * K + k] = sum;
  }
}

// Writes the slice of the global batch of N images owned by MPI rank `rank`
// (contiguous, near-equal split) as [*start, *start + *count).
static void conv_image_range(int rank, int *start, int *count) {
  const int begin = N * rank / mpi_world_size;
  const int end = N * (rank + 1) / mpi_world_size;
  *start = begin;
  *count = end - begin;
}

// Distributes work. Rank 0 sends N, the whole filter and each peer's image
// slice. Every other rank receives them into host buffers it allocates with
// alloc_tensor. Returns the number of images this rank processes. Rank 0's
// slice always starts at image 0, so the images live at the start of `input`
// on every rank.
static int conv_scatter_inputs(void) {
  if (mpi_rank == 0) {
    for (int peer = 1; peer < mpi_world_size; ++peer) {
      int start, count;
      conv_image_range(peer, &start, &count);
      MPI_Send(&N, 1, MPI_INT, peer, 0, MPI_COMM_WORLD);
      MPI_Send(filter, K * C * R * S, MPI_FLOAT, peer, 1, MPI_COMM_WORLD);
      MPI_Send(input + (size_t)start * C * H * W, count * C * H * W, MPI_FLOAT,
               peer, 2, MPI_COMM_WORLD);
    }
  } else {
    N = 0;
    // MPI_STATUS_IGNORE, not NULL, is the valid "no status wanted" argument.
    MPI_Recv(&N, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  }

  int my_start, my_count;
  conv_image_range(mpi_rank, &my_start, &my_count);

  if (mpi_rank != 0) {
    alloc_tensor(&input, my_count, C, H, W);
    alloc_tensor(&output, my_count, K, OH, OW);
    alloc_tensor(&filter, K, C, R, S);
    MPI_Recv(filter, K * C * R * S, MPI_FLOAT, 0, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    MPI_Recv(input, my_count * C * H * W, MPI_FLOAT, 0, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  }
  return my_count;
}

// Runs this rank's `num_images` images on all devices. The filter is uploaded
// (and, under TR, transposed) once per device. Then an OpenMP loop gives each
// device a contiguous run of images, each uploaded, convolved and copied back
// into `output` one at a time.
static void conv_run_on_devices(int num_images) {
  const size_t filter_elems = (size_t)K * C * R * S;
  const size_t image_elems = (size_t)C * H * W;
  const size_t result_elems = (size_t)K * OH * OW;

  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    CHECK_ERROR(cudaMemcpy(_d_filter[i], filter, filter_elems * sizeof(float),
                           cudaMemcpyHostToDevice));
#ifdef TR
    {
      // Each y-thread transposes rowblocksz output channels, so a block of
      // 64 y-threads covers 64 * rowblocksz channels (the old K / 64 + 1
      // sizing launched ~rowblocksz times more blocks than needed).
      const int k_per_block = 64 * rowblocksz;
      dim3 block(1, 64, 1);
      dim3 grid(C, (K + k_per_block - 1) / k_per_block, 1);
      weighttranspose_kernal<<<grid, block>>>(_d_filter[i], d_filter[i], K, C, R, S);
      CHECK_ERROR(cudaGetLastError());
    }
#endif
  }

#pragma omp parallel for
  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    const int j_start = i * num_images / num_devices;
    const int j_end = (i + 1) * num_images / num_devices;

    for (int j = j_start; j < j_end; ++j) {
      CHECK_ERROR(cudaMemcpy(_d_input[i], input + (size_t)j * image_elems,
                             image_elems * sizeof(float), cudaMemcpyHostToDevice));
#ifdef TR
      {
        dim3 block(16, 16, 2);
        dim3 grid((W + 15) / 16, (H + 15) / 16, (C + 1) / 2);
        inputtranspose_kernal<<<grid, block>>>(_d_input[i], d_input[i], C, H, W);
        CHECK_ERROR(cudaGetLastError());
      }
#endif
      {
        // x -> oh, y -> ow (16 per block), z -> groups of rowblocksz output
        // channels (2 groups per block).
        dim3 block(1, 16, 2);
        dim3 grid(OH, (OW + 15) / 16, (K + rowblocksz * 2 - 1) / (rowblocksz * 2));
        conv_kernal<<<grid, block>>>(d_input[i], d_filter[i], _d_output[i],
                                     C, H, W, K, R, S, OH, OW, stride, pad, dilation);
        CHECK_ERROR(cudaGetLastError());
      }
#ifdef TR
      {
        dim3 block(16, 16, 2);
        dim3 grid((OW + 15) / 16, (OH + 15) / 16, (K + 1) / 2);
        outputtranspose_kernal<<<grid, block>>>(_d_output[i], d_output[i], K, OH, OW);
        CHECK_ERROR(cudaGetLastError());
      }
#endif
      CHECK_ERROR(cudaMemcpy(output + (size_t)j * result_elems, d_output[i],
                             result_elems * sizeof(float), cudaMemcpyDeviceToHost));
    }
  }

  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));
    CHECK_ERROR(cudaDeviceSynchronize());
  }
}

// Collects results on rank 0. Peers send their `num_images` outputs and then
// release every host buffer conv_scatter_inputs allocated for them. The
// filter buffer was previously leaked.
static void conv_gather_outputs(int num_images) {
  if (mpi_rank == 0) {
    for (int peer = 1; peer < mpi_world_size; ++peer) {
      int start, count;
      conv_image_range(peer, &start, &count);
      MPI_Recv(output + (size_t)start * K * OH * OW, count * K * OH * OW, MPI_FLOAT,
               peer, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
  } else {
    MPI_Send(output, num_images * K * OH * OW, MPI_FLOAT, 0, 0, MPI_COMM_WORLD);

    // NOTE(review): assumes alloc_tensor allocates with malloc, as the
    // original free(input)/free(output) calls already did.
    free(input);
    free(output);
    free(filter);
    input = NULL;
    output = NULL;
    filter = NULL;
  }
}

/*
 * Distributed convolution entry point (collective over MPI_COMM_WORLD).
 *
 * Rank 0 passes the full NCHW input, KCRS filter and an output buffer for the
 * full result. Shape and hyper-parameters come from the values stored by
 * convolution_init; the _N.._stride arguments here are not read. On return,
 * rank 0's `_output` holds the whole batch.
 */
void convolution(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  input = _input;
  output = _output;
  filter = _filter;

  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

  const int num_images = conv_scatter_inputs();
  conv_run_on_devices(num_images);
  conv_gather_outputs(num_images);
}

/*
 * One-time setup. Stores the problem shape and hyper-parameters, derives
 * OH/OW, queries MPI rank/size, and allocates the per-device buffers on every
 * CUDA device used by this rank.
 *
 * The per-device pointer tables hold a fixed number of entries, so the
 * device count is clamped to their capacity. Previously a node with more
 * GPUs indexed past the end of those static arrays. A failing
 * cudaGetDeviceCount (e.g. no device) now aborts instead of silently
 * leaving num_devices at 0 and producing no output.
 */
void convolution_init(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  N = _N; C = _C; H = _H; W = _W;
  K = _K; R = _R; S = _S;
  pad = _pad;
  dilation = _dilation;
  stride = _stride;

  OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
  OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;

  CHECK_ERROR(cudaGetDeviceCount(&num_devices));
  const int max_devices = (int)(sizeof(d_input) / sizeof(d_input[0]));
  if (num_devices > max_devices) num_devices = max_devices;

  MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);

  const size_t image_bytes = (size_t)C * H * W * sizeof(float);
  const size_t result_bytes = (size_t)K * OH * OW * sizeof(float);
  const size_t filter_bytes = (size_t)K * C * R * S * sizeof(float);

  for (int i = 0; i < num_devices; i++) {
    CHECK_ERROR(cudaSetDevice(i));

    CHECK_ERROR(cudaMalloc(&d_input[i], image_bytes));
    CHECK_ERROR(cudaMalloc(&d_output[i], result_bytes));
    CHECK_ERROR(cudaMalloc(&d_filter[i], filter_bytes));
#ifdef TR
    // Separate staging buffers: host-layout data lands in _d_input/_d_filter
    // and conv_kernal writes _d_output; the transpose kernels move data
    // between these and the d_* buffers.
    CHECK_ERROR(cudaMalloc(&_d_input[i], image_bytes));
    CHECK_ERROR(cudaMalloc(&_d_output[i], result_bytes));
    CHECK_ERROR(cudaMalloc(&_d_filter[i], filter_bytes));
#else
    _d_input[i] = d_input[i];
    _d_output[i] = d_output[i];
    _d_filter[i] = d_filter[i];
#endif
  }
}

/*
 * Releases every device buffer allocated by convolution_init. The staging
 * (_d_*) buffers are freed separately only under TR; otherwise they alias the
 * d_* buffers and must not be freed twice. The shape arguments are unused.
 */
void convolution_final(
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  for (int dev = 0; dev < num_devices; ++dev) {
    CHECK_ERROR(cudaSetDevice(dev));

    float *owned[] = {d_input[dev], d_output[dev], d_filter[dev]};
    for (float *buf : owned) CHECK_ERROR(cudaFree(buf));

#ifdef TR
    float *staging[] = {_d_input[dev], _d_output[dev], _d_filter[dev]};
    for (float *buf : staging) CHECK_ERROR(cudaFree(buf));
#endif
  }
}