chundoong-lab-ta/SamsungDS22/submissions/HW5/bumhee86.lee/kernel.cl

174 lines
5.3 KiB
Common Lisp

// super super slow sgemm kernel by heehoon
// Edge length of the square tiles staged in local memory; also the factor that
// turns a work-group id into a global row/column offset inside the kernel.
#define NUM_WORK_ITEM (32)
// Scalar build: number of C elements each work-item accumulates (bound of the
// `w` loops). float16 build: divisor that converts float column counts into
// float16 units.
#define VECTOR_WIDTH (16)
// Row stride between consecutive C elements owned by one work-item; the kernel
// addresses its rows as i + w * RTS.
#define RTS (NUM_WORK_ITEM / VECTOR_WIDTH)
// Non-zero selects the scalar `float` kernel signature and fast path; zero
// selects the `float16` variant.
#define USING_NON_VECTOR (1)
// sgemm - tiled single-precision matrix multiply, C = A * B.
//
// Layout, as implied by the indexing below: A is M x K addressed as
// A[row * K + col], B is K x N addressed as B[row * N + col] and C is M x N
// addressed as C[row * N + col].
//
// Scalar build (USING_NON_VECTOR != 0):
//   A work-group produces one NUM_WORK_ITEM x NUM_WORK_ITEM tile of C. Each
//   work-item owns tile column get_local_id(1) and VECTOR_WIDTH rows spaced
//   RTS apart starting at get_local_id(0). The tile index i + w * RTS only
//   stays inside the tile while get_local_id(0) < RTS, so the launch is
//   expected to use a local size of (RTS, NUM_WORK_ITEM).
//   NOTE(review): the host code is not visible here; confirm the local size.
//
// NON_OPTIMAL == 0 takes the fast path: no bounds checks and K / NUM_WORK_ITEM
// truncates, so it is only valid when every touched index is in range
// (presumably M, N and K are multiples of NUM_WORK_ITEM; verify in the caller).
// Any other value takes the guarded path, which zero-pads the tiles and skips
// out-of-range stores.
//
// NOTE(review): with USING_NON_VECTOR == 0 the guarded path still treats A, B
// and C as scalar arrays although they are declared float16 pointers; it looks
// like that configuration will not compile as written and also assumes a
// different launch shape. Confirm before enabling the float16 build.
#if (USING_NON_VECTOR)
__kernel void sgemm(__global float* __restrict A, __global float* __restrict B, __global float*__restrict C, int M, int N, int K, int NON_OPTIMAL)
#else
__kernel void sgemm(__global float16* __restrict A, __global float16* __restrict B, __global float16*__restrict C, int M, int N, int K, int NON_OPTIMAL)
#endif
{
    const int i = get_local_id(0); // row index inside the tile
    const int j = get_local_id(1); // column index inside the tile
    const int global_row = NUM_WORK_ITEM * get_group_id(0) + i;
    const int global_col = NUM_WORK_ITEM * get_group_id(1) + j;

    // One partial sum per owned row. Sized by the same macro that bounds every
    // loop over it, so changing VECTOR_WIDTH cannot overrun the array.
    float intermediate_val[VECTOR_WIDTH] = {0.0f};

    // __local storage has to be declared at kernel function scope.
    __local float tileA[NUM_WORK_ITEM][NUM_WORK_ITEM];
    __local float tileB[NUM_WORK_ITEM][NUM_WORK_ITEM];
#if !(USING_NON_VECTOR)
    // float16 tiles for the vector fast path: each row holds NUM_WORK_ITEM floats.
    __local float16 tileA_vec[NUM_WORK_ITEM][NUM_WORK_ITEM / VECTOR_WIDTH];
    __local float16 tileB_vec[NUM_WORK_ITEM][NUM_WORK_ITEM / VECTOR_WIDTH];
#endif

    if (NON_OPTIMAL == 0)
    {
#if (USING_NON_VECTOR)
        const int num_tiles = K / NUM_WORK_ITEM;
        for (int t = 0; t < num_tiles; t++)
        {
            // Position of this work-item inside the current K-slice.
            const int t_row = NUM_WORK_ITEM * t + i;
            const int t_col = NUM_WORK_ITEM * t + j;

            // Stage the A and B tiles; each work-item copies VECTOR_WIDTH
            // elements of each.
            for (int w = 0; w < VECTOR_WIDTH; w++)
            {
                tileA[i + w * RTS][j] = A[(global_row + w * RTS) * K + t_col];
                tileB[i + w * RTS][j] = B[(t_row + w * RTS) * N + global_col];
            }
            // Tiles must be complete before any work-item reads them.
            barrier(CLK_LOCAL_MEM_FENCE);

            for (int k = 0; k < NUM_WORK_ITEM; k++)
            {
                // The B operand is shared by all rows this work-item owns.
                const float b = tileB[k][j];
                for (int w = 0; w < VECTOR_WIDTH; w++)
                {
                    intermediate_val[w] += tileA[i + w * RTS][k] * b;
                }
            }
            // Nobody may overwrite the tiles while others still read them.
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        for (int w = 0; w < VECTOR_WIDTH; w++)
        {
            C[(global_row + w * RTS) * N + global_col] = intermediate_val[w];
        }
#else
        // float16 fast path: column indices are in float16 units. The tile
        // shape implies get_local_id(1) < NUM_WORK_ITEM / VECTOR_WIDTH.
        // NOTE(review): launch shape for this build is not visible; confirm.
        const int vec_col = (NUM_WORK_ITEM / VECTOR_WIDTH) * get_group_id(1) + j;
        float16 acc = (float16)(0.0f);
        const int num_tiles = K / NUM_WORK_ITEM;
        for (int t = 0; t < num_tiles; t++)
        {
            const int t_row = NUM_WORK_ITEM * t + i;
            const int t_col = (NUM_WORK_ITEM / VECTOR_WIDTH) * t + j;
            tileA_vec[i][j] = A[global_row * (K / VECTOR_WIDTH) + t_col];
            tileB_vec[i][j] = B[t_row * (N / VECTOR_WIDTH) + vec_col];
            barrier(CLK_LOCAL_MEM_FENCE);

            for (int k = 0; k < NUM_WORK_ITEM / VECTOR_WIDTH; k++)
            {
                const float16 vecA = tileA_vec[i][k];
                for (int w = 0; w < VECTOR_WIDTH; w++)
                {
                    const float16 vecB = tileB_vec[VECTOR_WIDTH * k + w][j];
                    // Vector components cannot be selected with a runtime
                    // index, hence the explicit dispatch on the lane number.
                    float valA;
                    switch (w)
                    {
                    case 0: valA = vecA.s0; break;
                    case 1: valA = vecA.s1; break;
                    case 2: valA = vecA.s2; break;
                    case 3: valA = vecA.s3; break;
                    case 4: valA = vecA.s4; break;
                    case 5: valA = vecA.s5; break;
                    case 6: valA = vecA.s6; break;
                    case 7: valA = vecA.s7; break;
                    case 8: valA = vecA.s8; break;
                    case 9: valA = vecA.s9; break;
                    case 10: valA = vecA.sA; break;
                    case 11: valA = vecA.sB; break;
                    case 12: valA = vecA.sC; break;
                    case 13: valA = vecA.sD; break;
                    case 14: valA = vecA.sE; break;
                    case 15: valA = vecA.sF; break;
                    default: valA = 0.0f; break; // unreachable for w < 16
                    }
                    acc += vecB * valA;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }
        C[global_row * (N / VECTOR_WIDTH) + vec_col] = acc;
#endif
    }
    else
    {
        // Guarded path: round the tile count up and zero-pad whatever falls
        // outside the matrices so the padding contributes nothing to the sums.
        const int num_tiles = (K + NUM_WORK_ITEM - 1) / NUM_WORK_ITEM;
        for (int t = 0; t < num_tiles; t++)
        {
            const int t_row = NUM_WORK_ITEM * t + i;
            const int t_col = NUM_WORK_ITEM * t + j;

            for (int w = 0; w < VECTOR_WIDTH; w++)
            {
                const int a_row = global_row + w * RTS;
                const int b_row = t_row + w * RTS;

                if (a_row < M && t_col < K)
                {
                    tileA[i + w * RTS][j] = A[a_row * K + t_col];
                }
                else
                {
                    tileA[i + w * RTS][j] = 0.0f;
                }

                if (b_row < K && global_col < N)
                {
                    tileB[i + w * RTS][j] = B[b_row * N + global_col];
                }
                else
                {
                    tileB[i + w * RTS][j] = 0.0f;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);

            for (int k = 0; k < NUM_WORK_ITEM; k++)
            {
                const float b = tileB[k][j];
                for (int w = 0; w < VECTOR_WIDTH; w++)
                {
                    intermediate_val[w] += tileA[i + w * RTS][k] * b;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        for (int w = 0; w < VECTOR_WIDTH; w++)
        {
            const int out_row = global_row + w * RTS;
            // Rows only grow with w, so the first out-of-range row ends the loop.
            if (out_row >= M || global_col >= N)
            {
                break;
            }
            C[out_row * N + global_col] = intermediate_val[w];
        }
    }
}