105 lines
2.6 KiB
C++
105 lines
2.6 KiB
C++
#include "mat_mul.h"
|
|
|
|
#include <cstdlib>
|
|
#include <cstdio>
|
|
#include <pthread.h>
|
|
#include <immintrin.h>
|
|
|
|
/*
 * Shared state for the multiplication currently in flight.  mat_mul() fills
 * these in before starting the workers; the workers only read them (and
 * write through C).  Matrices are indexed row-major: A[i*K + k],
 * B[k*N + j], C[i*N + j].
 */
static float *A;          /* left operand,  M x K */
static float *B;          /* right operand, K x N */
static float *C;          /* result,        M x N */
static int M;             /* rows of A and C */
static int N;             /* columns of B and C */
static int K;             /* columns of A == rows of B */
static int num_threads;   /* number of row slices the work is divided into */
|
|
|
|
static void* mat_mul_thread(void *data) {
|
|
// TODO: parallelize & optimize matrix multiplication
|
|
|
|
// A: MxK, B: KxN, C: MxN
|
|
int num =*(int*) data;
|
|
int Block=40;
|
|
|
|
int KB,j,k,kk;
|
|
float Aik[4];
|
|
|
|
int KBlock = (K/Block)*Block;
|
|
|
|
int KU = (KBlock/4)*4;
|
|
int slice = M / num_threads;
|
|
int start = num * slice;
|
|
int end = num == num_threads - 1 ? M : (num + 1) * slice;
|
|
for(kk = 0; kk < KBlock; kk += Block){
|
|
for (int i=start; i < end; i++) {
|
|
|
|
if( kk+Block < KU ) KB = kk + Block;
|
|
else KB = KU;
|
|
for (k = kk; k < KB; k+=4) {
|
|
Aik[0] = A[i*K + k];
|
|
Aik[1] = A[i*K + k+1];
|
|
Aik[2] = A[i*K + k+2];
|
|
Aik[3] = A[i*K + k+3];
|
|
|
|
for (j = 0; j < N; j++) {
|
|
C[i*N+j] += Aik[0] * B[k*N + j] //A[i*K + k] * B[k*N + j];
|
|
+ Aik[1] * B[(k+1)*N + j] //A[i*K + k] * B[k*N + j];
|
|
+ Aik[2] * B[(k+2)*N + j] //A[i*K + k] * B[k*N + j];
|
|
+ Aik[3] * B[(k+3)*N + j]; //A[i*K + k] * B[k*N + j];
|
|
}
|
|
}
|
|
if(K -KB < 4){
|
|
for (; k < KBlock; k++) {
|
|
Aik[0] = A[i*K + k];
|
|
for (j = 0; j < N; j++) {
|
|
C[i*N+j] += Aik[0] * B[k*N + j]; //A[i*K + k] * B[k*N + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if(K-KBlock < Block){
|
|
for (; kk < K; kk++) {
|
|
for (int i=start; i < end; i++) {
|
|
Aik[0] = A[i*K + kk];
|
|
for (j = 0; j < N; j++) {
|
|
C[i*N+j] += Aik[0] * B[kk*N + j]; //A[i*K + k] * B[k*N + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return NULL;
|
|
}
|
|
|
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
|
|
A = _A, B = _B, C = _C;
|
|
M = _M, N = _N, K = _K;
|
|
num_threads = _num_threads;
|
|
|
|
// TODO: create '_num_threads' pthreads
|
|
pthread_t * threads;
|
|
threads = (pthread_t *) malloc(sizeof(pthread_t) * num_threads);
|
|
|
|
int temp[num_threads];
|
|
int i,rc;
|
|
|
|
for(int i=0; i<num_threads; i++){
|
|
temp[i] =i;
|
|
int result = pthread_create(&threads[i], NULL, mat_mul_thread, (void *) &temp[i]);
|
|
if(result != 0){
|
|
printf("error %d : %d",i,result);
|
|
}
|
|
}
|
|
|
|
|
|
for(i=0; i<num_threads; i++){
|
|
|
|
rc = pthread_join(threads[i], NULL);
|
|
|
|
if( rc != 0){
|
|
printf("Error in threads[%d]: %d\n",i,rc);
|
|
exit(1);
|
|
}
|
|
}
|
|
free(threads);
|
|
|
|
}
|