// kernel.cl -- OpenCL SGEMM kernels
// (chundoong-lab-ta/SamsungDS22/submissions/HW5/youngsik.eom/kernel.cl)
// Last modified: 2022-09-29 18:01:45 +09:00
// super super slow sgemm kernel by heehoon
/*__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
int i = get_global_id(0); // row index of C
int j = get_global_id(1); // column index of C
if (i >= M || j >= N) return; // boundary check
C[i * N + j] = 0;
for (int k = 0; k < K; k++) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}*/
// 2nd method - Tiled
/*
#define TS 32
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
// int i = get_global_id(0); // row index of C
// int j = get_global_id(1); // column index of C
//if (i >= M || j >= N) return; // boundary check
const int row = get_local_id(0); // local row ID (0~31)
const int col = get_local_id(1); // local col ID (0~31)
const int global_row = TS * get_group_id(0) + row; // row ID (0~M)
const int global_col = TS * get_group_id(1) + col; // col ID (0~N)
if (global_row >= M || global_col >= N) return; // boundary check
__local float Asub[TS][TS];
__local float Bsub[TS][TS];
float acc_val = 0.0f;
const int num_tiles = K / TS;
for(int t=0; t<num_tiles; t++){
const int t_row = TS*t + row;
const int t_col = TS*t + col;
Asub[row][col] = A[global_row*K + t_col];
Bsub[row][col] = B[t_row*N + global_col];
barrier(CLK_LOCAL_MEM_FENCE);
for(int k=0; k<TS; k++){
acc_val += Asub[row][k] * Bsub[k][col];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
C[global_row*N + global_col] = acc_val;
}*/
// 3rd method - Tiled + More work
/*
#define TS 32
#define WPT 8
#define RTS (TS/WPT)
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
const int row = get_local_id(0); // local row ID (0~31)
const int col = get_local_id(1); // local col ID (0~RTS-1, i.e. 0~3)
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
const int global_col = TS * get_group_id(1) + col; // col ID (N/WPT IDs)
if (global_row >= M || global_col >= N) return; // boundary check
__local float Asub[TS][TS];
__local float Bsub[TS][TS];
float acc_val[WPT] = {0.0f,};
const int num_tiles = K / TS;
for(int t=0; t<num_tiles; t++){
const int t_row = TS*t + row;
const int t_col = TS*t + col;
for(int w=0; w<WPT; w++){
Asub[row][col + w*RTS] = A[global_row*K + t_col + w*RTS];
Bsub[row][col + w*RTS] = B[t_row*N + global_col + w*RTS];
}
barrier(CLK_LOCAL_MEM_FENCE);
for(int k=0; k<TS; k++){
for(int w=0; w<WPT; w++){
acc_val[w] += Asub[row][k] * Bsub[k][col+w*RTS];
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
for(int w=0; w<WPT; w++){
C[global_row*N + global_col+w*RTS] = acc_val[w];
}
}*/
// 4th method - Tile + Wide4
/*
#define TS 32
#define WIDTH 4
__kernel void sgemm(__global float4 *A, __global float4 *B, __global float4 *C, int M, int N, int K) {
const int row = get_local_id(0); // local row ID (0~31)
const int col = get_local_id(1); // local col ID (TS/WIDTH)
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // col ID (N/WIDTH IDs)
__local float4 Asub[TS][TS/WIDTH];
__local float4 Bsub[TS][TS/WIDTH];
float4 acc_val = {0.0f, 0.0f, 0.0f, 0.0f};
const int num_tiles = K / TS;
for(int t=0; t<num_tiles; t++){
const int t_row = TS*t + row;
const int t_col = (TS/WIDTH)*t + col;
Asub[row][col] = A[global_row*(K/WIDTH) + t_col];
Bsub[row][col] = B[t_row*(N/WIDTH) + global_col];
barrier(CLK_LOCAL_MEM_FENCE);
float4 vecA, vecB;
float valA;
for(int k=0; k<TS/WIDTH; k++){
vecA = Asub[row][k];
for(int w=0; w<WIDTH; w++){
vecB = Bsub[WIDTH*k + w][col];
switch(w){
case 0: valA = vecA.x; break;
case 1: valA = vecA.y; break;
case 2: valA = vecA.z; break;
case 3: valA = vecA.w; break;
}
acc_val.x += vecB.x * valA;
acc_val.y += vecB.y * valA;
acc_val.z += vecB.z * valA;
acc_val.w += vecB.w * valA;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
C[global_row*(N/WIDTH) + global_col] = acc_val;
}*/
// 5th method - Tiled + Wide8
/*
#define TS 32
#define WIDTH 8
__kernel void sgemm(__global float8 *A, __global float8 *B, __global float8 *C, int M, int N, int K) {
const int row = get_local_id(0); // local row ID (0~31)
const int col = get_local_id(1); // local col ID (TS/WIDTH)
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // col ID (N/WIDTH IDs)
__local float8 Asub[TS][TS/WIDTH];
__local float8 Bsub[TS][TS/WIDTH];
float8 acc_val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
const int num_tiles = K / TS;
for(int t=0; t<num_tiles; t++){
const int t_row = TS*t + row;
const int t_col = (TS/WIDTH)*t + col;
Asub[row][col] = A[global_row*(K/WIDTH) + t_col];
Bsub[row][col] = B[t_row*(N/WIDTH) + global_col];
barrier(CLK_LOCAL_MEM_FENCE);
float8 vecA, vecB;
float valA;
for(int k=0; k<TS/WIDTH; k++){
vecA = Asub[row][k];
for(int w=0; w<WIDTH; w++){
vecB = Bsub[WIDTH*k + w][col];
switch(w){
case 0: valA = vecA.s0; break;
case 1: valA = vecA.s1; break;
case 2: valA = vecA.s2; break;
case 3: valA = vecA.s3; break;
case 4: valA = vecA.s4; break;
case 5: valA = vecA.s5; break;
case 6: valA = vecA.s6; break;
case 7: valA = vecA.s7; break;
}
acc_val += vecB * valA;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
C[global_row*(N/WIDTH) + global_col] = acc_val;
}*/
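// 6th method - Tiled + Wide8 + Unaligned (problem: direct float8 vector loads/stores from unaligned addresses, suspected alignment fault, so elements are copied one at a time)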
/*
#define TS 32
#define WIDTH 8
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
const int row = get_local_id(0); // local row ID (0~31)
const int col = get_local_id(1); // local col ID (TS/WIDTH)
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // col ID (N/WIDTH IDs)
// Eliminate remaining thread
if((global_row >= M) || (global_col*WIDTH >= N)) return;
//const int incomplete_ws_row = (M - global_row) < (M % TS);
//const int incomplete_ws_col = (N - global_col*WIDTH) < (N % TS);
//const int remain_ws_col = min(N - global_col*WIDTH, TS);
const int remain_th_N = min(N - global_col*WIDTH, WIDTH);
const int incomplete_ws_row = (M - global_row) <= (M % TS);
const int incomplete_ws_col = (N - global_col*WIDTH) <= (N % TS);
__local float8 Asub[TS][TS/WIDTH];
__local float8 Bsub[TS][TS/WIDTH];
float8 acc_val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
const int num_tiles = K/TS;
//C[global_row*N + global_col] = 100.0f*incomplete_ws_row + 10.0f*incomplete_ws_col + 1.0f;
//C[global_row*N + global_col] = remain_th_N*1.0f;
//C[global_row*N + global_col] = num_tiles*-1.0f;
//return;
if(incomplete_ws_row == 0 && incomplete_ws_col == 0) {
for(int t=0; t<num_tiles; t++) {
const int t_row = TS*t + row;
const int t_col = (TS/WIDTH)*t + col;
float *cur_A = A + global_row*K + t_col*WIDTH;
//Asub[row][col] = *(float8 *)cur_A;
Asub[row][col].s0 = *(cur_A + 0);
Asub[row][col].s1 = *(cur_A + 1);
Asub[row][col].s2 = *(cur_A + 2);
Asub[row][col].s3 = *(cur_A + 3);
Asub[row][col].s4 = *(cur_A + 4);
Asub[row][col].s5 = *(cur_A + 5);
Asub[row][col].s6 = *(cur_A + 6);
Asub[row][col].s7 = *(cur_A + 7);
float *cur_B = B + t_row*N + global_col*WIDTH;
//Bsub[row][col] = *(float8 *)cur_B;
Bsub[row][col].s0 = *(cur_B + 0);
Bsub[row][col].s1 = *(cur_B + 1);
Bsub[row][col].s2 = *(cur_B + 2);
Bsub[row][col].s3 = *(cur_B + 3);
Bsub[row][col].s4 = *(cur_B + 4);
Bsub[row][col].s5 = *(cur_B + 5);
Bsub[row][col].s6 = *(cur_B + 6);
Bsub[row][col].s7 = *(cur_B + 7);
barrier(CLK_LOCAL_MEM_FENCE);
float8 vecA, vecB;
float valA;
for(int k=0; k<TS/WIDTH; k++) {
vecA = Asub[row][k];
for(int w=0; w<WIDTH; w++) {
vecB = Bsub[WIDTH*k + w][col];
switch(w){
case 0: valA = vecA.s0; break;
case 1: valA = vecA.s1; break;
case 2: valA = vecA.s2; break;
case 3: valA = vecA.s3; break;
case 4: valA = vecA.s4; break;
case 5: valA = vecA.s5; break;
case 6: valA = vecA.s6; break;
case 7: valA = vecA.s7; break;
}
acc_val += vecB * valA;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
//Deal with the last (partial) tile if it exists
}
else {
for(int th=0; th<remain_th_N; th++) {
for(int k=0; k<K; k++) {
float val = A[global_row*K + k] * B[k*N + global_col*WIDTH + th];
switch(th) {
case 0: acc_val.s0 += val; break;
case 1: acc_val.s1 += val; break;
case 2: acc_val.s2 += val; break;
case 3: acc_val.s3 += val; break;
case 4: acc_val.s4 += val; break;
case 5: acc_val.s5 += val; break;
case 6: acc_val.s6 += val; break;
case 7: acc_val.s7 += val; break;
}
}
}
}
//C[global_row] = remain_th_N; return;
float *cur_C = (C + global_row*N + global_col*WIDTH);
//if(remain_th_N == WIDTH){
// *(float8 *)(cur_C) = acc_val; // Alignment fault??
//}else
{
for(int th=0; th<remain_th_N; th++) {
switch(th) {
case 0: *(cur_C + 0) = acc_val.s0; break;
case 1: *(cur_C + 1) = acc_val.s1; break;
case 2: *(cur_C + 2) = acc_val.s2; break;
case 3: *(cur_C + 3) = acc_val.s3; break;
case 4: *(cur_C + 4) = acc_val.s4; break;
case 5: *(cur_C + 5) = acc_val.s5; break;
case 6: *(cur_C + 6) = acc_val.s6; break;
case 7: *(cur_C + 7) = acc_val.s7; break;
}
}
}
//*cur_C = A[global_row*K + 0];
//*cur_C = acc_val.s0;
// Check if there is an incomplete tile.
//if(K == num_tiles*TS){
// *(float8 *)(cur_C) = acc_val;
//}else{
// //Calc the remaining number of works for the current thread. (0~8)
// float *cur_A = A + global_row*K + t_col*WIDTH;
// float *cur_B = B + t_row*N + global_col*WIDTH;
// int remaining_num = min(K - t_col*WIDTH, WIDTH);
// for(int r=0; r<remaining_num; r++){
// acc_
// }
//}
//int adj_WIDTH = min(K - t_col*WIDTH, WIDTH);
//if(adj_WIDTH == WIDTH)
//else
//{
// //Last
// for(int i=0; i<adj_WIDTH; i++){
// switch(i){
// case 0: *(cur_C+i) = acc_val.s0; break;
// case 1: *(cur_C+i) = acc_val.s1; break;
// case 2: *(cur_C+i) = acc_val.s2; break;
// case 3: *(cur_C+i) = acc_val.s3; break;
// case 4: *(cur_C+i) = acc_val.s4; break;
// case 5: *(cur_C+i) = acc_val.s5; break;
// case 6: *(cur_C+i) = acc_val.s6; break;
// case 7: *(cur_C+i) = acc_val.s7; break;
// }
// }
//}
}*/
// 7th method - Tiled + More work per thread + Unaligned (arbitrary M, N, K)
//
// Computes C = A * B for row-major matrices: A is M x K, B is K x N and
// C is M x N.
//
// Each work-group computes one TS x TS tile of C. Each work-item owns a
// single column of that tile (globalCol) and WPT of its rows:
// globalRow, globalRow + RTS, ..., globalRow + (WPT-1)*RTS.
//
// Local memory is indexed as [row + w*RTS][col] with w < WPT. That only
// stays inside the TS x TS arrays if get_local_id(0) is in [0, RTS) and
// get_local_id(1) is in [0, TS), so the local work size must be (RTS, TS).
// NOTE(review): confirm the host enqueues with local = {RTS, TS} and
// global = {ceil(M/TS)*RTS, ceil(N/TS)*TS}.
//
// Out-of-range elements of A and B are loaded as 0.0f, so partial tiles at
// the M, N and K edges contribute nothing. Out-of-range elements of C are
// never written. No work-item returns early, so all of them reach both
// barriers, as barrier() requires.

// Tile size: rows/cols of C per work-group, and width of each K slice.
// Can be overridden at build time (e.g. "-DTS=64"); the host's local size
// must match.
#ifndef TS
#define TS 32
#endif
// Work per thread: number of C elements each work-item computes.
#ifndef WPT
#define WPT 8
#endif
// Reduced tile size: number of work-item rows per work-group.
#define RTS (TS / WPT)

#if (TS % WPT) != 0
#error "TS must be a multiple of WPT"
#endif

__kernel void sgemm(const __global float *A, const __global float *B,
                    __global float *C, int M, int N, int K) {
    // Local work-item coordinates: row in [0, RTS), col in [0, TS).
    const int row = get_local_id(0);
    const int col = get_local_id(1);
    // First row of C owned by this work-item, and the one column it owns.
    const int globalRow = TS * get_group_id(0) + row;
    const int globalCol = TS * get_group_id(1) + col;

    // One TS x TS tile each of A and B, shared by the work-group.
    __local float Asub[TS][TS];
    __local float Bsub[TS][TS];

    // Per-work-item accumulators, one per owned element of C.
    float acc[WPT];
    for (int w = 0; w < WPT; w++) {
        acc[w] = 0.0f;
    }

    // Round up so a trailing partial tile (K % TS != 0) is also processed.
    const int numTiles = (K + TS - 1) / TS;
    for (int t = 0; t < numTiles; t++) {
        // Column of A and base row of B this work-item loads for tile t.
        const int tiledCol = TS * t + col;
        const int tiledRow = TS * t + row;

        // Cooperative load: every work-item fills WPT slots of each tile.
        // Out-of-range slots are zero-filled so they add nothing to acc.
        for (int w = 0; w < WPT; w++) {
            const int r = row + w * RTS;            // row within the tile
            const int aRow = globalRow + w * RTS;   // row of A being loaded
            const int bRow = tiledRow + w * RTS;    // row of B being loaded
            if (aRow < M && tiledCol < K) {
                Asub[r][col] = A[aRow * K + tiledCol];
            } else {
                Asub[r][col] = 0.0f;
            }
            if (bRow < K && globalCol < N) {
                Bsub[r][col] = B[bRow * N + globalCol];
            } else {
                Bsub[r][col] = 0.0f;
            }
        }
        // The whole tile must be loaded before any work-item reads it.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Multiply the tiles. Thanks to the zero padding, iterating over
        // all TS values of k is safe even for the last, partial tile.
        for (int k = 0; k < TS; k++) {
            const float b = Bsub[k][col];
            for (int w = 0; w < WPT; w++) {
                acc[w] += Asub[row + w * RTS][k] * b;
            }
        }
        // Keep the next iteration from overwriting the tile while other
        // work-items are still reading it.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store only the elements that lie inside C.
    for (int w = 0; w < WPT; w++) {
        const int cRow = globalRow + w * RTS;
        if (cRow < M && globalCol < N) {
            C[cRow * N + globalCol] = acc[w];
        }
    }
}