// Tile size constant. It is not referenced by the kernel below; presumably it
// is the thread-block edge length used by the host-side launch code
// (NOTE(review): confirm against the caller).
#define TS 8

/**
 * Direct (naive) 2D convolution: every thread computes one output element.
 *
 * Indexing used by this kernel (row-major, innermost index last):
 *   _input  is read   as [n][c][h][w]     with extents _N x _C x _H  x _W
 *   _filter is read   as [k][c][i][j]     with extents _K x _C x _R  x _S
 *   _output is written as [n][k][row][col] with extents _N x _K x OH x OW
 * where
 *   OH = (_H + 2*_pad - _dilation*(_R-1) - 1) / _stride + 1
 *   OW = (_W + 2*_pad - _dilation*(_S-1) - 1) / _stride + 1
 *
 * Thread mapping (2D launch):
 *   x dimension -> output row, must cover [0, OH)
 *   y dimension -> fused (n, k, output column), must cover [0, _N*_K*OW)
 * Threads outside those ranges exit without writing. No shared memory is used.
 *
 * Input taps that fall outside the image (i.e. in the padding region) are
 * skipped, which is equivalent to zero padding.
 *
 * @param _input     input tensor (read only by this kernel)
 * @param _output    output tensor, one element written per in-range thread
 * @param _filter    filter weights (read only by this kernel)
 * @param _N         batch size
 * @param _C         number of input channels
 * @param _H         input height
 * @param _W         input width
 * @param _K         number of output channels (filters)
 * @param _R         filter height
 * @param _S         filter width
 * @param _pad       padding applied on every side of the input
 * @param _dilation  spacing between filter taps
 * @param _stride    step between consecutive output positions (must be > 0)
 */
__global__ void conv(
    float *_input,
    float *_output,
    float *_filter,
    int _N, int _C, int _H, int _W,
    int _K, int _R, int _S,
    int _pad, int _dilation, int _stride)
{
    const int globalRow = blockDim.x * blockIdx.x + threadIdx.x;
    const int globalCol = blockDim.y * blockIdx.y + threadIdx.y;

    // A non-positive stride would divide by zero below.
    if (_stride <= 0) return;

    // Extent the filter can slide over, minus one. Negative means the dilated
    // filter does not fit into the padded input, so there is no output at all
    // (truncating integer division would otherwise report a bogus size).
    const int spanH = _H + 2 * _pad - _dilation * (_R - 1) - 1;
    const int spanW = _W + 2 * _pad - _dilation * (_S - 1) - 1;
    if (spanH < 0 || spanW < 0) return;

    const int OH = spanH / _stride + 1;
    const int OW = spanW / _stride + 1;

    // Bounds guard comes first so the decode below never divides by zero
    // (e.g. when _K == 0). Computed in 64 bits to avoid int overflow.
    const long long totalCols = (long long)_N * _K * OW;
    if (globalRow >= OH || (long long)globalCol >= totalCols) return;

    // Decode the fused y index into (n, k, col).
    int rem = globalCol;
    const int n = rem / (_K * OW);
    rem -= n * (_K * OW);
    const int k = rem / OW;
    const int col = rem - k * OW;
    const int row = globalRow;

    // Top-left input coordinate touched by this output element.
    const int start_row = row * _stride - _pad;
    const int start_col = col * _stride - _pad;

    float o = 0.0f;
    for (int c = 0; c < _C; c++) {
        for (int i = 0; i < _R; i++) {
            const int h = start_row + i * _dilation;
            if (h < 0 || h >= _H) continue;  // whole filter row lies in padding
            for (int j = 0; j < _S; j++) {
                const int w = start_col + j * _dilation;
                if (w < 0 || w >= _W) continue;  // tap lies in padding

                // Flat offsets are built in size_t so large tensors do not
                // overflow 32-bit arithmetic.
                const size_t inIdx =
                    (((size_t)n * _C + c) * _H + h) * _W + w;
                const size_t filIdx =
                    (((size_t)k * _C + c) * _R + i) * _S + j;

                o += _input[inIdx] * _filter[filIdx];
            }
        }
    }

    const size_t outIdx = (((size_t)n * _K + k) * OH + row) * OW + col;
    _output[outIdx] = o;
}