// Super, super slow SGEMM kernel by heehoon.
// Edge length of the square tiles of A and B that the sgemm kernel stages
// in local memory (A_mini / B_mini are TILE_SIZE x TILE_SIZE).
#define TILE_SIZE 32

// Number of output elements each work-item accumulates (size of its
// private accumulator array in the sgemm kernel).
#define WORK_PER_THREAD 8

// Row stride between the successive elements handled by one work-item.
#define REG_TILE_SIZE (TILE_SIZE / WORK_PER_THREAD)

// REG_TILE_SIZE is computed with integer division; if it truncated, the
// strided row loops in the kernel would no longer cover every row of a tile.
#if (TILE_SIZE % WORK_PER_THREAD) != 0
#error "TILE_SIZE must be a multiple of WORK_PER_THREAD"
#endif

/*
 * Tiled single-precision matrix multiply: C = A * B.
 *
 * A is read as an M x K row-major matrix, B as a K x N row-major matrix,
 * and C is written as an M x N row-major matrix.
 *
 * Each work-group produces one TILE_SIZE x TILE_SIZE tile of C. Each
 * work-item owns WORK_PER_THREAD elements of that tile, all in column
 * get_local_id(1), at rows get_local_id(0) + w * REG_TILE_SIZE.
 *
 * NOTE(review): the indexing only covers a whole tile if the work-group
 * is launched as REG_TILE_SIZE x TILE_SIZE work-items (dimension 0 x
 * dimension 1) -- confirm against the host-side enqueue.
 *
 * Elements that fall outside the matrices are loaded as zero and are never
 * written back, so M, N and K need not be multiples of TILE_SIZE.
 */
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K)
{
    const int lr = get_local_id(0);
    const int lc = get_local_id(1);
    const int gr = TILE_SIZE * get_group_id(0) + lr;
    const int gc = TILE_SIZE * get_group_id(1) + lc;

    __local float A_mini[TILE_SIZE][TILE_SIZE];
    __local float B_mini[TILE_SIZE][TILE_SIZE];

    // Per-work-item accumulators, one per owned row of the output tile.
    float temp[WORK_PER_THREAD];
    for (int i = 0; i < WORK_PER_THREAD; i++)
    {
        temp[i] = 0.0f;
    }

    const int tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
    for (int tile = 0; tile < tiles; tile++)
    {
        // Column of A / first row of B covered by this step along K.
        const int kBase = tile * TILE_SIZE;
        const int Acol = kBase + lc;

        // Stage one tile of A and one tile of B in local memory; each
        // work-item fills WORK_PER_THREAD strided rows of its own column.
        for (int work = 0; work < WORK_PER_THREAD; work++)
        {
            const int tileRow = lr + work * REG_TILE_SIZE;
            const int Arow = gr + work * REG_TILE_SIZE;
            const int Brow = kBase + tileRow;

            // Offsets are widened to size_t so that row * width cannot
            // overflow a 32-bit int on large matrices.
            if (Arow < M && Acol < K)
            {
                A_mini[tileRow][lc] = A[(size_t)Arow * K + Acol];
            }
            else
            {
                A_mini[tileRow][lc] = 0.0f;
            }

            if (Brow < K && gc < N)
            {
                B_mini[tileRow][lc] = B[(size_t)Brow * N + gc];
            }
            else
            {
                B_mini[tileRow][lc] = 0.0f;
            }
        }
        // Every element of both tiles must be written before any is read.
        barrier(CLK_LOCAL_MEM_FENCE);

        for (int k = 0; k < TILE_SIZE; k++)
        {
            // All of this work-item's outputs share the same column of B.
            const float b = B_mini[k][lc];
            for (int work = 0; work < WORK_PER_THREAD; work++)
            {
                temp[work] += A_mini[lr + work * REG_TILE_SIZE][k] * b;
            }
        }
        // Do not let the next iteration overwrite tiles still being read.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write back only the results that lie inside C.
    for (int work = 0; work < WORK_PER_THREAD; work++)
    {
        const int Crow = gr + work * REG_TILE_SIZE;
        if (Crow < M && gc < N)
        {
            C[(size_t)Crow * N + gc] = temp[work];
        }
    }
}