#include "mat_mul.h" #include #include #include #include #include #include "util.h" #define NODE0 0 using namespace std; static float *A, *B, *C; static int M, N, K; static int num_threads; static int mpi_rank, mpi_world_size; #define ITILE 32 #define JTILE 1024 #define KTILE 1024 static void mat_mul_omp() { // TODO: parallelize & optimize matrix multiplication // Use num_threads per node // A[M*K] * B[K*N] = C[M*N]; int node_rows = M / mpi_world_size; int start, end, rows; start = mpi_rank * node_rows; end = (mpi_rank == mpi_world_size - 1) ? M : (mpi_rank+1)*node_rows; rows = end - start; int tid = omp_get_thread_num(); int slice = rows / num_threads; start = tid * slice + min(tid, rows % num_threads); end = (tid + 1) * slice + min(tid + 1, rows % num_threads); #pragma omp parallel for for (int ii = start; ii < end; ii += ITILE) { for( int jj = 0; jj < N; jj += JTILE) { for (int kk = 0; kk < K; kk += KTILE) { for (int k = kk; k < min(K, kk + KTILE); ++k) { for (int i = ii; i < min(end, ii + ITILE); ++i) { float Aik = A[i * K + k]; for (int j = jj; j < min(N, jj + JTILE); ++j) { C[i * N + j] += Aik * B[k * N + j]; } } } } } } } void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K, int _num_threads, int _mpi_rank, int _mpi_world_size) { A = _A, B = _B, C = _C; M = _M, N = _N, K = _K; num_threads = _num_threads, mpi_rank = _mpi_rank, mpi_world_size = _mpi_world_size; // TODO: parallelize & optimize matrix multiplication on multi-node // You must allocate & initialize A, B, C for non-root processes MPI_Request request; MPI_Status status; // FIXME: for now, only root process runs the matrix multiplication. if (mpi_rank == 0){ int node_rows = M / mpi_world_size; int start, end, rows; for(int node = 1; node < mpi_world_size; node++) { start = node * node_rows; end = (node == mpi_world_size - 1) ? M : (node + 1) * node_rows; rows = end - start; MPI_Isend(&A[start*K], rows*K, MPI_FLOAT, node, 0, MPI_COMM_WORLD, &request); MPI_Isend(B, K*N, MPI_FLOAT, node, 0, MPI_COMM_WORLD, &request); } #pragma omp parallel num_threads(num_threads) mat_mul_omp(); for(int node = 1; node < mpi_world_size; node++) { start = node * node_rows; end = (node == mpi_world_size - 1) ? M : (node + 1) * node_rows; rows = end - start; MPI_Recv(&C[start*N], rows*N, MPI_FLOAT, node, 0, MPI_COMM_WORLD, &status); } } else { int node_rows = M / mpi_world_size; int start, end, rows; start = mpi_rank * node_rows; end = (mpi_rank == mpi_world_size - 1) ? M : (mpi_rank + 1) * node_rows; rows = end - start; alloc_mat(&A, M, K); alloc_mat(&B, K, N); alloc_mat(&C, M, N); zero_mat(C, M, N); MPI_Recv(A, rows*K, MPI_FLOAT, NODE0, 0, MPI_COMM_WORLD,&status); MPI_Recv(B, K*N, MPI_FLOAT, NODE0, 0, MPI_COMM_WORLD,&status); #pragma omp parallel num_threads(num_threads) mat_mul_omp(); MPI_Isend(C, rows*N, MPI_FLOAT, NODE0, 0, MPI_COMM_WORLD,&request); } }