80 lines
2.2 KiB
Common Lisp
80 lines
2.2 KiB
Common Lisp
|
|
||
|
#if 0
|
||
|
// super super slow sgemm kernel by heehoon
|
||
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
||
|
int i = get_global_id(0); // row index of C
|
||
|
int j = get_global_id(1); // column index of C
|
||
|
if (i >= M || j >= N) return; // boundary check
|
||
|
|
||
|
C[i * N + j] = 0;
|
||
|
for (int k = 0; k < K; k++) {
|
||
|
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#else
|
||
|
|
||
|
// Tiling parameters for the tiled sgemm kernel. TILE_SIZE and WPT may be
// overridden from the host build options (e.g. "-DTILE_SIZE=16 -DWPT=4");
// the defaults below are used otherwise.
#ifndef TILE_SIZE
#define TILE_SIZE 32 // edge length of the square tiles staged in local memory
#endif
#ifndef WPT
#define WPT 8 // work per thread: output rows of C computed by each work-item
#endif
#define RTS (TILE_SIZE / WPT) // reduced tile size: work-items per tile along dimension 0

// RTS is an integer division: if WPT does not divide TILE_SIZE, only
// RTS * WPT < TILE_SIZE rows of each tile would be loaded and computed.
#if (TILE_SIZE % WPT) != 0
#error "TILE_SIZE must be a multiple of WPT"
#endif
|
||
|
|
||
|
|
||
|
// Tiled SGEMM: C = A * B with row-major A (M x K), B (K x N) and C (M x N).
//
// Each work-group computes one TILE_SIZE x TILE_SIZE tile of C. It walks the
// K dimension one tile at a time, staging the matching tiles of A and B in
// local memory. Each work-item accumulates WPT elements of C in private
// registers: tile rows row, row + RTS, ..., row + (WPT - 1) * RTS, all in
// tile column col.
//
// Launch geometry implied by the indexing below:
//   - local size  = (RTS, TILE_SIZE)
//   - global size = (ceil(M / TILE_SIZE) * RTS, ceil(N / TILE_SIZE) * TILE_SIZE)
// Out-of-range elements are zero-padded on load and skipped on store, so M, N
// and K need not be multiples of TILE_SIZE. Work-items must not return early:
// every work-item of the group has to reach both barriers.
//
// Flat indices are computed in size_t so matrices with more than INT_MAX
// elements do not overflow signed int arithmetic.
//
// NOTE(review): dimension 0 (row) varies fastest within a work-group, so
// neighbouring work-items touch Asub/Bsub rows TILE_SIZE floats apart; on
// devices with 32 local-memory banks this likely causes bank conflicts.
// Consider padding the inner dimension after profiling.
__kernel void sgemm(__global const float *restrict A,
                    __global const float *restrict B,
                    __global float *restrict C,
                    int M, int N, int K) {
    const int row = get_local_id(0); // local row ID, 0 .. RTS - 1
    const int col = get_local_id(1); // local column ID, 0 .. TILE_SIZE - 1
    // First row of C (M dimension) handled by this work-item; it also
    // handles globalRow + w * RTS for w = 1 .. WPT - 1.
    const int globalRow = TILE_SIZE * get_group_id(0) + row;
    const int globalCol = TILE_SIZE * get_group_id(1) + col; // column of C (N dimension)
    const int numTiles = (K + TILE_SIZE - 1) / TILE_SIZE;

    // Current tiles of A and B, shared by the whole work-group.
    __local float Asub[TILE_SIZE][TILE_SIZE];
    __local float Bsub[TILE_SIZE][TILE_SIZE];

    // Private accumulators, one per output row handled by this work-item.
    float res[WPT];
    for (int w = 0; w < WPT; w++) {
        res[w] = 0.0f;
    }

    for (int t = 0; t < numTiles; t++) {
        const int tiledRow = TILE_SIZE * t + row; // k index (row of B) for w == 0
        const int tiledCol = TILE_SIZE * t + col; // k index (column of A)

        // Cooperatively load one tile of A and one of B. Anything outside the
        // matrices is zero-padded so the product loop needs no bounds checks.
        for (int w = 0; w < WPT; w++) {
            const int tileRow = w * RTS + row;
            const int aRow = w * RTS + globalRow;
            const int bRow = w * RTS + tiledRow;

            Asub[tileRow][col] = (aRow < M && tiledCol < K)
                ? A[(size_t)aRow * K + tiledCol]
                : 0.0f;
            Bsub[tileRow][col] = (bRow < K && globalCol < N)
                ? B[(size_t)bRow * N + globalCol]
                : 0.0f;
        }

        // Wait until the whole tile is in local memory.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TILE_SIZE; k++) {
            const float b = Bsub[k][col]; // shared by all WPT rows: read once
            for (int w = 0; w < WPT; w++) {
                res[w] += Asub[w * RTS + row][k] * b;
            }
        }

        // Make sure every work-item is done reading this tile before the
        // next iteration overwrites it.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store results, skipping rows/columns that fall outside C.
    for (int w = 0; w < WPT; w++) {
        const int cRow = w * RTS + globalRow;
        if (cRow < M && globalCol < N) {
            C[(size_t)cRow * N + globalCol] = res[w];
        }
    }
}
|
||
|
#endif
|