#include "mat_mul.h"
#include "util.h"

/*
 * NOTE(review): the system #include lines of this file had their <...>
 * targets stripped; they are reconstructed from the names the code uses
 * (MPI_* -> <mpi.h>, malloc/free -> <stdlib.h>, size_t -> <stddef.h>).
 */
#include <mpi.h>
#include <stddef.h>
#include <stdlib.h>

#include "omp.h"

static float *A, *B, *C;
static int M, N, K;
static int num_threads;
static int mpi_rank, mpi_world_size;

/* Message tags for the rank-0 <-> worker protocol in mat_mul(). */
enum { TAG_B = 0, TAG_A = 1, TAG_C = 2 };

static int min(int x, int y) { return x < y ? x : y; }

/*
 * Tiled C[0:M, 0:N] += A[0:M, 0:K] * B[0:K, 0:N] on the file-scope row-major
 * matrices, parallelised over row tiles with OpenMP.
 *
 * Accumulates into C; it does not clear C first.
 * Sets the OpenMP thread count to the file-scope num_threads.
 */
static void mat_mul_omp(void) {
  enum { ITILESIZE = 32, JTILESIZE = 1024, KTILESIZE = 1024 };

  omp_set_num_threads(num_threads);
  /* Each iteration of ii owns a disjoint set of C rows, so threads never
     write the same element. */
#pragma omp parallel for
  for (int ii = 0; ii < M; ii += ITILESIZE) {
    const int i_end = min(M, ii + ITILESIZE);
    for (int jj = 0; jj < N; jj += JTILESIZE) {
      const int j_end = min(N, jj + JTILESIZE);
      for (int kk = 0; kk < K; kk += KTILESIZE) {
        const int k_end = min(K, kk + KTILESIZE);
        for (int k = kk; k < k_end; k++) {
          const float *b_row = &B[(size_t)k * N];
          for (int i = ii; i < i_end; i++) {
            /* i-k-j order: reuse one A element across a contiguous run of
               B and C, keeping the innermost loop unit-stride. */
            const float ar = A[(size_t)i * K + k];
            float *c_row = &C[(size_t)i * N];
            for (int j = jj; j < j_end; j++) {
              c_row[j] += ar * b_row[j];
            }
          }
        }
      }
    }
  }
}

/*
 * Distributed matrix multiply C = A * B (row-major, A is MxK, B is KxN),
 * with rows of A/C split across MPI ranks and OpenMP inside each rank.
 *
 * Row split: rank 0 keeps the first M / world_size + M % world_size rows;
 * each other rank gets M / world_size consecutive rows after that. If there
 * are fewer rows than ranks, rank 0 computes everything and the other ranks
 * return without communicating.
 *
 * Rank 0 sends every worker all of B plus its slice of A. It then computes
 * its own rows and receives each worker's C slice into _C.
 *
 * Rank 0 accumulates into _C without clearing it.
 * NOTE(review): this presumably relies on the caller passing a zeroed _C;
 * confirm.
 *
 * On worker ranks, A/B/C are re-pointed at buffers from alloc_mat(), so the
 * _A/_B/_C arguments are not read there.
 */
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K,
             int _num_threads, int _mpi_rank, int _mpi_world_size) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads, mpi_rank = _mpi_rank,
  mpi_world_size = _mpi_world_size;

  const int averow = M / mpi_world_size;
  const int extra = M % mpi_world_size;
  const int master_rows = averow + extra;
  /* Every rank computes the same value, so both sides agree on who takes part. */
  const int num_workers = averow > 0 ? mpi_world_size - 1 : 0;

  if (mpi_rank == 0) {
    MPI_Request *reqs = NULL;
    if (num_workers > 0) {
      reqs = malloc(2 * (size_t)num_workers * sizeof *reqs);
      if (reqs == NULL) {
        MPI_Abort(MPI_COMM_WORLD, 1);
        return;
      }
    }

    for (int node = 1; node <= num_workers; node++) {
      const size_t offset =
          (size_t)master_rows + (size_t)(node - 1) * (size_t)averow;
      MPI_Isend(B, K * N, MPI_FLOAT, node, TAG_B, MPI_COMM_WORLD,
                &reqs[2 * (node - 1)]);
      MPI_Isend(&A[offset * (size_t)K], averow * K, MPI_FLOAT, node, TAG_A,
                MPI_COMM_WORLD, &reqs[2 * (node - 1) + 1]);
    }

    /* Compute rank 0's own rows while the sends are in flight. */
    M = master_rows;
    mat_mul_omp();

    /* Nonblocking requests must be completed before returning. */
    if (num_workers > 0) {
      MPI_Waitall(2 * num_workers, reqs, MPI_STATUSES_IGNORE);
    }
    free(reqs);

    for (int node = 1; node <= num_workers; node++) {
      const size_t offset =
          (size_t)master_rows + (size_t)(node - 1) * (size_t)averow;
      MPI_Recv(&C[offset * (size_t)N], averow * N, MPI_FLOAT, node, TAG_C,
               MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
  } else {
    /* Fewer rows than ranks: rank 0 does all the work, nothing to receive. */
    if (averow == 0) {
      return;
    }

    /* NOTE(review): these alloc_mat() buffers are never released, so repeated
       calls leak them. Freeing them requires knowing alloc_mat()'s allocator;
       confirm in util before adding the matching release. */
    alloc_mat(&B, K, N);
    MPI_Recv(B, K * N, MPI_FLOAT, 0, TAG_B, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    alloc_mat(&A, averow, K);
    alloc_mat(&C, averow, N);
    zero_mat(C, averow, N);
    MPI_Recv(A, averow * K, MPI_FLOAT, 0, TAG_A, MPI_COMM_WORLD,
             MPI_STATUS_IGNORE);

    M = averow;
    mat_mul_omp();

    /* Blocking send so the transfer is complete before this rank returns. */
    MPI_Send(C, averow * N, MPI_FLOAT, 0, TAG_C, MPI_COMM_WORLD);
  }
}