[Refactor] refactor & modify operations and comments

This commit is contained in:
Jaehwan Lee 2024-04-27 15:12:18 +00:00
parent 3a11f82ac1
commit c20814b81b
6 changed files with 770 additions and 726 deletions

View File

@ -11,107 +11,107 @@
int main(int argc, char **argv) {
/* MPI Initialization */
int mpi_rank, mpi_size;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);
/* Parse arguments */
if (mpi_rank == 0) {
parse_args(argc, argv);
}
/* MPI Initialization */
int mpi_rank, mpi_size;
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);
/* Parse arguments */
if (mpi_rank == 0) {
parse_args(argc, argv);
}
////////////////////////////////////////////////////////////////////
// INITIALIZATION //
////////////////////////////////////////////////////////////////////
int *input, *output;
////////////////////////////////////////////////////////////////////
// INITIALIZATION //
////////////////////////////////////////////////////////////////////
int *input, *output;
if (mpi_rank == 0) {
/* Load input (size: N x input_tokens) from file */
fprintf(stderr, "[LOG] Reading input from %s\n", input_fname);
size_t input_size;
input = (int *)read_binary(input_fname, &input_size);
if (input_size % input_tokens != 0) {
fprintf(stderr, "[ERROR] Invalid input size\n");
exit(1);
}
if (mpi_rank == 0) {
/* Load input (size: N x input_tokens) from file */
fprintf(stderr, "[LOG] Reading input from %s\n", input_fname);
size_t input_size;
input = (int *)read_binary(input_fname, &input_size);
if (input_size % input_tokens != 0) {
fprintf(stderr, "[ERROR] Invalid input size\n");
exit(1);
}
/* Allocate output (size: N x T) */
output = (int *)malloc(N * T * sizeof(int));
}
/* Initialize parameters and activations */
if (mpi_rank == 0) fprintf(stderr, "[LOG] Initializing... \n");
initialize_parameters(param_fname);
initialize_activations();
/* Allocate output (size: N x T) */
output = (int *)malloc(N * T * sizeof(int));
}
/* Initialize parameters and activations */
if (mpi_rank == 0) fprintf(stderr, "[LOG] Initializing... \n");
initialize_parameters(param_fname);
initialize_activations();
/* Cannot surpass the max_seq_len of the model */
assert(input_tokens + T <= MAX_SEQ_LEN);
/* Cannot surpass the max_seq_len of the model */
assert(input_tokens + T <= MAX_SEQ_LEN);
////////////////////////////////////////////////////////////////////
// MODEL COMPUTATION //
////////////////////////////////////////////////////////////////////
double st = 0.0, et = 0.0;
////////////////////////////////////////////////////////////////////
// MODEL COMPUTATION //
////////////////////////////////////////////////////////////////////
double st = 0.0, et = 0.0;
if (mpi_rank == 0) {
fprintf(stdout, " Start computation... \n");
st = get_time();
}
if (mpi_rank == 0) {
fprintf(stdout, " Start computation... \n");
st = get_time();
}
/* Text Generation */
MPI_Barrier(MPI_COMM_WORLD);
generate_tokens(input, output, N, T);
MPI_Barrier(MPI_COMM_WORLD);
/* Text Generation */
MPI_Barrier(MPI_COMM_WORLD);
generate_tokens(input, output, N, T);
MPI_Barrier(MPI_COMM_WORLD);
if (mpi_rank == 0) {
et = get_time();
if (mpi_rank == 0) {
et = get_time();
/* Print the result */
fprintf(stdout, " Done!\n");
fprintf(stdout, " Elapsed time: %lf (sec)\n", et - st);
fprintf(stdout, " Throughput: %lf (tokens/sec)\n",
N*T / (et - st));
}
////////////////////////////////////////////////////////////////////
// FINALIZATION //
////////////////////////////////////////////////////////////////////
/* Finalize parameters and activations */
if (mpi_rank == 0) fprintf(stderr, "[LOG] Finalizing... \n");
finalize_parameters();
finalize_activations();
if (mpi_rank == 0) {
/* Save output */
if (S) {
fprintf(stdout, " Saving output... \n");
write_binary(output, output_fname, N*T);
}
/* Print the result */
fprintf(stdout, " Done!\n");
fprintf(stdout, " Elapsed time: %lf (sec)\n", et - st);
fprintf(stdout, " Throughput: %lf (tokens/sec)\n",
N*T / (et - st));
}
////////////////////////////////////////////////////////////////////
// FINALIZATION //
////////////////////////////////////////////////////////////////////
/* Finalize parameters and activations */
if (mpi_rank == 0) fprintf(stderr, "[LOG] Finalizing... \n");
finalize_parameters();
finalize_activations();
if (mpi_rank == 0) {
/* Save output */
if (S) {
fprintf(stdout, " Saving output... \n");
write_binary(output, output_fname, N*T);
}
/* Validation */
if (V) {
fprintf(stdout, " Validation... \n");
/* Validation */
if (V) {
fprintf(stdout, " Validation... \n");
int *answer = (int *)read_binary(answer_fname, NULL);
int ret = check_validation(output, answer, N*T);
if (ret == -1) {
fprintf(stdout, " Validation passed!\n");
} else {
fprintf(stdout, " Validation failed: First mismatch "
"at prompt[#%d], token_ID[#%d] (output[%d]=%d <-> "
"answer[%d]=%d)\n", ret / T, ret % T, ret,
output[ret], ret, answer[ret]);
}
}
}
int *answer = (int *)read_binary(answer_fname, NULL);
int ret = check_validation(output, answer, N*T);
if (ret == -1) {
fprintf(stdout, " Validation passed!\n");
} else {
fprintf(stdout, " Validation failed: First mismatch "
"at prompt[#%d], token_ID[#%d] (output[%d]=%d <-> "
"answer[%d]=%d)\n", ret / T, ret % T, ret,
output[ret], ret, answer[ret]);
}
}
}
/* MPI Finalization */
MPI_Finalize();
/* MPI Finalization */
MPI_Finalize();
return 0;
return 0;
}

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
#!/bin/sh
# Launches ./main under mpirun with validation (-v) and output saving (-s)
# enabled, generating 5 tokens (-t) for 1 prompt (-n).
mpirun --bind-to none -mca btl ^openib -npernode 1 \
  --oversubscribe -quiet \
  ./main \
  -v -s \
  -t 5 \
  -n 1

