#include "mat_mul.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <mpi.h>

#define my_MASTER 0
#define FROM_MASTER 1
#define FROM_WORKER 2
#define ITILESIZE 32
#define JTILESIZE 1024
#define KTILESIZE 1024
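
/*
 * Hybrid MPI + OpenMP dense matrix multiplication, C = A * B (row-major,
 * single precision).  Rows of A and C are split into contiguous blocks, one
 * block per MPI rank; every rank multiplies its block against a full copy of
 * B using the tiled OpenMP kernel below.  ITILESIZE / JTILESIZE / KTILESIZE
 * set the cache-blocking factors of that kernel.
 */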
static float *A, *B, *C;
static int M, N, K;
static int num_threads;
static int mpi_rank, mpi_world_size;

// Tiled multiply of this rank's row block; each OpenMP thread handles the
// contiguous row range [is, ie) of that block.
static void mat_mul_omp(int tid, int my_M) {
  // Split my_M rows as evenly as possible across num_threads threads.
  int is = my_M / num_threads * tid + std::min(tid, my_M % num_threads);
  int ie = my_M / num_threads * (tid + 1) + std::min(tid + 1, my_M % num_threads);

  // Cache blocking over i/j/k; the innermost j loop walks B and C contiguously.
  for (int ii = is; ii < ie; ii += ITILESIZE) {
    for (int jj = 0; jj < N; jj += JTILESIZE) {
      for (int kk = 0; kk < K; kk += KTILESIZE) {
        for (int k = kk; k < std::min(K, kk + KTILESIZE); k++) {
          for (int i = ii; i < std::min(ie, ii + ITILESIZE); i++) {
            float ar = A[i * K + k];
            for (int j = jj; j < std::min(N, jj + JTILESIZE); j++) {
              C[i * N + j] += ar * B[k * N + j];
            }
          }
        }
      }
    }
  }
}
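
/*
 * Entry point.  With a single rank the whole problem is solved locally.  With
 * several ranks, rank 0 (my_MASTER) keeps the first `averow` rows of A, sends
 * every worker its own row block of A plus all of B, computes its block while
 * those transfers are in flight, and finally gathers the workers' C blocks
 * back into the caller's C.  The last rank absorbs the M % mpi_world_size
 * leftover rows.
 */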
void mat_mul(float *_A, float *_B, float *_C, int _M, int _N, int _K,
             int _num_threads, int _mpi_rank, int _mpi_world_size) {
  A = _A, B = _B, C = _C;
  M = _M, N = _N, K = _K;
  num_threads = _num_threads;  // Threads per node
  mpi_rank = _mpi_rank;
  mpi_world_size = _mpi_world_size;

  int numworkers, source, dest, mtype, averow, extra, offset;
  int my_M;
  int *my_offset;  // Starting row of each rank's block, indexed by rank
  int *my_rows;    // Number of rows in each rank's block
  my_offset = (int *)malloc(sizeof(int) * mpi_world_size);
  my_rows = (int *)malloc(sizeof(int) * mpi_world_size);
  for (int i = 0; i < mpi_world_size; ++i) my_offset[i] = 0;
  for (int i = 0; i < mpi_world_size; ++i) my_rows[i] = 0;
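
  // Single-rank runs skip MPI entirely (final else branch); otherwise the
  // rows are distributed across ranks and gathered back afterwards.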
  if (mpi_world_size > 1) {
    numworkers = mpi_world_size;
    averow = M / numworkers;  // Base number of rows per rank
    extra = M % numworkers;   // Leftover rows, absorbed by the last rank
    offset = 0;
    my_M = (mpi_rank == numworkers - 1) ? (averow + extra) : (averow);
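
    /*
     * Master: stream each worker's slice of A and the whole of B out with
     * non-blocking sends, overlap them with the local multiply, then collect
     * every worker's block of C into the right row offset.
     */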
    if (mpi_rank == my_MASTER) {
      // One request per non-blocking call so every transfer can be completed;
      // a single shared request would only track the last call.
      MPI_Request *reqs = (MPI_Request *)malloc(sizeof(MPI_Request) * 3 * (numworkers - 1));
      int nreq = 0;

      offset = averow;  // Master keeps rows [0, averow)
      mtype = FROM_MASTER;
      for (dest = 1; dest < numworkers; dest++) {
        my_rows[dest] = (dest == numworkers - 1) ? (averow + extra) : (averow);
        my_offset[dest] = offset;
        MPI_Isend(&A[my_offset[dest] * K], my_rows[dest] * K, MPI_FLOAT, dest, mtype,
                  MPI_COMM_WORLD, &reqs[nreq++]);
        MPI_Isend(B, K * N, MPI_FLOAT, dest, mtype, MPI_COMM_WORLD, &reqs[nreq++]);
        offset = offset + my_rows[dest];
      }

      // Compute the master's own block while the sends are in flight.
      for (int l = 0; l < averow * N; ++l) C[l] = 0.0;

      #pragma omp parallel num_threads(num_threads)
      {
        #pragma omp for
        for (int i = 0; i < num_threads; ++i) {
          mat_mul_omp(i, averow);
        }
      }

      // Post receives for the workers' C blocks, then complete everything.
      mtype = FROM_WORKER;
      for (int i = 1; i < numworkers; i++) {
        source = i;
        MPI_Irecv(&C[my_offset[source] * N], my_rows[source] * N, MPI_FLOAT, source, mtype,
                  MPI_COMM_WORLD, &reqs[nreq++]);
      }
      MPI_Waitall(nreq, reqs, MPI_STATUSES_IGNORE);
      free(reqs);
    }
    else {
      // Worker: receive this rank's rows of A and a full copy of B, multiply,
      // then send the resulting rows of C back to the master.
      A = (float *)malloc(sizeof(float) * my_M * K);
      B = (float *)malloc(sizeof(float) * K * N);
      mtype = FROM_MASTER;
      MPI_Request reqs[2];
      MPI_Irecv(A, my_M * K, MPI_FLOAT, my_MASTER, mtype, MPI_COMM_WORLD, &reqs[0]);
      MPI_Irecv(B, K * N, MPI_FLOAT, my_MASTER, mtype, MPI_COMM_WORLD, &reqs[1]);

      // Zero the local C block while the receives are in flight.
      C = (float *)malloc(sizeof(float) * my_M * N);
      for (int l = 0; l < my_M * N; ++l) C[l] = 0.0;
      // Both receives must complete before the kernel reads A and B.
      MPI_Waitall(2, reqs, MPI_STATUSES_IGNORE);

      #pragma omp parallel num_threads(num_threads)
      {
        #pragma omp for
        for (int i = 0; i < num_threads; ++i) {
          mat_mul_omp(i, my_M);
        }
      }

      mtype = FROM_WORKER;
      MPI_Send(C, my_M * N, MPI_FLOAT, my_MASTER, mtype, MPI_COMM_WORLD);
    }
  }
  else {
    // Single rank: no MPI traffic, just the tiled OpenMP kernel over all M rows.
    for (int l = 0; l < M * N; ++l) C[l] = 0.0;

    #pragma omp parallel num_threads(num_threads)
    {
      #pragma omp for
      for (int i = 0; i < num_threads; ++i) {
        mat_mul_omp(i, M);
      }
    }
  }

  free(my_offset);
  free(my_rows);
}