69 lines
2.5 KiB
Common Lisp
69 lines
2.5 KiB
Common Lisp
/*
 * Compile-time tiling parameters for the sgemm kernel.
 * Each default is wrapped in #ifndef so it can be overridden from the
 * program build options (e.g. "-DTS=16") without editing this file and
 * without a conflicting macro redefinition. The defaults are unchanged.
 */

/* Tile edge length, in floats: each local-memory tile is TS x TS floats. */
#ifndef TS
#define TS 32
#endif

/*
 * WPT and RTS are not referenced by the sgemm kernel in this file.
 * NOTE(review): presumably "work per thread" / "reduced tile size" left over
 * from a register-blocked variant; confirm whether the host uses them
 * before removing.
 */
#ifndef WPT
#define WPT 8
#endif

#define RTS (TS / WPT)

/* Vector width: floats per vector element (sgemm uses float8, so 8). */
#ifndef WIDTH
#define WIDTH 8
#endif
// The inner product below addresses float8 lanes explicitly (.s0 ... .s7),
// so the vector width is fixed at 8; fail the build instead of computing
// garbage if the tiling macros are changed inconsistently.
#if WIDTH != 8
#error "sgemm uses float8 vectors: WIDTH must be 8"
#endif
#if (TS % WIDTH) != 0
#error "TS must be a multiple of WIDTH"
#endif

/*
 * sgemm: C = A * B for row-major single-precision matrices, vectorised with
 * float8 along the contiguous (column) dimension.
 *
 *   A : M x K, row stride K/WIDTH float8 elements
 *   B : K x N, row stride N/WIDTH float8 elements
 *   C : M x N, row stride N/WIDTH float8 elements (overwritten, not added to)
 *
 * Each work-item computes WIDTH consecutive columns of one row of C.
 * Work-group layout implied by the indexing below:
 *   local size  = (TS, TS/WIDTH)
 *   global size >= (M, N/WIDTH), each rounded up to a multiple of the
 *                  local size
 * Work-items outside the matrix load zeros and skip the final store, so M,
 * N and K need not be multiples of TS. K and N must still be multiples of
 * WIDTH, because rows are addressed in whole float8 units.
 */
__kernel void sgemm(__global const float8 *A, __global const float8 *B,
                    __global float8 *C, int M, int N, int K) {
    // Thread identifiers
    const int row = get_local_id(0);  // local row within the tile: 0..TS-1
    const int col = get_local_id(1);  // local float8 column: 0..TS/WIDTH-1
    const int globalRow = TS * get_group_id(0) + row;            // row of C: 0..M-1
    const int globalCol = (TS / WIDTH) * get_group_id(1) + col;  // float8 column of C: 0..N/WIDTH-1

    // Matrix widths measured in float8 elements.
    const int vecK = K / WIDTH;
    const int vecN = N / WIDTH;

    // One TS x TS tile of A and one of B, each stored as TS rows of
    // TS/WIDTH float8 vectors.
    __local float8 Asub[TS][TS / WIDTH];
    __local float8 Bsub[TS][TS / WIDTH];

    // Per-work-item accumulator for WIDTH adjacent outputs of C.
    float8 acc = (float8)(0.0f);

    // Round up so a trailing partial tile of K is still processed; its
    // out-of-range entries are zero-filled below and contribute nothing.
    const int numTiles = (K + TS - 1) / TS;

    for (int tile = 0; tile < numTiles; tile++) {
        // Load one tile of A and B into local memory. Every work-item writes
        // its slot (zero when out of range) so the whole tile is defined and
        // every work-item still reaches both barriers.
        const int tiledRow = TS * tile + row;            // k index = row of B
        const int tiledCol = (TS / WIDTH) * tile + col;  // float8 column of A
        Asub[row][col] = (globalRow < M && tiledCol < vecK)
                             ? A[globalRow * vecK + tiledCol]
                             : (float8)(0.0f);
        Bsub[row][col] = (tiledRow < K && globalCol < vecN)
                             ? B[tiledRow * vecN + globalCol]
                             : (float8)(0.0f);

        // Synchronise to make sure the tile is loaded
        barrier(CLK_LOCAL_MEM_FENCE);

        // acc[j] += A[row][kk] * B[kk][col*WIDTH + j] over the TS k-values of
        // this tile. Each float8 of Asub packs WIDTH consecutive k-values;
        // scalar * float8 broadcasts the scalar across all lanes.
        for (int k = 0; k < TS / WIDTH; k++) {
            const float8 vecA = Asub[row][k];
            const int kk = WIDTH * k;
            acc += vecA.s0 * Bsub[kk + 0][col];
            acc += vecA.s1 * Bsub[kk + 1][col];
            acc += vecA.s2 * Bsub[kk + 2][col];
            acc += vecA.s3 * Bsub[kk + 3][col];
            acc += vecA.s4 * Bsub[kk + 4][col];
            acc += vecA.s5 * Bsub[kk + 5][col];
            acc += vecA.s6 * Bsub[kk + 6][col];
            acc += vecA.s7 * Bsub[kk + 7][col];
        }

        // Synchronise before the next iteration overwrites the tile
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store the final results in C; work-items in the padding region of a
    // rounded-up NDRange have no output element.
    if (globalRow < M && globalCol < vecN) {
        C[globalRow * vecN + globalCol] = acc;
    }
}