416 lines
14 KiB
Common Lisp
416 lines
14 KiB
Common Lisp
|
// super super slow sgemm kernel by heehoon
|
||
|
/*__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
||
|
int i = get_global_id(0); // row index of C
|
||
|
int j = get_global_id(1); // column index of C
|
||
|
if (i >= M || j >= N) return; // boundary check
|
||
|
|
||
|
C[i * N + j] = 0;
|
||
|
for (int k = 0; k < K; k++) {
|
||
|
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
}
|
||
|
}*/
|
||
|
|
||
|
|
||
|
// 2nd method - Tiled
|
||
|
/*
|
||
|
#define TS 32
|
||
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
||
|
// int i = get_global_id(0); // row index of C
|
||
|
// int j = get_global_id(1); // column index of C
|
||
|
//if (i >= M || j >= N) return; // boundary check
|
||
|
|
||
|
const int row = get_local_id(0); // local row ID (0~31)
|
||
|
const int col = get_local_id(1); // local col ID (0~31)
|
||
|
const int global_row = TS * get_group_id(0) + row; // row ID (0~M)
|
||
|
const int global_col = TS * get_group_id(1) + col; // row ID (0~N)
|
||
|
if (global_row >= M || global_col >= N) return; // boundary check
|
||
|
|
||
|
__local float Asub[TS][TS];
|
||
|
__local float Bsub[TS][TS];
|
||
|
|
||
|
float acc_val = 0.0f;
|
||
|
const int num_tiles = K / TS;
|
||
|
for(int t=0; t<num_tiles; t++){
|
||
|
const int t_row = TS*t + row;
|
||
|
const int t_col = TS*t + col;
|
||
|
Asub[row][col] = A[global_row*K + t_col];
|
||
|
Bsub[row][col] = B[t_row*N + global_col];
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
|
||
|
for(int k=0; k<TS; k++){
|
||
|
acc_val += Asub[row][k] * Bsub[k][col];
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
}
|
||
|
|
||
|
C[global_row*N + global_col] = acc_val;
|
||
|
}*/
|
||
|
|
||
|
// 3rd method - Tiled + More work
|
||
|
/*
|
||
|
#define TS 32
|
||
|
#define WPT 8
|
||
|
#define RTS (TS/WPT)
|
||
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
||
|
const int row = get_local_id(0); // local row ID (0~31)
|
||
|
const int col = get_local_id(1); // local col ID (0~4)
|
||
|
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
|
||
|
const int global_col = TS * get_group_id(1) + col; // row ID (N/WPT IDs)
|
||
|
if (global_row >= M || global_col >= N) return; // boundary check
|
||
|
|
||
|
__local float Asub[TS][TS];
|
||
|
__local float Bsub[TS][TS];
|
||
|
|
||
|
float acc_val[WPT] = {0.0f,};
|
||
|
|
||
|
const int num_tiles = K / TS;
|
||
|
for(int t=0; t<num_tiles; t++){
|
||
|
const int t_row = TS*t + row;
|
||
|
const int t_col = TS*t + col;
|
||
|
for(int w=0; w<WPT; w++){
|
||
|
Asub[row][col + w*RTS] = A[global_row*K + t_col + w*RTS];
|
||
|
Bsub[row][col + w*RTS] = B[t_row*N + global_col + w*RTS];
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
|
||
|
for(int k=0; k<TS; k++){
|
||
|
for(int w=0; w<WPT; w++){
|
||
|
acc_val[w] += Asub[row][k] * Bsub[k][col+w*RTS];
|
||
|
}
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
}
|
||
|
|
||
|
for(int w=0; w<WPT; w++){
|
||
|
C[global_row*N + global_col+w*RTS] = acc_val[w];
|
||
|
}
|
||
|
}*/
|
||
|
|
||
|
|
||
|
// 4th method - Tile + Wide4
|
||
|
/*
|
||
|
#define TS 32
|
||
|
#define WIDTH 4
|
||
|
__kernel void sgemm(__global float4 *A, __global float4 *B, __global float4 *C, int M, int N, int K) {
|
||
|
const int row = get_local_id(0); // local row ID (0~31)
|
||
|
const int col = get_local_id(1); // local col ID (TS/WIDTH)
|
||
|
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
|
||
|
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // row ID (N/WIDTH IDs)
|
||
|
|
||
|
__local float4 Asub[TS][TS/WIDTH];
|
||
|
__local float4 Bsub[TS][TS/WIDTH];
|
||
|
|
||
|
float4 acc_val = {0.0f, 0.0f, 0.0f, 0.0f};
|
||
|
|
||
|
const int num_tiles = K / TS;
|
||
|
for(int t=0; t<num_tiles; t++){
|
||
|
const int t_row = TS*t + row;
|
||
|
const int t_col = (TS/WIDTH)*t + col;
|
||
|
Asub[row][col] = A[global_row*(K/WIDTH) + t_col];
|
||
|
Bsub[row][col] = B[t_row*(N/WIDTH) + global_col];
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
|
||
|
float4 vecA, vecB;
|
||
|
float valA;
|
||
|
for(int k=0; k<TS/WIDTH; k++){
|
||
|
vecA = Asub[row][k];
|
||
|
for(int w=0; w<WIDTH; w++){
|
||
|
vecB = Bsub[WIDTH*k + w][col];
|
||
|
switch(w){
|
||
|
case 0: valA = vecA.x; break;
|
||
|
case 1: valA = vecA.y; break;
|
||
|
case 2: valA = vecA.z; break;
|
||
|
case 3: valA = vecA.w; break;
|
||
|
}
|
||
|
acc_val.x += vecB.x * valA;
|
||
|
acc_val.y += vecB.y * valA;
|
||
|
acc_val.z += vecB.z * valA;
|
||
|
acc_val.w += vecB.w * valA;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
}
|
||
|
|
||
|
C[global_row*(N/WIDTH) + global_col] = acc_val;
|
||
|
}*/
|
||
|
|
||
|
// 5th method - Tiled + Wide8
|
||
|
/*
|
||
|
#define TS 32
|
||
|
#define WIDTH 8
|
||
|
__kernel void sgemm(__global float8 *A, __global float8 *B, __global float8 *C, int M, int N, int K) {
|
||
|
const int row = get_local_id(0); // local row ID (0~31)
|
||
|
const int col = get_local_id(1); // local col ID (TS/WIDTH)
|
||
|
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
|
||
|
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // row ID (N/WIDTH IDs)
|
||
|
|
||
|
__local float8 Asub[TS][TS/WIDTH];
|
||
|
__local float8 Bsub[TS][TS/WIDTH];
|
||
|
|
||
|
float8 acc_val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||
|
|
||
|
const int num_tiles = K / TS;
|
||
|
for(int t=0; t<num_tiles; t++){
|
||
|
const int t_row = TS*t + row;
|
||
|
const int t_col = (TS/WIDTH)*t + col;
|
||
|
Asub[row][col] = A[global_row*(K/WIDTH) + t_col];
|
||
|
Bsub[row][col] = B[t_row*(N/WIDTH) + global_col];
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
|
||
|
float8 vecA, vecB;
|
||
|
float valA;
|
||
|
for(int k=0; k<TS/WIDTH; k++){
|
||
|
vecA = Asub[row][k];
|
||
|
for(int w=0; w<WIDTH; w++){
|
||
|
vecB = Bsub[WIDTH*k + w][col];
|
||
|
switch(w){
|
||
|
case 0: valA = vecA.s0; break;
|
||
|
case 1: valA = vecA.s1; break;
|
||
|
case 2: valA = vecA.s2; break;
|
||
|
case 3: valA = vecA.s3; break;
|
||
|
case 4: valA = vecA.s4; break;
|
||
|
case 5: valA = vecA.s5; break;
|
||
|
case 6: valA = vecA.s6; break;
|
||
|
case 7: valA = vecA.s7; break;
|
||
|
}
|
||
|
acc_val += vecB * valA;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
}
|
||
|
|
||
|
C[global_row*(N/WIDTH) + global_col] = acc_val;
|
||
|
}*/
|
||
|
|
||
|
// 6th method - Tiled + Wide8 + Unaligned -Vector8 Load problem!! -.-
|
||
|
/*
|
||
|
#define TS 32
|
||
|
#define WIDTH 8
|
||
|
__kernel void sgemm(__global float *A, __global float *B, __global float *C, int M, int N, int K) {
|
||
|
const int row = get_local_id(0); // local row ID (0~31)
|
||
|
const int col = get_local_id(1); // local col ID (TS/WIDTH)
|
||
|
const int global_row = TS * get_group_id(0) + row; // row ID (M IDs)
|
||
|
const int global_col = (TS/WIDTH) * get_group_id(1) + col; // row ID (N/WIDTH IDs)
|
||
|
// Eliminate remaining thread
|
||
|
if((global_row >= M) || (global_col*WIDTH >= N)) return;
|
||
|
|
||
|
//const int incomplete_ws_row = (M - global_row) < (M % TS);
|
||
|
//const int incomplete_ws_col = (N - global_col*WIDTH) < (N % TS);
|
||
|
//const int remain_ws_col = min(N - global_col*WIDTH, TS);
|
||
|
const int remain_th_N = min(N - global_col*WIDTH, WIDTH);
|
||
|
const int incomplete_ws_row = (M - global_row) <= (M % TS);
|
||
|
const int incomplete_ws_col = (N - global_col*WIDTH) <= (N % TS);
|
||
|
|
||
|
__local float8 Asub[TS][TS/WIDTH];
|
||
|
__local float8 Bsub[TS][TS/WIDTH];
|
||
|
|
||
|
float8 acc_val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
|
||
|
|
||
|
|
||
|
const int num_tiles = K/TS;
|
||
|
//C[global_row*N + global_col] = 100.0f*incomplete_ws_row + 10.0f*incomplete_ws_col + 1.0f;
|
||
|
//C[global_row*N + global_col] = remain_th_N*1.0f;
|
||
|
//C[global_row*N + global_col] = num_tiles*-1.0f;
|
||
|
//return;
|
||
|
|
||
|
if(incomplete_ws_row == 0 && incomplete_ws_col == 0) {
|
||
|
for(int t=0; t<num_tiles; t++) {
|
||
|
const int t_row = TS*t + row;
|
||
|
const int t_col = (TS/WIDTH)*t + col;
|
||
|
|
||
|
float *cur_A = A + global_row*K + t_col*WIDTH;
|
||
|
//Asub[row][col] = *(float8 *)cur_A;
|
||
|
Asub[row][col].s0 = *(cur_A + 0);
|
||
|
Asub[row][col].s1 = *(cur_A + 1);
|
||
|
Asub[row][col].s2 = *(cur_A + 2);
|
||
|
Asub[row][col].s3 = *(cur_A + 3);
|
||
|
Asub[row][col].s4 = *(cur_A + 4);
|
||
|
Asub[row][col].s5 = *(cur_A + 5);
|
||
|
Asub[row][col].s6 = *(cur_A + 6);
|
||
|
Asub[row][col].s7 = *(cur_A + 7);
|
||
|
|
||
|
float *cur_B = B + t_row*N + global_col*WIDTH;
|
||
|
//Bsub[row][col] = *(float8 *)cur_B;
|
||
|
Bsub[row][col].s0 = *(cur_B + 0);
|
||
|
Bsub[row][col].s1 = *(cur_B + 1);
|
||
|
Bsub[row][col].s2 = *(cur_B + 2);
|
||
|
Bsub[row][col].s3 = *(cur_B + 3);
|
||
|
Bsub[row][col].s4 = *(cur_B + 4);
|
||
|
Bsub[row][col].s5 = *(cur_B + 5);
|
||
|
Bsub[row][col].s6 = *(cur_B + 6);
|
||
|
Bsub[row][col].s7 = *(cur_B + 7);
|
||
|
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
|
||
|
float8 vecA, vecB;
|
||
|
float valA;
|
||
|
for(int k=0; k<TS/WIDTH; k++) {
|
||
|
vecA = Asub[row][k];
|
||
|
for(int w=0; w<WIDTH; w++) {
|
||
|
vecB = Bsub[WIDTH*k + w][col];
|
||
|
switch(w){
|
||
|
case 0: valA = vecA.s0; break;
|
||
|
case 1: valA = vecA.s1; break;
|
||
|
case 2: valA = vecA.s2; break;
|
||
|
case 3: valA = vecA.s3; break;
|
||
|
case 4: valA = vecA.s4; break;
|
||
|
case 5: valA = vecA.s5; break;
|
||
|
case 6: valA = vecA.s6; break;
|
||
|
case 7: valA = vecA.s7; break;
|
||
|
}
|
||
|
acc_val += vecB * valA;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
barrier(CLK_LOCAL_MEM_FENCE);
|
||
|
}
|
||
|
|
||
|
//Deal with the last tile if exist
|
||
|
}
|
||
|
else {
|
||
|
for(int th=0; th<remain_th_N; th++) {
|
||
|
for(int k=0; k<K; k++) {
|
||
|
float val = A[global_row*K + k] * B[k*N + global_col*WIDTH + th];
|
||
|
switch(th) {
|
||
|
case 0: acc_val.s0 += val; break;
|
||
|
case 1: acc_val.s1 += val; break;
|
||
|
case 2: acc_val.s2 += val; break;
|
||
|
case 3: acc_val.s3 += val; break;
|
||
|
case 4: acc_val.s4 += val; break;
|
||
|
case 5: acc_val.s5 += val; break;
|
||
|
case 6: acc_val.s6 += val; break;
|
||
|
case 7: acc_val.s7 += val; break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//C[global_row] = remain_th_N; return;
|
||
|
float *cur_C = (C + global_row*N + global_col*WIDTH);
|
||
|
//if(remain_th_N == WIDTH){
|
||
|
// *(float8 *)(cur_C) = acc_val; // Alignment fault??
|
||
|
//}else
|
||
|
{
|
||
|
for(int th=0; th<remain_th_N; th++) {
|
||
|
switch(th) {
|
||
|
case 0: *(cur_C + 0) = acc_val.s0; break;
|
||
|
case 1: *(cur_C + 1) = acc_val.s1; break;
|
||
|
case 2: *(cur_C + 2) = acc_val.s2; break;
|
||
|
case 3: *(cur_C + 3) = acc_val.s3; break;
|
||
|
case 4: *(cur_C + 4) = acc_val.s4; break;
|
||
|
case 5: *(cur_C + 5) = acc_val.s5; break;
|
||
|
case 6: *(cur_C + 6) = acc_val.s6; break;
|
||
|
case 7: *(cur_C + 7) = acc_val.s7; break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
//*cur_C = A[global_row*K + 0];
|
||
|
//*cur_C = acc_val.s0;
|
||
|
|
||
|
// Check if there is an incomplete tile.
|
||
|
|
||
|
//if(K == num_tiles*TS){
|
||
|
// *(float8 *)(cur_C) = acc_val;
|
||
|
//}else{
|
||
|
// //Calc the remaining number of works for the current thread. (0~8)
|
||
|
// float *cur_A = A + global_row*K + t_col*WIDTH;
|
||
|
// float *cur_B = B + t_row*N + global_col*WIDTH;
|
||
|
// int remaining_num = min(K - t_col*WIDTH, WIDTH);
|
||
|
// for(int r=0; r<remaining_num; r++){
|
||
|
// acc_
|
||
|
// }
|
||
|
//}
|
||
|
|
||
|
//int adj_WIDTH = min(K - t_col*WIDTH, WIDTH);
|
||
|
|
||
|
//if(adj_WIDTH == WIDTH)
|
||
|
//else
|
||
|
//{
|
||
|
// //Last
|
||
|
// for(int i=0; i<adj_WIDTH; i++){
|
||
|
// switch(i){
|
||
|
// case 0: *(cur_C+i) = acc_val.s0; break;
|
||
|
// case 1: *(cur_C+i) = acc_val.s1; break;
|
||
|
// case 2: *(cur_C+i) = acc_val.s2; break;
|
||
|
// case 3: *(cur_C+i) = acc_val.s3; break;
|
||
|
// case 4: *(cur_C+i) = acc_val.s4; break;
|
||
|
// case 5: *(cur_C+i) = acc_val.s5; break;
|
||
|
// case 6: *(cur_C+i) = acc_val.s6; break;
|
||
|
// case 7: *(cur_C+i) = acc_val.s7; break;
|
||
|
// }
|
||
|
// }
|
||
|
//}
|
||
|
}*/
|
||
|
|
||
|
// 7th method - Tiled + More work per thread + Unaligned (bounds-checked) sizes
//
// Row-major SGEMM: C (M x N) = A (M x K) * B (K x N).
// Each work-group computes one TS x TS tile of C; each work-item computes WPT
// elements of one column of that tile, at rows spaced RTS apart.
// Tile entries that fall outside M, N or K are zero-padded in local memory, so
// M, N and K need not be multiples of TS.
//
// Launch requirements (implied by the indexing below):
//   - local work size must be exactly {RTS, TS} (= {4, 32}); dimension 0 walks
//     rows, dimension 1 walks columns. Enforced via reqd_work_group_size, so a
//     mismatching launch fails at enqueue time instead of corrupting local memory.
//   - to cover all of C, the global size must be at least
//     {ceil(M/TS) * RTS, ceil(N/TS) * TS}.
// NOTE(review): indices are 32-bit int, so M*K, K*N and M*N must stay below INT_MAX.

