// chundoong-lab-ta/SHPC2022/hw5_answer/matmul/kernel.cl
//
// 48 lines
// 1.3 KiB
// Common Lisp (language label carried over from the file listing; the content below is OpenCL C)

// Edge length of the square tiles that sgemm stages in __local memory.
// The kernel derives tile origins from get_group_id() * BLOCK_SIZE and indexes
// the tiles with get_local_id(), so it presumably has to be launched with
// BLOCK_SIZE x BLOCK_SIZE work-groups -- TODO confirm against the host code.
#define BLOCK_SIZE 16
// Smaller of a and b. Function-like macro: the selected argument is evaluated
// twice, so do not pass expressions with side effects.
// NOTE(review): not referenced anywhere in the visible part of this file.
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
/*
 * Tiled single-precision matrix multiply: C = A * B.
 *
 * A is read as a row-major M x K matrix (A[row * K + col]), B as a row-major
 * K x N matrix (B[row * N + col]); C is written as a row-major M x N matrix
 * (C[row * N + col]).
 *
 * NDRange dimension 0 indexes columns of C, dimension 1 indexes rows of C.
 * Tile origins are computed as get_group_id() * BLOCK_SIZE and the local tiles
 * are indexed with get_local_id(), so the kernel assumes a work-group size of
 * BLOCK_SIZE x BLOCK_SIZE. The global size may be padded beyond M and N:
 * out-of-range loads are replaced by 0 and out-of-range work-items store
 * nothing.
 */
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
  int j = get_global_id(0); // column index of C
  int i = get_global_id(1); // row index of C
  int gj = get_group_id(0);
  int gi = get_group_id(1);

  // Skip work-groups whose tile lies completely outside C. The condition only
  // depends on the group id, so either every work-item of a group returns or
  // none does; the barriers below are therefore still reached uniformly.
  if (gi * BLOCK_SIZE >= M || gj * BLOCK_SIZE >= N) return;

  int lj = get_local_id(0);
  int li = get_local_id(1);

  // One tile of A and one tile of B, shared by the whole work-group.
  __local float Alocal[BLOCK_SIZE][BLOCK_SIZE];
  __local float Blocal[BLOCK_SIZE][BLOCK_SIZE];

  // Each work-item accumulates exactly one element of C, so the running sum
  // can live in a private variable instead of a __local tile.
  float acc = 0.f;

  int A_row_index = gi * BLOCK_SIZE + li;
  int B_col_index = gj * BLOCK_SIZE + lj;

  // Walk the K dimension one tile at a time.
  for (int bk = 0; bk < K; bk += BLOCK_SIZE) {
    // Each work-item loads one element of each tile; positions outside the
    // matrices are zero-filled so they do not contribute to the dot product.
    int A_col_index = bk + lj;
    Alocal[li][lj] = (A_row_index < M && A_col_index < K)
                         ? A[A_row_index * K + A_col_index]
                         : 0.f;

    int B_row_index = bk + li;
    Blocal[li][lj] = (B_row_index < K && B_col_index < N)
                         ? B[B_row_index * N + B_col_index]
                         : 0.f;

    // Both tiles must be fully populated before any work-item reads them.
    barrier(CLK_LOCAL_MEM_FENCE);

    for (int lk = 0; lk < BLOCK_SIZE; ++lk) {
      acc += Alocal[li][lk] * Blocal[lk][lj];
    }

    // Everyone must be done reading before the next iteration overwrites
    // the tiles.
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Work-items in the padded part of the NDRange have nothing to store.
  if (i < M && j < N) {
    C[i * N + j] = acc;
  }
}