93 lines
2.3 KiB
Common Lisp
93 lines
2.3 KiB
Common Lisp
|
|
/*
 * Tiling parameters shared by the sgemm kernels in this file.
 *
 * TS, WPTF and WPTL may be overridden at program build time (e.g. "-D TS=64"
 * in the clBuildProgram options); the defaults below are used otherwise.
 * The #error checks reject values the kernels' indexing cannot handle.
 */

/* Tile edge: each work-group computes one TS x TS block of C and stages
 * TS x TS tiles of A and B in local memory. */
#ifndef TS
#define TS 32
#endif

/* Number of outputs (rows of C) each work-item accumulates in sgemm. */
#ifndef WPTF
#define WPTF 8
#endif

/* Number of outputs (rows of C) each work-item accumulates in sgemm2. */
#ifndef WPTL
#define WPTL 16
#endif

/* Work-items along local dimension 0 (the row dimension) per tile:
 * RTSF for sgemm, RTSL for sgemm2. Each work-item covers rows
 * row, row + RTS*, row + 2*RTS*, ... inside the tile. */
#define RTSF (TS/WPTF)
#define RTSL (TS/WPTL)

/* If these did not divide evenly, RTSF/RTSL would truncate and part of each
 * tile would never be loaded or computed, silently corrupting C. */
#if (TS % WPTF) != 0
#error "TS must be a multiple of WPTF"
#endif
#if (TS % WPTL) != 0
#error "TS must be a multiple of WPTL"
#endif
|
|
|
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
|
const int row = get_local_id(0);
|
|
const int col = get_local_id(1);
|
|
const int global_row = TS*get_group_id(0)+row;
|
|
const int global_col = TS*get_group_id(1)+col;
|
|
|
|
__local float Asub[TS][TS];
|
|
__local float Bsub[TS][TS];
|
|
|
|
float intermediate_val[WPTF];
|
|
for(int w=0; w<WPTF; w++) {
|
|
intermediate_val[w] = 0.0f;
|
|
}
|
|
|
|
const int num_tiles = (K%TS)>0 ? K/TS+1 : K/TS;
|
|
|
|
for(int t=0; t<num_tiles; t++){
|
|
for(int w=0; w<WPTF; w++){
|
|
const int t_row = TS*t+row;
|
|
const int t_col = TS*t+col;
|
|
|
|
if(global_row+w*RTSF>=M || t_col >= K) {Asub[row+w*RTSF][col]=0.0f;}
|
|
else
|
|
Asub[row+w*RTSF][col]=A[(global_row+w*RTSF)*K+t_col];
|
|
|
|
if(t_row+w*RTSF>=K||global_col>=N) {Bsub[row+w*RTSF][col]=0.0f;}
|
|
else
|
|
Bsub[row+w*RTSF][col]=B[(t_row+w*RTSF)*N+global_col];
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
for (int k = 0; k < TS; k++) {
|
|
for(int w=0; w<WPTF; w++) {
|
|
intermediate_val[w] += Asub[row+w*RTSF][k]*Bsub[k][col];
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
|
|
for(int w=0;w<WPTF;w++) {
|
|
if(global_row+w*RTSF>=M || global_col >=N) continue;
|
|
else C[(global_row+w*RTSF)*N+global_col]=intermediate_val[w];
|
|
}
|
|
}
|
|
|
|
__kernel void sgemm2(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
|
const int row = get_local_id(0);
|
|
const int col = get_local_id(1);
|
|
const int global_row = TS*get_group_id(0)+row;
|
|
const int global_col = TS*get_group_id(1)+col;
|
|
__local float Asub[TS][TS];
|
|
__local float Bsub[TS][TS];
|
|
|
|
float intermediate_val[WPTL];
|
|
for(int w=0; w<WPTL; w++) {
|
|
intermediate_val[w] = 0.0f;
|
|
}
|
|
|
|
const int num_tiles = K/TS;
|
|
|
|
for(int t=0; t<num_tiles; t++){
|
|
for(int w=0; w<WPTL; w++){
|
|
const int t_row = TS*t+row;
|
|
const int t_col = TS*t+col;
|
|
|
|
Asub[row+w*RTSL][col]=A[(global_row+w*RTSL)*K+t_col];
|
|
Bsub[row+w*RTSL][col]=B[(t_row+w*RTSL)*N+global_col];
|
|
}
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
for (int k = 0; k < TS; k++) {
|
|
for(int w=0; w<WPTL; w++) {
|
|
intermediate_val[w] += Asub[row+w*RTSL][k]*Bsub[k][col];
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
|
|
for(int w=0; w<WPTL; w++) {
|
|
C[(global_row+w*RTSL)*N+global_col]=intermediate_val[w];
|
|
}
|
|
}
|