#include "mat_mul.h"

// NOTE(review): the targets of the system #include lines are missing from this
// copy of the file (only bare "#include" tokens remain). The headers below are
// the minimum this translation unit needs; confirm against the original list.
#include <algorithm>
#include <cstddef>
#include <pthread.h>
#include <vector>

using namespace std;

// Operands and dimensions of the multiplication currently in flight.
// mat_mul() stores them here before starting workers and the workers read
// them, so two overlapping mat_mul() calls would overwrite each other's state.
static float *A, *B, *C;
static int M, N, K;
static int num_threads;

// Half-open range [start, end) of rows of C assigned to one worker.
typedef struct thread_info {
  int start, end;
} thread_info;

// Number of consecutive k indices processed per pass over a worker's rows.
static const int K_BLOCK = 32;

// Worker entry point: accumulates rows [start, end) of the product into C.
//
// `data` points to a thread_info that must stay valid until the worker
// returns. A is indexed as M x K, B as K x N and C as M x N, all row-major.
// The result is added to the existing contents of C (C += A * B); this
// presumably relies on the caller zero-initializing C -- TODO confirm.
// Always returns NULL.
static void *mat_mul_thread(void *data) {
  const thread_info *in = (const thread_info *)data;
  const int row_begin = in->start;
  const int row_end = min(in->end, M);  // never run past the last row

  for (int kk = 0; kk < K; kk += K_BLOCK) {
    const int k_end = kk + min(K_BLOCK, K - kk);
    for (int i = row_begin; i < row_end; ++i) {
      // Offsets are computed in size_t so large matrices cannot overflow int.
      const float *a_row = A + (size_t)i * (size_t)K;
      float *c_row = C + (size_t)i * (size_t)N;
      for (int k = kk; k < k_end; ++k) {
        const float a = a_row[k];
        const float *b_row = B + (size_t)k * (size_t)N;
        // One plain loop handles every N. The previous hand-unrolled 8-wide
        // variant was selected only when N < 8, where it read and wrote past
        // the end of the row.
        for (int j = 0; j < N; ++j) {
          c_row[j] += a * b_row[j];
        }
      }
    }
  }
  return NULL;
}

// Computes C += A * B using up to _num_threads POSIX threads.
//
// _A is indexed as _M x _K, _B as _K x _N and _C as _M x _N, all row-major.
// Rows of C are divided as evenly as possible among min(_M, _num_threads)
// workers. A non-positive _num_threads is treated as 1. If any dimension is
// non-positive there is nothing to compute and the function returns without
// touching C. If a thread cannot be created, its share of rows is computed on
// the calling thread instead.
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K,
             int _num_threads) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads;

  if (M <= 0 || N <= 0 || K <= 0) {
    return;
  }

  const int requested = num_threads > 0 ? num_threads : 1;
  const int n_split = min(M, requested);  // never more workers than rows
  const int base_rows = M / n_split;
  const int extra_rows = M % n_split;  // the first extra_rows workers take one more row

  // Sized once up front so the addresses handed to the workers stay stable.
  vector<thread_info> t_pool(n_split);
  vector<pthread_t> threads;
  threads.reserve(n_split);

  int next_row = 0;
  for (int i = 0; i < n_split; i++) {
    t_pool[i].start = next_row;
    next_row += base_rows + (i < extra_rows ? 1 : 0);
    t_pool[i].end = next_row;

    pthread_t tid;
    if (pthread_create(&tid, NULL, mat_mul_thread, &t_pool[i]) == 0) {
      threads.push_back(tid);
    } else {
      // Could not start a worker: do its slice here so no rows are skipped.
      mat_mul_thread(&t_pool[i]);
    }
  }

  // Join only the threads that were actually created.
  for (size_t i = 0; i < threads.size(); i++) {
    pthread_join(threads[i], NULL);
  }
}