[Refactor] refactor & modify operations and comments
This commit is contained in:
parent
3a11f82ac1
commit
c20814b81b
|
@ -11,107 +11,107 @@
|
|||
|
||||
int main(int argc, char **argv) {
|
||||
|
||||
/* MPI Initialization */
|
||||
int mpi_rank, mpi_size;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);
|
||||
|
||||
/* Parse arguments */
|
||||
if (mpi_rank == 0) {
|
||||
parse_args(argc, argv);
|
||||
}
|
||||
/* MPI Initialization */
|
||||
int mpi_rank, mpi_size;
|
||||
MPI_Init(&argc, &argv);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &mpi_size);
|
||||
|
||||
/* Parse arguments */
|
||||
if (mpi_rank == 0) {
|
||||
parse_args(argc, argv);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// INITIALIZATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
int *input, *output;
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// INITIALIZATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
int *input, *output;
|
||||
|
||||
if (mpi_rank == 0) {
|
||||
/* Load input (size: N x input_tokens) from file */
|
||||
fprintf(stderr, "[LOG] Reading input from %s\n", input_fname);
|
||||
size_t input_size;
|
||||
input = (int *)read_binary(input_fname, &input_size);
|
||||
|
||||
if (input_size % input_tokens != 0) {
|
||||
fprintf(stderr, "[ERROR] Invalid input size\n");
|
||||
exit(1);
|
||||
}
|
||||
if (mpi_rank == 0) {
|
||||
/* Load input (size: N x input_tokens) from file */
|
||||
fprintf(stderr, "[LOG] Reading input from %s\n", input_fname);
|
||||
size_t input_size;
|
||||
input = (int *)read_binary(input_fname, &input_size);
|
||||
|
||||
if (input_size % input_tokens != 0) {
|
||||
fprintf(stderr, "[ERROR] Invalid input size\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
/* Allocate output (size: N x T) */
|
||||
output = (int *)malloc(N * T * sizeof(int));
|
||||
}
|
||||
|
||||
/* Initialize parameters and activations */
|
||||
if (mpi_rank == 0) fprintf(stderr, "[LOG] Initializing... \n");
|
||||
initialize_parameters(param_fname);
|
||||
initialize_activations();
|
||||
/* Allocate output (size: N x T) */
|
||||
output = (int *)malloc(N * T * sizeof(int));
|
||||
}
|
||||
|
||||
/* Initialize parameters and activations */
|
||||
if (mpi_rank == 0) fprintf(stderr, "[LOG] Initializing... \n");
|
||||
initialize_parameters(param_fname);
|
||||
initialize_activations();
|
||||
|
||||
/* Cannot surpass the max_seq_len of the model */
|
||||
assert(input_tokens + T <= MAX_SEQ_LEN);
|
||||
/* Cannot surpass the max_seq_len of the model */
|
||||
assert(input_tokens + T <= MAX_SEQ_LEN);
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// MODEL COMPUTATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
double st = 0.0, et = 0.0;
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// MODEL COMPUTATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
double st = 0.0, et = 0.0;
|
||||
|
||||
if (mpi_rank == 0) {
|
||||
fprintf(stdout, " Start computation... \n");
|
||||
st = get_time();
|
||||
}
|
||||
if (mpi_rank == 0) {
|
||||
fprintf(stdout, " Start computation... \n");
|
||||
st = get_time();
|
||||
}
|
||||
|
||||
/* Text Generation */
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
generate_tokens(input, output, N, T);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
/* Text Generation */
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
generate_tokens(input, output, N, T);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
if (mpi_rank == 0) {
|
||||
et = get_time();
|
||||
if (mpi_rank == 0) {
|
||||
et = get_time();
|
||||
|
||||
/* Print the result */
|
||||
fprintf(stdout, " Done!\n");
|
||||
fprintf(stdout, " Elapsed time: %lf (sec)\n", et - st);
|
||||
fprintf(stdout, " Throughput: %lf (tokens/sec)\n",
|
||||
N*T / (et - st));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// FINALIZATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
/* Finalize parameters and activations */
|
||||
if (mpi_rank == 0) fprintf(stderr, "[LOG] Finalizing... \n");
|
||||
finalize_parameters();
|
||||
finalize_activations();
|
||||
|
||||
if (mpi_rank == 0) {
|
||||
/* Save output */
|
||||
if (S) {
|
||||
fprintf(stdout, " Saving output... \n");
|
||||
write_binary(output, output_fname, N*T);
|
||||
}
|
||||
/* Print the result */
|
||||
fprintf(stdout, " Done!\n");
|
||||
fprintf(stdout, " Elapsed time: %lf (sec)\n", et - st);
|
||||
fprintf(stdout, " Throughput: %lf (tokens/sec)\n",
|
||||
N*T / (et - st));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// FINALIZATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
/* Finalize parameters and activations */
|
||||
if (mpi_rank == 0) fprintf(stderr, "[LOG] Finalizing... \n");
|
||||
finalize_parameters();
|
||||
finalize_activations();
|
||||
|
||||
if (mpi_rank == 0) {
|
||||
/* Save output */
|
||||
if (S) {
|
||||
fprintf(stdout, " Saving output... \n");
|
||||
write_binary(output, output_fname, N*T);
|
||||
}
|
||||
|
||||
/* Validation */
|
||||
if (V) {
|
||||
fprintf(stdout, " Validation... \n");
|
||||
/* Validation */
|
||||
if (V) {
|
||||
fprintf(stdout, " Validation... \n");
|
||||
|
||||
int *answer = (int *)read_binary(answer_fname, NULL);
|
||||
int ret = check_validation(output, answer, N*T);
|
||||
if (ret == -1) {
|
||||
fprintf(stdout, " Validation passed!\n");
|
||||
} else {
|
||||
fprintf(stdout, " Validation failed: First mismatch "
|
||||
"at prompt[#%d], token_ID[#%d] (output[%d]=%d <-> "
|
||||
"answer[%d]=%d)\n", ret / T, ret % T, ret,
|
||||
output[ret], ret, answer[ret]);
|
||||
}
|
||||
}
|
||||
}
|
||||
int *answer = (int *)read_binary(answer_fname, NULL);
|
||||
int ret = check_validation(output, answer, N*T);
|
||||
if (ret == -1) {
|
||||
fprintf(stdout, " Validation passed!\n");
|
||||
} else {
|
||||
fprintf(stdout, " Validation failed: First mismatch "
|
||||
"at prompt[#%d], token_ID[#%d] (output[%d]=%d <-> "
|
||||
"answer[%d]=%d)\n", ret / T, ret % T, ret,
|
||||
output[ret], ret, answer[ret]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* MPI Finalization */
|
||||
MPI_Finalize();
|
||||
/* MPI Finalization */
|
||||
MPI_Finalize();
|
||||
|
||||
return 0;
|
||||
return 0;
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,8 +1,8 @@
|
|||
#!/bin/sh
|
||||
|
||||
mpirun --bind-to none -mca btl ^openib -npernode 1 \
|
||||
--oversubscribe -quiet \
|
||||
./main \
|
||||
-v -s \
|
||||
-t 5 \
|
||||
-n 1 \
|
||||
--oversubscribe -quiet \
|
||||
./main \
|
||||
-v -s \
|
||||
-t 5 \
|
||||
-n 1 \
|
|
@ -7,32 +7,32 @@
|
|||
using namespace std;
|
||||
|
||||
/*
 * Build a zero-filled tensor with the given shape.
 *
 * shape_: extent of each dimension; its length becomes ndim.
 *
 * Aborts if shape_ has more dimensions than the fixed-size shape[] member
 * can hold, or if the element buffer cannot be allocated.
 */
Tensor::Tensor(const vector<int> &shape_) {
  /* shape[] has a fixed capacity; reject higher ranks instead of writing
   * past the end of the array */
  if (shape_.size() > sizeof(shape) / sizeof(shape[0])) {
    abort();
  }

  ndim = shape_.size();
  for (size_t i = 0; i < ndim; i++) {
    shape[i] = shape_[i];
  }

  int N_ = num_elem();
  buf = (float *)calloc(N_, sizeof(float));
  /* calloc may legitimately return NULL for a zero-element request */
  if (buf == nullptr && N_ != 0) {
    abort();
  }
}
|
||||
|
||||
/*
 * Build a tensor with the given shape whose contents are copied from buf_.
 *
 * shape_: extent of each dimension; its length becomes ndim.
 * buf_:   source of num_elem() floats; the data is copied, so the caller
 *         keeps ownership of buf_.
 *
 * Aborts if shape_ has more dimensions than the fixed-size shape[] member
 * can hold, or if the element buffer cannot be allocated.
 */
Tensor::Tensor(const vector<int> &shape_, float *buf_) {
  /* shape[] has a fixed capacity; reject higher ranks instead of writing
   * past the end of the array */
  if (shape_.size() > sizeof(shape) / sizeof(shape[0])) {
    abort();
  }

  ndim = shape_.size();
  for (size_t i = 0; i < ndim; i++) {
    shape[i] = shape_[i];
  }

  int N_ = num_elem();
  buf = (float *)malloc(N_ * sizeof(float));
  /* malloc may legitimately return NULL for a zero-byte request */
  if (buf == nullptr && N_ != 0) {
    abort();
  }
  if (N_ != 0) {
    memcpy(buf, buf_, N_ * sizeof(float));
  }
}
|
||||
|
||||
/* Release the element buffer owned by this tensor. */
Tensor::~Tensor() {
  /* free(NULL) is a no-op, so a never-allocated buffer needs no guard */
  free(buf);
}
|
||||
|
||||
int Tensor::num_elem() {
|
||||
int size = 1;
|
||||
for (size_t i = 0; i < ndim; i++) {
|
||||
size *= shape[i];
|
||||
}
|
||||
return size;
|
||||
int size = 1;
|
||||
for (size_t i = 0; i < ndim; i++) {
|
||||
size *= shape[i];
|
||||
}
|
||||
return size;
|
||||
}
|
|
@ -6,13 +6,13 @@
|
|||
using namespace std;
|
||||
|
||||
struct Tensor {
|
||||
size_t ndim = 0;
|
||||
int shape[4];
|
||||
float *buf = nullptr;
|
||||
size_t ndim = 0;
|
||||
int shape[4];
|
||||
float *buf = nullptr;
|
||||
|
||||
Tensor(const vector<int> &shape_);
|
||||
Tensor(const vector<int> &shape_, float *buf_);
|
||||
~Tensor();
|
||||
Tensor(const vector<int> &shape_);
|
||||
Tensor(const vector<int> &shape_, float *buf_);
|
||||
~Tensor();
|
||||
|
||||
int num_elem();
|
||||
int num_elem();
|
||||
};
|
|
@ -17,99 +17,99 @@ char answer_fname[] = "./data/answer.bin";
|
|||
char output_fname[] = "./data/output.bin";
|
||||
|
||||
void parse_args(int argc, char **argv) {
|
||||
int args;
|
||||
while ((args = getopt(argc, argv, "i:o:a:p:n:t:vswh")) != -1) {
|
||||
switch (args) {
|
||||
case 'i':
|
||||
strcpy(input_fname, optarg);
|
||||
break;
|
||||
case 'o':
|
||||
strcpy(output_fname, optarg);
|
||||
break;
|
||||
case 'a':
|
||||
strcpy(answer_fname, optarg);
|
||||
break;
|
||||
case 'p':
|
||||
strcpy(param_fname, optarg);
|
||||
break;
|
||||
case 'n':
|
||||
N = atoi(optarg);
|
||||
break;
|
||||
case 't':
|
||||
T = atoi(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
V = true;
|
||||
break;
|
||||
case 's':
|
||||
S = true;
|
||||
break;
|
||||
case 'h':
|
||||
print_help();
|
||||
exit(0);
|
||||
break;
|
||||
default:
|
||||
print_help();
|
||||
exit(0);
|
||||
break;
|
||||
}
|
||||
int args;
|
||||
while ((args = getopt(argc, argv, "i:o:a:p:n:t:vswh")) != -1) {
|
||||
switch (args) {
|
||||
case 'i':
|
||||
strcpy(input_fname, optarg);
|
||||
break;
|
||||
case 'o':
|
||||
strcpy(output_fname, optarg);
|
||||
break;
|
||||
case 'a':
|
||||
strcpy(answer_fname, optarg);
|
||||
break;
|
||||
case 'p':
|
||||
strcpy(param_fname, optarg);
|
||||
break;
|
||||
case 'n':
|
||||
N = atoi(optarg);
|
||||
break;
|
||||
case 't':
|
||||
T = atoi(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
V = true;
|
||||
break;
|
||||
case 's':
|
||||
S = true;
|
||||
break;
|
||||
case 'h':
|
||||
print_help();
|
||||
exit(0);
|
||||
break;
|
||||
default:
|
||||
print_help();
|
||||
exit(0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n=============================================\n");
|
||||
fprintf(stdout, " Model: GPT-2 (12 layers)\n");
|
||||
fprintf(stdout, "---------------------------------------------\n");
|
||||
fprintf(stdout, " Validation: %s\n", V ? "ON" : "OFF");
|
||||
fprintf(stdout, " Save output: %s\n", S ? "ON" : "OFF");
|
||||
fprintf(stdout, " Number of Prompts: %d\n", N);
|
||||
fprintf(stdout, " Number of Tokens to generate: %d\n", T);
|
||||
fprintf(stdout, "=============================================\n\n");
|
||||
fprintf(stdout, "\n=============================================\n");
|
||||
fprintf(stdout, " Model: GPT-2 (12 layers)\n");
|
||||
fprintf(stdout, "---------------------------------------------\n");
|
||||
fprintf(stdout, " Validation: %s\n", V ? "ON" : "OFF");
|
||||
fprintf(stdout, " Save output: %s\n", S ? "ON" : "OFF");
|
||||
fprintf(stdout, " Number of Prompts: %d\n", N);
|
||||
fprintf(stdout, " Number of Tokens to generate: %d\n", T);
|
||||
fprintf(stdout, "=============================================\n\n");
|
||||
}
|
||||
|
||||
/* Print the usage line and the description of every option to stdout. */
void print_help() {
  fprintf(stdout,
          " Usage: ./main [-i 'pth'] [-p 'pth'] [-o 'pth'] [-a 'pth']"
          " [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-h]\n");
  fprintf(stdout, " Options:\n");
  fprintf(stdout, " -i: Input binary path (default: data/input.bin)\n");
  fprintf(stdout, " -p: Model parameter path (default: assets/model_file.bin)\n");
  /* The built-in output path is ./data/output.bin, not ./output.bin */
  fprintf(stdout, " -o: Output binary path (default: data/output.bin)\n");
  fprintf(stdout, " -a: Answer binary path (default: data/answer.bin)\n");
  fprintf(stdout, " -n: Number of prompts (default: 1)\n");
  fprintf(stdout, " -t: Number of tokens to generate (default: 5)\n");
  fprintf(stdout, " -v: Enable validation (default: OFF)\n");
  fprintf(stdout, " -s: Enable saving output tensor (default: OFF)\n");
  fprintf(stdout, " -h: Print manual and options (default: OFF)\n");
}
|
||||
|
||||
/*
 * Read an entire binary file into a freshly malloc'd buffer.
 *
 * fname: path of the file to read.
 * size:  if not NULL, receives the number of 4-byte elements in the file
 *        (file size in bytes divided by 4).
 *
 * Returns the buffer; the caller owns it and must free() it.
 * Prints an error and calls exit(-1) if the file cannot be opened, is
 * empty, cannot be measured, cannot be buffered, or cannot be read in full.
 */
void* read_binary(const char *fname, size_t *size) {
  FILE *f = fopen(fname, "rb");
  if (f == NULL) {
    fprintf(stderr, "[ERROR] Cannot open file \'%s\'\n", fname);
    exit(-1);
  }

  /* Measure the file; ftell reports -1 on failure, which must not be
   * turned into a huge unsigned size */
  long end_pos = -1;
  if (fseek(f, 0, SEEK_END) == 0) {
    end_pos = ftell(f);
  }
  if (end_pos <= 0) {
    fprintf(stderr, "[ERROR] Cannot read file \'%s\'\n", fname);
    exit(-1);
  }
  rewind(f);

  size_t size_ = (size_t)end_pos;
  void *buf = malloc(size_);
  if (buf == NULL) {
    fprintf(stderr, "[ERROR] Cannot allocate %zu bytes for file \'%s\'\n",
            size_, fname);
    exit(-1);
  }

  /* A short read would leave the tail of buf uninitialized */
  size_t ret = fread(buf, 1, size_, f);
  if (ret != size_) {
    fprintf(stderr, "[ERROR] Cannot read file \'%s\'\n", fname);
    exit(-1);
  }
  fclose(f);

  if (size != NULL)
    *size = (size_t)(size_ / 4);  // 4 bytes per float or int

  return buf;
}
|
||||
|
||||
/*
 * Write size_ ints from output to filename as raw binary, replacing any
 * existing file.
 *
 * Prints an error and calls exit(-1) if size_ is negative, the file cannot
 * be opened, or not every element could be written and flushed.
 */
void write_binary(int *output, const char *filename, int size_) {
  if (size_ < 0) {
    fprintf(stderr, "[ERROR] Invalid element count %d for file \'%s\'\n",
            size_, filename);
    exit(-1);
  }

  /* "wb": the payload is binary, so no text-mode translation is wanted */
  FILE *f = fopen(filename, "wb");
  if (f == NULL) {
    fprintf(stderr, "[ERROR] Cannot open file \'%s\'\n", filename);
    exit(-1);
  }

  size_t count = (size_t)size_;
  size_t written = fwrite(output, sizeof(int), count, f);
  /* fclose flushes buffered data, so its result matters for a write stream */
  int close_rc = fclose(f);
  if (written != count || close_rc != 0) {
    fprintf(stderr, "[ERROR] Cannot write file \'%s\'\n", filename);
    exit(-1);
  }
}
|
||||
|
||||
double get_time() {
|
||||
|
@ -121,33 +121,33 @@ double get_time() {
|
|||
|
||||
/*
 * Compare the generated tokens with the reference answer.
 *
 * output, answer: arrays of size_ token IDs.
 * size_:          number of elements to compare.
 *
 * Up to size_ * 0.0001 (truncated to an int) mismatches are tolerated.
 * Returns -1 when the number of mismatches stays within that tolerance,
 * otherwise the index of the first mismatching element.
 */
int check_validation(int* output, int* answer, int size_) {
  int ret = -1;
  int mismatch_idx = -1;
  int tolerance = size_ * 0.0001;  // Error tolerance percentage

  /* No NaN test: the tokens are ints, which cannot hold NaN */
  for (int i = 0; i < size_; i++) {
    if (output[i] != answer[i]) {
      /* Decrease tolerance */
      tolerance--;

      /* Save the first mismatch index */
      if (mismatch_idx == -1) {
        mismatch_idx = i;
      }

      /* Break if tolerance is reached */
      if (tolerance < 0) {
        ret = mismatch_idx;
        break;
      }
    }
  }

  return ret;
}
|
Loading…
Reference in New Issue