chundoong-lab-ta/SamsungDS22/submissions/HW2/ty.jeon/mat_mul.cpp

106 lines
2.4 KiB
C++

#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
#include <iostream>
#include <immintrin.h>
using namespace std;
#define BLOCKSIZE 128
#define TILEWIDTH 2
#define min(x, y) (((x) > (y)) ? (y) : (x))
#define max(x, y) (((x) > (y)) ? (x) : (y))
static float *A, *B, *C;
static int M, N, K;
static int num_threads;
/*
 * Worker entry point (pthread start routine): accumulates this worker's rows
 * of C += A * B.
 *
 * data: pointer to an int holding the worker index, expected in
 *       [0, num_threads).
 * Returns NULL always.
 *
 * Uses the file-level A, B, C, M, N, K and num_threads. The indexing treats
 * A as M x K, B as K x N and C as M x N, all row-major. The product is ADDED
 * to whatever C already holds; C is never cleared here.
 * NOTE(review): presumably the caller zeroes C beforehand - confirm.
 *
 * Each worker writes only its own contiguous row range of C, so workers do
 * not write to the same elements. The vector path uses unaligned 256-bit
 * loads/stores and fused multiply-add, so the build must enable AVX and FMA.
 */
static void* mat_mul_thread(void *data) {
  const int pid = *(int *)data;

  // Split the M rows as evenly as possible: every worker gets M / num_threads
  // rows and the first M % num_threads workers get one extra, so no single
  // worker is left with the whole remainder.
  const int base_rows = M / num_threads;
  const int extra_rows = M % num_threads;
  const int start = pid * base_rows + (pid < extra_rows ? pid : extra_rows);
  const int end = start + base_rows + (pid < extra_rows ? 1 : 0);

  const int k_block = 32;        // depth of one k tile
  const int cols_per_step = 32;  // 4 vectors x 8 floats per unrolled step
  // Columns [0, n_vec) go through the vector path, [n_vec, N) through scalar.
  const int n_vec = (N / cols_per_step) * cols_per_step;

  // Row strides as size_t so row offsets do not overflow int on big matrices.
  const size_t lda = (size_t)K;
  const size_t ldn = (size_t)N;

  for (int kk = 0; kk < K; kk += k_block) {
    const int k_end = (K - kk < k_block) ? K : kk + k_block;

    for (int i = start; i < end; ++i) {
      const float *a_row = A + (size_t)i * lda;
      float *c_row = C + (size_t)i * ldn;

      for (int k = kk; k < k_end; ++k) {
        const float a_ik = a_row[k];
        const float *b_row = B + (size_t)k * ldn;
        const __m256 a_vec = _mm256_set1_ps(a_ik);

        // c_row[j .. j+31] += a_ik * b_row[j .. j+31], four vectors at a time.
        for (int j = 0; j < n_vec; j += cols_per_step) {
          __m256 c0 = _mm256_loadu_ps(c_row + j);
          __m256 c1 = _mm256_loadu_ps(c_row + j + 8);
          __m256 c2 = _mm256_loadu_ps(c_row + j + 16);
          __m256 c3 = _mm256_loadu_ps(c_row + j + 24);

          const __m256 b0 = _mm256_loadu_ps(b_row + j);
          const __m256 b1 = _mm256_loadu_ps(b_row + j + 8);
          const __m256 b2 = _mm256_loadu_ps(b_row + j + 16);
          const __m256 b3 = _mm256_loadu_ps(b_row + j + 24);

          c0 = _mm256_fmadd_ps(b0, a_vec, c0);
          c1 = _mm256_fmadd_ps(b1, a_vec, c1);
          c2 = _mm256_fmadd_ps(b2, a_vec, c2);
          c3 = _mm256_fmadd_ps(b3, a_vec, c3);

          _mm256_storeu_ps(c_row + j, c0);
          _mm256_storeu_ps(c_row + j + 8, c1);
          _mm256_storeu_ps(c_row + j + 16, c2);
          _mm256_storeu_ps(c_row + j + 24, c3);
        }

        // Scalar tail for the last N % 32 columns. These elements are
        // disjoint from the vector range and still see k in ascending order.
        for (int j = n_vec; j < N; ++j) {
          c_row[j] += a_ik * b_row[j];
        }
      }
    }
  }
  return NULL;
}
/*
 * Multi-threaded driver: publishes the operands through the file-level
 * statics, runs mat_mul_thread once per worker index, and returns after every
 * worker has finished.
 *
 * _A, _B, _C:     matrix buffers handed to the workers unchanged.
 * _M, _N, _K:     matrix dimensions handed to the workers unchanged.
 * _num_threads:   requested worker count; values below 1 are treated as 1
 *                 because the workers divide by this count.
 *
 * If a thread cannot be created, or the bookkeeping array cannot be
 * allocated, the affected work runs on the calling thread instead, so every
 * worker index is still processed exactly once.
 *
 * Not reentrant: the operands live in file-level statics shared by all calls.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = (_num_threads > 0) ? _num_threads : 1;

  // One heap array instead of variable-length arrays (not standard C++).
  struct worker_slot {
    pthread_t thread;
    int pid;
    int started;  // non-zero once pthread_create succeeded for this slot
  };
  struct worker_slot *slots =
      (struct worker_slot *)malloc((size_t)num_threads * sizeof *slots);

  if (slots == NULL) {
    // No room for bookkeeping: do the whole job as a single worker.
    int only_pid = 0;
    num_threads = 1;
    mat_mul_thread(&only_pid);
    return;
  }

  for (int t = 0; t < num_threads; ++t) {
    slots[t].pid = t;
    slots[t].started =
        (pthread_create(&slots[t].thread, NULL, mat_mul_thread, &slots[t].pid) == 0);
  }

  for (int t = 0; t < num_threads; ++t) {
    if (slots[t].started) {
      pthread_join(slots[t].thread, NULL);
    } else {
      // Thread creation failed: compute this worker's share here so the
      // result is complete, and never join an uninitialised pthread_t.
      mat_mul_thread(&slots[t].pid);
    }
  }

  free(slots);
}