119 lines
3.5 KiB
C++
119 lines
3.5 KiB
C++
#include "mat_mul.h"
|
|
|
|
#include <cstdlib>
|
|
#include <cstdio>
|
|
#include <pthread.h>
|
|
|
|
/* Operands of the current mat_mul() call, stored row-major and shared with
 * the worker threads:
 *   A is M x K (indexed A[i * K + k]),
 *   B is K x N (indexed B[k * N + j]),
 *   C is M x N (indexed C[i * N + j]). */
static float *A;
static float *B;
static float *C;

static int M;
static int N;
static int K;

/* Thread count requested by the caller of mat_mul(). */
static int num_threads;
|
|
|
|
/* Per-worker argument passed to mat_mul_thread().
 * The tag no longer starts with an underscore: such names are reserved for
 * the implementation at file scope in both C and C++. */
typedef struct mat_mul_arg {
    int idx;   /* zero-based worker index */
    int size;  /* rows of C per worker; this worker starts at row idx * size */
} mat_mul_arg;
|
|
|
|
/* Returns the smaller of a and b.
 * `static` gives the helper internal linkage. In C, a plain `inline`
 * definition provides no external definition, so a non-inlined call would
 * fail to link; `static inline` avoids that. */
static inline int func_min(int a, int b) {
    return (a > b) ? b : a;
}
|
|
static void* mat_mul_thread(void *data) {
|
|
// TODO: parallelize & optimize matrix multiplication
|
|
int thread_idx, thread_size ;
|
|
int low_M, high_M ;
|
|
int i, j , k ;
|
|
int c_base, c_idx ;
|
|
int line_start ;
|
|
int a_base ;
|
|
float a[45] ;
|
|
float b[ 5] ;
|
|
|
|
mat_mul_arg *local;
|
|
|
|
local = (mat_mul_arg*)data ;
|
|
|
|
thread_idx = local->idx ;
|
|
thread_size = local->size ;
|
|
|
|
low_M = thread_idx * thread_size ;
|
|
high_M = func_min((low_M + thread_size), M);
|
|
|
|
line_start = low_M * N ;
|
|
for ( k = 0 ; k < K - 45 ; k += 45){
|
|
c_base = line_start ;
|
|
for ( i = low_M; i < high_M; ++i) {
|
|
//---------- START OF LOOP -------------//
|
|
a_base = i * K + k ;
|
|
for (int ll = 0 ; ll < 45 ; ll++)
|
|
a[ll] = A[a_base + ll] ;
|
|
|
|
for (int ll = 0 ; ll < 45 ; ll += 5){
|
|
for ( j = 0; j < N; ++j) {
|
|
c_idx = c_base + j ;
|
|
|
|
b[0] = B[(k + ll + 0) * N + j] ;
|
|
b[1] = B[(k + ll + 1) * N + j] ;
|
|
b[2] = B[(k + ll + 2) * N + j] ;
|
|
b[3] = B[(k + ll + 3) * N + j] ;
|
|
b[4] = B[(k + ll + 4) * N + j] ;
|
|
|
|
C[c_idx] = (a[ll + 0] * b[0]) +
|
|
(a[ll + 1] * b[1]) +
|
|
(a[ll + 2] * b[2]) +
|
|
(a[ll + 3] * b[3]) +
|
|
(a[ll + 4] * b[4]) +
|
|
C[c_idx] ;
|
|
}
|
|
}
|
|
c_base += N ;
|
|
//---------- END OF LOOP -------------//
|
|
}
|
|
}
|
|
|
|
// Processing Remained Area
|
|
for ( ; k < K ; ++k) {
|
|
for ( i = low_M; i < high_M; ++i) {
|
|
for ( j = 0; j < N; ++j) {
|
|
C[i * N + j] += A[i * K + k] * B[k * N + j];
|
|
}
|
|
}
|
|
}
|
|
return NULL;
|
|
}
|
|
|
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
|
|
A = _A, B = _B, C = _C;
|
|
M = _M, N = _N, K = _K;
|
|
num_threads = _num_threads;
|
|
|
|
//TODO: create '_num_threads' pthreads
|
|
pthread_t thread[num_threads];
|
|
mat_mul_arg arg [num_threads];
|
|
|
|
int i ;
|
|
int M_div_thread;
|
|
int real_num_thread ;
|
|
int thread_unit ;
|
|
|
|
M_div_thread = M / num_threads ;
|
|
if (M_div_thread == 0) {
|
|
real_num_thread = M ;
|
|
thread_unit = 1 ;
|
|
} else if (M_div_thread * num_threads < M){
|
|
thread_unit = M_div_thread + 1 ;
|
|
real_num_thread = num_threads ;
|
|
} else {
|
|
thread_unit = M_div_thread ;
|
|
real_num_thread = num_threads ;
|
|
}
|
|
|
|
for (i = 0 ; i < real_num_thread ; i++){
|
|
arg[i].idx = i;
|
|
arg[i].size = thread_unit;
|
|
//pthread_create(&thread[i], NULL, mat_mul_thread, NULL);
|
|
pthread_create(&thread[i], NULL, mat_mul_thread, &arg[i]);
|
|
}
|
|
|
|
for (i = 0 ; i < real_num_thread ; i++){
|
|
pthread_join(thread[i], NULL);
|
|
}
|
|
}
|