/*
 * chundoong-lab-ta/SamsungDS22/submissions/HW2/km.hero.lee/mat_mul.cpp
 *
 * 178 lines
 * 5.3 KiB
 * C++
 * Raw Normal View History
 *
 * 2022-09-29 18:01:45 +09:00
 */
#include "mat_mul.h"
#include <cstdio>
#include <cstdlib>
#include <pthread.h>
#ifdef _OAVX
#include <immintrin.h>
#endif
/* Smaller of x and y. Being a macro, the selected argument is evaluated twice. */
#define MIN(x, y) (((x) < (y)) ? (x) : (y))

/* Tile edge / blocking step, in elements, used by the blocked loops. */
#define B_SIZE 32

/* Operand matrices shared with the worker threads; assigned by mat_mul(). */
static float *A;
static float *B;
static float *C;

/* Problem dimensions: A is indexed as M x K, B as K x N, C as M x N. */
static int M;
static int N;
static int K;

/* Number of worker threads the rows of C are divided among. */
static int num_threads;
/*
 * Worker entry point for pthread_create(): accumulates this worker's share of
 * the product into C, i.e. C[i][j] += sum_k A[i][k] * B[k][j] for the rows i
 * that belong to the worker.
 *
 * data    pointer to an int holding the worker index (0 .. num_threads - 1).
 * return  always NULL.
 *
 * Reads the file-scope A, B, C, M, N, K and num_threads. The matrices are
 * addressed row-major: A[i * K + k], B[k * N + j], C[i * N + j].
 *
 * Exactly one kernel is compiled, chosen by a build-time macro, and every
 * kernel touches only rows [row_begin, row_end). Workers therefore never
 * write the same element of C, and no row is accumulated more than once.
 *
 *   _O0                naive i-j-k order
 *   _O1                i-k-j order
 *   _O2 / _O3_BASE     blocked on i, k and j
 *   _O3_DEL_I          blocked on k and j
 *   _O3_DEL_J          blocked on i and k
 *   _O3_DEL_K          blocked on i and j
 *   _OAVX              i-k-j order with 512-bit vectors over j
 *   (none of these)    blocked on k only; _O3_DEL_IJ and _O4 select the same
 *                      loop nest and therefore also end up here
 */
