// Source: chundoong-lab-ta/SamsungDS22/submissions/HW2/mjstyle.kim/mat_mul.cpp
// (139 lines, 4.0 KiB, C++)

#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
// Operand and result pointers, captured by mat_mul() for the worker threads.
// The workers index them as A[i*K + k], B[k*N + j] and C[i*N + j].
static float *A, *B, *C;
// Matrix dimensions captured by mat_mul(): i ranges over M, j over N, k over K.
static int M, N, K;
// Thread count requested by the caller of mat_mul().
static int num_threads;
// Per-slot tile origin. mat_mul() writes entry [slot] before starting the
// slot's thread, and mat_mul_thread() reads it to find its tile.
// NOTE: mat_mul() uses slot indices up to num_threads-1, so these arrays
// (and the two below) can only serve up to 40 slots.
int i_pos[40];
int j_pos[40];
int k_pos[40];
// Entry [slot] holds the value `slot`; its address is the argument passed to
// the thread started in that slot.
short t_idx_arr[40];
// Entry [slot] is set to 1 once a thread has been created in that slot and is
// cleared after the final join at the end of mat_mul().
short t_flag[40];
// Tile extent along the rows of C (i direction).
#define I_SIZE 16
// Tile extent along the columns of C (j direction).
#define J_SIZE 4096
// Tile extent along the inner dimension (k direction).
#define K_SIZE 1024
#if 1
static void* mat_mul_thread(void *data) {
int j;
short *t_idx = (short *)data;
int ii = i_pos[*t_idx]+I_SIZE, jj = j_pos[*t_idx]+J_SIZE, kk = k_pos[*t_idx]+K_SIZE;
if(ii > M) ii = M;
if(jj > N) jj = N;
if(kk > K) kk = K;
for(int k = k_pos[*t_idx]; k < kk; k++) {
for(int i = i_pos[*t_idx]; i < ii; i++) {
float A_tmp = A[i * K + k];
for(j = j_pos[*t_idx]; j <= jj-16; j=j+16) {
C[i * N + j] += A_tmp * B[k * N + j];
C[i * N + j+1] += A_tmp * B[k * N + j+1];
C[i * N + j+2] += A_tmp * B[k * N + j+2];
C[i * N + j+3] += A_tmp * B[k * N + j+3];
C[i * N + j+4] += A_tmp * B[k * N + j+4];
C[i * N + j+5] += A_tmp * B[k * N + j+5];
C[i * N + j+6] += A_tmp * B[k * N + j+6];
C[i * N + j+7] += A_tmp * B[k * N + j+7];
C[i * N + j+8] += A_tmp * B[k * N + j+8];
C[i * N + j+9] += A_tmp * B[k * N + j+9];
C[i * N + j+10] += A_tmp * B[k * N + j+10];
C[i * N + j+11] += A_tmp * B[k * N + j+11];
C[i * N + j+12] += A_tmp * B[k * N + j+12];
C[i * N + j+13] += A_tmp * B[k * N + j+13];
C[i * N + j+14] += A_tmp * B[k * N + j+14];
C[i * N + j+15] += A_tmp * B[k * N + j+15];
}
if(j != N) {
for (j = j ; j < N; ++j) {
C[i * N + j] += A_tmp * B[k * N + j];
}
}
}
}
return NULL;
}
/*
 * Accumulates the product A * B into C using pthreads.
 *
 * _A, _B and _C are accessed by the workers as A[i*K + k], B[k*N + j] and
 * C[i*N + j]; _M, _N and _K are the corresponding dimensions. The workers
 * only add to C, and this function does not clear it first.
 *
 * _num_threads is the requested number of concurrent workers. It is clamped
 * to the range [1, capacity of the slot arrays], so a non-positive or
 * oversized request can no longer index past the bookkeeping arrays.
 *
 * The iteration space is cut into I_SIZE x J_SIZE x K_SIZE tiles. Each tile
 * is handed to a worker slot in round-robin order; before a slot is reused
 * the thread previously started in it is joined.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
  /* Number of entries available in i_pos/j_pos/k_pos/t_idx_arr/t_flag. */
  enum { MAX_SLOTS = sizeof(t_flag) / sizeof(t_flag[0]) };

  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads;

  int slots = num_threads;
  if (slots < 1) slots = 1;
  if (slots > MAX_SLOTS) slots = MAX_SLOTS;

  pthread_t thread[MAX_SLOTS];

  for (int k = 0; k < K; k += K_SIZE) {
    /*
     * Restart the round-robin for every k chunk: a given (i, j) tile then
     * lands in the same slot each time, so its previous k chunk has been
     * joined before the next one starts adding to the same part of C.
     */
    int slot = 0;
    for (int i = 0; i < M; i += I_SIZE) {
      for (int j = 0; j < N; j += J_SIZE) {
        /* Wait for the thread still occupying this slot, if any. */
        if (t_flag[slot]) {
          pthread_join(thread[slot], NULL);
          t_flag[slot] = 0;
        }

        t_idx_arr[slot] = (short)slot;
        i_pos[slot] = i;
        j_pos[slot] = j;
        k_pos[slot] = k;

        if (pthread_create(&thread[slot], NULL, mat_mul_thread, &t_idx_arr[slot]) == 0) {
          t_flag[slot] = 1;
        } else {
          /*
           * No thread could be started: compute the tile on the calling
           * thread instead of dropping it, and leave the flag clear so the
           * unset pthread_t is never joined.
           */
          mat_mul_thread(&t_idx_arr[slot]);
        }

        slot = (slot + 1 == slots) ? 0 : slot + 1;
      }
    }
  }

  /* Drain every worker that is still outstanding. */
  for (int slot = 0; slot < slots; ++slot) {
    if (t_flag[slot]) {
      pthread_join(thread[slot], NULL);
      t_flag[slot] = 0;
    }
  }
}
#else
/*
 * Thread entry point for the single-worker variant: walks the whole
 * K x M x N iteration space and adds A[row*K + depth] * B[depth*N + col]
 * to C[row*N + col] for every combination.
 *
 * data is not used. Always returns NULL.
 */
static void* mat_mul_thread(void *data) {
  int col;
  for (int depth = 0; depth < K; ++depth) {
    for (int row = 0; row < M; ++row) {
      float scale = A[row * K + depth];

      /* Full groups of 16 consecutive columns. */
      col = 0;
      while (col <= N - 16) {
        for (int lane = 0; lane < 16; ++lane) {
          C[row * N + col + lane] += scale * B[depth * N + col + lane];
        }
        col += 16;
      }

      /* Whatever is left when N is not a multiple of 16. */
      while (col < N) {
        C[row * N + col] += scale * B[depth * N + col];
        ++col;
      }
    }
  }
  return NULL;
}
/*
 * Single-worker variant: publishes the operands and dimensions in the
 * file-scope variables, runs mat_mul_thread() on one pthread and waits for
 * it to finish.
 *
 * _num_threads is recorded but only one worker is started.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads;

  // TODO: create '_num_threads' pthreads
  /*
   * A single handle is enough for the one worker; indexing a
   * num_threads-sized array at [2] would be out of bounds for fewer than
   * three threads.
   */
  pthread_t worker;
  if (pthread_create(&worker, NULL, mat_mul_thread, NULL) != 0) {
    /* Could not start a thread: do the work on the calling thread. */
    mat_mul_thread(NULL);
    return;
  }
  pthread_join(worker, NULL);
}
#endif