#include "mat_mul.h"

#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

/*
 * Cache-blocking tile sizes: rows of C, columns of C, and steps of the shared
 * inner dimension processed per tile. They affect speed only, not results.
 */
enum { I_SIZE = 16, J_SIZE = 4096, K_SIZE = 1024 };

/* One unit of work: accumulate rows [row_begin, row_end) of C += A * B. */
struct mat_mul_task {
    const float *A;
    const float *B;
    float *C;
    int N;
    int K;
    int row_begin;
    int row_end;
    pthread_t thread; /* meaningful only when running is true */
    bool running;     /* a worker thread was successfully started */
};

static int min_int(int a, int b)
{
    return a < b ? a : b;
}

/*
 * c[j] += a * b[j] for j in [j0, j1).
 * c and b must not overlap; restrict lets the compiler vectorize the loop.
 */
static void row_axpy(float *restrict c, const float *restrict b, float a,
                     int j0, int j1)
{
    for (int j = j0; j < j1; j++) {
        c[j] += a * b[j];
    }
}

/*
 * Compute the task's rows of C += A * B using I/J/K tiling.
 *
 * Every element of C receives its contributions in ascending k order, and
 * each element is written by exactly one task, so the result does not depend
 * on how rows are split across threads. Offsets are computed in size_t so
 * that i * N and i * K cannot overflow int for large matrices.
 */
static void mat_mul_rows(const struct mat_mul_task *t)
{
    const int N = t->N;
    const int K = t->K;

    for (int i0 = t->row_begin; i0 < t->row_end; i0 += I_SIZE) {
        const int i1 = min_int(i0 + I_SIZE, t->row_end);
        for (int k0 = 0; k0 < K; k0 += K_SIZE) {
            const int k1 = min_int(k0 + K_SIZE, K);
            for (int j0 = 0; j0 < N; j0 += J_SIZE) {
                /* Clamp to this tile only, so column tiles never overlap. */
                const int j1 = min_int(j0 + J_SIZE, N);
                for (int k = k0; k < k1; k++) {
                    const float *b_row = t->B + (size_t)k * (size_t)N;
                    for (int i = i0; i < i1; i++) {
                        const float a_ik =
                            t->A[(size_t)i * (size_t)K + (size_t)k];
                        row_axpy(t->C + (size_t)i * (size_t)N, b_row, a_ik,
                                 j0, j1);
                    }
                }
            }
        }
    }
}

/* pthread entry point; arg is a struct mat_mul_task *. */
static void *mat_mul_thread(void *arg)
{
    mat_mul_rows(arg);
    return NULL;
}

/*
 * Accumulate the matrix product into C: C += A * B.
 *
 * A is M x K, B is K x N and C is M x N, all dense row-major float arrays.
 * C is added to, not overwritten, so the caller must initialise it (e.g. to
 * zero for a plain product). C must not overlap A or B.
 *
 * The rows of C are split into contiguous, nearly equal chunks, one per
 * thread. The calling thread processes the first chunk itself. A num_threads
 * below 1 is treated as 1, and at most M threads are used.
 *
 * If the work descriptors cannot be allocated, or a worker thread fails to
 * start, the affected rows are computed on the calling thread instead, so
 * the result is always complete. Non-positive M, N or K is a no-op.
 */
void mat_mul(float *A, float *B, float *C, int M, int N, int K,
             int num_threads)
{
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    int nthreads = num_threads < 1 ? 1 : num_threads;
    if (nthreads > M) {
        nthreads = M; /* every thread needs at least one row */
    }

    struct mat_mul_task *tasks = malloc((size_t)nthreads * sizeof *tasks);
    if (tasks == NULL) {
        /* Out of memory: still produce the full result, single-threaded. */
        struct mat_mul_task all = { .A = A, .B = B, .C = C, .N = N, .K = K,
                                    .row_begin = 0, .row_end = M };
        mat_mul_rows(&all);
        return;
    }

    /* The first (M % nthreads) chunks each get one extra row. */
    const int base = M / nthreads;
    const int extra = M % nthreads;
    int row = 0;
    for (int t = 0; t < nthreads; t++) {
        const int rows = base + (t < extra ? 1 : 0);
        tasks[t] = (struct mat_mul_task){ .A = A, .B = B, .C = C,
                                          .N = N, .K = K,
                                          .row_begin = row,
                                          .row_end = row + rows,
                                          .running = false };
        row += rows;
    }

    for (int t = 1; t < nthreads; t++) {
        tasks[t].running = pthread_create(&tasks[t].thread, NULL,
                                          mat_mul_thread, &tasks[t]) == 0;
    }

    mat_mul_rows(&tasks[0]);

    for (int t = 1; t < nthreads; t++) {
        if (tasks[t].running) {
            pthread_join(tasks[t].thread, NULL);
        } else {
            /* Thread creation failed: do that chunk here. */
            mat_mul_rows(&tasks[t]);
        }
    }

    free(tasks);
}