139 lines
4.0 KiB
C++
139 lines
4.0 KiB
C++
#include "mat_mul.h"
|
|
|
|
#include <cstdlib>
|
|
#include <cstdio>
|
|
#include <pthread.h>
|
|
|
|
/* Operand and result matrices plus their dimensions, captured by mat_mul()
 * so the worker routine can reach them without an argument struct.
 * As indexed by the workers: A is M x K, B is K x N, C is M x N, row-major. */
static float *A;
static float *B;
static float *C;
static int M;
static int N;
static int K;

/* Worker count most recently passed to (or chosen by) mat_mul(). */
static int num_threads;

/* Per-slot tile descriptor: the worker in slot s handles the tile whose
 * first row is i_pos[s], first column is j_pos[s] and first reduction
 * index is k_pos[s]. */
int i_pos[40];
int j_pos[40];
int k_pos[40];

/* t_idx_arr[s] holds s; its address is what a worker receives as its slot id. */
short t_idx_arr[40];

/* Non-zero while slot s holds a started thread that has not been joined. */
short t_flag[40];

/* Tile extents: rows of C (i), columns of C (j), reduction steps (k). */
#define I_SIZE 16
#define J_SIZE 4096
#define K_SIZE 1024
|
|
|
|
#if 1
|
|
static void* mat_mul_thread(void *data) {
|
|
int j;
|
|
short *t_idx = (short *)data;
|
|
int ii = i_pos[*t_idx]+I_SIZE, jj = j_pos[*t_idx]+J_SIZE, kk = k_pos[*t_idx]+K_SIZE;
|
|
if(ii > M) ii = M;
|
|
if(jj > N) jj = N;
|
|
if(kk > K) kk = K;
|
|
for(int k = k_pos[*t_idx]; k < kk; k++) {
|
|
for(int i = i_pos[*t_idx]; i < ii; i++) {
|
|
float A_tmp = A[i * K + k];
|
|
for(j = j_pos[*t_idx]; j <= jj-16; j=j+16) {
|
|
C[i * N + j] += A_tmp * B[k * N + j];
|
|
C[i * N + j+1] += A_tmp * B[k * N + j+1];
|
|
C[i * N + j+2] += A_tmp * B[k * N + j+2];
|
|
C[i * N + j+3] += A_tmp * B[k * N + j+3];
|
|
C[i * N + j+4] += A_tmp * B[k * N + j+4];
|
|
C[i * N + j+5] += A_tmp * B[k * N + j+5];
|
|
C[i * N + j+6] += A_tmp * B[k * N + j+6];
|
|
C[i * N + j+7] += A_tmp * B[k * N + j+7];
|
|
C[i * N + j+8] += A_tmp * B[k * N + j+8];
|
|
C[i * N + j+9] += A_tmp * B[k * N + j+9];
|
|
C[i * N + j+10] += A_tmp * B[k * N + j+10];
|
|
C[i * N + j+11] += A_tmp * B[k * N + j+11];
|
|
C[i * N + j+12] += A_tmp * B[k * N + j+12];
|
|
C[i * N + j+13] += A_tmp * B[k * N + j+13];
|
|
C[i * N + j+14] += A_tmp * B[k * N + j+14];
|
|
C[i * N + j+15] += A_tmp * B[k * N + j+15];
|
|
}
|
|
if(j != N) {
|
|
for (j = j ; j < N; ++j) {
|
|
C[i * N + j] += A_tmp * B[k * N + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return NULL;
|
|
}
|
|
|
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
|
|
A = _A, B = _B, C = _C;
|
|
M = _M, N = _N, K = _K;
|
|
num_threads = _num_threads;
|
|
|
|
short t_idx=0;
|
|
// TODO: create '_num_threads' pthreads
|
|
pthread_t thread[num_threads];
|
|
|
|
for(int k = 0; k < K; k=k+K_SIZE) {
|
|
t_idx = 0;
|
|
for(int i = 0; i < M; i=i+I_SIZE) {
|
|
for(int j = 0; j < N; j=j+J_SIZE) {
|
|
if(t_flag[t_idx])
|
|
pthread_join(thread[t_idx], NULL);
|
|
t_idx_arr[t_idx] = t_idx; i_pos[t_idx] = i; j_pos[t_idx] = j; k_pos[t_idx] = k;
|
|
pthread_create(&thread[t_idx], NULL, mat_mul_thread, &t_idx_arr[t_idx]);
|
|
t_flag[t_idx] = 1;
|
|
if(t_idx == num_threads-1) {
|
|
t_idx = 0;
|
|
} else {
|
|
t_idx++;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
for(t_idx = 0; t_idx < num_threads; ++t_idx) {
|
|
if(t_flag[t_idx]) {
|
|
pthread_join(thread[t_idx], NULL);
|
|
t_flag[t_idx] = 0;
|
|
}
|
|
}
|
|
}
|
|
#else
|
|
static void* mat_mul_thread(void *data) {
|
|
int j;
|
|
for (int k = 0; k < K; ++k) {
|
|
for (int i = 0; i < M; ++i) {
|
|
float A_tmp = A[i * K + k];
|
|
for (j = 0; j <= N-16; j = j+16) {
|
|
C[i * N + j] += A_tmp * B[k * N + j];
|
|
C[i * N + j+1] += A_tmp * B[k * N + j+1];
|
|
C[i * N + j+2] += A_tmp * B[k * N + j+2];
|
|
C[i * N + j+3] += A_tmp * B[k * N + j+3];
|
|
C[i * N + j+4] += A_tmp * B[k * N + j+4];
|
|
C[i * N + j+5] += A_tmp * B[k * N + j+5];
|
|
C[i * N + j+6] += A_tmp * B[k * N + j+6];
|
|
C[i * N + j+7] += A_tmp * B[k * N + j+7];
|
|
C[i * N + j+8] += A_tmp * B[k * N + j+8];
|
|
C[i * N + j+9] += A_tmp * B[k * N + j+9];
|
|
C[i * N + j+10] += A_tmp * B[k * N + j+10];
|
|
C[i * N + j+11] += A_tmp * B[k * N + j+11];
|
|
C[i * N + j+12] += A_tmp * B[k * N + j+12];
|
|
C[i * N + j+13] += A_tmp * B[k * N + j+13];
|
|
C[i * N + j+14] += A_tmp * B[k * N + j+14];
|
|
C[i * N + j+15] += A_tmp * B[k * N + j+15];
|
|
}
|
|
if(j != N) {
|
|
for (j = j ; j < N; ++j) {
|
|
C[i * N + j] += A_tmp * B[k * N + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
/*
 * Reference driver: computes C += A * B (A _M x _K, B _K x _N, C _M x _N)
 * on a single worker thread and waits for it to finish. `_num_threads` is
 * recorded, but this variant does not split the work across threads.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
    A = _A, B = _B, C = _C;
    M = _M, N = _N, K = _K;
    num_threads = _num_threads;

    // A single handle suffices. Indexing a num_threads-sized array at [2]
    // overflowed it whenever fewer than three threads were requested.
    pthread_t worker;
    if (pthread_create(&worker, NULL, mat_mul_thread, NULL) == 0) {
        pthread_join(worker, NULL);
    } else {
        // No thread could be started: compute on the caller's thread.
        mat_mul_thread(NULL);
    }
}
|
|
#endif
|