View File

@ -7,32 +7,32 @@
using namespace std;
/*
 * Creates a zero-filled tensor.
 *
 * shape_: extent of each dimension; its length becomes ndim and must not
 *         exceed the capacity of the fixed-size shape[] member.
 *
 * The buffer holds num_elem() floats and is released by the destructor.
 */
Tensor::Tensor(const vector<int> &shape_) {
  /* shape[] has a fixed capacity; a larger rank would write past its end */
  if (shape_.size() > sizeof(shape) / sizeof(shape[0])) abort();

  ndim = shape_.size();
  for (size_t i = 0; i < ndim; i++) {
    shape[i] = shape_[i];
  }

  /* calloc returns zero-initialized storage */
  int N_ = num_elem();
  buf = (float *)calloc(N_, sizeof(float));
}
/*
 * Creates a tensor holding a private copy of existing data.
 *
 * shape_: extent of each dimension; its length becomes ndim and must not
 *         exceed the capacity of the fixed-size shape[] member.
 * buf_:   source data; num_elem() floats are read from it. The caller keeps
 *         ownership of buf_.
 */
Tensor::Tensor(const vector<int> &shape_, float *buf_) {
  /* shape[] has a fixed capacity; a larger rank would write past its end */
  if (shape_.size() > sizeof(shape) / sizeof(shape[0])) abort();

  ndim = shape_.size();
  for (size_t i = 0; i < ndim; i++) {
    shape[i] = shape_[i];
  }

  int N_ = num_elem();
  const size_t bytes = N_ * sizeof(float);
  buf = (float *)malloc(bytes);
  if (bytes > 0) {
    /* Copying into a failed allocation would be undefined behaviour */
    if (buf == nullptr) abort();
    memcpy(buf, buf_, bytes);
  }
}
/* Releases the element buffer exactly once. */
Tensor::~Tensor() {
  if (buf != nullptr) {
    free(buf);
    buf = nullptr; /* guard against any later accidental reuse */
  }
}
/*
 * Returns the total number of elements: the product of the first ndim
 * entries of shape[] (1 when ndim is 0).
 */
