chundoong-lab-ta/SamsungDS22/submissions/HW5/yw0.kim/kernel.cl

139 lines
3.6 KiB
Common Lisp

// super super slow sgemm kernel by heehoon
#define TILE_SIZE 32
#define WORK_PER_THREAD 8
#define REG_TILE_SIZE (TILE_SIZE / WORK_PER_THREAD)
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K)
{
int lr = get_local_id(0);
int lc = get_local_id(1);
int gr = TILE_SIZE * get_group_id(0) + lr;
int gc = TILE_SIZE * get_group_id(1) + lc;
// printf("group_id: (%d, %d), (gr, gc): (%d, %d), (lr, lc): (%d, %d)\n", (int)get_group_id(0), (int)get_group_id(1), gr, gc, lr, lc);
__local float A_mini[TILE_SIZE][TILE_SIZE];
__local float B_mini[TILE_SIZE][TILE_SIZE];
float temp[WORK_PER_THREAD];
for(int i = 0; i < WORK_PER_THREAD; i++)
{
temp[i] = 0.0;
}
int tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
int tr = lr;
int tc = lc;
for(int tile = 0; tile < tiles; tile++)
{
int tileRow = lr;
int tileCol = lc;
int Arow = gr;
int Acol = tc;
int Brow = tr;
int Bcol = gc;
for(int work = 0; work < WORK_PER_THREAD; work++)
{
// if(gr == 32 && gc >= 32 && tile == 1 && work == 0)
// {
// printf("(gr, gc): (%d, %d), A_mini[%d][%d], B_mini[%d][%d], A[%d][%d], B[%d][%d]\n", gr, gc, (lr + work * REG_TILE_SIZE), lc, (lr + work * REG_TILE_SIZE), lc, (gr + work * REG_TILE_SIZE), tc, (tr + work * REG_TILE_SIZE), gc);
// }
if(Arow < M && Acol < K)
{
A_mini[tileRow][tileCol] = A[Arow * K + tc];
}
else
{
A_mini[tileRow][tileCol] = 0.0;
}
if(Brow < K && Bcol < N)
{
B_mini[tileRow][tileCol] = B[Brow * N + gc];
}
else
{
B_mini[tileRow][tileCol] = 0.0;
}
tileRow += REG_TILE_SIZE;
Arow += REG_TILE_SIZE;
Brow += REG_TILE_SIZE;
}
barrier(CLK_LOCAL_MEM_FENCE);
// if(gr == 32 && gc == 32 && tile == 1)
// {
// printf("(gr, gc): (%d, %d), (lr, lc): (%d, %d)\n", gr, gc, lr, lc);
// printf("------ A_mini ------\n");
// for(int i = 0; i < TILE_SIZE; i++)
// {
// for(int j = 0; j < TILE_SIZE; j++)
// {
// printf("%.3f ", A_mini[i][j]);
// }
// printf("\n");
// }
// printf("------ B_mini ------\n");
// for(int i = 0; i < TILE_SIZE; i++)
// {
// for(int j = 0; j < TILE_SIZE; j++)
// {
// printf("%.3f ", B_mini[i][j]);
// }
// printf("\n");
// }
// printf("------ C ------\n");
// for(int i = 0; i < M; i++)
// {
// for(int j = 0; j < N; j++)
// {
// printf("%.3f ", C[i * N + j]);
// }
// printf("\n");
// }
// }
// barrier(CLK_LOCAL_MEM_FENCE);
for(int k = 0; k < TILE_SIZE; k++)
{
int AtileRow = lr;
int BtileCol = lc;
for(int work = 0; work < WORK_PER_THREAD; work++)
{
temp[work] += A_mini[AtileRow][k] * B_mini[k][BtileCol];
// if(gr == 32 && gc == 0 && work == 0)
// {
// printf("temp[0]: %.3f, %.3f * %.3f = %.3f\n", temp[0], A_mini[lr + work * REG_TILE_SIZE][k], B_mini[k][lc], A_mini[lr + work * REG_TILE_SIZE][k] * B_mini[k][lc]);
// }
AtileRow += REG_TILE_SIZE;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
tr += TILE_SIZE;
tc += TILE_SIZE;
// if(gr == 32 && gc == 32)
// {
// printf("temp[0] = %.3f\n", temp[0]);
// }
}
int Crow = gr;
int Ccol = gc;
for(int work = 0; work < WORK_PER_THREAD; work++)
{
if(Crow < M && Ccol < N)
{
C[Crow * N + Ccol] = temp[work];
}
Crow += REG_TILE_SIZE;
}
}