// Origin: chundoong-lab-ta/SamsungDS22/submissions/final/jihye65.park/tmp-B/convolution_org.cu
// (file-viewer metadata carried over with the copy: 47 lines, 1.1 KiB, Plaintext)

// NOTE(review): TS is not referenced by any code visible in this file;
// presumably it is the thread-block edge length used by the host-side
// launch configuration — confirm against the caller.
#define TS 8
/**
 * Direct 2D convolution: each thread computes one output element.
 *
 * Tensor layouts, as fixed by the index arithmetic below:
 *   _input  : [_N][_C][_H][_W]
 *   _filter : [_K][_C][_R][_S]
 *   _output : [_N][_K][OH][OW]
 * where
 *   OH = (_H + 2*_pad - _dilation*(_R-1) - 1) / _stride + 1
 *   OW = (_W + 2*_pad - _dilation*(_S-1) - 1) / _stride + 1
 *
 * Launch layout: 2D grid/block.
 *   - x dimension indexes the output row and must cover OH.
 *   - y dimension indexes the flattened (n, k, output column) triple and
 *     must cover _N * _K * OW.
 * Threads outside those extents exit without touching memory.
 *
 * Input positions that fall outside [0,_H) x [0,_W) (the padding region)
 * contribute nothing to the sum.
 *
 * Preconditions: _stride must be non-zero (it is used as a divisor).
 * No shared memory is used.
 *
 * NOTE(review): threadIdx.x selects the output *row*, so neighbouring
 * threads in x touch addresses OW elements apart; swapping the roles of
 * x and y would need a matching change in the host launch code.
 */
__global__ void conv(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
  const int globalRow = blockDim.x * blockIdx.x + threadIdx.x;
  const int globalCol = blockDim.y * blockIdx.y + threadIdx.y;

  const int OH = (_H + 2 * _pad - _dilation * (_R - 1) - 1) / _stride + 1;
  const int OW = (_W + 2 * _pad - _dilation * (_S - 1) - 1) / _stride + 1;

  // Guard first: this also keeps the divisions below away from a zero
  // divisor when the output (or _K) is empty.
  if (globalRow >= OH || globalCol >= _N * _K * OW) return;

  // Unflatten globalCol into (batch n, output channel k, output column col).
  const int n = globalCol / (_K * OW);
  const int rem = globalCol - n * (_K * OW);
  const int k = rem / OW;
  const int col = rem - k * OW;
  const int row = globalRow;

  // Top-left corner of the receptive field in input coordinates
  // (may be negative because of padding).
  const int start_row = row * _stride - _pad;
  const int start_col = col * _stride - _pad;

  float o = 0.0f;
  for (int c = 0; c < _C; c++) {
    for (int i = 0; i < _R; i++) {
      const int h = start_row + i * _dilation;
      if (h < 0 || h >= _H) continue;  // whole filter row lies in padding
      for (int j = 0; j < _S; j++) {
        const int w = start_col + j * _dilation;
        if (w < 0 || w >= _W) continue;  // padding column
        // size_t offsets so large tensors do not overflow 32-bit int.
        const size_t in_idx =
            (((size_t)n * _C + c) * _H + h) * _W + w;
        const size_t fil_idx =
            (((size_t)k * _C + c) * _R + i) * _S + j;
        o += _input[in_idx] * _filter[fil_idx];
      }
    }
  }

  const size_t out_idx = (((size_t)n * _K + k) * OH + row) * OW + col;
  _output[out_idx] = o;
}