chundoong-lab-ta/SamsungDS22/submissions/HW5/taekyung.yeo/kernel.cl

93 lines
2.3 KiB
Common Lisp

/* Edge length of the square local tiles: Asub and Bsub are declared TS x TS. */
#define TS 32
/* Number of output values each work-item accumulates in the sgemm kernel. */
#define WPTF 8
/* Number of output values each work-item accumulates in the sgemm2 kernel. */
#define WPTL 16
/* Row spacing between successive outputs of one work-item in sgemm. */
#define RTSF (TS/WPTF)
/* Row spacing between successive outputs of one work-item in sgemm2. */
#define RTSL (TS/WPTL)
/*
 * sgemm: tiled single-precision matrix product C = A * B with bounds checks.
 *
 * A is indexed as A[r*K + c], B as B[r*N + c] and C as C[r*N + c].
 * For every step along K the work-group stages one TS x TS tile of A and
 * one of B in __local memory; each work-item then updates WPTF accumulators
 * that belong to a single output column and are spaced RTSF rows apart.
 * Tile entries that fall outside the matrices are stored as 0.0f, and
 * outputs outside C are never written.
 *
 * NOTE(review): the staging loop fills the whole tile only if
 * get_local_id(0) covers [0, RTSF) and get_local_id(1) covers [0, TS) --
 * confirm against the local work size chosen by the host.
 */
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
    const int lrow = get_local_id(0);
    const int lcol = get_local_id(1);
    const int out_row = TS * get_group_id(0) + lrow;
    const int out_col = TS * get_group_id(1) + lcol;

    __local float Asub[TS][TS];
    __local float Bsub[TS][TS];

    /* Per-work-item partial sums, one per output row handled here. */
    float acc[WPTF];
    for (int w = 0; w < WPTF; w++) {
        acc[w] = 0.0f;
    }

    /* Full tiles along K, plus one more when a remainder is left over. */
    const int num_tiles = K / TS + ((K % TS) > 0 ? 1 : 0);

    for (int t = 0; t < num_tiles; t++) {
        const int k_base = TS * t;
        const int a_col = k_base + lcol;

        /* Stage the current tiles; out-of-range entries become zero. */
        for (int w = 0; w < WPTF; w++) {
            const int sub_row = lrow + w * RTSF;
            const int a_row = out_row + w * RTSF;
            const int b_row = k_base + sub_row;

            Asub[sub_row][lcol] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
            Bsub[sub_row][lcol] = (b_row < K && out_col < N) ? B[b_row * N + out_col] : 0.0f;
        }

        /* Both tiles must be completely written before anyone reads them. */
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TS; k++) {
            const float b_val = Bsub[k][lcol];
            for (int w = 0; w < WPTF; w++) {
                acc[w] += Asub[lrow + w * RTSF][k] * b_val;
            }
        }

        /* Do not overwrite the tiles while other work-items still use them. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    /* Write back only the results that lie inside C. */
    for (int w = 0; w < WPTF; w++) {
        const int c_row = out_row + w * RTSF;
        if (c_row < M && out_col < N) {
            C[c_row * N + out_col] = acc[w];
        }
    }
}
/*
 * sgemm2: tiled single-precision matrix product C = A * B without bounds
 * checks.
 *
 * A is indexed as A[r*K + c], B as B[r*N + c] and C as C[r*N + c].
 * The work-group stages TS x TS tiles of A and B in __local memory and each
 * work-item updates WPTL accumulators of one output column, spaced RTSL
 * rows apart. Only K/TS whole tiles are processed and every load and store
 * is unconditional.
 *
 * NOTE(review): presumably the host launches this kernel only when M, N and
 * K are multiples of TS, with get_local_id(0) covering [0, RTSL) and
 * get_local_id(1) covering [0, TS) -- confirm against the caller.
 */
__kernel void sgemm2(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
    const int lrow = get_local_id(0);
    const int lcol = get_local_id(1);
    const int out_row = TS * get_group_id(0) + lrow;
    const int out_col = TS * get_group_id(1) + lcol;

    __local float Asub[TS][TS];
    __local float Bsub[TS][TS];

    /* Per-work-item partial sums, one per output row handled here. */
    float acc[WPTL];
    for (int w = 0; w < WPTL; w++) {
        acc[w] = 0.0f;
    }

    /* Whole tiles only; a remainder of K is not visited. */
    const int num_tiles = K / TS;

    for (int t = 0; t < num_tiles; t++) {
        const int k_base = TS * t;
        const int a_col = k_base + lcol;

        /* Stage the current tiles of A and B. */
        for (int w = 0; w < WPTL; w++) {
            const int sub_row = lrow + w * RTSL;
            const int a_row = out_row + w * RTSL;
            const int b_row = k_base + sub_row;

            Asub[sub_row][lcol] = A[a_row * K + a_col];
            Bsub[sub_row][lcol] = B[b_row * N + out_col];
        }

        /* Both tiles must be completely written before anyone reads them. */
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TS; k++) {
            const float b_val = Bsub[k][lcol];
            for (int w = 0; w < WPTL; w++) {
                acc[w] += Asub[lrow + w * RTSL][k] * b_val;
            }
        }

        /* Do not overwrite the tiles while other work-items still use them. */
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int w = 0; w < WPTL; w++) {
        const int c_row = out_row + w * RTSL;
        C[c_row * N + out_col] = acc[w];
    }
}