759 lines
33 KiB
C++
759 lines
33 KiB
C++
|
#include "mat_mul.h"
|
||
|
|
||
|
#include <cstdlib>
|
||
|
#include <cstdio>
|
||
|
#include <pthread.h>
|
||
|
#include <math.h>
|
||
|
#include <unistd.h>
|
||
|
#include <immintrin.h>
|
||
|
|
||
|
/*
 * Shared state for one mat_mul() call.
 *
 * mat_mul() stores its arguments here before starting the worker threads;
 * mat_mul_thread() reads them. The workers index the buffers as
 * A[i * K + k], B[k * N + j] and C[i * N + j].
 */
static float *A, *B, *C;   /* input A, input B, output C */
static int M, N, K;        /* loop extents: i < M, j < N, k < K */
static int num_threads;    /* number of worker slices the rows of C are split into */
|
||
|
|
||
|
/*
 * Per-thread argument record handed to mat_mul_thread() through the
 * void* parameter of pthread_create().
 */
struct thread_info {
  pthread_t thread_id;  /* handle filled in by pthread_create(), used by pthread_join() */
  int thread_num;       /* application-defined index of this worker, starting at 0 */
};
|
||
|
|
||
|
/*
 * Returns the smaller of two ints.
 * When the values are equal, b is returned (indistinguishable for ints).
 */
