// Source: chundoong-lab-ta/SamsungDS22/submissions/HW2/hkyoo.kim/mat_mul.cpp
// (original listing: 105 lines, 2.6 KiB, C++)

#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
#include <immintrin.h>
// Operands of the current product, published by mat_mul() before any worker
// thread is created. Workers read A and B and write disjoint rows of C.
static float *A;  // M x K
static float *B;  // K x N
static float *C;  // M x N

// Matrix dimensions for the current product.
static int M;
static int N;
static int K;

// Number of workers mat_mul() launches; mat_mul_thread() uses it to split rows.
static int num_threads;
// Worker body for mat_mul(): accumulates A * B into a contiguous band of rows
// of C, using the file-scope operands A (M x K), B (K x N) and C (M x N).
// Results are added (+=) to the existing contents of C; presumably the caller
// passes C zeroed -- confirm against the driver.
//
// `data` points to this worker's index in [0, num_threads). Rows are split as
// evenly as possible: the first M % num_threads workers take one extra row.
// Previously the whole remainder went to the last worker, and with
// num_threads > M that single worker computed every row.
//
// k is walked in blocks of `Block` rows of B, so the same Block rows of B are
// reused for every row of the band (intended to keep them cache-resident).
// Inside a block, k is unrolled by 4. The K % Block leftover k values are
// applied one at a time afterwards. The per-element floating-point evaluation
// order matches the previous implementation exactly.
static void* mat_mul_thread(void *data) {
  const int num = *(int *) data;
  const int Block = 40;                   // must stay a multiple of the unroll factor 4
  const int KBlock = (K / Block) * Block; // end of the last full k-block

  // Balanced row partition: sizes differ by at most one row between workers.
  const int base = M / num_threads;
  const int extra = M % num_threads;
  const int start = num * base + (num < extra ? num : extra);
  const int end = start + base + (num < extra ? 1 : 0);

  for (int kk = 0; kk < KBlock; kk += Block) {
    for (int i = start; i < end; i++) {
      const float *a_row = &A[i * K];
      float *c_row = &C[i * N];
      for (int k = kk; k < kk + Block; k += 4) {
        const float a0 = a_row[k];
        const float a1 = a_row[k + 1];
        const float a2 = a_row[k + 2];
        const float a3 = a_row[k + 3];
        const float *b0 = &B[k * N];  // rows k .. k+3 of B
        const float *b1 = b0 + N;
        const float *b2 = b1 + N;
        const float *b3 = b2 + N;
        for (int j = 0; j < N; j++) {
          c_row[j] += a0 * b0[j] + a1 * b1[j] + a2 * b2[j] + a3 * b3[j];
        }
      }
    }
  }

  // Leftover k values past the last full block, one rank-1 update each.
  for (int k = KBlock; k < K; k++) {
    const float *b_row = &B[k * N];
    for (int i = start; i < end; i++) {
      const float a = A[i * K + k];
      float *c_row = &C[i * N];
      for (int j = 0; j < N; j++) {
        c_row[j] += a * b_row[j];
      }
    }
  }
  return NULL;
}
// Accumulates A (_M x _K) * B (_K x _N) into C (_M x _N) using _num_threads
// worker threads running mat_mul_thread(), each on a disjoint band of rows.
//
// Robustness guarantees:
//  - _num_threads < 1 is treated as 1. Previously no worker ran, C was left
//    untouched, and a zero-length array was declared.
//  - If the worker bookkeeping cannot be allocated, the whole product is
//    computed on the calling thread.
//  - If pthread_create fails for a worker, an error is printed and that
//    worker's rows are computed on the calling thread. Its handle is never
//    passed to pthread_join, because joining a thread that was never created
//    is undefined behavior.
//  - A pthread_join failure still prints an error and exits with status 1.
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads > 0 ? _num_threads : 1;

  pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_threads);
  // Worker indices must stay alive until the workers are joined. They are
  // heap-allocated because variable-length arrays are not standard C++.
  int *ids = (int *) malloc(sizeof(int) * num_threads);
  unsigned char *created = (unsigned char *) malloc(sizeof(unsigned char) * num_threads);
  if (threads == NULL || ids == NULL || created == NULL) {
    // No room to track workers: fall back to a single-threaded computation.
    free(threads);
    free(ids);
    free(created);
    num_threads = 1;
    int only_id = 0;
    mat_mul_thread(&only_id);
    return;
  }

  for (int i = 0; i < num_threads; i++) {
    ids[i] = i;
    int result = pthread_create(&threads[i], NULL, mat_mul_thread, (void *) &ids[i]);
    created[i] = (result == 0);
    if (result != 0) {
      printf("error %d : %d\n", i, result);
      // Compute this worker's rows here so C is still complete.
      mat_mul_thread(&ids[i]);
    }
  }

  for (int i = 0; i < num_threads; i++) {
    if (!created[i]) {
      continue;  // no thread exists for this slot; its rows were done inline
    }
    int rc = pthread_join(threads[i], NULL);
    if (rc != 0) {
      printf("Error in threads[%d]: %d\n", i, rc);
      exit(1);
    }
  }

  free(threads);
  free(ids);
  free(created);
}