#include "mat_mul.h"

/*
 * NOTE(review): the header names of the four system includes were missing in
 * the reviewed copy of this file. The list below is what the code in this file
 * requires (pthreads, size_t/NULL) plus the usual stdio/stdlib pair -- confirm
 * against the project's original include list.
 */
#include <pthread.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

#define MAX_NUM_THREAD 100
#define BLOCK_SIZE 45
#define UNROLL_SIZE 8

#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define likely(x) __builtin_expect((x), 1)
#define unlikely(x) __builtin_expect((x), 0)

#if defined(USE_MATRIX_UNROLL) && UNROLL_SIZE != 8
#error "the unrolled inner loop below is written for UNROLL_SIZE == 8"
#endif

/* Operands and dimensions shared (read-only, except C) by all worker threads. */
static float *A, *B, *C;
static int M, N, K;
static int num_threads;

/* Per-thread argument: the worker's index in [0, num_threads). */
typedef struct _thread_param_ {
    int pid;
} _thread_param;

/*
 * Worker: accumulates A*B into the rows of C owned by this thread.
 *
 * Rows are split into num_threads contiguous slices of M / num_threads rows;
 * the last worker additionally takes the M % num_threads leftover rows.
 * The k dimension is walked in chunks of BLOCK_SIZE. C is updated with +=,
 * i.e. the product is added to whatever C already contains.
 *
 * data points to a _thread_param. Always returns NULL.
 */
static void *mat_mul_thread(void *data)
{
    const _thread_param *tp = data;
    const int pid = tp->pid;
    const int pslice = M / num_threads;
    const int m_start = pid * pslice;
    const int m_end = (pid == num_threads - 1) ? M : (pid + 1) * pslice;
#if defined(USE_MATRIX_UNROLL)
    /* Largest multiple of UNROLL_SIZE not exceeding N. */
    const int q = (N / UNROLL_SIZE) * UNROLL_SIZE;
#endif

    for (int kk = 0; kk < K; kk += BLOCK_SIZE) {
        const int k_end = MIN(kk + BLOCK_SIZE, K);

        for (int i = m_start; likely(i < m_end); ++i) {
            /* Row offsets are computed in size_t so large matrices do not
             * overflow int index arithmetic. */
            const float *a_row = A + (size_t)i * (size_t)K;
            float *c_row = C + (size_t)i * (size_t)N;

            for (int k = kk; k < k_end; ++k) {
                const float Aik = a_row[k];
                const float *b_row = B + (size_t)k * (size_t)N;
                int j = 0;

#if defined(USE_MATRIX_UNROLL)
                for (; likely(j < q); j += UNROLL_SIZE) {
                    c_row[j]     += Aik * b_row[j];
                    c_row[j + 1] += Aik * b_row[j + 1];
                    c_row[j + 2] += Aik * b_row[j + 2];
                    c_row[j + 3] += Aik * b_row[j + 3];
                    c_row[j + 4] += Aik * b_row[j + 4];
                    c_row[j + 5] += Aik * b_row[j + 5];
                    c_row[j + 6] += Aik * b_row[j + 6];
                    c_row[j + 7] += Aik * b_row[j + 7];
                }
#endif
                /* Whole row when not unrolling; otherwise the < UNROLL_SIZE
                 * leftover columns. */
                for (; j < N; ++j) {
                    c_row[j] += Aik * b_row[j];
                }
            }
        }
    }

    return NULL;
}

/*
 * Adds the product of _A (_M x _K) and _B (_K x _N) into _C (_M x _N), all
 * indexed as row-major arrays, using up to _num_threads pthreads.
 *
 * _C is accumulated into (+=), not overwritten, so the caller must initialise
 * it. The thread count is clamped to [1, MIN(MAX_NUM_THREAD, _M)]. If any
 * dimension is non-positive the function returns without touching _C.
 *
 * The operands are published through file-scope variables, so concurrent
 * calls to mat_mul are not supported.
 */
void mat_mul(float *_A, float *_B, float *_C,
             int _M, int _N, int _K, int _num_threads)
{
    A = _A, B = _B, C = _C;
    M = _M, N = _N, K = _K;
    num_threads = _num_threads;

    /* Nothing to compute for an empty dimension. */
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* Keep the count inside the fixed-size arrays below, never more workers
     * than rows, and at least one worker so the product is always computed. */
    if (num_threads > MAX_NUM_THREAD) {
        num_threads = MAX_NUM_THREAD;
    }
    if (num_threads > M) {
        num_threads = M;
    }
    if (num_threads < 1) {
        num_threads = 1;
    }

    pthread_t thread[MAX_NUM_THREAD];
    _thread_param tparam[MAX_NUM_THREAD];
    int started[MAX_NUM_THREAD];

    for (int i = 0; i < num_threads; i++) {
        tparam[i].pid = i;
        started[i] =
            (pthread_create(&thread[i], NULL, mat_mul_thread, &tparam[i]) == 0);
        if (unlikely(!started[i])) {
            /* Could not spawn a worker: compute its slice on this thread so
             * the result is still complete, and skip the join below. */
            mat_mul_thread(&tparam[i]);
        }
    }

    for (int i = 0; i < num_threads; i++) {
        if (started[i]) {
            pthread_join(thread[i], NULL);
        }
    }
}