// super super slow sgemm kernel by heehoon

// Edge length of the square tile of C produced by one work-group; also the
// edge length of the local-memory tiles of A and B.
#define TS 32
// Number of C elements accumulated by each work-item (one per loop step `w`).
#define WPT 32
// Row distance between two consecutive elements handled by the same work-item.
#define RTS 1
// Factors applied to the group id in the padding kernels.
// NOTE(review): presumably these equal the work-group dimensions used when the
// padding kernels are enqueued -- confirm against the host code.
#define PADDINGX TS
#define PADDINGY TS
#define MIN(a, b) ((a) < (b) ? (a) : (b))

// The local tiles are declared [TS][TS] and indexed with row + w*RTS for
// w in [0, WPT), so the rows touched by one work-item column must span exactly
// one tile.
#if (WPT * RTS) != TS
#error "sgemm requires WPT * RTS == TS"
#endif

// Tiled single-precision matrix multiply: C = A * B.
//
//   A : read as A[r * K + c]  (row stride K)
//   B : read as B[r * N + c]  (row stride N)
//   C : written as C[r * N + c] (row stride N)
//   M : accepted for interface compatibility; not used by the kernel body
//   N : row stride of B and C
//   K : row stride of A and the length of the reduction dimension
//
// Each work-item accumulates WPT results for the rows global_row + w*RTS in
// column global_col.
//
// No bounds checks are performed: every element of every TS x TS tile is
// loaded and every accumulated result is stored. The matrix dimensions are
// therefore expected to be multiples of TS; paddingAddZeroes and
// paddingRemoveZeroes in this file can be used to pad the operands and to
// strip the padding from the result.
//
// NOTE(review): the indexing implies a work-group of RTS x TS work-items
// (dimension 0 x dimension 1) -- confirm against the host launch parameters.
__kernel void sgemm(const __global float *A, const __global float *B, __global float *C,
                    int M, int N, int K) {
  const int row = get_local_id(0);
  const int col = get_local_id(1);
  const int global_row = TS * get_group_id(0) + row;
  const int global_col = TS * get_group_id(1) + col;

  // Tiles of A and B staged in local memory and shared by the work-group.
  __local float Asub[TS][TS];
  __local float Bsub[TS][TS];

  // Per-work-item accumulators, one for each of the WPT rows it owns.
  float Csub[WPT];
  for (int w = 0; w < WPT; w++) {
    Csub[w] = 0.0f;
  }

  // Walk over the reduction dimension one TS-wide tile at a time.
  const int num_tiles = (K + TS - 1) / TS;
  for (int t = 0; t < num_tiles; t++) {
    // Position of this work-item inside the current tile along K. These only
    // depend on t, so they are computed once per tile rather than once per w.
    const int t_row = TS * t + row;
    const int t_col = TS * t + col;

    // Each work-item copies WPT elements of A and WPT elements of B.
    #pragma unroll
    for (int w = 0; w < WPT; w++) {
      Asub[row + w*RTS][col] = A[(global_row + w*RTS) * K + t_col];
      Bsub[row + w*RTS][col] = B[(t_row + w*RTS) * N + global_col];
    }

    // Make the freshly loaded tiles visible to the whole work-group.
    barrier(CLK_LOCAL_MEM_FENCE);

    for (int k = 0; k < TS; k++) {
      for (int w = 0; w < WPT; w++) {
        Csub[w] += Asub[row + w*RTS][k] * Bsub[k][col];
      }
    }

    // Do not overwrite the tiles until every work-item has finished with them.
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Write the WPT accumulated results back to C.
  #pragma unroll
  for (int w = 0; w < WPT; w++) {
    C[(global_row + w*RTS) * N + global_col] = Csub[w];
  }
}

// Pad the P * Q matrix with zeroes to form a P_XL * Q_XL matrix
//
//   P, Q       : dimensions of `input` (read with row stride Q)
//   input      : source matrix
//   P_XL, Q_XL : dimensions of `output` (written with row stride Q_XL)
//   output     : destination matrix
//
// Each work-item handles one element (tx, ty) of the output. Work-items that
// fall outside the P_XL * Q_XL range do nothing.
__kernel void paddingAddZeroes(const int P, const int Q,
                               const __global float* input,
                               const int P_XL, const int Q_XL,
                               __global float* output) {
  const int tx = get_group_id(0) * PADDINGX + get_local_id(0);
  const int ty = get_group_id(1) * PADDINGY + get_local_id(1);

  // Check whether we are within bounds of the XL matrix
  if (tx < P_XL && ty < Q_XL) {
    // Copy the input or pad a zero. The ternary only evaluates the selected
    // operand, so `input` is never read outside the P * Q range.
    const float value = (tx < P && ty < Q) ? input[tx*Q + ty] : 0.0f;
    output[tx*Q_XL + ty] = value;
  }
}

// Remove padded values from a P_XL * Q_XL matrix to form a P * Q matrix
//
//   P_XL, Q_XL : dimensions of `input` (read with row stride Q_XL)
//   input      : padded source matrix
//   P, Q       : dimensions of `output` (written with row stride Q)
//   output     : destination matrix
//
// Each work-item handles one element (tx, ty). P_XL is not used by the body;
// only the row stride Q_XL of the padded matrix is needed.
__kernel void paddingRemoveZeroes(const int P_XL, const int Q_XL,
                                  const __global float* input,
                                  const int P, const int Q,
                                  __global float* output) {
  const int tx = get_group_id(0) * PADDINGX + get_local_id(0);
  const int ty = get_group_id(1) * PADDINGY + get_local_id(1);

  // Only store the result if within P * Q bounds
  if (tx < P && ty < Q) {
    output[tx*Q + ty] = input[tx*Q_XL + ty];
  }
}