#include "mat_mul.h"

#include <pthread.h>
#include <stddef.h>
#include <stdlib.h>

/*
 * Multithreaded dense matrix multiplication: C += A * B.
 *
 * NOTE(review): this file arrived whitespace-collapsed and the names of the
 * six system headers it originally included were lost; the list above holds
 * only what the code below needs. Restore any others if the build relied on
 * them.
 *
 * Design notes, condensed from the author's benchmark log (the commented-out
 * experimental variants were removed):
 *   - In i-j-k order the innermost loop walks down a column of B (stride N).
 *     i-k-j order instead scales one row of B and adds it to one row of C,
 *     streaming both contiguously; it was roughly 9x faster single-threaded.
 *   - Rows of C are split into bands, one per thread. Bands are disjoint, so
 *     no synchronisation is needed while computing.
 *   - Cache tiling on top of that gave the largest gain. 64-element tiles with
 *     a 16-wide unrolled inner loop were the fastest variant recorded
 *     (~175 GFLOPS reported), ahead of hand-written AVX2 FMA (~149 GFLOPS).
 */

/*
 * MM_BLOCK:  edge length, in elements, of the square tiles the i/j/k loops
 *            are blocked by.
 * MM_UNROLL: columns processed per unrolled step of the innermost loop.
 */
enum { MM_BLOCK = 64, MM_UNROLL = 16 };

/* Per-thread work description; also used as argument to mat_mul_thread(). */
struct thread_info {
    pthread_t thread_id; /* ID returned by pthread_create() */
    int thread_num;      /* application-defined thread index */
    int started;         /* nonzero iff pthread_create() succeeded (must join) */
    const float *A;      /* row-major, K columns */
    const float *B;      /* row-major, N columns */
    float *C;            /* row-major, N columns; accumulated into */
    int N;               /* columns of B and C */
    int K;               /* columns of A == rows of B */
    int row_begin;       /* first row of C owned by this thread */
    int row_end;         /* one past the last row of C owned by this thread */
};

/* Returns the smaller of a and b. */
int min(int a, int b)
{
    return a < b ? a : b;
}

/* Fills in every field of *t except thread_id, which pthread_create() sets. */
static void init_task(struct thread_info *t, int thread_num,
                      const float *A, const float *B, float *C,
                      int N, int K, int row_begin, int row_end)
{
    t->thread_num = thread_num;
    t->started = 0;
    t->A = A;
    t->B = B;
    t->C = C;
    t->N = N;
    t->K = K;
    t->row_begin = row_begin;
    t->row_end = row_end;
}

/*
 * c[j] += a * b[j] for j in [j0, j1).
 * The main loop handles MM_UNROLL columns per step (a constant trip count the
 * compiler can fully unroll / vectorise); the scalar loop finishes the tail,
 * so any j1 - j0 is handled correctly.
 */
static void scaled_row_add(float *c, const float *b, float a, int j0, int j1)
{
    int j = j0;
    for (; j + MM_UNROLL <= j1; j += MM_UNROLL) {
        for (int u = 0; u < MM_UNROLL; ++u) {
            c[j + u] += a * b[j + u];
        }
    }
    for (; j < j1; ++j) {
        c[j] += a * b[j];
    }
}

/*
 * Accumulates rows [t->row_begin, t->row_end) of C += A * B using tiled
 * i-k-j loops. Each C element receives its k contributions in increasing k
 * order, exactly as an untiled i-k-j loop would.
 */
static void multiply_rows(const struct thread_info *t)
{
    const float *A = t->A;
    const float *B = t->B;
    float *C = t->C;
    const int N = t->N;
    const int K = t->K;

    for (int ii = t->row_begin; ii < t->row_end; ii += MM_BLOCK) {
        const int i_end = min(ii + MM_BLOCK, t->row_end);
        for (int jj = 0; jj < N; jj += MM_BLOCK) {
            const int j_end = min(jj + MM_BLOCK, N);
            for (int kk = 0; kk < K; kk += MM_BLOCK) {
                const int k_end = min(kk + MM_BLOCK, K);
                for (int i = ii; i < i_end; ++i) {
                    /* size_t offsets: i * N overflows int for large matrices. */
                    float *c_row = C + (size_t)i * (size_t)N;
                    const float *a_row = A + (size_t)i * (size_t)K;
                    for (int k = kk; k < k_end; ++k) {
                        scaled_row_add(c_row, B + (size_t)k * (size_t)N,
                                       a_row[k], jj, j_end);
                    }
                }
            }
        }
    }
}

/* pthread entry point: computes the row band described by data. */
static void *mat_mul_thread(void *data)
{
    multiply_rows((const struct thread_info *)data);
    return NULL;
}

/*
 * Computes C += A * B using up to _num_threads POSIX threads.
 *
 * _A is _M x _K, _B is _K x _N and _C is _M x _N, all dense row-major float
 * arrays. C is accumulated into, not overwritten: the caller must zero it
 * first to obtain the plain product (NOTE(review): presumably the benchmark
 * driver does this before each call -- confirm).
 *
 * Rows of C are divided into contiguous bands whose sizes differ by at most
 * one row, one band per thread. _num_threads < 1 is treated as 1, and the
 * thread count is capped at _M so no thread is idle. If creating a thread
 * fails, its band is computed on the calling thread instead; if allocating
 * the per-thread bookkeeping fails, the whole product is computed serially.
 * Returns only after every row of C has been updated. Non-positive
 * dimensions leave C untouched.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads)
{
    if (_M <= 0 || _N <= 0 || _K <= 0) {
        return;
    }

    int nthreads = _num_threads < 1 ? 1 : _num_threads;
    if (nthreads > _M) {
        nthreads = _M;
    }

    /* Cast required when compiled as C++; harmless in C. */
    struct thread_info *tinfo =
        (struct thread_info *)calloc((size_t)nthreads, sizeof *tinfo);
    if (tinfo == NULL) {
        /* Out of memory for bookkeeping: still produce a correct result. */
        struct thread_info whole;
        init_task(&whole, 0, _A, _B, _C, _N, _K, 0, _M);
        multiply_rows(&whole);
        return;
    }

    const int base = _M / nthreads;
    const int extra = _M % nthreads;
    int row = 0;
    for (int t = 0; t < nthreads; ++t) {
        /* The first `extra` threads take one additional row each. */
        const int rows = base + (t < extra ? 1 : 0);
        init_task(&tinfo[t], t, _A, _B, _C, _N, _K, row, row + rows);
        row += rows;

        if (pthread_create(&tinfo[t].thread_id, NULL, mat_mul_thread,
                           &tinfo[t]) == 0) {
            tinfo[t].started = 1;
        } else {
            /* thread_id is unspecified on failure, so never join it; do the
             * band here. Bands are disjoint, so this cannot race with the
             * threads already running. */
            multiply_rows(&tinfo[t]);
        }
    }

    for (int t = 0; t < nthreads; ++t) {
        if (tinfo[t].started) {
            pthread_join(tinfo[t].thread_id, NULL);
        }
    }

    free(tinfo);
}