/* Tile edge length. Can be overridden at program build time (e.g. -DBLOCK_SIZE=16)
 * for devices whose maximum work-group size is below BLOCK_SIZE * BLOCK_SIZE. */
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 32
#endif

#define MIN(a, b) (((a) < (b)) ? (a) : (b))

/*
 * sgemm: C = A * B using BLOCK_SIZE x BLOCK_SIZE tiles staged in local memory.
 *
 * Indexing (row-major): A[r * K + c] is M x K, B[r * N + c] is K x N,
 * C[r * N + c] is M x N.
 *
 * Launch: 2-D NDRange where dimension 0 walks columns of C and dimension 1
 * walks rows. The work-group size MUST be BLOCK_SIZE x BLOCK_SIZE x 1; this is
 * enforced by reqd_work_group_size, because every work-item loads exactly one
 * element of each tile and the inner product reads the full tile. Global sizes
 * are expected to be N and M rounded up to multiples of BLOCK_SIZE; out-of-range
 * elements are zero-padded on load and skipped on store.
 */
__kernel __attribute__((reqd_work_group_size(BLOCK_SIZE, BLOCK_SIZE, 1)))
void sgemm(const __global float *restrict A,
           const __global float *restrict B,
           __global float *C,
           int M, int N, int K)
{
    const int gj = get_group_id(0); /* tile column */
    const int gi = get_group_id(1); /* tile row */

    /* Entire tile lies outside C. The condition depends only on the group id,
     * so the whole work-group leaves together and no barrier is skipped by a
     * subset of work-items. */
    if (gi * BLOCK_SIZE >= M || gj * BLOCK_SIZE >= N)
        return;

    const int lj = get_local_id(0);
    const int li = get_local_id(1);

    /* One coordinate pair used for both loads and the final store, so the
     * result is written to the same element whose inputs were loaded (the
     * original mixed get_global_id with group/local ids, which diverge under a
     * global work offset). */
    const int row = gi * BLOCK_SIZE + li;
    const int col = gj * BLOCK_SIZE + lj;

    __local float Alocal[BLOCK_SIZE][BLOCK_SIZE];
    __local float Blocal[BLOCK_SIZE][BLOCK_SIZE];

    float c = 0.f;
    for (int bk = 0; bk < K; bk += BLOCK_SIZE) {
        const int a_col = bk + lj;
        const int b_row = bk + li;

        /* Zero-pad past the matrix edges so the full-tile inner loop below is
         * correct for sizes that are not multiples of BLOCK_SIZE. Offsets are
         * computed in size_t to avoid int overflow on large matrices. */
        Alocal[li][lj] = (row < M && a_col < K) ? A[(size_t)row * K + a_col] : 0.f;
        Blocal[li][lj] = (b_row < K && col < N) ? B[(size_t)b_row * N + col] : 0.f;

        /* Tiles must be fully populated before anyone reads them. */
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int lk = 0; lk < BLOCK_SIZE; ++lk)
            c += Alocal[li][lk] * Blocal[lk][lj];

        /* Everyone must finish reading before the next iteration overwrites. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (row < M && col < N)
        C[(size_t)row * N + col] = c;
}