int Tensor::num_elem() {
  int size = 1;
  for (size_t i = 0; i < ndim; i++) {
    size *= shape[i];
  }
  return size;
}

View File

@ -6,13 +6,13 @@
using namespace std;
/*
 * Float tensor of up to 4 dimensions backed by a heap buffer.
 *
 * NOTE(review): the destructor frees buf but no copy/move operations are
 * declared, so copying a Tensor by value would free the same buffer twice.
 * Pass Tensors by pointer or reference.
 */
struct Tensor {
  size_t ndim = 0;      /* number of valid entries in shape[] */
  int shape[4];         /* extent of each dimension */
  float *buf = nullptr; /* element storage, freed by the destructor */

  /* Allocates a zero-filled buffer for the given shape. */
  Tensor(const std::vector<int> &shape_);
  /* Allocates a buffer for the given shape and copies buf_ into it. */
  Tensor(const std::vector<int> &shape_, float *buf_);
  ~Tensor();

  /* Product of the first ndim extents. */
  int num_elem();
};

View File

@ -17,99 +17,99 @@ char answer_fname[] = "./data/answer.bin";
char output_fname[] = "./data/output.bin";
/*
 * Copies a path option into one of the fixed-size global filename buffers.
 * Exits with an error instead of writing past the end of the buffer.
 *
 * NOTE(review): the buffers are sized by their default strings, so longer
 * paths are rejected here; enlarge the buffers at their definitions to
 * accept them.
 */
template <size_t CAP>
static void set_path_or_die(char (&dst)[CAP], const char *src, char opt) {
  if (strlen(src) >= CAP) {
    fprintf(stderr,
            "[ERROR] Path given to -%c is too long (max %zu characters)\n",
            opt, CAP - 1);
    exit(1);
  }
  strcpy(dst, src);
}

/*
 * Parses the command line into the global option variables and prints a
 * summary banner to stdout.
 *
 *   -i/-o/-a/-p  set the input, output, answer and parameter file paths
 *   -n / -t      set N (prompts) and T (tokens to generate) via atoi
 *   -v / -s      enable validation / output saving
 *   -h           print help and exit(0)
 *
 * Any other option (including 'w', which the option string accepts but no
 * case handles) prints help and exits with status 0.
 */
void parse_args(int argc, char **argv) {
  int args;
  while ((args = getopt(argc, argv, "i:o:a:p:n:t:vswh")) != -1) {
    switch (args) {
      case 'i':
        set_path_or_die(input_fname, optarg, 'i');
        break;
      case 'o':
        set_path_or_die(output_fname, optarg, 'o');
        break;
      case 'a':
        set_path_or_die(answer_fname, optarg, 'a');
        break;
      case 'p':
        set_path_or_die(param_fname, optarg, 'p');
        break;
      case 'n':
        N = atoi(optarg);
        break;
      case 't':
        T = atoi(optarg);
        break;
      case 'v':
        V = true;
        break;
      case 's':
        S = true;
        break;
      case 'h':
      default:
        print_help();
        exit(0);
        break;
    }
  }

  fprintf(stdout, "\n=============================================\n");
  fprintf(stdout, " Model: GPT-2 (12 layers)\n");
  fprintf(stdout, "---------------------------------------------\n");
  fprintf(stdout, " Validation: %s\n", V ? "ON" : "OFF");
  fprintf(stdout, " Save output: %s\n", S ? "ON" : "OFF");
  fprintf(stdout, " Number of Prompts: %d\n", N);
  fprintf(stdout, " Number of Tokens to generate: %d\n", T);
  fprintf(stdout, "=============================================\n\n");
}
/* Prints the usage line and the description of every option to stdout. */
void print_help() {
  fprintf(stdout,
          " Usage: ./main [-i 'pth'] [-p 'pth'] [-o 'pth'] [-a 'pth']"
          " [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-h]\n");
  fprintf(stdout, " Options:\n");
  fprintf(stdout, " -i: Input binary path (default: data/input.bin)\n");
  fprintf(stdout, " -p: Model parameter path (default: assets/model_file.bin)\n");
  /* The default matches the initial value of output_fname */
  fprintf(stdout, " -o: Output binary path (default: data/output.bin)\n");
  fprintf(stdout, " -a: Answer binary path (default: data/answer.bin)\n");
  fprintf(stdout, " -n: Number of prompts (default: 1)\n");
  fprintf(stdout, " -t: Number of tokens to generate (default: 5)\n");
  fprintf(stdout, " -v: Enable validation (default: OFF)\n");
  fprintf(stdout, " -s: Enable saving output tensor (default: OFF)\n");
  fprintf(stdout, " -h: Print manual and options (default: OFF)\n");
}
/*
 * Reads an entire binary file into a freshly malloc'ed buffer.
 *
 * fname: path of the file to read.
 * size:  optional out-parameter; when non-NULL it receives the number of
 *        4-byte elements in the file (file length in bytes / 4).
 *
 * Returns the buffer; the caller releases it with free().
 * Prints an error and calls exit(-1) when the file cannot be opened, is
 * empty, or cannot be read completely.
 */
