chundoong-lab-ta/SamsungDS22/submissions/HW2/seong81.kim/mat_mul.cpp

119 lines
3.5 KiB
C++
Raw Normal View History

2022-09-29 18:01:45 +09:00
#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
// Matrices shared between mat_mul() and its worker threads. They are indexed as
// flattened row-major arrays: A[i * K + k], B[k * N + j], C[i * N + j].
static float *A;
static float *B;
static float *C;
// Matrix dimensions used with the indexing above: i < M, j < N, k < K.
static int M;
static int N;
static int K;
// Thread count passed to the most recent mat_mul() call.
static int num_threads;
// Work descriptor handed to each worker thread through pthread_create().
// The worker processes rows starting at idx * size and covering at most
// size rows.
typedef struct _mat_mul_arg {
  int idx;   // zero-based position of this worker's row band
  int size;  // number of rows in one band
} mat_mul_arg;
// Returns the smaller of a and b (b when they are equal, which is the same value).
// Declared static like the other file-private symbols so the helper has
// internal linkage and never depends on an out-of-line external definition.
static inline int func_min(int a, int b) {
  if (a > b) {
    return b;
  }
  return a;
}
static void* mat_mul_thread(void *data) {
// TODO: parallelize & optimize matrix multiplication
int thread_idx, thread_size ;
int low_M, high_M ;
int i, j , k ;
int c_base, c_idx ;
int line_start ;
int a_base ;
float a[45] ;
float b[ 5] ;
mat_mul_arg *local;
local = (mat_mul_arg*)data ;
thread_idx = local->idx ;
thread_size = local->size ;
low_M = thread_idx * thread_size ;
high_M = func_min((low_M + thread_size), M);
line_start = low_M * N ;
for ( k = 0 ; k < K - 45 ; k += 45){
c_base = line_start ;
for ( i = low_M; i < high_M; ++i) {
//---------- START OF LOOP -------------//
a_base = i * K + k ;
for (int ll = 0 ; ll < 45 ; ll++)
a[ll] = A[a_base + ll] ;
for (int ll = 0 ; ll < 45 ; ll += 5){
for ( j = 0; j < N; ++j) {
c_idx = c_base + j ;
b[0] = B[(k + ll + 0) * N + j] ;
b[1] = B[(k + ll + 1) * N + j] ;
b[2] = B[(k + ll + 2) * N + j] ;
b[3] = B[(k + ll + 3) * N + j] ;
b[4] = B[(k + ll + 4) * N + j] ;
C[c_idx] = (a[ll + 0] * b[0]) +
(a[ll + 1] * b[1]) +
(a[ll + 2] * b[2]) +
(a[ll + 3] * b[3]) +
(a[ll + 4] * b[4]) +
C[c_idx] ;
}
}
c_base += N ;
//---------- END OF LOOP -------------//
}
}
// Processing Remained Area
for ( ; k < K ; ++k) {
for ( i = low_M; i < high_M; ++i) {
for ( j = 0; j < N; ++j) {
C[i * N + j] += A[i * K + k] * B[k * N + j];
}
}
}
return NULL;
}
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
A = _A, B = _B, C = _C;
M = _M, N = _N, K = _K;
num_threads = _num_threads;
//TODO: create '_num_threads' pthreads
pthread_t thread[num_threads];
mat_mul_arg arg [num_threads];
int i ;
int M_div_thread;
int real_num_thread ;
int thread_unit ;
M_div_thread = M / num_threads ;
if (M_div_thread == 0) {
real_num_thread = M ;
thread_unit = 1 ;
} else if (M_div_thread * num_threads < M){
thread_unit = M_div_thread + 1 ;
real_num_thread = num_threads ;
} else {
thread_unit = M_div_thread ;
real_num_thread = num_threads ;
}
for (i = 0 ; i < real_num_thread ; i++){
arg[i].idx = i;
arg[i].size = thread_unit;
//pthread_create(&thread[i], NULL, mat_mul_thread, NULL);
pthread_create(&thread[i], NULL, mat_mul_thread, &arg[i]);
}
for (i = 0 ; i < real_num_thread ; i++){
pthread_join(thread[i], NULL);
}
}