106 lines
2.4 KiB
C++
106 lines
2.4 KiB
C++
|
#include "mat_mul.h"

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <immintrin.h>
#include <pthread.h>

using namespace std;
|
||
|
|
||
|
// Tuning constants. NOTE(review): nothing in this file reads BLOCKSIZE or
// TILEWIDTH; they are kept in case other code depends on them.
#define BLOCKSIZE 128
#define TILEWIDTH 2

// Function-like helpers. Each argument may be evaluated more than once, so
// never pass expressions with side effects.
#define min(x, y) (((x) > (y)) ? (y) : (x))
#define max(x, y) (((x) > (y)) ? (x) : (y))

// Operands of the current mat_mul() call. They are written by mat_mul()
// before any worker starts and only read by the workers afterwards.
static float *A = NULL;
static float *B = NULL;
static float *C = NULL;

// Matrix dimensions as indexed by the workers: A[M][K], B[K][N], C[M][N].
static int M = 0;
static int N = 0;
static int K = 0;

// Number of workers sharing the rows of C.
static int num_threads = 0;
||
|
|
||
|
static void* mat_mul_thread(void *data) {
|
||
|
// TODO: parallelize & optimize matrix multiplication
|
||
|
|
||
|
int pid = * (int*) data;
|
||
|
|
||
|
int slice = M / num_threads;
|
||
|
int start = pid * slice;
|
||
|
int end = pid == num_threads - 1 ? M : (pid + 1) * slice;
|
||
|
|
||
|
float Aik;
|
||
|
int bs = 32;
|
||
|
|
||
|
int N_safe = (N / bs) * bs;
|
||
|
|
||
|
__m256 sv;
|
||
|
__m256 cv[4];
|
||
|
__m256 bv[4];
|
||
|
|
||
|
for(int kk = 0; kk < K; kk += bs){
|
||
|
for(int i = start; i < end; ++i){
|
||
|
for(int k = kk; k < min(kk + bs, K); ++k){
|
||
|
sv = _mm256_set1_ps(A[i*K + k]);
|
||
|
for(int v = 0; v < N_safe/8; v+=4){
|
||
|
cv[0] = _mm256_loadu_ps(&C[i*N + v * 8]);
|
||
|
cv[1] = _mm256_loadu_ps(&C[i*N + (v+1) * 8]);
|
||
|
cv[2] = _mm256_loadu_ps(&C[i*N + (v+2) * 8]);
|
||
|
cv[3] = _mm256_loadu_ps(&C[i*N + (v+3) * 8]);
|
||
|
|
||
|
bv[0] = _mm256_loadu_ps(&B[k*N + v * 8]);
|
||
|
bv[1] = _mm256_loadu_ps(&B[k*N + (v+1) * 8]);
|
||
|
bv[2] = _mm256_loadu_ps(&B[k*N + (v+2) * 8]);
|
||
|
bv[3] = _mm256_loadu_ps(&B[k*N + (v+3) * 8]);
|
||
|
|
||
|
cv[0] = _mm256_fmadd_ps(bv[0], sv, cv[0]);
|
||
|
cv[1] = _mm256_fmadd_ps(bv[1], sv, cv[1]);
|
||
|
cv[2] = _mm256_fmadd_ps(bv[2], sv, cv[2]);
|
||
|
cv[3] = _mm256_fmadd_ps(bv[3], sv, cv[3]);
|
||
|
|
||
|
_mm256_storeu_ps(&C[i*N + v * 8], cv[0]);
|
||
|
_mm256_storeu_ps(&C[i*N + (v+1) * 8], cv[1]);
|
||
|
_mm256_storeu_ps(&C[i*N + (v+2) * 8], cv[2]);
|
||
|
_mm256_storeu_ps(&C[i*N + (v+3) * 8], cv[3]);
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for(int kk = 0; kk < K; kk += bs){
|
||
|
for(int i = start; i < end; ++i){
|
||
|
for(int k = kk; k < min(kk + bs, K); ++k){
|
||
|
Aik = A[i*K + k];
|
||
|
for(int j = N_safe; j < N; ++j){
|
||
|
C[i*N + j] += Aik * B[k*N + j];
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return NULL;
|
||
|
}
|
||
|
|
||
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
|
||
|
A = _A, B = _B, C = _C; M = _M, N = _N, K = _K;
|
||
|
num_threads = _num_threads;
|
||
|
|
||
|
// TODO: create '_num_threads' pthreads
|
||
|
|
||
|
pthread_t thread[num_threads];
|
||
|
int pid[num_threads];
|
||
|
int i;
|
||
|
|
||
|
for(i = 0; i < num_threads; i++){
|
||
|
pid[i] = i;
|
||
|
|
||
|
pthread_create(&(thread[i]), NULL, mat_mul_thread, &pid[i]); // original code
|
||
|
}
|
||
|
|
||
|
for(i = 0; i < num_threads; i++){
|
||
|
pthread_join(thread[i], NULL);
|
||
|
|
||
|
}
|
||
|
|
||
|
}
|