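// chundoong-lab-ta/SamsungDS22/submissions/HW5/jj15.kim/kernel.cl
//
// 45 lines
// 1.1 KiB
// Labelled "Common Lisp" (apparently a misdetection of the .cl extension); the content is an OpenCL C kernel.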

// Tiled single-precision matrix multiply: C = A * B.
//
// Layout, as established by the indexing below (all row-major):
//   A is M x K  (element (i, k) at A[i * K + k])
//   B is K x N  (element (k, j) at B[k * N + j])
//   C is M x N  (element (i, j) at C[i * N + j])
//
// Each work-group produces one TS x TS tile of C. A work-item with local id
// (r, c) accumulates WPT outputs in column c, at tile rows r, r + RTS,
// r + 2 * RTS, ... The PA/PB tiles are only filled completely when the local
// work size is (RTS, TS) = (4, 32); a larger local size would index past the
// tiles. NOTE(review): confirm the host launches with that local size and a
// global size that covers ceil(M / TS) x ceil(N / TS) groups.
//
// M, N and K need not be multiples of TS: out-of-range tile elements are
// loaded as 0.0f and out-of-range outputs are not written.
#define TS 32              // tile edge length (elements)
#define WPT 8              // outputs accumulated per work-item
#define RTS (TS/WPT)       // row stride between a work-item's outputs
__kernel void sgemm(__global const float *A, __global const float *B, __global float *C, int M, int N, int K) {
    const int r = get_local_id(0);
    const int c = get_local_id(1);
    const int g_r = TS * get_group_id(0) + r;   // first global row handled by this work-item
    const int g_c = TS * get_group_id(1) + c;   // global column handled by this work-item

    __local float PA[TS][TS];   // current TS x TS tile of A
    __local float PB[TS][TS];   // current TS x TS tile of B

    float tmp[WPT];
    for (int w = 0; w < WPT; w++) {
        tmp[w] = 0.0f;
    }

    const int n_tiles = (K + TS - 1) / TS;
    for (int t = 0; t < n_tiles; t++) {
        // Position of this work-item inside the K dimension for tile t;
        // independent of w, so computed once per tile.
        const int t_r = TS * t + r;
        const int t_c = TS * t + c;

        for (int w = 0; w < WPT; w++) {
            const int l_r = r + w * RTS;     // row inside the local tiles
            const int a_r = g_r + w * RTS;   // global row of A
            const int b_r = t_r + w * RTS;   // global row of B
            // Offsets are widened to size_t before the multiply so that
            // row * pitch cannot overflow int for large matrices.
            PA[l_r][c] = (a_r < M && t_c < K) ? A[(size_t)a_r * K + t_c] : 0.0f;
            PB[l_r][c] = (b_r < K && g_c < N) ? B[(size_t)b_r * N + g_c] : 0.0f;
        }
        // Both tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TS; k++) {
            const float b = PB[k][c];        // shared by all WPT accumulators
            for (int w = 0; w < WPT; w++) {
                tmp[w] += PA[r + w * RTS][k] * b;
            }
        }
        // All reads of this tile must finish before the next iteration overwrites it.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int w = 0; w < WPT; w++) {
        const int c_r = g_r + w * RTS;
        if (c_r < M && g_c < N) {
            C[(size_t)c_r * N + g_c] = tmp[w];
        }
    }
}