static void* mat_mul_thread(void *data) {
  const int tid = *(int *) data;

  /* Round up so the last M % num_threads rows are not left out. Workers whose
   * range starts at or beyond M get row_begin >= row_end and do nothing. */
  const int rows_per_thread = (M + num_threads - 1) / num_threads;
  const int row_begin = tid * rows_per_thread;
  const int row_end = MIN(M, row_begin + rows_per_thread);

#if defined(_O0)
  for (int i = row_begin; i < row_end; ++i) {
    for (int j = 0; j < N; ++j) {
      for (int k = 0; k < K; ++k) {
        C[i * N + j] += A[i * K + k] * B[k * N + j];
      }
    }
  }
#elif defined(_O1)
  for (int i = row_begin; i < row_end; ++i) {
    for (int k = 0; k < K; ++k) {
      const float a_ik = A[i * K + k];
      for (int j = 0; j < N; ++j) {
        C[i * N + j] += a_ik * B[k * N + j];
      }
    }
  }
#elif defined(_O2) || defined(_O3_BASE)
  for (int ii = row_begin; ii < row_end; ii += B_SIZE) {
    const int i_end = MIN(ii + B_SIZE, row_end);
    for (int kk = 0; kk < K; kk += B_SIZE) {
      const int k_end = MIN(kk + B_SIZE, K);
      for (int jj = 0; jj < N; jj += B_SIZE) {
        const int j_end = MIN(jj + B_SIZE, N);
        for (int i = ii; i < i_end; i++) {
          for (int k = kk; k < k_end; k++) {
            const float a_ik = A[i * K + k];
            for (int j = jj; j < j_end; j++) {
              C[i * N + j] += a_ik * B[k * N + j];
            }
          }
        }
      }
    }
  }
#elif defined(_O3_DEL_I)
  for (int kk = 0; kk < K; kk += B_SIZE) {
    const int k_end = MIN(kk + B_SIZE, K);
    for (int jj = 0; jj < N; jj += B_SIZE) {
      const int j_end = MIN(jj + B_SIZE, N);
      for (int i = row_begin; i < row_end; i++) {
        for (int k = kk; k < k_end; k++) {
          const float a_ik = A[i * K + k];
          for (int j = jj; j < j_end; j++) {
            C[i * N + j] += a_ik * B[k * N + j];
          }
        }
      }
    }
  }
#elif defined(_O3_DEL_J)
  for (int ii = row_begin; ii < row_end; ii += B_SIZE) {
    const int i_end = MIN(ii + B_SIZE, row_end);
    for (int kk = 0; kk < K; kk += B_SIZE) {
      const int k_end = MIN(kk + B_SIZE, K);
      for (int i = ii; i < i_end; i++) {
        for (int k = kk; k < k_end; k++) {
          const float a_ik = A[i * K + k];
          for (int j = 0; j < N; j++) {
            C[i * N + j] += a_ik * B[k * N + j];
          }
        }
      }
    }
  }
#elif defined(_O3_DEL_K)
  for (int ii = row_begin; ii < row_end; ii += B_SIZE) {
    const int i_end = MIN(ii + B_SIZE, row_end);
    for (int jj = 0; jj < N; jj += B_SIZE) {
      const int j_end = MIN(jj + B_SIZE, N);
      for (int i = ii; i < i_end; i++) {
        for (int k = 0; k < K; k++) {
          const float a_ik = A[i * K + k];
          for (int j = jj; j < j_end; j++) {
            C[i * N + j] += a_ik * B[k * N + j];
          }
        }
      }
    }
  }
#elif defined(_OAVX)
  /* Needs <immintrin.h> and a target with AVX-512F enabled. One __m512 holds
   * 16 floats; unaligned load/store is used because row offsets into B and C
   * are not guaranteed to be 64-byte aligned. */
  for (int i = row_begin; i < row_end; ++i) {
    for (int k = 0; k < K; ++k) {
      const float a_ik = A[i * K + k];
      const __m512 a_bcast = _mm512_set1_ps(a_ik); /* same A element in all lanes */
      int j = 0;
      for (; j + 16 <= N; j += 16) {
        const __m512 b_vec = _mm512_loadu_ps(B + (k * N + j));
        __m512 c_vec = _mm512_loadu_ps(C + (i * N + j));
        c_vec = _mm512_fmadd_ps(a_bcast, b_vec, c_vec);
        _mm512_storeu_ps(C + (i * N + j), c_vec);
      }
      /* Scalar tail for the N % 16 trailing columns. */
      for (; j < N; ++j) {
        C[i * N + j] += a_ik * B[k * N + j];
      }
    }
  }
#else
  for (int kk = 0; kk < K; kk += B_SIZE) {
    const int k_end = MIN(kk + B_SIZE, K);
    for (int i = row_begin; i < row_end; i++) {
      for (int k = kk; k < k_end; k++) {
        const float a_ik = A[i * K + k];
        for (int j = 0; j < N; j++) {
          C[i * N + j] += a_ik * B[k * N + j];
        }
      }
    }
  }
#endif

  return NULL;
}
/*
 * Accumulates the matrix product into _C (C += A * B) using worker threads.
 *
 * _A            M x K input, row-major.
 * _B            K x N input, row-major.
 * _C            M x N output; the workers add into it with +=, so the caller
 *               decides what it holds on entry.
 * _M, _N, _K    matrix dimensions.
 * _num_threads  requested number of worker threads; values below 1 are
 *               treated as 1 so the product is still computed.
 *
 * The arguments are published through the file-scope variables read by
 * mat_mul_thread(), so concurrent calls to mat_mul() are not supported.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = (_num_threads > 0) ? _num_threads : 1;

  /* Both arrays are sized from num_threads, so any thread count fits. */
  pthread_t *threads = new pthread_t[num_threads];
  int *tid = new int[num_threads];

  /* Start workers until one fails to start; only started ones are joined. */
  int launched = 0;
  for (int i = 0; i < num_threads; i++) {
    tid[i] = i;
    if (pthread_create(&threads[i], NULL, mat_mul_thread, &tid[i]) != 0) {
      break;
    }
    launched++;
  }

  /* Workers that could not be started still own rows of C. Run their share
   * on the calling thread so the result is complete. */
  for (int i = launched; i < num_threads; i++) {
    tid[i] = i;
    (void) mat_mul_thread(&tid[i]);
  }

  for (int i = 0; i < launched; i++) {
    pthread_join(threads[i], NULL);
  }

  delete[] tid;
  delete[] threads;
}