#include #include #include #include #include #include "util.h" #include "convolution.h" static bool validation = false; static int num_iterations = 1; static int N = 8, C = 8, H = 8, W = 8; static int K = 8, R = 3, S = 3; static int pad = 0; static int dilation = 1; static int stride = 1; static int OH, OW; static int mpi_rank, mpi_world_size; static void print_help(const char* prog_name) { if (mpi_rank == 0) { printf("Usage: %s [-vh] [-p padding] [-d dilation] [-s stride] [-n num_iterations] N C H W K R S\n", prog_name); printf("Options:\n"); printf(" -h : print this page.\n"); printf(" -v : validate convolution. (default: off)\n"); printf(" -n : number of iterations (default: 1)\n"); printf(" -p : padding size. (default: 0)\n"); printf(" -d : dilation size. (default: 1)\n"); printf(" -s : stride size. (default: 1)\n"); printf(" N : number of images, i.e., batch size. (default: 8)\n"); printf(" C : channel dimension. (default: 8)\n"); printf(" H : image height. (default: 8)\n"); printf(" W : image width. (default: 8)\n"); printf(" K : number of filters. (default: 8)\n"); printf(" R : filter height. (default: 3)\n"); printf(" S : filter width. (default: 3)\n"); } } static void parse_opt(int argc, char **argv) { int c; while ((c = getopt(argc, argv, "vhn:p:d:s:")) != -1) { switch (c) { case 'v': validation = true; break; case 'n': num_iterations = atoi(optarg); break; case 'p': pad = atoi(optarg); break; case 'd': dilation = atoi(optarg); break; case 's': stride = atoi(optarg); break; case 'h': default: print_help(argv[0]); MPI_Finalize(); exit(0); } } for (int i = optind, j = 0; i < argc; ++i, ++j) { switch (j) { case 0: N = atoi(argv[i]); break; case 1: C = atoi(argv[i]); break; case 2: H = atoi(argv[i]); break; case 3: W = atoi(argv[i]); break; case 4: K = atoi(argv[i]); break; case 5: R = atoi(argv[i]); break; case 6: S = atoi(argv[i]); break; default: break; } } OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1; OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1; if (mpi_rank == 0) { printf("Options:\n"); printf(" Input size: N = %d, C = %d, H = %d, W = %d\n", N, C, H, W); printf(" Output size: N = %d, K = %d, OH = %d, OW = %d\n", N, K, OH, OW); printf(" Filter size: K = %d, C = %d, R = %d, S = %d\n", K, C, R, S); printf(" Number of iterations: %d\n", num_iterations); printf(" Validation: %s\n", validation ? "on" : "off"); printf("\n"); } } int main(int argc, char **argv) { MPI_Init(&argc, &argv); MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank); MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size); parse_opt(argc, argv); if (mpi_rank == 0) { printf("Initializing... "); fflush(stdout); } float *input, *output, *filter; if (mpi_rank == 0) { alloc_tensor(&input, N, C, H, W); alloc_tensor(&output, N, K, OH, OW); alloc_tensor(&filter, K, C, R, S); rand_tensor(input, N, C, H, W); rand_tensor(filter, K, C, R, S); } MPI_Barrier(MPI_COMM_WORLD); if (mpi_rank == 0) { printf("done!\n"); } if (mpi_rank == 0) { printf("Initializing Convolution...\n"); fflush(stdout); } convolution_init(N, C, H, W, K, R, S, pad, dilation, stride); MPI_Barrier(MPI_COMM_WORLD); int OH = (H + 2 * pad - dilation * (R - 1) - 1) / stride + 1; int OW = (W + 2 * pad - dilation * (S - 1) - 1) / stride + 1; double flops_sum = 0; for (int i = 0; i < num_iterations; ++i) { if (mpi_rank == 0) { printf("Calculating...(iter=%d) ", i); fflush(stdout); zero_tensor(output, N, K, OH, OW); } MPI_Barrier(MPI_COMM_WORLD); timer_start(0); convolution(input, output, filter, N, C, H, W, K, R, S, pad, dilation, stride); MPI_Barrier(MPI_COMM_WORLD); double elapsed_time = timer_stop(0); double flops = (double) 2.0 * N * OH * OW * K * C * R * S / elapsed_time; if (mpi_rank == 0) { printf("%f sec\n", elapsed_time); } flops_sum += flops; } convolution_final(N, C, H, W, K, R, S, pad, dilation, stride); MPI_Barrier(MPI_COMM_WORLD); if (mpi_rank == 0) { double flops_avg = flops_sum / num_iterations; printf("Avg. throughput: %f GFLOPS\n", flops_avg / 1e9); //printf("Avg. throughput: %f GFLOPS\n", 2.0 * M * N * K / elapsed_time_avg / 1e9); if (validation) { check_convolution(input, output, filter, N, C, H, W, K, R, S, pad, dilation, stride); } } MPI_Finalize(); return 0; }