/* Pastebin metadata residue: 47 lines, 1.1 KiB. The paste was labelled
 * "Common Lisp", but the content below is an OpenCL C kernel. */
/*
 * Tile edge length; the kernel below expects a BLOCK_SIZE x BLOCK_SIZE
 * work-group. Guarded so the host can override it at program build time,
 * e.g. with the build option "-DBLOCK_SIZE=32".
 */
#ifndef BLOCK_SIZE
#define BLOCK_SIZE 16
#endif

/*
 * Row-major accessors for the BLOCK_SIZE x BLOCK_SIZE local tiles:
 * AS(i, j) / BS(i, j) is row i, column j of the tile held in the variables
 * named As / Bs at the point of use. Both expand to lvalues. Arguments are
 * parenthesized so expression arguments such as AS(ty + 1, k) index the
 * intended element.
 */
#define AS(i, j) As[(j) + (i) * BLOCK_SIZE]
#define BS(i, j) Bs[(j) + (i) * BLOCK_SIZE]
__kernel void sgemm(__global float *A, __global float *B, __global float *C, __local float *As, __local float *Bs, int M, int N, int K) {
|
|
// block index
|
|
int bx = get_group_id(0);
|
|
int by = get_group_id(1);
|
|
|
|
// thread index
|
|
int tx = get_local_id(0);
|
|
int ty = get_local_id(1);
|
|
|
|
// index of sub-matrix of A
|
|
int aBegin = K * BLOCK_SIZE * by;
|
|
int aEnd = aBegin + K - 1;
|
|
|
|
// step size through sub-matrice of A
|
|
int aStep = BLOCK_SIZE;
|
|
|
|
// index of the first sub-matrix of B
|
|
int bBegin = BLOCK_SIZE * bx;
|
|
int bStep = BLOCK_SIZE * N;
|
|
|
|
// Csub
|
|
float Csub = 0.0f;
|
|
|
|
for (int a = aBegin, b = bBegin; a <= aEnd; a += aStep, b += bStep) {
|
|
AS(ty, tx) = A[a + K * ty + tx];
|
|
BS(ty, tx) = B[b + N * ty + tx];
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
|
|
#pragma unroll
|
|
for (int k = 0; k < BLOCK_SIZE; k++) {
|
|
Csub += AS(ty, k) * BS(k, tx);
|
|
}
|
|
|
|
barrier(CLK_LOCAL_MEM_FENCE);
|
|
}
|
|
|
|
// if (get_global_id(1) < M) {
|
|
C[get_global_id(1) * get_global_size(0) + get_global_id(0)] = Csub;
|
|
// }
|
|
}
|