112 lines
2.6 KiB
C++
112 lines
2.6 KiB
C++
#include "mat_mul.h"
|
|
#include <cstdlib>
|
|
#include <cstdio>
|
|
#include <pthread.h>
|
|
|
|
// #define N_TILE
|
|
#define K_TILE
|
|
#define BLOCK_SIZE 48
|
|
#define UNROLL_SIZE 4
|
|
#define min(A, B) (((A) > (B)) ? (B) : (A))
|
|
|
|
/*
 * Work descriptor handed to one worker thread: the half-open row range
 * [row_start, row_start + row_num) of C that the worker computes.
 *
 * The struct tag is 'meta_data' rather than '_meta_data_t': identifiers that
 * begin with an underscore are reserved at file/global scope in both C and C++.
 */
typedef struct meta_data
{
    int row_start; /* first row of C assigned to the worker (inclusive) */
    int row_num;   /* number of consecutive rows; 0 when there are more threads than rows */
} meta_data_t;
|
|
|
|
/*
 * Operands shared with the worker threads. mat_mul() stores them before any
 * worker is started; workers only read them, and each writes its own
 * disjoint range of rows of C.
 */
static float *A;        /* left operand, row-major with K columns */
static float *B;        /* right operand, row-major K x N */
static float *C;        /* result, row-major with N columns; accumulated into with += */
static int M;           /* number of rows of A and C */
static int N;           /* number of columns of B and C */
static int K;           /* inner dimension: columns of A / rows of B */
static int num_threads; /* worker count used by mat_mul() */
|
|
|
|
static void *mat_mul_thread(void *data)
|
|
{
|
|
// TODO: parallelize & optimize matrix multiplication
|
|
|
|
meta_data_t *meta_data = (meta_data_t *)data;
|
|
int bs = BLOCK_SIZE;
|
|
int row_start = meta_data->row_start;
|
|
int row_end = meta_data->row_num + row_start;
|
|
|
|
free(data);
|
|
|
|
float Ark = 0.0;
|
|
#ifdef K_TILE
|
|
for (int k_tile = 0; k_tile < K; k_tile += bs)
|
|
{
|
|
#endif
|
|
for (int r = row_start; r < row_end; r++)
|
|
{
|
|
#ifdef K_TILE
|
|
int k_limit = min(k_tile + bs, K);
|
|
for (int k = k_tile; k < k_limit; k++)
|
|
// for (int k = k_tile; k < min(k_tile + bs, K); k++)
|
|
#else
|
|
for (int k = 0; k < K; k++)
|
|
#endif
|
|
{
|
|
Ark = A[r * K + k];
|
|
int unroll_size = UNROLL_SIZE;
|
|
int n_unroll_limit = (N / unroll_size) * unroll_size;
|
|
for (int n = 0; n < n_unroll_limit; n += unroll_size)
|
|
{
|
|
#if UNROLL_SIZE >= 4
|
|
C[r * N + n] += Ark * B[k * N + n];
|
|
C[r * N + n + 1] += Ark * B[k * N + n + 1];
|
|
C[r * N + n + 2] += Ark * B[k * N + n + 2];
|
|
C[r * N + n + 3] += Ark * B[k * N + n + 3];
|
|
#endif
|
|
#if UNROLL_SIZE == 8
|
|
C[r * N + n + 4] += Ark * B[k * N + n + 4];
|
|
C[r * N + n + 5] += Ark * B[k * N + n + 5];
|
|
C[r * N + n + 6] += Ark * B[k * N + n + 6];
|
|
C[r * N + n + 7] += Ark * B[k * N + n + 7];
|
|
#endif
|
|
}
|
|
|
|
for (int n = n_unroll_limit; n < N; n++)
|
|
{
|
|
C[r * N + n] += Ark * B[k * N + n];
|
|
}
|
|
}
|
|
}
|
|
#ifdef K_TILE
|
|
}
|
|
#endif
|
|
|
|
return NULL;
|
|
}
|
|
|
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads)
|
|
{
|
|
A = _A, B = _B, C = _C;
|
|
M = _M, N = _N, K = _K;
|
|
num_threads = _num_threads;
|
|
|
|
// TODO: create '_num_threads' pthreads
|
|
pthread_t *threads = (pthread_t *)malloc(sizeof(pthread_t) * num_threads);
|
|
meta_data_t *meta_data = NULL;
|
|
|
|
int rem_row = M % num_threads;
|
|
int num_row_per_thread = M / num_threads + 1;
|
|
|
|
for (int i = 0, row = 0; i < num_threads; i++, row += num_row_per_thread)
|
|
{
|
|
if (i == rem_row)
|
|
{
|
|
num_row_per_thread--;
|
|
}
|
|
|
|
meta_data = (meta_data_t *)malloc(sizeof(meta_data_t));
|
|
meta_data->row_start = row;
|
|
meta_data->row_num = num_row_per_thread;
|
|
|
|
pthread_create(&threads[i], NULL, mat_mul_thread, (void *)meta_data);
|
|
}
|
|
|
|
for (int i = 0; i < num_threads; i++)
|
|
{
|
|
pthread_join(threads[i], NULL);
|
|
}
|
|
}
|