chundoong-lab-ta/SamsungDS22/submissions/HW2/yw0.kim/mat_mul.cpp

112 lines
2.6 KiB
C++
Raw Normal View History

2022-09-29 18:01:45 +09:00
#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
// Build-time tuning switches for the multiplication kernel.
// N_TILE is left disabled; no code visible in this file tests it.
// #define N_TILE
// When defined, mat_mul_thread walks the K dimension in tiles of BLOCK_SIZE.
#define K_TILE
// Tile length (in elements) used for the K-dimension tiling.
#define BLOCK_SIZE 48
// Number of C elements updated per iteration of the unrolled inner loop.
// Only the values 4 and 8 have a matching unrolled body in mat_mul_thread.
#define UNROLL_SIZE 4
// Smaller of A and B. Function-like macro: an argument may be evaluated
// twice, so do not pass expressions with side effects.
#define min(A, B) (((A) > (B)) ? (B) : (A))
// Work descriptor handed to one worker thread: a contiguous range of rows.
// mat_mul() allocates it on the heap and mat_mul_thread() frees it.
typedef struct _meta_data_t
{
int row_start; // index of the first row in the range
int row_num; // number of rows in the range
} meta_data_t;
// Operand and result matrices shared by all worker threads; set by mat_mul().
// They are indexed as flat row-major arrays:
// A[r * K + k], B[k * N + n], C[r * N + n].
static float *A, *B, *C;
// Matrix dimensions; per the indexing above A is M x K, B is K x N, C is M x N.
static int M, N, K;
// Number of worker threads used by mat_mul().
static int num_threads;
#if defined(UNROLL_SIZE) && UNROLL_SIZE != 4 && UNROLL_SIZE != 8
#error "UNROLL_SIZE must be 4 or 8: the unrolled loop body only covers those widths"
#endif

/*
 * Worker thread entry point.
 *
 * Accumulates C[r][n] += A[r][k] * B[k][n] for every row r in
 * [row_start, row_start + row_num) described by the descriptor. C is only
 * accumulated into here; this function never clears it.
 *
 * data: heap-allocated meta_data_t; ownership passes to this function, which
 *       frees it after copying out the row range.
 * Returns NULL.
 */
static void *mat_mul_thread(void *data)
{
    meta_data_t *meta_data = (meta_data_t *)data;
    const int row_start = meta_data->row_start;
    const int row_end = row_start + meta_data->row_num;
    free(data);

    // Columns below this bound are handled UNROLL_SIZE at a time; the rest
    // are handled one by one. Depends only on N, so compute it once.
    const int n_unroll_limit = (N / UNROLL_SIZE) * UNROLL_SIZE;

#ifdef K_TILE
    // Walk K in tiles of BLOCK_SIZE; every row is visited once per tile.
    const int bs = BLOCK_SIZE;
    for (int k_tile = 0; k_tile < K; k_tile += bs)
    {
        const int k_begin = k_tile;
        const int k_limit = min(k_tile + bs, K);
#else
    {
        const int k_begin = 0;
        const int k_limit = K;
#endif
        for (int r = row_start; r < row_end; r++)
        {
            // Row offsets are computed in size_t so that large matrices do
            // not overflow 32-bit int index arithmetic.
            const float *a_row = A + (size_t)r * K;
            float *c_row = C + (size_t)r * N;

            for (int k = k_begin; k < k_limit; k++)
            {
                const float Ark = a_row[k];
                const float *b_row = B + (size_t)k * N;

                for (int n = 0; n < n_unroll_limit; n += UNROLL_SIZE)
                {
                    c_row[n] += Ark * b_row[n];
                    c_row[n + 1] += Ark * b_row[n + 1];
                    c_row[n + 2] += Ark * b_row[n + 2];
                    c_row[n + 3] += Ark * b_row[n + 3];
#if UNROLL_SIZE == 8
                    c_row[n + 4] += Ark * b_row[n + 4];
                    c_row[n + 5] += Ark * b_row[n + 5];
                    c_row[n + 6] += Ark * b_row[n + 6];
                    c_row[n + 7] += Ark * b_row[n + 7];
#endif
                }
                // Remaining columns when N is not a multiple of UNROLL_SIZE.
                for (int n = n_unroll_limit; n < N; n++)
                {
                    c_row[n] += Ark * b_row[n];
                }
            }
        }
    }
    return NULL;
}
/*
 * Multiplies _A by _B, accumulating into _C, using worker threads.
 *
 * The _M rows are split into contiguous slices, one per thread; the first
 * (_M % threads) slices get one extra row. Each slice is processed by
 * mat_mul_thread(), which accumulates into _C.
 *
 * _A, _B, _C:   flat row-major matrices, indexed by mat_mul_thread() as
 *               A[r * K + k], B[k * N + n], C[r * N + n].
 * _M, _N, _K:   matrix dimensions.
 * _num_threads: number of worker threads; values <= 0 are treated as 1.
 *
 * If a thread cannot be created, its slice is computed on the calling thread
 * so the result is still complete. If memory cannot be allocated, an error is
 * printed and the process exits.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads)
{
    A = _A, B = _B, C = _C;
    M = _M, N = _N, K = _K;
    // Guard against division by zero below and against doing no work at all.
    num_threads = (_num_threads > 0) ? _num_threads : 1;

    pthread_t *threads = (pthread_t *)malloc(sizeof(pthread_t) * num_threads);
    if (threads == NULL)
    {
        fprintf(stderr, "mat_mul: failed to allocate thread handles\n");
        exit(EXIT_FAILURE);
    }

    const int rem_row = M % num_threads;
    const int base_rows = M / num_threads;
    int num_created = 0; // handles in threads[0..num_created) are joinable
    int row = 0;

    for (int i = 0; i < num_threads; i++)
    {
        // The first rem_row slices absorb the remainder, one row each.
        const int rows = base_rows + ((i < rem_row) ? 1 : 0);

        meta_data_t *meta_data = (meta_data_t *)malloc(sizeof(meta_data_t));
        if (meta_data == NULL)
        {
            fprintf(stderr, "mat_mul: failed to allocate work descriptor\n");
            exit(EXIT_FAILURE);
        }
        meta_data->row_start = row;
        meta_data->row_num = rows;
        row += rows;

        if (pthread_create(&threads[num_created], NULL, mat_mul_thread, (void *)meta_data) == 0)
        {
            num_created++;
        }
        else
        {
            // Could not start a worker: compute this slice here instead.
            // mat_mul_thread() frees the descriptor. Slices cover disjoint
            // rows of C, so this does not conflict with running workers.
            mat_mul_thread((void *)meta_data);
        }
    }

    for (int i = 0; i < num_created; i++)
    {
        pthread_join(threads[i], NULL);
    }
    free(threads);
}