// File-viewer banner: 174 lines, 5.3 KiB, "Common Lisp" (the viewer's language guess; the code below is OpenCL C).
// super super slow sgemm kernel by heehoon

// Edge length of the square tiles: the kernel stages NUM_WORK_ITEM x
// NUM_WORK_ITEM tiles and scales the work-group id by this value.
#define NUM_WORK_ITEM (32)

// Number of output elements accumulated by one work-item in the scalar
// kernel; also used as the lane count of the float16 variant.
#define VECTOR_WIDTH (16)

// Row stride between consecutive elements handled by the same work-item:
// element w of a work-item lives on tile row (local_row + w * RTS).
#define RTS (NUM_WORK_ITEM / VECTOR_WIDTH)

// 1 selects the scalar (float) kernel, 0 selects the float16 variant.
#define USING_NON_VECTOR (1)

// RTS is an integer division; if it truncated, VECTOR_WIDTH * RTS would no
// longer span a full tile and some tile rows would never be loaded.
#if (NUM_WORK_ITEM % VECTOR_WIDTH) != 0
#error "NUM_WORK_ITEM must be a multiple of VECTOR_WIDTH"
#endif

#if (USING_NON_VECTOR)
|
|
__kernel void sgemm(__global float* __restrict A, __global float* __restrict B, __global float*__restrict C, int M, int N, int K, int NON_OPTIMAL)
|
|
#else
|
|
__kernel void sgemm(__global float16* __restrict A, __global float16* __restrict B, __global float16*__restrict C, int M, int N, int K, int NON_OPTIMAL)
|
|
#endif
|
|
{
|
|
const int i = get_local_id(0); // row index of C
|
|
const int j = get_local_id(1); // column index of C
|
|
const int global_row = NUM_WORK_ITEM * get_group_id(0) + i;
|
|
const int global_col = NUM_WORK_ITEM * get_group_id(1) + j;
|
|
float intermediate_val[16] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
|
|
|
__local float tileA[NUM_WORK_ITEM][NUM_WORK_ITEM];
|
|
__local float tileB[NUM_WORK_ITEM][NUM_WORK_ITEM];
|
|
|
|
if (NON_OPTIMAL == 0)
|
|
{
|
|
#if (USING_NON_VECTOR)
|
|
const int num_tiles = K / NUM_WORK_ITEM;
|
|
|
|
// printf("i : %d, j : %d, global_row : %d, global_col : %d\n", i, j, global_row, global_col);
|
|
for (int t = 0; t < num_tiles; t++)
|
|
{
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
const int t_row = NUM_WORK_ITEM * t + i;
|
|
const int t_col = NUM_WORK_ITEM * t + j;
|
|
|
|
tileA[i + w * RTS][j] = A[((global_row + w * RTS)) * K + t_col];
|
|
tileB[i + w * RTS][j] = B[((t_row + w * RTS)) * N + global_col];
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
for (int k = 0; k < NUM_WORK_ITEM; k++)
|
|
{
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
intermediate_val[w] += tileA[i + w * RTS][k] * tileB[k][j];
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
C[(global_row + w * RTS) * N + global_col] = intermediate_val[w];
|
|
}
|
|
#else
|
|
const int global_col = (NUM_WORK_ITEM/VECTOR_WIDTH) * get_group_id(1) + j;
|
|
float16 intermediate_val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
|
|
|
__local float16 tileA[NUM_WORK_ITEM][NUM_WORK_ITEM/VECTOR_WIDTH];
|
|
__local float16 tileB[NUM_WORK_ITEM][NUM_WORK_ITEM/VECTOR_WIDTH];
|
|
|
|
const int num_tiles = K / NUM_WORK_ITEM;
|
|
for (int t = 0; t < num_tiles; t++)
|
|
{
|
|
const int t_row = NUM_WORK_ITEM * t + i;
|
|
const int t_col = (NUM_WORK_ITEM/VECTOR_WIDTH) * t + j;
|
|
tileA[i][j] = A[global_row * (K/VECTOR_WIDTH) + t_col];
|
|
tileB[i][j] = B[t_row * (N/VECTOR_WIDTH) + global_col];
|
|
|
|
float16 vecA, vecB;
|
|
float valA;
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
for (int k = 0; k < NUM_WORK_ITEM/VECTOR_WIDTH; k++)
|
|
{
|
|
vecA = tileA[i][k];
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
vecB = tileB[VECTOR_WIDTH * k + w][j];
|
|
#if 0
|
|
valA = ((float*)&vecA)[w];
|
|
#else
|
|
switch(w)
|
|
{
|
|
case 0: valA = vecA.s0; break;
|
|
case 1: valA = vecA.s1; break;
|
|
case 2: valA = vecA.s2; break;
|
|
case 3: valA = vecA.s3; break;
|
|
case 4: valA = vecA.s4; break;
|
|
case 5: valA = vecA.s5; break;
|
|
case 6: valA = vecA.s6; break;
|
|
case 7: valA = vecA.s7; break;
|
|
case 8: valA = vecA.s8; break;
|
|
case 9: valA = vecA.s9; break;
|
|
case 10: valA = vecA.sA; break;
|
|
case 11: valA = vecA.sB; break;
|
|
case 12: valA = vecA.sC; break;
|
|
case 13: valA = vecA.sD; break;
|
|
case 14: valA = vecA.sE; break;
|
|
case 15: valA = vecA.sF; break;
|
|
}
|
|
#endif
|
|
|
|
intermediate_val += vecB * valA;
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
C[global_row * (N/VECTOR_WIDTH) + global_col] = intermediate_val;
|
|
|
|
#endif
|
|
}
|
|
else
|
|
{
|
|
const int num_tiles = (K + NUM_WORK_ITEM - 1) / NUM_WORK_ITEM;
|
|
|
|
// printf("i : %d, j : %d, global_row : %d, global_col : %d\n", i, j, global_row, global_col);
|
|
for (int t = 0; t < num_tiles; t++)
|
|
{
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
const int t_row = NUM_WORK_ITEM * t + i;
|
|
const int t_col = NUM_WORK_ITEM * t + j;
|
|
if (global_row + w * RTS >= M || t_col >= K)
|
|
{
|
|
tileA[i + w * RTS][j] = 0.0f;
|
|
}
|
|
else
|
|
{
|
|
tileA[i + w * RTS][j] = A[((global_row + w * RTS)) * K + t_col];
|
|
}
|
|
|
|
if (t_row + w * RTS >= K || global_col >= N)
|
|
{
|
|
tileB[i + w * RTS][j] = 0.0f;
|
|
}
|
|
else
|
|
{
|
|
tileB[i + w * RTS][j] = B[((t_row + w * RTS)) * N + global_col];
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
for (int k = 0; k < NUM_WORK_ITEM; k++)
|
|
{
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
intermediate_val[w] += tileA[i + w * RTS][k] * tileB[k][j];
|
|
}
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
|
|
for (int w = 0; w < VECTOR_WIDTH; w++)
|
|
{
|
|
if(global_row + w * RTS >= M || global_col >= N)
|
|
{
|
|
break;
|
|
}
|
|
else
|
|
{
|
|
C[(global_row + w * RTS) * N + global_col] = intermediate_val[w];
|
|
}
|
|
}
|
|
// printf("C[%d] = , %+.3f, %+.3f, %+.3f, %+.3f,...\n", (global_row) * N + global_col, intermediate_val[0],intermediate_val[1],intermediate_val[2],intermediate_val[3]);
|
|
|
|
}
|
|
}
|