#define TileSize 32
#define SubWorkSize 8
#define Offset (TileSize/SubWorkSize)

/*
 * Tiled single-precision matrix multiply on row-major buffers:
 *
 *     C[i*N + j] = sum over k in [0, K) of A[i*K + k] * B[k*N + j]
 *
 * A is M x K, B is K x N, C is M x N.
 *
 * Each work-group produces one TileSize x TileSize tile of C, selected by
 * get_group_id(0) (tile row) and get_group_id(1) (tile column).  Each
 * work-item produces SubWorkSize elements of that tile, all in column
 * l_col, at rows l_row, l_row + Offset, l_row + 2*Offset, ...
 *
 * Launch requirement (follows from the indexing below, not enforced here):
 * the work-group must be Offset x TileSize work-items, i.e. local ids
 * 0..Offset-1 in dimension 0 and 0..TileSize-1 in dimension 1.  A smaller
 * group leaves parts of the local tiles unwritten; a larger one indexes
 * past the end of Asub/Bsub.
 *
 * M, N and K need not be multiples of TileSize: out-of-range elements are
 * loaded as 0.0f and out-of-range results are not stored.
 */
__kernel void sgemm(const __global float *A,
                    const __global float *B,
                    __global float *C,
                    int M, int N, int K)
{
    const int l_row = get_local_id(0);                     /* row inside the tile (first of SubWorkSize) */
    const int l_col = get_local_id(1);                     /* column inside the tile */
    const int g_row = TileSize * get_group_id(0) + l_row;  /* first global row of C for this work-item */
    const int g_col = TileSize * get_group_id(1) + l_col;  /* global column of C for this work-item */

    /* Work-group-shared staging tiles for the current K-slice of A and B. */
    __local float Asub[TileSize][TileSize];
    __local float Bsub[TileSize][TileSize];

    /* One running dot product per output row owned by this work-item. */
    float acc[SubWorkSize];
    for (int w = 0; w < SubWorkSize; w++) {
        acc[w] = 0.0f;
    }

    const int num_tiles = (K + TileSize - 1) / TileSize;   /* ceil(K / TileSize) */
    for (int t = 0; t < num_tiles; t++) {
        /* Position of this work-item inside the current K-slice; depends on t only. */
        const int t_row = TileSize * t + l_row;
        const int t_col = TileSize * t + l_col;

        /* Cooperative load: each work-item fills SubWorkSize entries of each tile. */
        for (int w = 0; w < SubWorkSize; w++) {
            const int s_row = l_row + w * Offset;          /* row inside the local tiles */
            const int a_row = g_row + w * Offset;          /* row of A */
            const int b_row = t_row + w * Offset;          /* row of B */

            /*
             * Flat indices are formed in size_t so that row * stride cannot
             * overflow int on large matrices.  Out-of-range elements are
             * zero-filled so they contribute nothing to the dot products.
             */
            if ((a_row < M) && (t_col < K)) {
                Asub[s_row][l_col] = A[(size_t)a_row * K + t_col];
            } else {
                Asub[s_row][l_col] = 0.0f;
            }

            if ((b_row < K) && (g_col < N)) {
                Bsub[s_row][l_col] = B[(size_t)b_row * N + g_col];
            } else {
                Bsub[s_row][l_col] = 0.0f;
            }
        }

        /* Both tiles must be completely written before any work-item reads them. */
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TileSize; k++) {
            for (int w = 0; w < SubWorkSize; w++) {
                acc[w] += Asub[l_row + w * Offset][k] * Bsub[k][l_col];
            }
        }

        /* Nobody may overwrite the tiles for the next slice while others still read them. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    /*
     * Write back the in-range results.  Every work-item has passed all
     * barriers by now, so leaving early here is safe.  Rows grow with w,
     * so the first row at or beyond M ends the loop.
     */
    if (g_col < N) {
        for (int w = 0; w < SubWorkSize; w++) {
            const int c_row = g_row + w * Offset;
            if (c_row >= M) {
                break;
            }
            C[(size_t)c_row * N + g_col] = acc[w];
        }
    }
}