void* read_binary(const char *fname, size_t *size) {
  FILE *f = fopen(fname, "rb");
  if (f == NULL) {
    fprintf(stderr, "[ERROR] Cannot open file \'%s\'\n", fname);
    exit(-1);
  }

  /* Determine the file length; ftell reports failure with a negative value */
  long length = -1;
  if (fseek(f, 0, SEEK_END) == 0) {
    length = ftell(f);
  }
  if (length < 0) {
    fprintf(stderr, "[ERROR] Cannot read file \'%s\'\n", fname);
    fclose(f);
    exit(-1);
  }
  rewind(f);

  size_t size_ = (size_t)length;
  void *buf = malloc(size_);
  size_t ret = (buf != NULL) ? fread(buf, 1, size_, f) : 0;
  fclose(f);

  /* Reject empty files, failed allocations and short reads alike */
  if (ret == 0 || ret != size_) {
    fprintf(stderr, "[ERROR] Cannot read file \'%s\'\n", fname);
    free(buf);
    exit(-1);
  }

  if (size != NULL)
    *size = (size_t)(size_ / 4); // 4 bytes per float or int

  return buf;
}
/*
 * Writes size_ ints from output to filename, replacing any existing file.
 *
 * The file is opened in binary mode so the bytes are written unmodified.
 * Prints an error and calls exit(-1) when the file cannot be opened or the
 * data cannot be written completely.
 */
void write_binary(int *output, const char *filename, int size_) {
  FILE *f = fopen(filename, "wb");
  if (f == NULL) {
    fprintf(stderr, "[ERROR] Cannot open file \'%s\'\n", filename);
    exit(-1);
  }

  const size_t count = size_ > 0 ? (size_t)size_ : 0;
  const size_t written = fwrite(output, sizeof(int), count, f);

  /* fclose flushes buffered data, so its result is part of the write check */
  if (fclose(f) != 0 || written != count) {
    fprintf(stderr, "[ERROR] Cannot write file \'%s\'\n", filename);
    exit(-1);
  }
}
double get_time() {
@ -121,33 +121,33 @@ double get_time() {
/*
 * Compares output against answer element by element.
 *
 * output, answer: arrays of at least size_ ints.
 * size_:          number of elements to compare.
 *
 * Up to size_ * 0.0001 (truncated to an int) mismatches are tolerated.
 * Returns -1 when the number of mismatches stays within that tolerance,
 * otherwise the index of the first mismatch.
 *
 * The former isnan() test was removed: the elements are ints, which can
 * never be NaN, so that branch could not execute.
 */
int check_validation(int* output, int* answer, int size_) {
  int first_mismatch = -1;
  int tolerance = static_cast<int>(size_ * 0.0001); // Error tolerance percentage

  for (int i = 0; i < size_; i++) {
    if (output[i] == answer[i]) continue;

    /* Remember where the first disagreement happened */
    if (first_mismatch == -1) {
      first_mismatch = i;
    }

    /* Each mismatch consumes tolerance; fail once it is exhausted */
    tolerance--;
    if (tolerance < 0) {
      return first_mismatch;
    }
  }

  return -1;
}