#include "mat_mul.h"

#include <pthread.h>
#include <stdbool.h>
#include <stdlib.h>

/*
 * Compile-time tuning knobs.
 *
 * K_TILE      - when defined, K is processed in tiles of BLOCK_SIZE: for each
 *               tile, every row of the thread's band consumes the same
 *               BLOCK_SIZE rows of B (presumably to keep them cache-resident).
 *               When undefined, K is walked in a single pass.
 * BLOCK_SIZE  - K-tile width; only used when K_TILE is defined.
 * UNROLL_SIZE - number of C/B columns updated per inner-loop step. Any
 *               positive value is correct; leftover columns go through a
 *               scalar tail loop.
 */
#define K_TILE
#define BLOCK_SIZE 48
#define UNROLL_SIZE 4

#if BLOCK_SIZE < 1
#error "BLOCK_SIZE must be positive"
#endif
#if UNROLL_SIZE < 1
#error "UNROLL_SIZE must be positive"
#endif

/* Work description for one thread: a contiguous band of rows of C. */
typedef struct _meta_data_t {
    const float *A; /* M x K, row-major (indexed A[r * K + k]) */
    const float *B; /* K x N, row-major (indexed B[k * N + n]) */
    float *C;       /* M x N, row-major; only this band's rows are written */
    int N;
    int K;
    int row_start;
    int row_num;
} meta_data_t;

/* Per-thread bookkeeping kept by mat_mul() so it knows which threads to join. */
typedef struct {
    pthread_t thread;
    bool started;
    meta_data_t work;
} worker_t;

/*
 * Accumulates rows [row_start, row_start + row_num) of A*B into C:
 *   C[r][n] += sum over k of A[r][k] * B[k][n].
 * Only those rows of C are written, so bands given to different threads
 * never touch the same memory. Requires N > 0 and K > 0.
 * For every element the k terms are added in increasing k order whether or
 * not K_TILE is enabled, so tiling does not change the floating-point result.
 * Offsets are computed in size_t to avoid signed int overflow on large inputs.
 */
static void mat_mul_rows(const meta_data_t *work)
{
    const size_t n_cols = (size_t)work->N;
    const size_t k_dim = (size_t)work->K;
    const size_t row_begin = (size_t)work->row_start;
    const size_t row_end = row_begin + (size_t)work->row_num;
    /* Columns [0, n_unrolled) use the unrolled loop; the rest use the tail. */
    const size_t n_unrolled = n_cols - n_cols % UNROLL_SIZE;
#ifdef K_TILE
    const size_t k_step = BLOCK_SIZE;
#else
    const size_t k_step = k_dim; /* a single tile spanning all of K */
#endif

    for (size_t k_tile = 0; k_tile < k_dim; k_tile += k_step) {
        const size_t k_limit =
            (k_dim - k_tile > k_step) ? k_tile + k_step : k_dim;

        for (size_t r = row_begin; r < row_end; r++) {
            const float *a_row = work->A + r * k_dim;
            float *c_row = work->C + r * n_cols;

            for (size_t k = k_tile; k < k_limit; k++) {
                const float a_rk = a_row[k];
                const float *b_row = work->B + k * n_cols;
                size_t n = 0;

                for (; n < n_unrolled; n += UNROLL_SIZE) {
                    /* Constant trip count, so the compiler can fully unroll
                     * or vectorize it; correct for any UNROLL_SIZE >= 1. */
                    for (size_t u = 0; u < UNROLL_SIZE; u++) {
                        c_row[n + u] += a_rk * b_row[n + u];
                    }
                }
                for (; n < n_cols; n++) {
                    c_row[n] += a_rk * b_row[n];
                }
            }
        }
    }
}

/* pthread entry point: `data` points at a meta_data_t owned by mat_mul(). */
static void *mat_mul_thread(void *data)
{
    mat_mul_rows((const meta_data_t *)data);
    return NULL;
}

/*
 * Multiplies row-major matrices using up to `num_threads` POSIX threads:
 *   for every r < M, n < N:  C[r*N + n] += sum_{k < K} A[r*K + k] * B[k*N + n]
 *
 * A is M x K, B is K x N, C is M x N. The product is accumulated into C.
 * NOTE(review): callers presumably zero-initialize C beforehand; confirm.
 *
 * Rows of C are split into contiguous bands whose sizes differ by at most
 * one (the first M % threads bands get the extra row), one band per thread.
 *
 * Edge cases:
 *  - Nonpositive M, N or K leaves C untouched.
 *  - num_threads < 1 is treated as 1, and at most M threads are started.
 *  - If the thread table cannot be allocated, or a thread cannot be created,
 *    the affected rows are computed on the calling thread, so C is always
 *    fully computed.
 *
 * Returns only after every started thread has been joined.
 */
void mat_mul(float *A, float *B, float *C, int M, int N, int K,
             int num_threads)
{
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    int thread_count = num_threads < 1 ? 1 : num_threads;
    if (thread_count > M) {
        thread_count = M; /* extra threads would only get zero rows */
    }

    const meta_data_t whole = {
        .A = A, .B = B, .C = C, .N = N, .K = K, .row_start = 0, .row_num = M,
    };

    worker_t *workers = malloc(sizeof *workers * (size_t)thread_count);
    if (workers == NULL) {
        /* No memory for the thread table: compute everything here. */
        mat_mul_rows(&whole);
        return;
    }

    const int base_rows = M / thread_count;
    const int extra_rows = M % thread_count;
    int row = 0;
    for (int i = 0; i < thread_count; i++) {
        worker_t *w = &workers[i];
        w->work = whole;
        w->work.row_start = row;
        w->work.row_num = base_rows + (i < extra_rows ? 1 : 0);
        row += w->work.row_num;

        w->started =
            pthread_create(&w->thread, NULL, mat_mul_thread, &w->work) == 0;
        if (!w->started) {
            /* Could not spawn a thread: do this band here so no rows are
             * skipped. Bands are disjoint, so this cannot race. */
            mat_mul_rows(&w->work);
        }
    }

    for (int i = 0; i < thread_count; i++) {
        if (workers[i].started) {
            pthread_join(workers[i].thread, NULL);
        }
    }
    free(workers);
}