// super super slow sgemm kernel by heehoon

// Edge length of the square tile of C produced by one work-group.
#define TS 32
// Number of C elements accumulated by each work-item.
#define WPT 8
// Row stride between the WPT elements owned by one work-item.
#define RTS (TS / WPT)

/*
 * Tiled single-precision matrix multiply: C = A * B.
 *
 * A is indexed as A[row * K + col] (M x K), B as B[row * N + col] (K x N),
 * and C as C[row * N + col] (M x N). C is overwritten, not accumulated into.
 *
 * Each work-group computes one TS x TS tile of C. A work-item with local id
 * (r, c) owns the WPT elements at tile rows r, r + RTS, ..., r + (WPT-1)*RTS
 * in tile column c. The local-memory indexing below therefore requires a
 * work-group size of (RTS, TS): a larger local size in either dimension
 * would index PA/PB out of bounds.
 *
 * M, N and K need not be multiples of TS: out-of-range tile elements are
 * loaded as zero and out-of-range C elements are not written.
 */
__kernel void sgemm(__global const float *A, __global const float *B,
                    __global float *C, int M, int N, int K)
{
    const int r = get_local_id(0);
    const int c = get_local_id(1);
    const int g_r = TS * get_group_id(0) + r;
    const int g_c = TS * get_group_id(1) + c;

    // Staging tiles of A and B shared by the whole work-group.
    __local float PA[TS][TS];
    __local float PB[TS][TS];

    float tmp[WPT];
    for (int w = 0; w < WPT; w++) {
        tmp[w] = 0.0f;
    }

    // Walk the K dimension one TS-wide tile at a time (rounded up).
    const int n_tiles = (K + TS - 1) / TS;
    for (int t = 0; t < n_tiles; t++) {
        // These depend only on the tile index, so compute them once per tile.
        const int t_r = TS * t + r;
        const int t_c = TS * t + c;

        for (int w = 0; w < WPT; w++) {
            const int l_r = r + w * RTS;    // row inside the local tile
            const int a_r = g_r + w * RTS;  // row of A (and of C)
            const int b_r = t_r + w * RTS;  // row of B

            // Widen before multiplying so the flattened index cannot
            // overflow 32-bit int on large matrices.
            PA[l_r][c] = (a_r >= M || t_c >= K)
                             ? 0.0f
                             : A[(size_t)a_r * K + t_c];
            PB[l_r][c] = (b_r >= K || g_c >= N)
                             ? 0.0f
                             : B[(size_t)b_r * N + g_c];
        }

        // Both tiles must be fully populated before anyone reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TS; k++) {
            for (int w = 0; w < WPT; w++) {
                tmp[w] += PA[r + w * RTS][k] * PB[k][c];
            }
        }

        // Nobody may overwrite the tiles until every work-item is done
        // reading them.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int w = 0; w < WPT; w++) {
        const int c_r = g_r + w * RTS;
        if (c_r >= M || g_c >= N) {
            continue;
        }
        C[(size_t)c_r * N + g_c] = tmp[w];
    }
}