/*
 * chundoong-lab-ta/SHPC2022/hw5_answer/matmul/kernel.cl
 *
 * OpenCL C kernel source (the ".cl" extension denotes OpenCL, not Common Lisp).
 */

/* Edge length, in elements, of the square tiles staged in __local memory by
 * the sgemm kernel; also the step used when walking the K dimension. */
#define BLOCK_SIZE 32

/* Smaller of two values. Function-like macro: each argument may be evaluated
 * twice, so do not pass expressions with side effects. */
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
/*
 * Tiled single-precision matrix multiply: C = A * B.
 *
 * As indexed below, A is M x K, B is K x N and C is M x N, all stored
 * row-major in __global memory. Each work-item accumulates one element of C.
 * A work-group cooperatively stages one BLOCK_SIZE x BLOCK_SIZE tile of A and
 * one of B in __local memory per step along K; tile elements that fall outside
 * the matrices are loaded as 0 so they do not contribute to the sum.
 *
 * The tiles are indexed directly by the local ids and the tile origin is
 * computed as group_id * BLOCK_SIZE, so the kernel expects to be launched
 * with a local work size of BLOCK_SIZE x BLOCK_SIZE (dimension 0 = columns,
 * dimension 1 = rows) and no global offset -- NOTE(review): confirm against
 * the host-side enqueue call.
 *
 * Parameters:
 *   A, B  input matrices (read only by this kernel)
 *   C     output matrix; only elements with row < M and column < N are written
 *   M, N, K  matrix dimensions as described above
 */
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
  int j = get_global_id(0);  // column index of C
  int i = get_global_id(1);  // row index of C
  int gj = get_group_id(0);
  int gi = get_group_id(1);

  // Skip work-groups whose tile lies entirely outside C. The condition depends
  // only on the group ids, so every work-item of a group takes the same branch
  // and no work-item is left waiting at the barriers below.
  if (gi * BLOCK_SIZE >= M || gj * BLOCK_SIZE >= N) return;

  int lj = get_local_id(0);
  int li = get_local_id(1);

  __local float Alocal[BLOCK_SIZE][BLOCK_SIZE];
  __local float Blocal[BLOCK_SIZE][BLOCK_SIZE];

  float c = 0.f;

  // Row of A and column of B that this work-item is responsible for loading.
  int A_row_index = (gi * BLOCK_SIZE + li);
  int B_col_index = (gj * BLOCK_SIZE + lj);

  for (int bk = 0; bk < K; bk += BLOCK_SIZE) {
    // Each work-item loads one element of the A tile and one of the B tile,
    // substituting 0 when the element is past the edge of the matrix.
    int A_col_index = bk + lj;
    Alocal[li][lj] = (A_row_index < M && A_col_index < K) ?
                         A[A_row_index * K + A_col_index] :
                         0.f;

    int B_row_index = bk + li;
    Blocal[li][lj] = (B_row_index < K && B_col_index < N) ?
                         B[B_row_index * N + B_col_index] :
                         0.f;

    // Both tiles must be completely written before any work-item reads them.
    barrier(CLK_LOCAL_MEM_FENCE);

    for (int lk = 0; lk < BLOCK_SIZE; ++lk) {
      c += Alocal[li][lk] * Blocal[lk][lj];
    }

    // Every work-item must finish reading before the next iteration
    // overwrites the tiles.
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Work-items in a partially covered edge tile may map outside C.
  if (i < M && j < N) {
    C[i * N + j] = c;
  }
}