chundoong-lab-ta/SamsungDS22/submissions/HW5/jinin.so/kernel.cl

82 lines
2.4 KiB
Common Lisp

// super super slow sgemm kernel by heehoon
// Edge length of the square tiles staged in local memory (Asub and Bsub are TS x TS).
#define TS 32
// Not referenced by the sgemm kernel visible here.
// NOTE(review): presumably a leftover vector width -- confirm before removing.
#define VW 8
// Row stride between the elements one work-item handles inside a tile
// (the kernel indexes tile rows as row + w*RTS); equals TS / WPT.
#define RTS 4
// Number of output elements accumulated by each work-item (size of acc[]).
#define WPT 8
// sgemm: computes C = A * B for row-major single-precision matrices.
//
//   A : M x K input,  element (i, k) is read from  A[i*K + k]
//   B : K x N input,  element (k, j) is read from  B[k*N + j]
//   C : M x N output, element (i, j) is written to C[i*N + j]
//
// Work decomposition, as implied by the indexing below:
//   - Each work-group produces one TS x TS tile of C.
//   - The work-item with local id (row, col) accumulates WPT results: tile
//     rows row, row + RTS, ..., row + (WPT-1)*RTS in tile column col.
//   - The tile loads fill every entry of Asub/Bsub, and stay inside them,
//     only when the local work size is (RTS, TS).
//     NOTE(review): confirm the host enqueues the kernel with that local size.
//
// M, N and K do not have to be multiples of TS: out-of-range tile entries are
// loaded as 0.0f and out-of-range results are never stored.
__kernel void sgemm(__global const float *A, __global const float *B,
                    __global float *C, int M, int N, int K) {
  const int row = get_local_id(0);  // First tile row of this work-item (needs row < RTS)
  const int col = get_local_id(1);  // Tile column of this work-item (needs col < TS)
  const int globalRow = TS * get_group_id(0) + row;  // First row of C handled here
  const int globalCol = TS * get_group_id(1) + col;  // Column of C handled here

  // No early return for work-items that fall outside C: every work-item of the
  // group has to reach the barriers below and contributes to the tile loads.

  // Local memory holding the current TS x TS tiles of A and B.
  __local float Asub[TS][TS];
  __local float Bsub[TS][TS];

  // One accumulator per output element owned by this work-item.
  float acc[WPT];
  for (int w = 0; w < WPT; w++) {
    acc[w] = 0.0f;
  }

  // Walk over the K dimension one tile at a time (last tile may be partial).
  const int numTiles = (K + TS - 1) / TS;
  for (int t = 0; t < numTiles; t++) {
    // Position of this work-item inside the K dimension for tile t;
    // independent of w, so computed once per tile.
    const int tiledRow = TS * t + row;
    const int tiledCol = TS * t + col;

    // Load the tiles of A and B, padding with zeros past the matrix edges.
    for (int w = 0; w < WPT; w++) {
      const int tileRow = row + w * RTS;
      const int aRow = globalRow + w * RTS;
      const int bRow = tiledRow + w * RTS;

      // Offsets are computed in size_t so that matrices with more than
      // 2^31 elements do not overflow 32-bit int arithmetic. The guards
      // guarantee every operand is non-negative before the conversion.
      if (aRow < M && tiledCol < K) {
        Asub[tileRow][col] = A[(size_t)aRow * K + tiledCol];
      } else {
        Asub[tileRow][col] = 0.0f;
      }

      if (bRow < K && globalCol < N) {
        Bsub[tileRow][col] = B[(size_t)bRow * N + globalCol];
      } else {
        Bsub[tileRow][col] = 0.0f;
      }
    }

    // Make the freshly loaded tiles visible to the whole work-group.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Multiply-accumulate the contribution of this tile.
    for (int k = 0; k < TS; k++) {
      for (int w = 0; w < WPT; w++) {
        acc[w] += Asub[row + w * RTS][k] * Bsub[k][col];
      }
    }

    // Nobody may overwrite the tiles until every work-item is done reading.
    barrier(CLK_LOCAL_MEM_FENCE);
  }

  // Write back only the results that lie inside C.
  for (int w = 0; w < WPT; w++) {
    const int cRow = globalRow + w * RTS;
    if (cRow < M && globalCol < N) {
      C[(size_t)cRow * N + globalCol] = acc[w];
    }
  }
}