119 lines
3.6 KiB
Common Lisp
119 lines
3.6 KiB
Common Lisp
// super super slow sgemm kernel by heehoon
|
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
|
const int i = get_global_id(0); // row index of C
|
|
const int j = get_global_id(1); // column index of C
|
|
if (i >= M || j >= N) return; // boundary check
|
|
|
|
float acc = 0;
|
|
#pragma unroll
|
|
for (int k = 0; k < K; k++) {
|
|
acc += A[i * K + k] * B[k * N + j];
|
|
}
|
|
C[i * N + j] = acc;
|
|
}
|
|
//==================================================================
|
|
// Matrix Multiplication
|
|
//------------------------------------------------------------------
|
|
// N
|
|
// A: K columns, M rows o-----o
|
|
// B: N columns, K rows | |
|
|
// C: N columns, M rows K | [B] |
|
|
// | |
|
|
// o-----o
|
|
// K N
|
|
// o-------o o-----o
|
|
// M | [A] | M | [C] |
|
|
// | | | |
|
|
// o-------o o-----o
|
|
//==================================================================
|
|
// #define TSM 128 // Tile-size of M
|
|
// #define TSN 128 // Tile-size of N
|
|
// #define TSK 16 // Tile-size of K
|
|
// #define WPT 8 // Work-per-thread
|
|
//
|
|
// __kernel void sgemm(const __global float *A, const __global float *B, __global float *C, const int M, const int N, const int K) {
|
|
//
|
|
// // Local index
|
|
// const int local_m = get_local_id(0);
|
|
// const int local_n = get_local_id(1);
|
|
//
|
|
// // Work-group offset
|
|
// const int offset_m = TSM * get_group_id(0);
|
|
// const int offset_n = TSN * get_group_id(1);
|
|
//
|
|
// // Local memory to fit a tile
|
|
// __local float A_sub [TSK][TSM];
|
|
// __local float B_sub [TSN][TSK+2]; // TODO : why +2 ?
|
|
//
|
|
// // Allocate register space
|
|
// float A_reg;
|
|
// float B_reg[WPT];
|
|
//
|
|
// // Initialize
|
|
// float acc[WPT][WPT];
|
|
// #pragma unroll
|
|
// for (int wm = 0; wm < WPT; wm++) {
|
|
// #pragma unroll
|
|
// for (int wn = 0; wn < WPT; wn++) {
|
|
// acc[wm][wn] = 0.0f;
|
|
// }
|
|
// }
|
|
//
|
|
// // Loop over all tiles
|
|
// for (int k = 0; k < K; k += TSK) {
|
|
//
|
|
// // Load per Thread
|
|
// #pragma unroll
|
|
// for (int la = 0; la < (TSK * WPT * WPT); la++) {
|
|
// int tid = local_n * (TSM / WPT) + local_m; // Thread id
|
|
// int id = la * (TSN / WPT) * (TSM / WPT) + tid;
|
|
// int row = id & TSM;
|
|
// int col = id / TSM;
|
|
//
|
|
// int tiledIndex = k + col;
|
|
// A_sub[col][row] = A[tiledIndex*M + offset_m + row];
|
|
// B_sub[row][col] = B[tiledIndex*N + offset_n + row];
|
|
// }
|
|
//
|
|
// // Synchronize
|
|
// barrier(CLK_LOCAL_MEM_FENCE);
|
|
//
|
|
// // Loop over the values of a single tile
|
|
// for (int k=0; k<TSK; k++) {
|
|
//
|
|
// // Cache the values of B_sub in registers
|
|
// #pragma unroll
|
|
// for (int wn=0; wn<WPT; wn++) {
|
|
// int col = local_n + wn * (TSN / WPT);
|
|
// B_reg[wn] = B_sub[col][k];
|
|
// }
|
|
//
|
|
// // Perform the computation
|
|
// #pragma unroll
|
|
// for (int wm = 0; wm < WPT; wm++) {
|
|
// int row = local_m + wm * (TSM / WPT);
|
|
// A_reg = A_sub[k][row];
|
|
// #pragma unroll
|
|
// for (int wn = 0; wn < WPT; wn++) {
|
|
// acc[wm][wn] += A_reg * B_reg[wn];
|
|
// }
|
|
// }
|
|
// }
|
|
//
|
|
// // Synchronize
|
|
// barrier(CLK_LOCAL_MEM_FENCE);
|
|
// }
|
|
//
|
|
// // Store the final results in C
|
|
// #pragma unroll
|
|
// for (int wm=0; wm<WPT; wm++) {
|
|
// int globalRow = offset_m + local_m + wm * (TSM / WPT);
|
|
// #pragma unroll
|
|
// for (int wn=0; wn<WPT; wn++) {
|
|
// int globalCol = offset_n + local_n + wn * (TSN / WPT);
|
|
// C[globalCol*M + globalRow] = acc[wm][wn];
|
|
// }
|
|
// }
|
|
// }
|
|
//
|