remove warmup

This commit is contained in:
Jaehwan Lee 2024-04-22 10:00:56 +00:00
parent 62612cbcfb
commit 4beb68a2a8
4 changed files with 4 additions and 17 deletions

View File

@ -51,13 +51,6 @@ int main(int argc, char **argv) {
/* Cannot surpass the max_seq_len of the model */
assert(N_SEQ + T <= N_CTX);
/* Warm-up */
if (W) {
if (mpi_rank == 0) fprintf(stdout, " Warming up... \n");
for (int i = 0; i < 3; i++)
generate_tokens(input, output, 1, 1);
}
////////////////////////////////////////////////////////////////////
// MODEL COMPUTATION //
////////////////////////////////////////////////////////////////////

View File

@ -5,5 +5,4 @@ mpirun --bind-to none -mca btl ^openib -npernode 1 \
./main \
-v -s \
-t 5 \
-n 1 \
# -w \
-n 1 \

View File

@ -9,7 +9,7 @@ using namespace std;
int N = 1;
int T = 5;
bool V, S, W;
bool V, S;
char input_fname[] = "./data/input.bin";
char param_fname[] = "./assets/model_file.bin";
@ -44,9 +44,6 @@ void parse_args(int argc, char **argv) {
case 's':
S = true;
break;
case 'w':
W = true;
break;
case 'h':
print_help();
exit(0);
@ -61,7 +58,6 @@ void parse_args(int argc, char **argv) {
fprintf(stdout, "\n=============================================\n");
fprintf(stdout, " Model: GPT-2 (12 layers)\n");
fprintf(stdout, "---------------------------------------------\n");
fprintf(stdout, " Warmup: %s\n", W ? "ON" : "OFF");
fprintf(stdout, " Validation: %s\n", V ? "ON" : "OFF");
fprintf(stdout, " Save output: %s\n", S ? "ON" : "OFF");
fprintf(stdout, " Number of Prompts: %d\n", N);
@ -72,7 +68,7 @@ void parse_args(int argc, char **argv) {
void print_help() {
fprintf(stdout,
" Usage: ./main [-i 'pth'] [-p 'pth'] [-o 'pth'] [-a 'pth']"
" [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-w] [-h]\n");
" [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-h]\n");
fprintf(stdout, " Options:\n");
fprintf(stdout, " -i: Input binary path (default: data/input.bin)\n");
fprintf(stdout, " -p: Model parameter path (default: assets/model_file.bin)\n");
@ -82,7 +78,6 @@ void print_help() {
fprintf(stdout, " -t: Number of tokens to generate (default: 5)\n");
fprintf(stdout, " -v: Enable validation (default: OFF)\n");
fprintf(stdout, " -s: Enable saving output tensor (default: OFF)\n");
fprintf(stdout, " -w: Enable warmup (default: OFF)\n");
fprintf(stdout, " -h: Print manual and options (default: OFF)\n");
}

View File

@ -7,7 +7,7 @@ using namespace std;
extern int N;
extern int T;
extern bool V, S, W;
extern bool V, S;
extern char param_fname[100];
extern char input_fname[100];