chundoong-lab-ta/SamsungDS22/submissions/HW5/h2.nam/kernel.cl

57 lines
1.5 KiB
OpenCL C
Raw Normal View History

2022-09-29 18:01:45 +09:00
// Tiling parameters for the sgemm kernel.
// TSIZE: edge length of the square tile of C computed by one work-group.
// WPT:   number of C outputs computed by each work-item.
// RSIZE: row stride between one work-item's outputs inside a tile; presumably
//        also the work-group size in dimension 0 (confirm with the host code).
#define TSIZE 32
#define WPT 16
#define RSIZE (TSIZE/WPT)

// RSIZE is an integer division: if WPT does not divide TSIZE evenly, the
// work-items would silently leave rows of every tile uncomputed, and
// WPT > TSIZE would make RSIZE zero. Reject such configurations at build time.
#if (TSIZE % WPT) != 0
#error "WPT must evenly divide TSIZE"
#endif
// Tiled single-precision matrix multiply: C = A * B, where A is M x K,
// B is K x N and C is M x N, all stored row-major (as the flat index
// arithmetic below implies).
//
// Work decomposition (implied by the index math, not checked at runtime):
//   - each work-group computes one TSIZE x TSIZE tile of C;
//   - local id 0 ("row") is expected to span [0, RSIZE) and local id 1
//     ("col") to span [0, TSIZE), i.e. a local size of (RSIZE, TSIZE);
//   - each work-item accumulates WPT outputs of C in private memory, for
//     tile rows row, row + RSIZE, ..., row + (WPT-1)*RSIZE, in column col.
// Out-of-range elements of A and B are staged as 0.0f and out-of-range
// elements of C are never written, so M, N and K need not be multiples of
// TSIZE. The host presumably rounds the global size up to whole tiles --
// confirm.
//
// NOTE(review): flat indices are computed in int and overflow if M*K, K*N
// or M*N exceeds INT_MAX.
__kernel void sgemm(__global const float *A, __global const float *B,
                    __global float *C, int M, int N, int K) {
    const int row = get_local_id(0);
    const int col = get_local_id(1);
    const int grow = TSIZE * get_group_id(0) + row;    // first global C row of this work-item
    const int gcolumn = TSIZE * get_group_id(1) + col; // global C column of this work-item
    const int t_num = (K + TSIZE - 1) / TSIZE;         // number of K-tiles, rounded up

    // Tiles of A and B shared by the whole work-group for one K-tile.
    __local float Sub_A[TSIZE][TSIZE];
    __local float Sub_B[TSIZE][TSIZE];

    float result[WPT] = { 0.0f };

    for (int j = 0; j < t_num; j++) {
        // Column of A / row of B loaded by this work-item in K-tile j.
        // They do not depend on i, so compute them once per tile.
        const int tcolumn = j * TSIZE + col;
        const int trow = j * TSIZE + row;

        // Cooperative load: each work-item stages WPT elements of each tile.
        for (int i = 0; i < WPT; i++) {
            const int r = row + i * RSIZE;       // row inside the tile
            const int a_row = grow + i * RSIZE;  // global row of A
            const int b_row = trow + i * RSIZE;  // global row of B

            Sub_A[r][col] = (a_row < M && tcolumn < K) ? A[a_row * K + tcolumn] : 0.0f;
            Sub_B[r][col] = (b_row < K && gcolumn < N) ? B[b_row * N + gcolumn] : 0.0f;
        }
        // Tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TSIZE; k++) {
            const float b = Sub_B[k][col];  // shared by all WPT outputs
            for (int i = 0; i < WPT; i++) {
                result[i] += Sub_A[row + i * RSIZE][k] * b;
            }
        }
        // Do not let the next iteration overwrite tiles still being read.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write back only in-range outputs. The column test does not depend on i,
    // and rows increase with i, so stop at the first row past M.
    if (gcolumn < N) {
        for (int i = 0; i < WPT; i++) {
            const int c_row = grow + i * RSIZE;
            if (c_row >= M) {
                break;
            }
            C[c_row * N + gcolumn] = result[i];
        }
    }
}