// super super slow sgemm kernel by heehoon #define TILE_SIZE 32 #define WORK_PER_THREAD 8 #define REG_TILE_SIZE (TILE_SIZE / WORK_PER_THREAD) __kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) { int lr = get_local_id(0); int lc = get_local_id(1); int gr = TILE_SIZE * get_group_id(0) + lr; int gc = TILE_SIZE * get_group_id(1) + lc; // printf("group_id: (%d, %d), (gr, gc): (%d, %d), (lr, lc): (%d, %d)\n", (int)get_group_id(0), (int)get_group_id(1), gr, gc, lr, lc); __local float A_mini[TILE_SIZE][TILE_SIZE]; __local float B_mini[TILE_SIZE][TILE_SIZE]; float temp[WORK_PER_THREAD]; for(int i = 0; i < WORK_PER_THREAD; i++) { temp[i] = 0.0; } int tiles = (K + TILE_SIZE - 1) / TILE_SIZE; int tr = lr; int tc = lc; for(int tile = 0; tile < tiles; tile++) { int tileRow = lr; int tileCol = lc; int Arow = gr; int Acol = tc; int Brow = tr; int Bcol = gc; for(int work = 0; work < WORK_PER_THREAD; work++) { // if(gr == 32 && gc >= 32 && tile == 1 && work == 0) // { // printf("(gr, gc): (%d, %d), A_mini[%d][%d], B_mini[%d][%d], A[%d][%d], B[%d][%d]\n", gr, gc, (lr + work * REG_TILE_SIZE), lc, (lr + work * REG_TILE_SIZE), lc, (gr + work * REG_TILE_SIZE), tc, (tr + work * REG_TILE_SIZE), gc); // } if(Arow < M && Acol < K) { A_mini[tileRow][tileCol] = A[Arow * K + tc]; } else { A_mini[tileRow][tileCol] = 0.0; } if(Brow < K && Bcol < N) { B_mini[tileRow][tileCol] = B[Brow * N + gc]; } else { B_mini[tileRow][tileCol] = 0.0; } tileRow += REG_TILE_SIZE; Arow += REG_TILE_SIZE; Brow += REG_TILE_SIZE; } barrier(CLK_LOCAL_MEM_FENCE); // if(gr == 32 && gc == 32 && tile == 1) // { // printf("(gr, gc): (%d, %d), (lr, lc): (%d, %d)\n", gr, gc, lr, lc); // printf("------ A_mini ------\n"); // for(int i = 0; i < TILE_SIZE; i++) // { // for(int j = 0; j < TILE_SIZE; j++) // { // printf("%.3f ", A_mini[i][j]); // } // printf("\n"); // } // printf("------ B_mini ------\n"); // for(int i = 0; i < TILE_SIZE; i++) // { // for(int j = 0; j < TILE_SIZE; j++) // { // printf("%.3f ", B_mini[i][j]); // } // printf("\n"); // } // printf("------ C ------\n"); // for(int i = 0; i < M; i++) // { // for(int j = 0; j < N; j++) // { // printf("%.3f ", C[i * N + j]); // } // printf("\n"); // } // } // barrier(CLK_LOCAL_MEM_FENCE); for(int k = 0; k < TILE_SIZE; k++) { int AtileRow = lr; int BtileCol = lc; for(int work = 0; work < WORK_PER_THREAD; work++) { temp[work] += A_mini[AtileRow][k] * B_mini[k][BtileCol]; // if(gr == 32 && gc == 0 && work == 0) // { // printf("temp[0]: %.3f, %.3f * %.3f = %.3f\n", temp[0], A_mini[lr + work * REG_TILE_SIZE][k], B_mini[k][lc], A_mini[lr + work * REG_TILE_SIZE][k] * B_mini[k][lc]); // } AtileRow += REG_TILE_SIZE; } } barrier(CLK_LOCAL_MEM_FENCE); tr += TILE_SIZE; tc += TILE_SIZE; // if(gr == 32 && gc == 32) // { // printf("temp[0] = %.3f\n", temp[0]); // } } int Crow = gr; int Ccol = gc; for(int work = 0; work < WORK_PER_THREAD; work++) { if(Crow < M && Ccol < N) { C[Crow * N + Ccol] = temp[work]; } Crow += REG_TILE_SIZE; } }