/*
 * chundoong-lab-ta/SamsungDS22/submissions/final/hs5006.kim/B/main.cpp
 *
 * 167 lines
 * 4.7 KiB
 * C++
 */

#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

#include <mpi.h>

#include "util.h"
#include "convolution.h"
// Run configuration; parse_opt() overwrites these from the command line.
static bool validation = false;  // -v: run check_convolution() after the last iteration
static int num_iterations = 1;   // -n

// Input tensor sizes: N images, C channels, H rows, W columns.
static int N = 8;
static int C = 8;
static int H = 8;
static int W = 8;

// Filter tensor sizes: K filters, R rows, S columns (channel count is C).
static int K = 8;
static int R = 3;
static int S = 3;

// Convolution hyper-parameters.
static int pad = 0;       // -p
static int dilation = 1;  // -d
static int stride = 1;    // -s

// Output spatial size, derived from the values above in parse_opt().
static int OH, OW;

// This process's rank and the size of MPI_COMM_WORLD, set in main().
static int mpi_rank, mpi_world_size;
static void print_help(const char* prog_name) {
if (mpi_rank == 0) {
printf("Usage: %s [-vh] [-p padding] [-d dilation] [-s stride] [-n num_iterations] N C H W K R S\n", prog_name);
printf("Options:\n");
printf(" -h : print this page.\n");
printf(" -v : validate convolution. (default: off)\n");
printf(" -n : number of iterations (default: 1)\n");
printf(" -p : padding size. (default: 0)\n");
printf(" -d : dilation size. (default: 1)\n");
printf(" -s : stride size. (default: 1)\n");
printf(" N : number of images, i.e., batch size. (default: 8)\n");
printf(" C : channel dimension. (default: 8)\n");
printf(" H : image height. (default: 8)\n");
printf(" W : image width. (default: 8)\n");
printf(" K : number of filters. (default: 8)\n");
printf(" R : filter height. (default: 3)\n");
printf(" S : filter width. (default: 3)\n");
}
}
static void parse_opt(int argc, char **argv) {
int c;
while ((c = getopt(argc, argv, "vhn:p:d:s:")) != -1) {
switch (c) {
case 'v':
validation = true;
break;
case 'n':
num_iterations = atoi(optarg);
break;
case 'p':
pad = atoi(optarg);
break;
case 'd':
dilation = atoi(optarg);
break;
case 's':
stride = atoi(optarg);
break;
case 'h':
default:
print_help(argv[0]);
MPI_Finalize();
exit(0);
}
}
for (int i = optind, j = 0; i < argc; ++i, ++j) {
switch (j) {
case 0: N = atoi(argv[i]); break;
case 1: C = atoi(argv[i]); break;
case 2: H = atoi(argv[i]); break;
case 3: W = atoi(argv[i]); break;
case 4: K = atoi(argv[i]); break;
case 5: R = atoi(argv[i]); break;
case 6: S = atoi(argv[i]); break;
default: break;
}
}
OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
if (mpi_rank == 0) {
printf("Options:\n");
printf(" Input size: N = %d, C = %d, H = %d, W = %d\n", N, C, H, W);
printf(" Output size: N = %d, K = %d, OH = %d, OW = %d\n", N, K, OH, OW);
printf(" Filter size: K = %d, C = %d, R = %d, S = %d\n", K, C, R, S);
printf(" Number of iterations: %d\n", num_iterations);
printf(" Validation: %s\n", validation ? "on" : "off");
printf("\n");
}
}
int main(int argc, char **argv) {
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size);
parse_opt(argc, argv);
if (mpi_rank == 0) {
printf("Initializing... "); fflush(stdout);
}
float *input, *output, *filter;
if (mpi_rank == 0) {
alloc_tensor(&input, N, C, H, W);
alloc_tensor(&output, N, K, OH, OW);
alloc_tensor(&filter, K, C, R, S);
rand_tensor(input, N, C, H, W);
rand_tensor(filter, K, C, R, S);
}
MPI_Barrier(MPI_COMM_WORLD);
if (mpi_rank == 0) {
printf("done!\n");
}
if (mpi_rank == 0) {
printf("Initializing Convolution...\n"); fflush(stdout);
}
convolution_init(N, C, H, W, K, R, S, pad, dilation, stride);
MPI_Barrier(MPI_COMM_WORLD);
int OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1;
int OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1;
double flops_sum = 0;
for (int i = 0; i < num_iterations; ++i) {
if (mpi_rank == 0) {
printf("Calculating...(iter=%d) ", i); fflush(stdout);
zero_tensor(output, N, K, OH, OW);
}
MPI_Barrier(MPI_COMM_WORLD);
timer_start(0);
convolution(input, output, filter, N, C, H, W, K, R, S, pad, dilation, stride);
MPI_Barrier(MPI_COMM_WORLD);
double elapsed_time = timer_stop(0);
double flops = (double) 2.0 * N * OH * OW * K * C * R * S / elapsed_time;
if (mpi_rank == 0) {
printf("%f sec\n", elapsed_time);
}
flops_sum += flops;
}
convolution_final(N, C, H, W, K, R, S, pad, dilation, stride);
MPI_Barrier(MPI_COMM_WORLD);
if (mpi_rank == 0) {
if (validation) {
check_convolution(input, output, filter, N, C, H, W, K, R, S, pad, dilation, stride);
}
double flops_avg = flops_sum / num_iterations;
printf("Avg. throughput: %f GFLOPS\n", flops_avg / 1e9);
//printf("Avg. throughput: %f GFLOPS\n", 2.0 * M * N * K / elapsed_time_avg / 1e9);
}
MPI_Finalize();
return 0;
}