#define TS 32 #define WPT 8 #define RTS (TS/WPT) #define WIDTH 8 __kernel void sgemm(__global float8 *A, __global float8 *B, __global float8 *C, int M, int N, int K) { // Thread identifiers const int row = get_local_id(0); // Local row ID (max: TS/WIDTH) const int col = get_local_id(1); // Local col ID (max: TS) const int globalRow = TS*get_group_id(0) + row; // 0..M/WIDTH const int globalCol = (TS/WIDTH)*get_group_id(1) + col; // 0..N // Local memory to fit a tile of TS*TS elements of A and B __local float8 Asub[TS][TS/WIDTH]; __local float8 Bsub[TS][TS/WIDTH]; // Initialise the accumulation registers float8 acc = { 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f }; // Loop over all tiles const int numTiles = K/TS; for (int tile=0; tile