/*
 * Edge length of the square tiles that sgemm stages in local memory.
 *
 * The kernel indexes its BLOCK_SIZE x BLOCK_SIZE tiles directly with
 * get_local_id(), so it has to be launched with a local work size of exactly
 * BLOCK_SIZE x BLOCK_SIZE. The guard lets the host override the value with a
 * -D BLOCK_SIZE=<n> build option without a macro redefinition.
 */
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 32
#endif

/*
 * Smaller of two values.
 *
 * Function-like macro: each argument may be evaluated twice, so never pass an
 * expression with side effects (e.g. MIN(i++, n)).
 */
#define MIN(a, b) (((a) < (b)) ? (a) : (b))

/*
 * Tiled single-precision matrix multiply: C = A * B.
 *
 * Indexing is row-major: A is read as an M x K matrix (A[row * K + col]),
 * B as K x N (B[row * N + col]) and C is written as M x N (C[row * N + col]).
 *
 * Each work-group produces one BLOCK_SIZE x BLOCK_SIZE tile of C. It walks the
 * K dimension in steps of BLOCK_SIZE, staging one tile of A and one tile of B
 * in local memory per step and accumulating the partial dot products in a
 * private register. Elements that fall outside A or B are staged as 0, so
 * M, N and K do not have to be multiples of BLOCK_SIZE.
 *
 * Launch requirements:
 *   - 2-D NDRange; dimension 0 runs over columns of C, dimension 1 over rows.
 *   - Local work size must be exactly BLOCK_SIZE x BLOCK_SIZE, because the
 *     local tiles are indexed directly with get_local_id().
 *
 * Parameters:
 *   A, B  input matrices (read only)
 *   C     output matrix; only elements with row < M and col < N are written
 *   M     rows of A and C
 *   N     columns of B and C
 *   K     columns of A / rows of B
 */
__kernel void sgemm(__global const float *A, __global const float *B, __global float *C, int M, int N, int K) {
    const int gj = get_group_id(0); /* tile column of C handled by this group */
    const int gi = get_group_id(1); /* tile row of C handled by this group */

    /*
     * Skip work-groups whose whole tile lies outside C. The test depends only
     * on group ids, so every work-item of a group takes the same branch and
     * none of them is left waiting at the barriers below.
     */
    if (gi * BLOCK_SIZE >= M || gj * BLOCK_SIZE >= N) return;

    const int lj = get_local_id(0); /* column inside the tile */
    const int li = get_local_id(1); /* row inside the tile */

    __local float Alocal[BLOCK_SIZE][BLOCK_SIZE];
    __local float Blocal[BLOCK_SIZE][BLOCK_SIZE];

    /*
     * Coordinates of the C element owned by this work-item. They are derived
     * from group/local ids (not get_global_id) so that the loads and the final
     * store always refer to the same element, even if the host launches the
     * kernel with a non-zero global work offset.
     */
    const int row = gi * BLOCK_SIZE + li;
    const int col = gj * BLOCK_SIZE + lj;

    float c = 0.f;

    for (int bk = 0; bk < K; bk += BLOCK_SIZE) {
        /* Each work-item stages one element of the A tile ... */
        const int a_col = bk + lj;
        Alocal[li][lj] = (row < M && a_col < K) ? A[row * K + a_col] : 0.f;

        /* ... and one element of the B tile; out-of-range entries become 0. */
        const int b_row = bk + li;
        Blocal[li][lj] = (b_row < K && col < N) ? B[b_row * N + col] : 0.f;

        /* Both tiles must be completely written before anyone reads them. */
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int lk = 0; lk < BLOCK_SIZE; ++lk) {
            c += Alocal[li][lk] * Blocal[lk][lj];
        }

        /* Nobody may overwrite the tiles while others still read them. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    /* Work-items in a partially covered edge tile have nothing to store. */
    if (row < M && col < N) {
        C[row * N + col] = c;
    }
}