chundoong-lab-ta/SamsungDS22/submissions/HW5/ym.tai/kernel.cl

80 lines
2.2 KiB
Common Lisp
Raw Normal View History

2022-09-29 18:01:45 +09:00
#if 0
// Naive sgemm kernel by heehoon, kept for reference (disabled by the enclosing #if 0).
//
// Computes C = A * B with one work-item per element of C, using the row-major
// index patterns A[i * K + k] (M x K), B[k * N + j] (K x N), C[i * N + j] (M x N).
// A and B are only read; C is only written.
__kernel void sgemm(__global const float *A, __global const float *B, __global float *C, int M, int N, int K) {
    const int i = get_global_id(0); // row index of C (range M)
    const int j = get_global_id(1); // column index of C (range N)
    if (i >= M || j >= N) return;   // skip work-items outside C (global size may be rounded up)

    // Accumulate in a private variable and store once, instead of doing a
    // global-memory read-modify-write of C on every iteration of k.
    float acc = 0.0f;
    for (int k = 0; k < K; k++) {
        // size_t products so i * K and k * N cannot overflow int on large matrices.
        acc += A[(size_t)i * K + k] * B[(size_t)k * N + j];
    }
    C[(size_t)i * N + j] = acc;
}
#else
#define TILE_SIZE 32            // edge length of the square tile of C computed by one work-group
#define WPT 8                   // work per thread: elements of C computed by each work-item
#define RTS (TILE_SIZE / WPT)   // reduced tile size: work-items per work-group in dimension 0

#if (TILE_SIZE % WPT) != 0
#error "TILE_SIZE must be a multiple of WPT, otherwise some tile rows are never computed"
#endif

/*
 * Tiled sgemm: C = A * B, using the row-major index patterns
 *   A[r * K + k] (M x K), B[k * N + c] (K x N), C[r * N + c] (M x N).
 *
 * Each work-group computes one TILE_SIZE x TILE_SIZE tile of C, staging the
 * matching tiles of A and B in __local memory one K-slice at a time. Each
 * work-item computes WPT elements of one column of that tile, at local rows
 * row, row + RTS, row + 2*RTS, ... (strided by RTS).
 *
 * Launch requirements:
 *   - the local size must be (RTS, TILE_SIZE): the Asub/Bsub indices
 *     w * RTS + row and col only stay below TILE_SIZE under that geometry;
 *   - presumably the host launches ceil(M / TILE_SIZE) work-groups in
 *     dimension 0 and ceil(N / TILE_SIZE) in dimension 1 -- TODO confirm
 *     against the host code.
 * Out-of-range elements are zero-padded on load and skipped on store, so
 * M, N and K need not be multiples of TILE_SIZE. There is no early return,
 * so every work-item reaches every barrier.
 */
__kernel void sgemm(__global const float *A, __global const float *B, __global float *C, int M, int N, int K) {
    const int row = get_local_id(0);                          // local row offset, in [0, RTS)
    const int col = get_local_id(1);                          // local column, in [0, TILE_SIZE)
    const int globalRow = TILE_SIZE * get_group_id(0) + row;  // first row of C handled here (range M)
    const int globalCol = TILE_SIZE * get_group_id(1) + col;  // column of C handled here (range N)
    const int numTiles = (K + TILE_SIZE - 1) / TILE_SIZE;     // tiles along the shared K dimension

    // Tiles of A and B shared by the whole work-group.
    __local float Asub[TILE_SIZE][TILE_SIZE];
    __local float Bsub[TILE_SIZE][TILE_SIZE];

    // Per-work-item accumulators, one per row handled.
    float res[WPT];
    for (int w = 0; w < WPT; w++) {
        res[w] = 0.0f;
    }

    for (int t = 0; t < numTiles; t++) {
        const int tiledRow = TILE_SIZE * t + row; // k index of the B rows loaded by this work-item
        const int tiledCol = TILE_SIZE * t + col; // k index of the A column loaded by this work-item

        // Cooperatively load one tile of A and one tile of B, zero-padding
        // anything outside the matrices so the compute loop needs no checks.
        for (int w = 0; w < WPT; w++) {
            const int localRow = w * RTS + row;
            const int aRow = globalRow + w * RTS; // row of A (range M)
            const int bRow = tiledRow + w * RTS;  // row of B (range K)
            // size_t products so row * K / row * N cannot overflow int.
            Asub[localRow][col] = (aRow < M && tiledCol < K) ? A[(size_t)aRow * K + tiledCol] : 0.0f;
            Bsub[localRow][col] = (bRow < K && globalCol < N) ? B[(size_t)bRow * N + globalCol] : 0.0f;
        }
        // Tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Partial products for this K-slice; Bsub[k][col] is shared by all WPT rows.
        for (int k = 0; k < TILE_SIZE; k++) {
            const float b = Bsub[k][col];
            for (int w = 0; w < WPT; w++) {
                res[w] += Asub[w * RTS + row][k] * b;
            }
        }
        // All reads must finish before the next iteration overwrites the tiles.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store the results that fall inside C.
    for (int w = 0; w < WPT; w++) {
        const int cRow = globalRow + w * RTS;
        if (cRow < M && globalCol < N) {
            C[(size_t)cRow * N + globalCol] = res[w];
        }
    }
}
#endif