// chundoong-lab-ta/SamsungDS22/submissions/HW2/c.w.son/mat_mul.cpp
//
// 759 lines
// 33 KiB
// C++

#include "mat_mul.h"
#include <cstdlib>
#include <cstdio>
#include <pthread.h>
#include <math.h>
#include <unistd.h>
#include <immintrin.h>
/* Operands and dimensions for the current mat_mul() call; mat_mul() stores
 * them here so every worker thread can read them without extra arguments. */
static float *A; /* left operand  */
static float *B; /* right operand */
static float *C; /* result matrix */
static int M;
static int N;
static int K;
static int num_threads; /* number of workers the rows are split across */
/* Per-worker argument record passed (by pointer) to each spawned thread. */
struct thread_info {
  pthread_t thread_id; /* handle written by pthread_create()          */
  int thread_num;      /* 0-based worker index chosen by the creator  */
};
/* Returns the smaller of the two integers. */
int min(int a, int b) {
  return (b < a) ? b : a;
}
static void* mat_mul_thread(void *data) {
// TODO: parallelize & optimize matrix multiplication
struct thread_info *tinfo = (thread_info*)data;
int thread_index = tinfo->thread_num;
// printf("Thread Num : %d \n", thread_index);
int m = M/num_threads;
// int n = N/num_threads;
int index_m = m * thread_index;
// int index_n = n * thread_index;
int m_mod = M%num_threads;
// int n_mod = N%num_threads;
// __m256 a0;
// __m256 b0;
// __m256 s0;
// s0 = _mm256_set_ps(0., 0., 0., 0., 0., 0., 0., 0.);
//
// int element_num = 64;
//
// for(int i=0; i< element_num /8 ; ++i){
// a0 = _mm256_load_ps(a + i * 8);
// b0 = _mm256_load_ps(b + i * 8);
// s0 = _mm256_fmadd_ps(a0, b0 ,s0);
// }
// float c = s0[0] + s0[1] + s0[2] + s0[3] + s0[4] + s0[5] + s0[6] + s0[7];
// printf("AVX : %f \n", c );
// printf("M para : %d, %d, %d \n", m , index_m, m_mod);
// printf("N para : %d, %d, %d \n", n , index_n, n_mod);
/*********************************************/
/* (1) basic(@single Thread) */
// Avg. time: 0.303500 sec
// Avg. throughput: 0.884466 GFLOPS
/*********************************************/
// for (int i = 0; i < M; ++i) {
// for (int j = 0; j < N; ++j) {
// for (int k = 0; k < K; ++k) {
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// }
/*******************************************/
/* (2) block(@single Thead) */
/*******************************************/
// int block = 4;
//
// for (int ii = 0; ii < M; ii += block)
// for (int jj = 0; jj < N; jj += block)
// for (int kk = 0; kk < K; kk += block)
// for (int i = ii; i < min(ii+block, M); i++)
// for (int j = jj; j < min(jj+block, N); j++)
// for (int k = kk; k < min(kk+block, K); k++)
// C[i * N + j] += A[i * K + k] * B[k * N + j];
/*****************************/
/* (3) thread split A Raw(2 thread) */
// Avg. time: 0.177479 sec
// Avg. throughput: 1.512491 GFLOPS
/*****************************/
// int block = 4;
// int block_index;
//
//
// if((num_threads-1) != thread_index){
// for(int i = index_m ; i < index_m + m; ++i)
// {
// for(int j = 0; j < N; ++j)
// {
// for (int k = 0; k < K; ++k){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// }
// }
// else{
// for(int i = index_m ; i < index_m + m + m_mod; ++i)
// {
// for(int j = 0; j < N; ++j)
// {
// for (int k = 0; k < K; ++k){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// }
// }
/***********************************************/
/* (4) thread split A Raw, Block */
/* Avg. time: 5.003404 sec */
/* Avg. throughput: 27.469088 GFLOPS */
/***********************************************/
// int block = 8;
// int block_index;
//
// if((num_threads-1) != thread_index){
// for(int ii = index_m ; ii < index_m + m; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
// for (int i = ii; i < min(ii+block, index_m + m); i++){
// for (int j = jj; j < min(jj+block, N); j++){
// for (int k = kk; k < min(kk+block, K); k++){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// }
// }
// }
// }
//
// }
// else{
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
// for (int j = jj; j < min(jj+block, N); j++){
//
// for (int k = kk; k < min(kk+block, K); k++){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
//
// }
// }
// }
// }
// }
// }
/************************************/
/* (5) thread split A Raw, Block , unrolling */
//Avg. time: 4.162063 sec
//Avg. throughput: 33.021836 GFLOPS
/************************************/
//int block = 8; // 4.846sec
// int block = 16; // 4.439sec / 32.3 GFLOPS
//
// int block_index;
// int in;
// int ik;
// float temp;
//
// if((num_threads-1) != thread_index){
// for(int ii = index_m ; ii < index_m + m; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
// for (int i = ii; i < min(ii+block, index_m + m); i++){
// in = i * N;
// ik = i * K;
// for (int j = jj; j < min(jj+block, N); j++){
// if(kk+block >= K){
// for (int k = kk; k < min(kk+block, K); k++){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// else{
// temp = C[in + j];
// temp += A[ik + (kk)] * B[(kk) * N + j];
// temp += A[ik + (kk+1)] * B[(kk+1) * N + j];
// temp += A[ik + (kk+2)] * B[(kk+2) * N + j];
// temp += A[ik + (kk+3)] * B[(kk+3) * N + j];
// temp += A[ik + (kk+4)] * B[(kk+4) * N + j];
// temp += A[ik + (kk+5)] * B[(kk+5) * N + j];
// temp += A[ik + (kk+6)] * B[(kk+6) * N + j];
// temp += A[ik + (kk+7)] * B[(kk+7) * N + j];
// temp += A[ik + (kk+8)] * B[(kk+8) * N + j];
// temp += A[ik + (kk+9)] * B[(kk+9) * N + j];
// temp += A[ik + (kk+10)] * B[(kk+10) * N + j];
// temp += A[ik + (kk+11)] * B[(kk+11) * N + j];
// temp += A[ik + (kk+12)] * B[(kk+12) * N + j];
// temp += A[ik + (kk+13)] * B[(kk+13) * N + j];
// temp += A[ik + (kk+14)] * B[(kk+14) * N + j];
// temp += A[ik + (kk+15)] * B[(kk+15) * N + j];
// C[in + j] = temp;
// }
// }
// }
// }
// }
// }
// }
// else{
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
// in = i * N;
// ik = i * K;
// for (int j = jj; j < min(jj+block, N); j++){
// if(kk+block >= K){
// for (int k = kk; k < min(kk+block, K); k++){
// C[i * N + j] += A[i * K + k] * B[k * N + j];
// }
// }
// else{
// temp = C[in + j];
// temp += A[ik + (kk)] * B[(kk) * N + j];
// temp += A[ik + (kk+1)] * B[(kk+1) * N + j];
// temp += A[ik + (kk+2)] * B[(kk+2) * N + j];
// temp += A[ik + (kk+3)] * B[(kk+3) * N + j];
// temp += A[ik + (kk+4)] * B[(kk+4) * N + j];
// temp += A[ik + (kk+5)] * B[(kk+5) * N + j];
// temp += A[ik + (kk+6)] * B[(kk+6) * N + j];
// temp += A[ik + (kk+7)] * B[(kk+7) * N + j];
// temp += A[ik + (kk+8)] * B[(kk+8) * N + j];
// temp += A[ik + (kk+9)] * B[(kk+9) * N + j];
// temp += A[ik + (kk+10)] * B[(kk+10) * N + j];
// temp += A[ik + (kk+11)] * B[(kk+11) * N + j];
// temp += A[ik + (kk+12)] * B[(kk+12) * N + j];
// temp += A[ik + (kk+13)] * B[(kk+13) * N + j];
// temp += A[ik + (kk+14)] * B[(kk+14) * N + j];
// temp += A[ik + (kk+15)] * B[(kk+15) * N + j];
// C[in + j] = temp;
// }
// }
// }
// }
// }
// }
// }
/************************************/
/* (6) single Thread, addressing, column (512,512,512) */
//Avg. time: 0.033199 sec
//Avg. throughput: 8.085689 GFLOPS
//
/************************************/
// for (int i = 0; i < M; ++i)
// {
// float * c = C + i * N;
// for (int j = 0; j < N; ++j)
// c[j] = 0;
// for (int k = 0; k < K; ++k)
// {
// const float * b = B + k * N;
// float a = A[i*K + k];
// for (int j = 0; j < N; ++j)
// c[j] += a * b[j];
// }
// }
/************************************/
/* (7) 2 Thread, (512,512,512) */
// Avg. time: 0.019109 sec (2 thread)
// Avg. throughput: 14.047585 GFLOPS (2 thread)
// Performance : 91 GFLOPS
/************************************/
// int block = 4;
// int block_index;
//
//
// if((num_threads-1) != thread_index){
// for(int i = index_m ; i < index_m + m; ++i)
// {
// float * c = C + i * N;
// for (int j = 0; j < N; ++j)
// c[j] = 0;
// for (int k = 0; k < K; ++k)
// {
//// const float * b = B + k * N;
// float * b = B + k * N;
// float a = A[i*K + k];
// for (int j = 0; j < N; ++j)
// c[j] += a * b[j];
// }
//// for(int j = 0; j < N; ++j)
//// {
//// for (int k = 0; k < K; ++k){
//// C[i * N + j] += A[i * K + k] * B[k * N + j];
//// }
//// }
// }
// }
// else{
// for(int i = index_m ; i < index_m + m + m_mod; ++i)
// {
// float * c = C + i * N;
// for (int j = 0; j < N; ++j)
// c[j] = 0;
// for (int k = 0; k < K; ++k)
// {
//// const float * b = B + k * N;
// float * b = B + k * N;
// float a = A[i*K + k];
// for (int j = 0; j < N; ++j)
// c[j] += a * b[j];
// }
// // for(int j = 0; j < N; ++j)
// // {
// // for (int k = 0; k < K; ++k){
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
// // }
// // }
// }
// }
/************************************/
/* (8) thread split A Raw, Block */
// Avg. time: 0.018701 sec
// Avg. throughput: 14.354194 GFLOPS
// Performance : 163.489443 GFLOPS (128 block)
/************************************/
// int block = 128;
//
// if((num_threads-1) != thread_index){
// for(int ii = index_m ; ii < index_m + m; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m); i++){
// float * c = C + i * N;
//
//// for (int j = jj; j < min(jj+block, N); j++)
//// c[j] = 0;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
// float a = A[i*K + k];
// for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
// }
//
// }
// }
// }
// }
// }
// else{
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
// float * c = C + i * N;
//
//// for (int j = jj; j < min(jj+block, N); j++)
//// c[j] = 0;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
// float a = A[i * K + k];
// for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
// }
//
//
// }
// }
// }
// }
// }
/************************************/
/* (9) thread split A Raw, Block */
// Avg. time: 0.066729 sec
// Avg. throughput: 4.022766 GFLOPS
/************************************/
// int block = 128;
//
// if((num_threads-1) != thread_index){
// for(int ii = index_m ; ii < index_m + m; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m); i++){
// float * c = C + i * N;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
// float a = A[i * K + k];
// if(jj+block >= N){
//
// for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
// }
// else{
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
// c[jj] += a * b[jj]; c[jj+1] += a * b[jj+1]; c[jj+2] += a * b[jj+2]; c[jj+3] += a * b[jj+3];
// c[jj+4] += a * b[jj+4]; c[jj+5] += a * b[jj+5]; c[jj+6] += a * b[jj+6]; c[jj+7] += a * b[jj+7];
// c[jj+8] += a * b[jj+8]; c[jj+9] += a * b[jj+9]; c[jj+10] += a * b[jj+10]; c[jj+11] += a * b[jj+11];
// c[jj+12] += a * b[jj+12]; c[jj+13] += a * b[jj+13]; c[jj+14] += a * b[jj+14]; c[jj+15] += a * b[jj+15];
// c[jj+16] += a * b[jj+16]; c[jj+17] += a * b[jj+17]; c[jj+18] += a * b[jj+18]; c[jj+19] += a * b[jj+19];
// c[jj+20] += a * b[jj+20]; c[jj+21] += a * b[jj+21]; c[jj+22] += a * b[jj+22]; c[jj+23] += a * b[jj+23];
// c[jj+24] += a * b[jj+24]; c[jj+25] += a * b[jj+25]; c[jj+26] += a * b[jj+26]; c[jj+27] += a * b[jj+27];
// c[jj+28] += a * b[jj+28]; c[jj+29] += a * b[jj+29]; c[jj+30] += a * b[jj+30]; c[jj+31] += a * b[jj+31];
//
// c[jj+32] += a * b[jj+32]; c[jj+33] += a * b[jj+33]; c[jj+34] += a * b[jj+34]; c[jj+35] += a * b[jj+35];
// c[jj+36] += a * b[jj+36]; c[jj+37] += a * b[jj+37]; c[jj+38] += a * b[jj+38]; c[jj+39] += a * b[jj+39];
// c[jj+40] += a * b[jj+40]; c[jj+41] += a * b[jj+41]; c[jj+42] += a * b[jj+42]; c[jj+43] += a * b[jj+43];
// c[jj+44] += a * b[jj+44]; c[jj+45] += a * b[jj+45]; c[jj+46] += a * b[jj+46]; c[jj+47] += a * b[jj+47];
// c[jj+48] += a * b[jj+48]; c[jj+49] += a * b[jj+49]; c[jj+50] += a * b[jj+50]; c[jj+51] += a * b[jj+51];
// c[jj+52] += a * b[jj+52]; c[jj+53] += a * b[jj+53]; c[jj+54] += a * b[jj+54]; c[jj+55] += a * b[jj+55];
// c[jj+56] += a * b[jj+56]; c[jj+57] += a * b[jj+57]; c[jj+58] += a * b[jj+58]; c[jj+59] += a * b[jj+59];
// c[jj+60] += a * b[jj+60]; c[jj+61] += a * b[jj+61]; c[jj+62] += a * b[jj+62]; c[jj+63] += a * b[jj+63];
//
// c[jj+64] += a * b[jj+64]; c[jj+65] += a * b[jj+65]; c[jj+66] += a * b[jj+66]; c[jj+67] += a * b[jj+67];
// c[jj+68] += a * b[jj+68]; c[jj+69] += a * b[jj+69]; c[jj+70] += a * b[jj+70]; c[jj+71] += a * b[jj+71];
// c[jj+72] += a * b[jj+72]; c[jj+73] += a * b[jj+73]; c[jj+74] += a * b[jj+74]; c[jj+75] += a * b[jj+75];
// c[jj+76] += a * b[jj+76]; c[jj+77] += a * b[jj+77]; c[jj+78] += a * b[jj+78]; c[jj+79] += a * b[jj+79];
// c[jj+80] += a * b[jj+80]; c[jj+81] += a * b[jj+81]; c[jj+82] += a * b[jj+82]; c[jj+83] += a * b[jj+83];
// c[jj+84] += a * b[jj+84]; c[jj+85] += a * b[jj+85]; c[jj+86] += a * b[jj+86]; c[jj+87] += a * b[jj+87];
// c[jj+88] += a * b[jj+88]; c[jj+89] += a * b[jj+89]; c[jj+90] += a * b[jj+90]; c[jj+91] += a * b[jj+91];
// c[jj+92] += a * b[jj+92]; c[jj+93] += a * b[jj+93]; c[jj+94] += a * b[jj+94]; c[jj+95] += a * b[jj+95];
//
// c[jj+96] += a * b[jj+96]; c[jj+97] += a * b[jj+97]; c[jj+98] += a * b[jj+98]; c[jj+99] += a * b[jj+99];
// c[jj+100] += a * b[jj+100]; c[jj+101] += a * b[jj+101]; c[jj+102] += a * b[jj+102]; c[jj+103] += a * b[jj+103];
// c[jj+104] += a * b[jj+104]; c[jj+105] += a * b[jj+105]; c[jj+106] += a * b[jj+106]; c[jj+107] += a * b[jj+107];
// c[jj+108] += a * b[jj+108]; c[jj+109] += a * b[jj+109]; c[jj+110] += a * b[jj+110]; c[jj+111] += a * b[jj+111];
// c[jj+112] += a * b[jj+112]; c[jj+113] += a * b[jj+113]; c[jj+114] += a * b[jj+114]; c[jj+115] += a * b[jj+115];
// c[jj+116] += a * b[jj+116]; c[jj+117] += a * b[jj+117]; c[jj+118] += a * b[jj+118]; c[jj+119] += a * b[jj+119];
// c[jj+120] += a * b[jj+120]; c[jj+121] += a * b[jj+121]; c[jj+122] += a * b[jj+122]; c[jj+123] += a * b[jj+123];
// c[jj+124] += a * b[jj+124]; c[jj+125] += a * b[jj+125]; c[jj+126] += a * b[jj+126]; c[jj+127] += a * b[jj+127];
// }
// }
//
// }
// }
// }
// }
// }
// else{
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
// float * c = C + i * N;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
// float a = A[i * K + k];
// if(jj+block >= N){
//
// for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
// }
// else{
// // C[i * N + j] += A[i * K + k] * B[k * N + j];
// c[jj] += a * b[jj]; c[jj+1] += a * b[jj+1]; c[jj+2] += a * b[jj+2]; c[jj+3] += a * b[jj+3];
// c[jj+4] += a * b[jj+4]; c[jj+5] += a * b[jj+5]; c[jj+6] += a * b[jj+6]; c[jj+7] += a * b[jj+7];
// c[jj+8] += a * b[jj+8]; c[jj+9] += a * b[jj+9]; c[jj+10] += a * b[jj+10]; c[jj+11] += a * b[jj+11];
// c[jj+12] += a * b[jj+12]; c[jj+13] += a * b[jj+13]; c[jj+14] += a * b[jj+14]; c[jj+15] += a * b[jj+15];
// c[jj+16] += a * b[jj+16]; c[jj+17] += a * b[jj+17]; c[jj+18] += a * b[jj+18]; c[jj+19] += a * b[jj+19];
// c[jj+20] += a * b[jj+20]; c[jj+21] += a * b[jj+21]; c[jj+22] += a * b[jj+22]; c[jj+23] += a * b[jj+23];
// c[jj+24] += a * b[jj+24]; c[jj+25] += a * b[jj+25]; c[jj+26] += a * b[jj+26]; c[jj+27] += a * b[jj+27];
// c[jj+28] += a * b[jj+28]; c[jj+29] += a * b[jj+29]; c[jj+30] += a * b[jj+30]; c[jj+31] += a * b[jj+31];
//
// c[jj+32] += a * b[jj+32]; c[jj+33] += a * b[jj+33]; c[jj+34] += a * b[jj+34]; c[jj+35] += a * b[jj+35];
// c[jj+36] += a * b[jj+36]; c[jj+37] += a * b[jj+37]; c[jj+38] += a * b[jj+38]; c[jj+39] += a * b[jj+39];
// c[jj+40] += a * b[jj+40]; c[jj+41] += a * b[jj+41]; c[jj+42] += a * b[jj+42]; c[jj+43] += a * b[jj+43];
// c[jj+44] += a * b[jj+44]; c[jj+45] += a * b[jj+45]; c[jj+46] += a * b[jj+46]; c[jj+47] += a * b[jj+47];
// c[jj+48] += a * b[jj+48]; c[jj+49] += a * b[jj+49]; c[jj+50] += a * b[jj+50]; c[jj+51] += a * b[jj+51];
// c[jj+52] += a * b[jj+52]; c[jj+53] += a * b[jj+53]; c[jj+54] += a * b[jj+54]; c[jj+55] += a * b[jj+55];
// c[jj+56] += a * b[jj+56]; c[jj+57] += a * b[jj+57]; c[jj+58] += a * b[jj+58]; c[jj+59] += a * b[jj+59];
// c[jj+60] += a * b[jj+60]; c[jj+61] += a * b[jj+61]; c[jj+62] += a * b[jj+62]; c[jj+63] += a * b[jj+63];
//
// c[jj+64] += a * b[jj+64]; c[jj+65] += a * b[jj+65]; c[jj+66] += a * b[jj+66]; c[jj+67] += a * b[jj+67];
// c[jj+68] += a * b[jj+68]; c[jj+69] += a * b[jj+69]; c[jj+70] += a * b[jj+70]; c[jj+71] += a * b[jj+71];
// c[jj+72] += a * b[jj+72]; c[jj+73] += a * b[jj+73]; c[jj+74] += a * b[jj+74]; c[jj+75] += a * b[jj+75];
// c[jj+76] += a * b[jj+76]; c[jj+77] += a * b[jj+77]; c[jj+78] += a * b[jj+78]; c[jj+79] += a * b[jj+79];
// c[jj+80] += a * b[jj+80]; c[jj+81] += a * b[jj+81]; c[jj+82] += a * b[jj+82]; c[jj+83] += a * b[jj+83];
// c[jj+84] += a * b[jj+84]; c[jj+85] += a * b[jj+85]; c[jj+86] += a * b[jj+86]; c[jj+87] += a * b[jj+87];
// c[jj+88] += a * b[jj+88]; c[jj+89] += a * b[jj+89]; c[jj+90] += a * b[jj+90]; c[jj+91] += a * b[jj+91];
// c[jj+92] += a * b[jj+92]; c[jj+93] += a * b[jj+93]; c[jj+94] += a * b[jj+94]; c[jj+95] += a * b[jj+95];
//
// c[jj+96] += a * b[jj+96]; c[jj+97] += a * b[jj+97]; c[jj+98] += a * b[jj+98]; c[jj+99] += a * b[jj+99];
// c[jj+100] += a * b[jj+100]; c[jj+101] += a * b[jj+101]; c[jj+102] += a * b[jj+102]; c[jj+103] += a * b[jj+103];
// c[jj+104] += a * b[jj+104]; c[jj+105] += a * b[jj+105]; c[jj+106] += a * b[jj+106]; c[jj+107] += a * b[jj+107];
// c[jj+108] += a * b[jj+108]; c[jj+109] += a * b[jj+109]; c[jj+110] += a * b[jj+110]; c[jj+111] += a * b[jj+111];
// c[jj+112] += a * b[jj+112]; c[jj+113] += a * b[jj+113]; c[jj+114] += a * b[jj+114]; c[jj+115] += a * b[jj+115];
// c[jj+116] += a * b[jj+116]; c[jj+117] += a * b[jj+117]; c[jj+118] += a * b[jj+118]; c[jj+119] += a * b[jj+119];
// c[jj+120] += a * b[jj+120]; c[jj+121] += a * b[jj+121]; c[jj+122] += a * b[jj+122]; c[jj+123] += a * b[jj+123];
// c[jj+124] += a * b[jj+124]; c[jj+125] += a * b[jj+125]; c[jj+126] += a * b[jj+126]; c[jj+127] += a * b[jj+127];
// }
// }
//
//
// }
// }
// }
// }
// }
/************************************/
/* (10) thread split A Raw, Block, AVX */
// Avg. time: 0.018701 sec
// Avg. throughput: 14.354194 GFLOPS
// Performance : 149.351403 GFLOPS (128 block)
/************************************/
// int block = 128;
//
// if((num_threads-1) != thread_index){
// for(int ii = index_m ; ii < index_m + m; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m); i++){
// float * c = C + i * N;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
//
// if(jj+block > N){
// float a = A[i*K + k];
// for (int j = jj; j < min(jj+block, N); j++){
// //printf("Normal Start\n");
// c[j] += a * b[j];
// //_mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
// }
// else{
// __m256 a = _mm256_set1_ps(A[i * K + k]);
// for (int j = jj; j < min(jj+block, N); j+=8){
// //c[j] += a * b[j];
// //printf("AVX Start\n");
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
// }
// }
//
// }
// }
// }
// }
// }
// else{
// for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
// for (int jj = 0; jj < N; jj += block){
// for (int kk = 0; kk < K; kk += block){
//
// for (int i = ii; i < min(ii+block, index_m + m + m_mod); i++){
// float * c = C + i * N;
//
// for (int k = kk; k < min(kk+block, K); ++k)
// {
// float * b = B + k * N;
//
// if(jj+block > N){
// float a = A[i * K + k];
// for (int j = jj; j < min(jj+block, N); j++){
// //printf("Normal Start\n");
// c[j] += a * b[j];
// //_mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
// }
// else{
// __m256 a = _mm256_set1_ps(A[i * K + k]);
// for (int j = jj; j < min(jj+block, N); j+=8){
// //c[j] += a * b[j];
// //printf("AVX Start\n");
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
// }
// }
//
//
// }
// }
// }
// }
// }
/************************************/
/* (11) thread split A Raw, Block, unrolling */
// Performance : 161.073243 GFLOPS (128 block)
// Performance : 164.608178 GFLOPS (64 block/8 unrolling)
// Performance : 174.608178 GFLOPS (64 block/16 unrolling)
// Performance : 176.608178 GFLOPS (64 block/32 unrolling)
/************************************/
int block = 64;
if((num_threads-1) != thread_index){
for(int ii = index_m ; ii < index_m + m; ii+=block){
int m_index_min = (ii+block < index_m + m) ? ii+block : index_m + m;
for (int jj = 0; jj < N; jj += block){
for (int kk = 0; kk < K; kk += block){
//for (int i = ii; i < min(ii+block, index_m + m); i++){
for (int i = ii; i < m_index_min; i++){
float * c = C + i * N;
int k_index_min = (kk+block < K) ? kk+block : K;
for (int k = kk; k < k_index_min; ++k)
{
float * b = B + k * N;
float a = A[i*K + k];
int j_index_min = (jj+block < N) ? jj+block : N;
//for (int j = jj; j < j_index_min; j++)
//for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
if(jj+block > N){
for (int j = jj; j < N; j++){
c[j] += a * b[j];
}
}
else{
for (int j = jj; j < j_index_min; j+=16)
{
c[j] += a * b[j]; c[j+1] += a * b[j+1]; c[j+2] += a * b[j+2]; c[j+3] += a * b[j+3];
c[j+4] += a * b[j+4]; c[j+5] += a * b[j+5]; c[j+6] += a * b[j+6]; c[j+7] += a * b[j+7];
c[j+8] += a * b[j+8]; c[j+9] += a * b[j+9]; c[j+10] += a * b[j+10]; c[j+11] += a * b[j+11];
c[j+12] += a * b[j+12]; c[j+13] += a * b[j+13]; c[j+14] += a * b[j+14]; c[j+15] += a * b[j+15];
// c[j+16] += a * b[j+16]; c[j+17] += a * b[j+17]; c[j+18] += a * b[j+18]; c[j+19] += a * b[j+19];
// c[j+20] += a * b[j+20]; c[j+21] += a * b[j+21]; c[j+22] += a * b[j+22]; c[j+23] += a * b[j+23];
// c[j+24] += a * b[j+24]; c[j+25] += a * b[j+25]; c[j+26] += a * b[j+26]; c[j+27] += a * b[j+27];
// c[j+28] += a * b[j+28]; c[j+29] += a * b[j+29]; c[j+30] += a * b[j+30]; c[j+31] += a * b[j+31];
}
// __m256 a = _mm256_set1_ps(A[i * K + k]);
// for (int j = jj; j < j_index_min; j+=8){
// //c[j] += a * b[j];
// //printf("AVX Start\n");
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
}
}
}
}
}
}
}
else{
for(int ii = index_m ; ii < index_m + m + m_mod; ii+=block){
int m_index_min = (ii+block < index_m + m + m_mod) ? ii+block : index_m + m+m_mod;
for (int jj = 0; jj < N; jj += block){
for (int kk = 0; kk < K; kk += block){
//for (int i = ii; i < min(ii+block, index_m + m +m_mod); i++){
for (int i = ii; i < m_index_min; i++){
float * c = C + i * N;
int k_index_min = (kk+block < K) ? kk+block : K;
for (int k = kk; k < k_index_min; ++k)
//for (int k = kk; k < min(kk+block, K); ++k)
{
float * b = B + k * N;
float a = A[i * K + k];
int j_index_min = (jj+block < N) ? jj+block : N;
//for (int j = jj; j < j_index_min; j++)
//for (int j = jj; j < min(jj+block, N); j++)
// c[j] += a * b[j];
if(jj+block > N){
for (int j = jj; j < N; j++){
c[j] += a * b[j];
}
}
else{
for (int j = jj; j < j_index_min; j+=16)
{
c[j] += a * b[j]; c[j+1] += a * b[j+1]; c[j+2] += a * b[j+2]; c[j+3] += a * b[j+3];
c[j+4] += a * b[j+4]; c[j+5] += a * b[j+5]; c[j+6] += a * b[j+6]; c[j+7] += a * b[j+7];
c[j+8] += a * b[j+8]; c[j+9] += a * b[j+9]; c[j+10] += a * b[j+10]; c[j+11] += a * b[j+11];
c[j+12] += a * b[j+12]; c[j+13] += a * b[j+13]; c[j+14] += a * b[j+14]; c[j+15] += a * b[j+15];
// c[j+16] += a * b[j+16]; c[j+17] += a * b[j+17]; c[j+18] += a * b[j+18]; c[j+19] += a * b[j+19];
// c[j+20] += a * b[j+20]; c[j+21] += a * b[j+21]; c[j+22] += a * b[j+22]; c[j+23] += a * b[j+23];
// c[j+24] += a * b[j+24]; c[j+25] += a * b[j+25]; c[j+26] += a * b[j+26]; c[j+27] += a * b[j+27];
// c[j+28] += a * b[j+28]; c[j+29] += a * b[j+29]; c[j+30] += a * b[j+30]; c[j+31] += a * b[j+31];
}
// __m256 a = _mm256_set1_ps(A[i * K + k]);
// for (int j = jj; j < j_index_min; j+=8){
// //c[j] += a * b[j];
// //printf("AVX Start\n");
// _mm256_storeu_ps(c + j + 0, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 0), _mm256_loadu_ps(c + j + 0)));
// //_mm256_storeu_ps(c + j + 8, _mm256_fmadd_ps(a,_mm256_loadu_ps(b + j + 8), _mm256_loadu_ps(c + j + 8)));
// }
}
}
}
}
}
}
}
return NULL;
}
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads) {
A = _A, B = _B, C = _C;
M = _M, N = _N, K = _K;
num_threads = _num_threads;
// TODO: create '_num_threads' pthreads
void *res;
struct thread_info *tinfo = (thread_info*)calloc(num_threads, sizeof(*tinfo));
for (int tnum = 0; tnum < num_threads; tnum++) {
tinfo[tnum].thread_num = tnum;
// printf("Thread Create : %d \n", tinfo[tnum].thread_num);
pthread_create(&tinfo[tnum].thread_id, NULL, mat_mul_thread, &tinfo[tnum]);
}
for (int tnum = 0; tnum < num_threads; tnum++) {
pthread_join(tinfo[tnum].thread_id, &res);
// printf("Thread Joined : %d \n", tinfo[tnum].thread_num);
}
// pthread_t thread;
// pthread_create(&thread, NULL, mat_mul_thread, NULL);
// pthread_join(thread, NULL);
}