47 lines
1.1 KiB
Plaintext
47 lines
1.1 KiB
Plaintext
// NOTE(review): TS is not referenced by the kernel visible in this file;
// presumably the thread-block edge length used by the host-side launch
// configuration — TODO confirm against the caller.
#define TS 8
|
|
|
|
/*
 * Direct 2D convolution: each thread computes exactly one output element.
 *
 * Thread mapping (as established by the index math below):
 *   x-dimension of the launch -> output row            (0 .. OH-1)
 *   y-dimension of the launch -> flattened (n, k, col) (0 .. _N*_K*OW-1),
 *                                 laid out as n*_K*OW + k*OW + col
 * Threads outside those ranges return without touching memory, so the grid
 * only has to cover the ranges (it may overshoot).
 *
 * Memory layout implied by the indexing:
 *   _input  : [_N][_C][_H][_W]
 *   _filter : [_K][_C][_R][_S]
 *   _output : [_N][_K][OH][OW]
 * where
 *   OH = (_H + 2*_pad - _dilation*(_R-1) - 1) / _stride + 1
 *   OW = (_W + 2*_pad - _dilation*(_S-1) - 1) / _stride + 1
 *
 * Parameters:
 *   _input, _output, _filter : device pointers to the tensors above
 *   _N, _C, _H, _W           : batch, input channels, input height/width
 *   _K, _R, _S               : output channels, filter height/width
 *   _pad                     : zero padding applied on every spatial border
 *   _dilation                : spacing between filter taps
 *   _stride                  : step between neighbouring output positions
 *
 * Taps that fall into the padding region contribute zero (they are skipped).
 */
__global__ void conv(
    float *_input, float *_output, float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride) {
    const int globalRow = blockDim.x * blockIdx.x + threadIdx.x;
    const int globalCol = blockDim.y * blockIdx.y + threadIdx.y;

    const int OH = (_H + 2 * _pad - _dilation * (_R - 1) - 1) / _stride + 1;
    const int OW = (_W + 2 * _pad - _dilation * (_S - 1) - 1) / _stride + 1;

    // Degenerate shapes produce no output; bail out before dividing by OW.
    if (OH <= 0 || OW <= 0) return;
    // Grid-tail guard: the launch may cover more threads than output elements.
    if (globalRow >= OH || globalCol >= _N * _K * OW) return;

    // Unflatten the y index into (batch n, output channel k, output column).
    const int n = globalCol / (_K * OW);
    const int rem = globalCol - n * (_K * OW);
    const int k = rem / OW;
    const int col = rem - k * OW;
    const int row = globalRow;

    // Top-left input coordinate of this output's receptive field
    // (negative values lie in the padding).
    const int start_row = row * _stride - _pad;
    const int start_col = col * _stride - _pad;

    // 64-bit offsets so large tensors do not overflow 32-bit index math.
    const long long inPlane = (long long)_H * _W;
    const long long filPlane = (long long)_R * _S;
    const float *inBatch = _input + (long long)n * _C * inPlane;
    const float *filBank = _filter + (long long)k * _C * filPlane;

    float acc = 0.0f;
    for (int c = 0; c < _C; c++) {
        const float *inChan = inBatch + c * inPlane;
        const float *filChan = filBank + c * filPlane;
        for (int i = 0; i < _R; i++) {
            const int h = start_row + i * _dilation;
            // The row test does not depend on j, so do it once per filter row.
            if (h < 0 || h >= _H) continue;
            for (int j = 0; j < _S; j++) {
                const int w = start_col + j * _dilation;
                if (w < 0 || w >= _W) continue;
                acc += inChan[(long long)h * _W + w] * filChan[i * _S + j];
            }
        }
    }

    _output[((long long)n * _K + k) * OH * OW + (long long)row * OW + col] = acc;
}
|
|
|