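// chundoong-lab-ta/SamsungDS22/submissions/HW5/jinho.yi/kernel.cl
//
// 69 lines
// 2.5 KiB
// Listed as "Common Lisp" (presumably inferred from the .cl extension); the content below is OpenCL C.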

// Tile edge length in scalar floats: the sgemm kernel stages TS x TS floats
// of A and of B in local memory per iteration (declared as [TS][TS/WIDTH] float8).
#define TS 32
// WPT and RTS are not referenced by the sgemm kernel visible in this file.
// NOTE(review): they look like leftovers from a work-per-thread variant - confirm before removing.
#define WPT 8
#define RTS (TS/WPT)
// Number of floats per vector element. The sgemm kernel is written with float8
// and selects lanes .s0 to .s7, so this value has to stay 8.
#define WIDTH 8
// Tiled single-precision matrix multiply C = A * B on float8-vectorised buffers.
//
// Indexing used below (all row-major):
//   A : M rows, K/WIDTH float8 per row
//   B : K rows, N/WIDTH float8 per row
//   C : M rows, N/WIDTH float8 per row
// Because the buffers are addressed as float8, N and K must be multiples of
// WIDTH (8). M, N and K do NOT have to be multiples of TS: out-of-range tile
// entries are zero-filled and out-of-range results are not stored. For such
// sizes the host has to round the global work size up to whole work-groups.
//
// The local tiles are indexed [get_local_id(0)][get_local_id(1)] and declared
// [TS][TS/WIDTH], so the kernel assumes a work-group of TS x (TS/WIDTH)
// work-items. Each work-item produces one float8 of C.
__kernel void sgemm(__global float8 *A, __global float8 *B, __global float8 *C, int M, int N, int K) {
    // Thread identifiers
    const int row = get_local_id(0);                        // Local row, 0..TS-1
    const int col = get_local_id(1);                        // Local float8 column, 0..TS/WIDTH-1
    const int globalRow = TS*get_group_id(0) + row;         // Row of A and C
    const int globalCol = (TS/WIDTH)*get_group_id(1) + col; // float8 column of B and C

    // Row lengths measured in float8 elements
    const int vecN = N/WIDTH;
    const int vecK = K/WIDTH;

    // Whether this work-item maps to a real element of C
    const bool rowInRange = globalRow < M;
    const bool colInRange = globalCol < vecN;

    // Local memory to fit a tile of TS*TS elements of A and B
    __local float8 Asub[TS][TS/WIDTH];
    __local float8 Bsub[TS][TS/WIDTH];

    // Initialise the accumulation registers
    float8 acc = (float8)(0.0f);

    // Loop over all tiles; round up so a trailing partial tile is not dropped
    const int numTiles = (K + TS - 1)/TS;
    for (int tile = 0; tile < numTiles; tile++) {
        // Position of this work-item's element inside the current K-slice
        const int tiledRow = TS*tile + row;         // Row of B
        const int tiledCol = (TS/WIDTH)*tile + col; // float8 column of A

        // Load one tile of A and B into local memory. Entries outside the
        // matrices become zero so they add nothing to the dot products.
        // Plain if-statements (not ?: on vectors) guarantee that no
        // out-of-range global load is ever issued.
        float8 tileA = (float8)(0.0f);
        if (rowInRange && tiledCol < vecK) {
            tileA = A[globalRow*vecK + tiledCol];
        }
        float8 tileB = (float8)(0.0f);
        if (colInRange && tiledRow < K) {
            tileB = B[tiledRow*vecN + globalCol];
        }
        Asub[row][col] = tileA;
        Bsub[row][col] = tileB;

        // Synchronise to make sure the tile is loaded
        barrier(CLK_LOCAL_MEM_FENCE);

        // Perform the computation for a single tile. One float8 of A holds
        // WIDTH consecutive k-values of this row; lane w pairs with row
        // WIDTH*k + w of the B tile. The scalar lane is broadcast across the
        // float8 of B, and the lanes are accumulated in ascending order.
        for (int k = 0; k < TS/WIDTH; k++) {
            const float8 vecA = Asub[row][k];
            acc += Bsub[WIDTH*k + 0][col] * vecA.s0;
            acc += Bsub[WIDTH*k + 1][col] * vecA.s1;
            acc += Bsub[WIDTH*k + 2][col] * vecA.s2;
            acc += Bsub[WIDTH*k + 3][col] * vecA.s3;
            acc += Bsub[WIDTH*k + 4][col] * vecA.s4;
            acc += Bsub[WIDTH*k + 5][col] * vecA.s5;
            acc += Bsub[WIDTH*k + 6][col] * vecA.s6;
            acc += Bsub[WIDTH*k + 7][col] * vecA.s7;
        }

        // Synchronise before loading the next tile
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store the final results in C. Work-items that only exist because the
    // grid was rounded up must not write.
    if (rowInRange && colInRange) {
        C[globalRow*vecN + globalCol] = acc;
    }
}