int min(int a, int b) {
  return (a < b) ? a : b;
}
|
||
|
|
||
|
static void* mat_mul_thread(void *data) {
|
||
|
// TODO: parallelize & optimize matrix multiplication
|
||
|
struct thread_info *tinfo = (thread_info*)data;
|
||
|
int thread_index = tinfo->thread_num;
|
||
|
// printf("Thread Num : %d \n", thread_index);
|
||
|
|
||
|
int m = M/num_threads;
|
||
|
// int n = N/num_threads;
|
||
|
|
||
|
int index_m = m * thread_index;
|
||
|
// int index_n = n * thread_index;
|
||
|
|
||
|
int m_mod = M%num_threads;
|
||
|
// int n_mod = N%num_threads;
|
||
|
|
||
|
// __m256 a0;
|
||
|
// __m256 b0;
|
||
|
// __m256 s0;
|
||
|
// s0 = _mm256_set_ps(0., 0., 0., 0., 0., 0., 0., 0.);
|
||
|
//
|
||
|
// int element_num = 64;
|
||
|
//
|
||
|
// for(int i=0; i< element_num /8 ; ++i){
|
||
|
// a0 = _mm256_load_ps(a + i * 8);
|
||
|
// b0 = _mm256_load_ps(b + i * 8);
|
||
|
// s0 = _mm256_fmadd_ps(a0, b0 ,s0);
|
||
|
// }
|
||
|
// float c = s0[0] + s0[1] + s0[2] + s0[3] + s0[4] + s0[5] + s0[6] + s0[7];
|
||
|
// printf("AVX : %f \n", c );
|
||
|
|
||
|
// printf("M para : %d, %d, %d \n", m , index_m, m_mod);
|
||
|
|
||
|
// printf("N para : %d, %d, %d \n", n , index_n, n_mod);
|
||
|
|
||
|
|
||
|
/*********************************************/
|
||
|
/* (1) basic(@single Thread) */
|
||
|
// Avg. time: 0.303500 sec
|
||
|
// Avg. throughput: 0.884466 GFLOPS
|
||
|
|
||
|
/*********************************************/
|
||
|
// for (int i = 0; i < M; ++i) {
|
||
|
// for (int j = 0; j < N; ++j) {
|
||
|
// for (int k = 0; k < K; ++k) {
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/*******************************************/
|
||
|
/* (2) block(@single Thead) */
|
||
|
/*******************************************/
|
||
|
// int block = 4;
|
||
|
//
|
||
|
// for (int ii = 0; ii < M; ii += block)
|
||
|
// for (int jj = 0; jj < N; jj += block)
|
||
|
// for (int kk = 0; kk < K; kk += block)
|
||
|
// for (int i = ii; i < min(ii+block, M); i++)
|
||
|
// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// for (int k = kk; k < min(kk+block, K); k++)
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
|
||
|
/*****************************/
|
||
|
/* (3) thread split A Raw(2 thread) */
|
||
|
// Avg. time: 0.177479 sec
|
||
|
// Avg. throughput: 1.512491 GFLOPS
|
||
|
|
||
|
/*****************************/
|
||
|
// int block = 4;
|
||
|
// int block_index;
|
||
|
//
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int i = index_m ; i < index_m + m; ++i)
|
||
|
// {
|
||
|
// for(int j = 0; j < N; ++j)
|
||
|
// {
|
||
|
// for (int k = 0; k < K; ++k){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int i = index_m ; i < index_m + m + m_mod; ++i)
|
||
|
// {
|
||
|
// for(int j = 0; j < N; ++j)
|
||
|
// {
|
||
|
// for (int k = 0; k < K; ++k){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/***********************************************/
|
||
|
/* (4) thread split A Raw, Block */
|
||
|
/* Avg. time: 5.003404 sec */
|
||
|
/* Avg. throughput: 27.469088 GFLOPS */
|
||
|
/***********************************************/
|
||
|
|
||
|
// int block = 8;
|
||
|
// int block_index;
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
// for (int k = kk; k < min(kk+block, K); k++){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); k++){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/************************************/
|
||
|
/* (5) thread split A Raw, Block , unrolling */
|
||
|
//Avg. time: 4.162063 sec
|
||
|
//Avg. throughput: 33.021836 GFLOPS
|
||
|
/************************************/
|
||
|
|
||
|
//int block = 8; // 4.846sec
|
||
|
// int block = 16; // 4.439sec / 32.3 GFLOPS
|
||
|
//
|
||
|
// int block_index;
|
||
|
// int in;
|
||
|
// int ik;
|
||
|
// float temp;
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
// in = i * N;
|
||
|
// ik = i * K;
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
// if(kk+block >= K){
|
||
|
// for (int k = kk; k < min(kk+block, K); k++){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// temp = C[in + j];
|
||
|
// temp += A[ik + (kk)] * B[(kk) * N + j];
|
||
|
// temp += A[ik + (kk+1)] * B[(kk+1) * N + j];
|
||
|
// temp += A[ik + (kk+2)] * B[(kk+2) * N + j];
|
||
|
// temp += A[ik + (kk+3)] * B[(kk+3) * N + j];
|
||
|
// temp += A[ik + (kk+4)] * B[(kk+4) * N + j];
|
||
|
// temp += A[ik + (kk+5)] * B[(kk+5) * N + j];
|
||
|
// temp += A[ik + (kk+6)] * B[(kk+6) * N + j];
|
||
|
// temp += A[ik + (kk+7)] * B[(kk+7) * N + j];
|
||
|
// temp += A[ik + (kk+8)] * B[(kk+8) * N + j];
|
||
|
// temp += A[ik + (kk+9)] * B[(kk+9) * N + j];
|
||
|
// temp += A[ik + (kk+10)] * B[(kk+10) * N + j];
|
||
|
// temp += A[ik + (kk+11)] * B[(kk+11) * N + j];
|
||
|
// temp += A[ik + (kk+12)] * B[(kk+12) * N + j];
|
||
|
// temp += A[ik + (kk+13)] * B[(kk+13) * N + j];
|
||
|
// temp += A[ik + (kk+14)] * B[(kk+14) * N + j];
|
||
|
// temp += A[ik + (kk+15)] * B[(kk+15) * N + j];
|
||
|
// C[in + j] = temp;
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
|
||
|
// in = i * N;
|
||
|
// ik = i * K;
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
// if(kk+block >= K){
|
||
|
// for (int k = kk; k < min(kk+block, K); k++){
|
||
|
// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// temp = C[in + j];
|
||
|
// temp += A[ik + (kk)] * B[(kk) * N + j];
|
||
|
// temp += A[ik + (kk+1)] * B[(kk+1) * N + j];
|
||
|
// temp += A[ik + (kk+2)] * B[(kk+2) * N + j];
|
||
|
// temp += A[ik + (kk+3)] * B[(kk+3) * N + j];
|
||
|
// temp += A[ik + (kk+4)] * B[(kk+4) * N + j];
|
||
|
// temp += A[ik + (kk+5)] * B[(kk+5) * N + j];
|
||
|
// temp += A[ik + (kk+6)] * B[(kk+6) * N + j];
|
||
|
// temp += A[ik + (kk+7)] * B[(kk+7) * N + j];
|
||
|
// temp += A[ik + (kk+8)] * B[(kk+8) * N + j];
|
||
|
// temp += A[ik + (kk+9)] * B[(kk+9) * N + j];
|
||
|
// temp += A[ik + (kk+10)] * B[(kk+10) * N + j];
|
||
|
// temp += A[ik + (kk+11)] * B[(kk+11) * N + j];
|
||
|
// temp += A[ik + (kk+12)] * B[(kk+12) * N + j];
|
||
|
// temp += A[ik + (kk+13)] * B[(kk+13) * N + j];
|
||
|
// temp += A[ik + (kk+14)] * B[(kk+14) * N + j];
|
||
|
// temp += A[ik + (kk+15)] * B[(kk+15) * N + j];
|
||
|
// C[in + j] = temp;
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
|
||
|
/************************************/
|
||
|
/* (6) single Thread, addressing, column (512,512,512) */
|
||
|
//Avg. time: 0.033199 sec
|
||
|
//Avg. throughput: 8.085689 GFLOPS
|
||
|
//
|
||
|
/************************************/
|
||
|
// for (int i = 0; i < M; ++i)
|
||
|
// {
|
||
|
// float * c = C + i * N;
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] = 0;
|
||
|
// for (int k = 0; k < K; ++k)
|
||
|
// {
|
||
|
// const float * b = B + k * N;
|
||
|
// float a = A[i*K + k];
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/************************************/
|
||
|
/* (7) 2 Thread, (512,512,512) */
|
||
|
// Avg. time: 0.019109 sec (2 thread)
|
||
|
// Avg. throughput: 14.047585 GFLOPS (2 thread)
|
||
|
// Performance : 91 GFLOPS
|
||
|
/************************************/
|
||
|
// int block = 4;
|
||
|
// int block_index;
|
||
|
//
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int i = index_m ; i < index_m + m; ++i)
|
||
|
// {
|
||
|
// float * c = C + i * N;
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] = 0;
|
||
|
// for (int k = 0; k < K; ++k)
|
||
|
// {
|
||
|
//// const float * b = B + k * N;
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i*K + k];
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
//// for(int j = 0; j < N; ++j)
|
||
|
//// {
|
||
|
//// for (int k = 0; k < K; ++k){
|
||
|
//// C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
//// }
|
||
|
//// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int i = index_m ; i < index_m + m + m_mod; ++i)
|
||
|
// {
|
||
|
// float * c = C + i * N;
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] = 0;
|
||
|
// for (int k = 0; k < K; ++k)
|
||
|
// {
|
||
|
//// const float * b = B + k * N;
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i*K + k];
|
||
|
// for (int j = 0; j < N; ++j)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
// // for(int j = 0; j < N; ++j)
|
||
|
// // {
|
||
|
// // for (int k = 0; k < K; ++k){
|
||
|
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// // }
|
||
|
// // }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
|
||
|
/************************************/
|
||
|
/* (8) thread split A Raw, Block */
|
||
|
// Avg. time: 0.018701 sec
|
||
|
// Avg. throughput: 14.354194 GFLOPS
|
||
|
// Performance : 163.489443 GFLOPS (128 block)
|
||
|
/************************************/
|
||
|
|
||
|
// int block = 128;
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
//// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
//// c[j] = 0;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i*K + k];
|
||
|
// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
//// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
//// c[j] = 0;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i * K + k];
|
||
|
// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
//
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
/************************************/
|
||
|
/* (9) thread split A Raw, Block */
|
||
|
// Avg. time: 0.066729 sec
|
||
|
// Avg. throughput: 4.022766 GFLOPS
|
||
|
|
||
|
/************************************/
|
||
|
// int block = 128;
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i * K + k];
|
||
|
// if(jj+block >= N){
|
||
|
//
|
||
|
// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
// else{
|
||
|
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// c[jj] += a * b[jj]; c[jj+1] += a * b[jj+1]; c[jj+2] += a * b[jj+2]; c[jj+3] += a * b[jj+3];
|
||
|
// c[jj+4] += a * b[jj+4]; c[jj+5] += a * b[jj+5]; c[jj+6] += a * b[jj+6]; c[jj+7] += a * b[jj+7];
|
||
|
// c[jj+8] += a * b[jj+8]; c[jj+9] += a * b[jj+9]; c[jj+10] += a * b[jj+10]; c[jj+11] += a * b[jj+11];
|
||
|
// c[jj+12] += a * b[jj+12]; c[jj+13] += a * b[jj+13]; c[jj+14] += a * b[jj+14]; c[jj+15] += a * b[jj+15];
|
||
|
// c[jj+16] += a * b[jj+16]; c[jj+17] += a * b[jj+17]; c[jj+18] += a * b[jj+18]; c[jj+19] += a * b[jj+19];
|
||
|
// c[jj+20] += a * b[jj+20]; c[jj+21] += a * b[jj+21]; c[jj+22] += a * b[jj+22]; c[jj+23] += a * b[jj+23];
|
||
|
// c[jj+24] += a * b[jj+24]; c[jj+25] += a * b[jj+25]; c[jj+26] += a * b[jj+26]; c[jj+27] += a * b[jj+27];
|
||
|
// c[jj+28] += a * b[jj+28]; c[jj+29] += a * b[jj+29]; c[jj+30] += a * b[jj+30]; c[jj+31] += a * b[jj+31];
|
||
|
//
|
||
|
// c[jj+32] += a * b[jj+32]; c[jj+33] += a * b[jj+33]; c[jj+34] += a * b[jj+34]; c[jj+35] += a * b[jj+35];
|
||
|
// c[jj+36] += a * b[jj+36]; c[jj+37] += a * b[jj+37]; c[jj+38] += a * b[jj+38]; c[jj+39] += a * b[jj+39];
|
||
|
// c[jj+40] += a * b[jj+40]; c[jj+41] += a * b[jj+41]; c[jj+42] += a * b[jj+42]; c[jj+43] += a * b[jj+43];
|
||
|
// c[jj+44] += a * b[jj+44]; c[jj+45] += a * b[jj+45]; c[jj+46] += a * b[jj+46]; c[jj+47] += a * b[jj+47];
|
||
|
// c[jj+48] += a * b[jj+48]; c[jj+49] += a * b[jj+49]; c[jj+50] += a * b[jj+50]; c[jj+51] += a * b[jj+51];
|
||
|
// c[jj+52] += a * b[jj+52]; c[jj+53] += a * b[jj+53]; c[jj+54] += a * b[jj+54]; c[jj+55] += a * b[jj+55];
|
||
|
// c[jj+56] += a * b[jj+56]; c[jj+57] += a * b[jj+57]; c[jj+58] += a * b[jj+58]; c[jj+59] += a * b[jj+59];
|
||
|
// c[jj+60] += a * b[jj+60]; c[jj+61] += a * b[jj+61]; c[jj+62] += a * b[jj+62]; c[jj+63] += a * b[jj+63];
|
||
|
//
|
||
|
// c[jj+64] += a * b[jj+64]; c[jj+65] += a * b[jj+65]; c[jj+66] += a * b[jj+66]; c[jj+67] += a * b[jj+67];
|
||
|
// c[jj+68] += a * b[jj+68]; c[jj+69] += a * b[jj+69]; c[jj+70] += a * b[jj+70]; c[jj+71] += a * b[jj+71];
|
||
|
// c[jj+72] += a * b[jj+72]; c[jj+73] += a * b[jj+73]; c[jj+74] += a * b[jj+74]; c[jj+75] += a * b[jj+75];
|
||
|
// c[jj+76] += a * b[jj+76]; c[jj+77] += a * b[jj+77]; c[jj+78] += a * b[jj+78]; c[jj+79] += a * b[jj+79];
|
||
|
// c[jj+80] += a * b[jj+80]; c[jj+81] += a * b[jj+81]; c[jj+82] += a * b[jj+82]; c[jj+83] += a * b[jj+83];
|
||
|
// c[jj+84] += a * b[jj+84]; c[jj+85] += a * b[jj+85]; c[jj+86] += a * b[jj+86]; c[jj+87] += a * b[jj+87];
|
||
|
// c[jj+88] += a * b[jj+88]; c[jj+89] += a * b[jj+89]; c[jj+90] += a * b[jj+90]; c[jj+91] += a * b[jj+91];
|
||
|
// c[jj+92] += a * b[jj+92]; c[jj+93] += a * b[jj+93]; c[jj+94] += a * b[jj+94]; c[jj+95] += a * b[jj+95];
|
||
|
//
|
||
|
// c[jj+96] += a * b[jj+96]; c[jj+97] += a * b[jj+97]; c[jj+98] += a * b[jj+98]; c[jj+99] += a * b[jj+99];
|
||
|
// c[jj+100] += a * b[jj+100]; c[jj+101] += a * b[jj+101]; c[jj+102] += a * b[jj+102]; c[jj+103] += a * b[jj+103];
|
||
|
// c[jj+104] += a * b[jj+104]; c[jj+105] += a * b[jj+105]; c[jj+106] += a * b[jj+106]; c[jj+107] += a * b[jj+107];
|
||
|
// c[jj+108] += a * b[jj+108]; c[jj+109] += a * b[jj+109]; c[jj+110] += a * b[jj+110]; c[jj+111] += a * b[jj+111];
|
||
|
// c[jj+112] += a * b[jj+112]; c[jj+113] += a * b[jj+113]; c[jj+114] += a * b[jj+114]; c[jj+115] += a * b[jj+115];
|
||
|
// c[jj+116] += a * b[jj+116]; c[jj+117] += a * b[jj+117]; c[jj+118] += a * b[jj+118]; c[jj+119] += a * b[jj+119];
|
||
|
// c[jj+120] += a * b[jj+120]; c[jj+121] += a * b[jj+121]; c[jj+122] += a * b[jj+122]; c[jj+123] += a * b[jj+123];
|
||
|
// c[jj+124] += a * b[jj+124]; c[jj+125] += a * b[jj+125]; c[jj+126] += a * b[jj+126]; c[jj+127] += a * b[jj+127];
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
// float a = A[i * K + k];
|
||
|
// if(jj+block >= N){
|
||
|
//
|
||
|
// for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
// }
|
||
|
// else{
|
||
|
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
|
||
|
// c[jj] += a * b[jj]; c[jj+1] += a * b[jj+1]; c[jj+2] += a * b[jj+2]; c[jj+3] += a * b[jj+3];
|
||
|
// c[jj+4] += a * b[jj+4]; c[jj+5] += a * b[jj+5]; c[jj+6] += a * b[jj+6]; c[jj+7] += a * b[jj+7];
|
||
|
// c[jj+8] += a * b[jj+8]; c[jj+9] += a * b[jj+9]; c[jj+10] += a * b[jj+10]; c[jj+11] += a * b[jj+11];
|
||
|
// c[jj+12] += a * b[jj+12]; c[jj+13] += a * b[jj+13]; c[jj+14] += a * b[jj+14]; c[jj+15] += a * b[jj+15];
|
||
|
// c[jj+16] += a * b[jj+16]; c[jj+17] += a * b[jj+17]; c[jj+18] += a * b[jj+18]; c[jj+19] += a * b[jj+19];
|
||
|
// c[jj+20] += a * b[jj+20]; c[jj+21] += a * b[jj+21]; c[jj+22] += a * b[jj+22]; c[jj+23] += a * b[jj+23];
|
||
|
// c[jj+24] += a * b[jj+24]; c[jj+25] += a * b[jj+25]; c[jj+26] += a * b[jj+26]; c[jj+27] += a * b[jj+27];
|
||
|
// c[jj+28] += a * b[jj+28]; c[jj+29] += a * b[jj+29]; c[jj+30] += a * b[jj+30]; c[jj+31] += a * b[jj+31];
|
||
|
//
|
||
|
// c[jj+32] += a * b[jj+32]; c[jj+33] += a * b[jj+33]; c[jj+34] += a * b[jj+34]; c[jj+35] += a * b[jj+35];
|
||
|
// c[jj+36] += a * b[jj+36]; c[jj+37] += a * b[jj+37]; c[jj+38] += a * b[jj+38]; c[jj+39] += a * b[jj+39];
|
||
|
// c[jj+40] += a * b[jj+40]; c[jj+41] += a * b[jj+41]; c[jj+42] += a * b[jj+42]; c[jj+43] += a * b[jj+43];
|
||
|
// c[jj+44] += a * b[jj+44]; c[jj+45] += a * b[jj+45]; c[jj+46] += a * b[jj+46]; c[jj+47] += a * b[jj+47];
|
||
|
// c[jj+48] += a * b[jj+48]; c[jj+49] += a * b[jj+49]; c[jj+50] += a * b[jj+50]; c[jj+51] += a * b[jj+51];
|
||
|
// c[jj+52] += a * b[jj+52]; c[jj+53] += a * b[jj+53]; c[jj+54] += a * b[jj+54]; c[jj+55] += a * b[jj+55];
|
||
|
// c[jj+56] += a * b[jj+56]; c[jj+57] += a * b[jj+57]; c[jj+58] += a * b[jj+58]; c[jj+59] += a * b[jj+59];
|
||
|
// c[jj+60] += a * b[jj+60]; c[jj+61] += a * b[jj+61]; c[jj+62] += a * b[jj+62]; c[jj+63] += a * b[jj+63];
|
||
|
//
|
||
|
// c[jj+64] += a * b[jj+64]; c[jj+65] += a * b[jj+65]; c[jj+66] += a * b[jj+66]; c[jj+67] += a * b[jj+67];
|
||
|
// c[jj+68] += a * b[jj+68]; c[jj+69] += a * b[jj+69]; c[jj+70] += a * b[jj+70]; c[jj+71] += a * b[jj+71];
|
||
|
// c[jj+72] += a * b[jj+72]; c[jj+73] += a * b[jj+73]; c[jj+74] += a * b[jj+74]; c[jj+75] += a * b[jj+75];
|
||
|
// c[jj+76] += a * b[jj+76]; c[jj+77] += a * b[jj+77]; c[jj+78] += a * b[jj+78]; c[jj+79] += a * b[jj+79];
|
||
|
// c[jj+80] += a * b[jj+80]; c[jj+81] += a * b[jj+81]; c[jj+82] += a * b[jj+82]; c[jj+83] += a * b[jj+83];
|
||
|
// c[jj+84] += a * b[jj+84]; c[jj+85] += a * b[jj+85]; c[jj+86] += a * b[jj+86]; c[jj+87] += a * b[jj+87];
|
||
|
// c[jj+88] += a * b[jj+88]; c[jj+89] += a * b[jj+89]; c[jj+90] += a * b[jj+90]; c[jj+91] += a * b[jj+91];
|
||
|
// c[jj+92] += a * b[jj+92]; c[jj+93] += a * b[jj+93]; c[jj+94] += a * b[jj+94]; c[jj+95] += a * b[jj+95];
|
||
|
//
|
||
|
// c[jj+96] += a * b[jj+96]; c[jj+97] += a * b[jj+97]; c[jj+98] += a * b[jj+98]; c[jj+99] += a * b[jj+99];
|
||
|
// c[jj+100] += a * b[jj+100]; c[jj+101] += a * b[jj+101]; c[jj+102] += a * b[jj+102]; c[jj+103] += a * b[jj+103];
|
||
|
// c[jj+104] += a * b[jj+104]; c[jj+105] += a * b[jj+105]; c[jj+106] += a * b[jj+106]; c[jj+107] += a * b[jj+107];
|
||
|
// c[jj+108] += a * b[jj+108]; c[jj+109] += a * b[jj+109]; c[jj+110] += a * b[jj+110]; c[jj+111] += a * b[jj+111];
|
||
|
// c[jj+112] += a * b[jj+112]; c[jj+113] += a * b[jj+113]; c[jj+114] += a * b[jj+114]; c[jj+115] += a * b[jj+115];
|
||
|
// c[jj+116] += a * b[jj+116]; c[jj+117] += a * b[jj+117]; c[jj+118] += a * b[jj+118]; c[jj+119] += a * b[jj+119];
|
||
|
// c[jj+120] += a * b[jj+120]; c[jj+121] += a * b[jj+121]; c[jj+122] += a * b[jj+122]; c[jj+123] += a * b[jj+123];
|
||
|
// c[jj+124] += a * b[jj+124]; c[jj+125] += a * b[jj+125]; c[jj+126] += a * b[jj+126]; c[jj+127] += a * b[jj+127];
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/************************************/
|
||
|
/* (10) thread split A Raw, Block, AVX */
|
||
|
// Avg. time: 0.018701 sec
|
||
|
// Avg. throughput: 14.354194 GFLOPS
|
||
|
// Performance : 149.351403 GFLOPS (128 block)
|
||
|
/************************************/
|
||
|
|
||
|
// int block = 128;
|
||
|
//
|
||
|
// if((num_threads-1) != thread_index){
|
||
|
// for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
//
|
||
|
// if(jj+block > N){
|
||
|
// float a = A[i*K + k];
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
// //printf("Normal Start\n");
|
||
|
// c[j] += a * b[j];
|
||
|
// //_mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// __m256 a = _mm256_set1_ps(A[i * K + k]);
|
||
|
// for (int j = jj; j < min(jj+block, N); j+=8){
|
||
|
// //c[j] += a * b[j];
|
||
|
// //printf("AVX Start\n");
|
||
|
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
// for (int jj = 0; jj < N; jj += block){
|
||
|
// for (int kk = 0; kk < K; kk += block){
|
||
|
//
|
||
|
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
|
||
|
// float * c = C + i * N;
|
||
|
//
|
||
|
// for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
// {
|
||
|
// float * b = B + k * N;
|
||
|
//
|
||
|
// if(jj+block > N){
|
||
|
// float a = A[i * K + k];
|
||
|
// for (int j = jj; j < min(jj+block, N); j++){
|
||
|
// //printf("Normal Start\n");
|
||
|
// c[j] += a * b[j];
|
||
|
// //_mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
// }
|
||
|
// else{
|
||
|
// __m256 a = _mm256_set1_ps(A[i * K + k]);
|
||
|
// for (int j = jj; j < min(jj+block, N); j+=8){
|
||
|
// //c[j] += a * b[j];
|
||
|
// //printf("AVX Start\n");
|
||
|
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
//
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
/************************************/
|
||
|
/* (11) thread split A Raw, Block, unrolling */
|
||
|
|
||
|
// Performance : 161.073243 GFLOPS (128 block)
|
||
|
// Performance : 164.608178 GFLOPS (64 block/8 unrolling)
|
||
|
// Performance : 174.608178 GFLOPS (64 block/16 unrolling)
|
||
|
// Performance : 176.608178 GFLOPS (64 block/32 unrolling)
|
||
|
/************************************/
|
||
|
int block = 64;
|
||
|
|
||
|
if((num_threads-1) != thread_index){
|
||
|
for(int ii = index_m ; ii < index_m + m; ii+=block){
|
||
|
int m_index_min = (ii+block < index_m + m) ? ii+block : index_m + m;
|
||
|
|
||
|
for (int jj = 0; jj < N; jj += block){
|
||
|
for (int kk = 0; kk < K; kk += block){
|
||
|
|
||
|
//for (int i = ii; i < min(ii+block, index_m + m); i++){
|
||
|
for (int i = ii; i < m_index_min; i++){
|
||
|
float * c = C + i * N;
|
||
|
|
||
|
int k_index_min = (kk+block < K) ? kk+block : K;
|
||
|
for (int k = kk; k < k_index_min; ++k)
|
||
|
{
|
||
|
float * b = B + k * N;
|
||
|
float a = A[i*K + k];
|
||
|
|
||
|
int j_index_min = (jj+block < N) ? jj+block : N;
|
||
|
//for (int j = jj; j < j_index_min; j++)
|
||
|
//for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
|
||
|
if(jj+block > N){
|
||
|
for (int j = jj; j < N; j++){
|
||
|
c[j] += a * b[j];
|
||
|
}
|
||
|
}
|
||
|
else{
|
||
|
for (int j = jj; j < j_index_min; j+=16)
|
||
|
{
|
||
|
c[j] += a * b[j]; c[j+1] += a * b[j+1]; c[j+2] += a * b[j+2]; c[j+3] += a * b[j+3];
|
||
|
c[j+4] += a * b[j+4]; c[j+5] += a * b[j+5]; c[j+6] += a * b[j+6]; c[j+7] += a * b[j+7];
|
||
|
c[j+8] += a * b[j+8]; c[j+9] += a * b[j+9]; c[j+10] += a * b[j+10]; c[j+11] += a * b[j+11];
|
||
|
c[j+12] += a * b[j+12]; c[j+13] += a * b[j+13]; c[j+14] += a * b[j+14]; c[j+15] += a * b[j+15];
|
||
|
// c[j+16] += a * b[j+16]; c[j+17] += a * b[j+17]; c[j+18] += a * b[j+18]; c[j+19] += a * b[j+19];
|
||
|
// c[j+20] += a * b[j+20]; c[j+21] += a * b[j+21]; c[j+22] += a * b[j+22]; c[j+23] += a * b[j+23];
|
||
|
// c[j+24] += a * b[j+24]; c[j+25] += a * b[j+25]; c[j+26] += a * b[j+26]; c[j+27] += a * b[j+27];
|
||
|
// c[j+28] += a * b[j+28]; c[j+29] += a * b[j+29]; c[j+30] += a * b[j+30]; c[j+31] += a * b[j+31];
|
||
|
}
|
||
|
// __m256 a = _mm256_set1_ps(A[i * K + k]);
|
||
|
// for (int j = jj; j < j_index_min; j+=8){
|
||
|
// //c[j] += a * b[j];
|
||
|
// //printf("AVX Start\n");
|
||
|
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
else{
|
||
|
for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
|
||
|
int m_index_min = (ii+block < index_m + m + m_mod) ? ii+block : index_m + m+m_mod;
|
||
|
|
||
|
for (int jj = 0; jj < N; jj += block){
|
||
|
for (int kk = 0; kk < K; kk += block){
|
||
|
|
||
|
//for (int i = ii; i < min(ii+block, index_m + m +m_mod); i++){
|
||
|
for (int i = ii; i < m_index_min; i++){
|
||
|
float * c = C + i * N;
|
||
|
|
||
|
int k_index_min = (kk+block < K) ? kk+block : K;
|
||
|
for (int k = kk; k < k_index_min; ++k)
|
||
|
//for (int k = kk; k < min(kk+block, K); ++k)
|
||
|
{
|
||
|
float * b = B + k * N;
|
||
|
float a = A[i * K + k];
|
||
|
|
||
|
int j_index_min = (jj+block < N) ? jj+block : N;
|
||
|
//for (int j = jj; j < j_index_min; j++)
|
||
|
//for (int j = jj; j < min(jj+block, N); j++)
|
||
|
// c[j] += a * b[j];
|
||
|
|
||
|
if(jj+block > N){
|
||
|
for (int j = jj; j < N; j++){
|
||
|
c[j] += a * b[j];
|
||
|
}
|
||
|
}
|
||
|
else{
|
||
|
for (int j = jj; j < j_index_min; j+=16)
|
||
|
{
|
||
|
c[j] += a * b[j]; c[j+1] += a * b[j+1]; c[j+2] += a * b[j+2]; c[j+3] += a * b[j+3];
|
||
|
c[j+4] += a * b[j+4]; c[j+5] += a * b[j+5]; c[j+6] += a * b[j+6]; c[j+7] += a * b[j+7];
|
||
|
c[j+8] += a * b[j+8]; c[j+9] += a * b[j+9]; c[j+10] += a * b[j+10]; c[j+11] += a * b[j+11];
|
||
|
c[j+12] += a * b[j+12]; c[j+13] += a * b[j+13]; c[j+14] += a * b[j+14]; c[j+15] += a * b[j+15];
|
||
|
// c[j+16] += a * b[j+16]; c[j+17] += a * b[j+17]; c[j+18] += a * b[j+18]; c[j+19] += a * b[j+19];
|
||
|
// c[j+20] += a * b[j+20]; c[j+21] += a * b[j+21]; c[j+22] += a * b[j+22]; c[j+23] += a * b[j+23];
|
||
|
// c[j+24] += a * b[j+24]; c[j+25] += a * b[j+25]; c[j+26] += a * b[j+26]; c[j+27] += a * b[j+27];
|
||
|
// c[j+28] += a * b[j+28]; c[j+29] += a * b[j+29]; c[j+30] += a * b[j+30]; c[j+31] += a * b[j+31];
|
||
|
}
|
||
|
// __m256 a = _mm256_set1_ps(A[i * K + k]);
|
||
|
// for (int j = jj; j < j_index_min; j+=8){
|
||
|
// //c[j] += a * b[j];
|
||
|
// //printf("AVX Start\n");
|
||
|
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
|
||
|
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
|
||
|
// }
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return NULL;
|
||
|
}
|
||
|
|
||
|
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
|
||
|
A = _A, B = _B, C = _C;
|
||
|
M = _M, N = _N, K = _K;
|
||
|
num_threads = _num_threads;
|
||
|
|
||
|
// TODO: create '_num_threads' pthreads
|
||
|
void *res;
|
||
|
struct thread_info *tinfo = (thread_info*)calloc(num_threads, sizeof(*tinfo));
|
||
|
|
||
|
for (int tnum = 0; tnum < num_threads; tnum++) {
|
||
|
tinfo[tnum].thread_num = tnum;
|
||
|
// printf("Thread Create : %d \n", tinfo[tnum].thread_num);
|
||
|
pthread_create(&tinfo[tnum].thread_id, NULL, mat_mul_thread, &tinfo[tnum]);
|
||
|
|
||
|
}
|
||
|
|
||
|
for (int tnum = 0; tnum < num_threads; tnum++) {
|
||
|
pthread_join(tinfo[tnum].thread_id, &res);
|
||
|
// printf("Thread Joined : %d \n", tinfo[tnum].thread_num);
|
||
|
}
|
||
|
// pthread_t thread;
|
||
|
// pthread_create(&thread, NULL, mat_mul_thread, NULL);
|
||
|
// pthread_join(thread, NULL);
|
||
|
}
|