// Paste-site metadata (not part of the kernel): 45 lines, 1.1 KiB.
// The paste site labeled this "Common Lisp"; the content is OpenCL C.
// super super slow sgemm kernel by heehoon
//
// Computes C = A * B for densely packed row-major float matrices:
//   A is M x K (element (i, k) at A[i * K + k])
//   B is K x N (element (k, j) at B[k * N + j])
//   C is M x N (element (i, j) at C[i * N + j])
//
// Tiling scheme:
//   - Each work-group computes one TS x TS tile of C.
//   - K is walked in TS-wide steps. For each step the group stages a
//     TS x TS tile of A and a TS x TS tile of B in local memory (PA, PB).
//   - Each work-item accumulates WPT outputs of C. All of them are in one
//     column, at rows spaced RTS apart within the tile.
//
// Launch contract (not checked here; the indexing below assumes it):
//   local size  = (RTS, TS), e.g. (4, 32) with the default TS/WPT
//   global size = (ceil(M / TS) * RTS, ceil(N / TS) * TS)
// A different local size leaves parts of PA/PB unloaded.
//
// TS and WPT may be overridden at build time (e.g. "-D TS=16 -D WPT=4").
// TS must be a multiple of WPT. A, B and C must not overlap.
#ifndef TS
#define TS 32           // tile edge: C rows/columns covered by one work-group
#endif
#ifndef WPT
#define WPT 8           // C outputs computed by each work-item
#endif
#define RTS (TS / WPT)  // work-items along dimension 0 of a work-group
#if (TS % WPT) != 0
#error "TS must be a multiple of WPT"
#endif

__kernel void sgemm(__global const float *restrict A,
                    __global const float *restrict B,
                    __global float *restrict C,
                    int M, int N, int K) {
    const int r = get_local_id(0);            // 0 .. RTS-1
    const int c = get_local_id(1);            // 0 .. TS-1
    const int g_r = TS * get_group_id(0) + r; // first C row owned by this item
    const int g_c = TS * get_group_id(1) + c; // C column owned by this item

    __local float PA[TS][TS];
    __local float PB[TS][TS];

    float acc[WPT];
    for (int w = 0; w < WPT; w++) {
        acc[w] = 0.0f;
    }

    // n_tiles is identical for every work-item, so all items execute the
    // same number of barriers below.
    const int n_tiles = (K + TS - 1) / TS;
    for (int t = 0; t < n_tiles; t++) {
        // Column of A and base row of B this item loads for tile t
        // (independent of w, so computed once per tile).
        const int t_r = TS * t + r;
        const int t_c = TS * t + c;

        for (int w = 0; w < WPT; w++) {
            const int a_row = g_r + w * RTS;
            const int b_row = t_r + w * RTS;
            // Out-of-range elements are staged as 0 so that partial edge
            // tiles contribute nothing to the dot products. Indices are
            // widened to size_t to avoid 32-bit int overflow on big inputs.
            PA[r + w * RTS][c] = (a_row < M && t_c < K)
                                     ? A[(size_t)a_row * K + t_c]
                                     : 0.0f;
            PB[r + w * RTS][c] = (b_row < K && g_c < N)
                                     ? B[(size_t)b_row * N + g_c]
                                     : 0.0f;
        }

        // Wait until the whole tile is staged before anyone reads it.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TS; k++) {
            const float b_kc = PB[k][c]; // shared by all WPT outputs
            for (int w = 0; w < WPT; w++) {
                acc[w] += PA[r + w * RTS][k] * b_kc;
            }
        }

        // Wait until everyone is done reading before the next tile overwrites PA/PB.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int w = 0; w < WPT; w++) {
        const int row = g_r + w * RTS;
        if (row < M && g_c < N) {
            C[(size_t)row * N + g_c] = acc[w];
        }
    }
}