#define _GNU_SOURCE #include "util.h" #include #include #include #include #include #include void matmul(const float *A, const float *B, float *C, int M, int N, int K, int threads_per_process, int mpi_rank, int mpi_world_size) { int start = ((M) / mpi_world_size) * mpi_rank; int end = ((M) / mpi_world_size) * (mpi_rank + 1); int nrow = end - start; if (nrow > 0) { MPI_Scatter(A, (nrow)*K, MPI_FLOAT, A, (nrow)*K, MPI_FLOAT, 0, MPI_COMM_WORLD); MPI_Bcast(B, K * N, MPI_FLOAT, 0, MPI_COMM_WORLD); #pragma omp parallel for for (int i = 0; i < nrow; ++i) { for (int k = 0; k < K; ++k) { for (int j = 0; j < N; ++j) { C[i * N + j] += A[i * K + k] * B[k * N + j]; } } } } if (mpi_rank == 0) { #pragma omp parallel for for (int i = nrow * mpi_world_size; i < M; ++i) { for (int k = 0; k < K; ++k) { for (int j = 0; j < N; ++j) { C[i * N + j] += A[i * K + k] * B[k * N + j]; } } } } if (nrow > 0) { MPI_Gather(C, (nrow)*N, MPI_FLOAT, C, (nrow)*N, MPI_FLOAT, 0, MPI_COMM_WORLD); } }