#define TS 32
#define WPT 8
#define RTS (TS/WPT)

__kernel __attribute__((reqd_work_group_size(RTS, TS, 1)))
void sgemm(const __global float *A, const __global float *B, __global float *C, int M, int N, int K) {
    // Work-item identifiers. The tile indexing (row + w*RTS, col) only stays
    // inside the TS x TS local tiles because row < RTS and col < TS.
    const int row = get_local_id(0);                 // Local row ID (0..RTS-1)
    const int col = get_local_id(1);                 // Local col ID (0..TS-1)
    const int globalRow = TS*get_group_id(0) + row;  // First C row handled by this work-item
    const int globalCol = TS*get_group_id(1) + col;  // C column handled by this work-item

    // No early return for out-of-range work-items: every work-item of the group
    // must reach the barriers below and still helps load (zero-padded) tile data.
    // Out-of-range results are masked at the final store instead.

    // Local memory holding one TS x TS tile of A and one of B.
    __local float Asub[TS][TS];
    __local float Bsub[TS][TS];

    // Per-work-item accumulators, one per output row.
    float acc[WPT];
    for (int w = 0; w < WPT; w++) {
        acc[w] = 0.0f;
    }

    // Round up so a trailing partial tile along K is still processed.
    const int numTiles = (K + TS - 1)/TS;
    for (int t = 0; t < numTiles; t++) {
        // Tile-relative coordinates; independent of w, so computed once per tile.
        const int tiledRow = TS*t + row;  // Row of B for this work-item's first load
        const int tiledCol = TS*t + col;  // Column of A for all of this work-item's loads

        // Load one tile of A and one tile of B, padding out-of-range entries with 0.
        for (int w = 0; w < WPT; w++) {
            const int aRow = globalRow + w*RTS;
            const int bRow = tiledRow + w*RTS;

            if (aRow < M && tiledCol < K) {
                Asub[row + w*RTS][col] = A[aRow*K + tiledCol];
            } else {
                Asub[row + w*RTS][col] = 0.0f;
            }

            if (bRow < K && globalCol < N) {
                Bsub[row + w*RTS][col] = B[bRow*N + globalCol];
            } else {
                Bsub[row + w*RTS][col] = 0.0f;
            }
        }

        // Wait until the whole tile is in local memory.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Zero padding makes the full TS-length inner product correct even on
        // the last, partial K tile.
        for (int k = 0; k < TS; k++) {
            const float bVal = Bsub[k][col];
            for (int w = 0; w < WPT; w++) {
                acc[w] += Asub[row + w*RTS][k] * bVal;
            }
        }

        // Every work-item must finish reading before the next tile overwrites local memory.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Store the results, skipping rows/columns that fall outside C.
    for (int w = 0; w < WPT; w++) {
        const int cRow = globalRow + w*RTS;
        if (cRow < M && globalCol < N) {
            C[cRow*N + globalCol] = acc[w];
        }
    }
}
|
||
|
|