remove warmup
This commit is contained in:
parent
62612cbcfb
commit
4beb68a2a8
|
@ -51,13 +51,6 @@ int main(int argc, char **argv) {
|
|||
/* Cannot surpass the max_seq_len of the model */
|
||||
assert(N_SEQ + T <= N_CTX);
|
||||
|
||||
/* Warm-up */
|
||||
if (W) {
|
||||
if (mpi_rank == 0) fprintf(stdout, " Warming up... \n");
|
||||
for (int i = 0; i < 3; i++)
|
||||
generate_tokens(input, output, 1, 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// MODEL COMPUTATION //
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -5,5 +5,4 @@ mpirun --bind-to none -mca btl ^openib -npernode 1 \
|
|||
./main \
|
||||
-v -s \
|
||||
-t 5 \
|
||||
-n 1 \
|
||||
# -w \
|
||||
-n 1 \
|
|
@ -9,7 +9,7 @@ using namespace std;
|
|||
|
||||
int N = 1;
|
||||
int T = 5;
|
||||
bool V, S, W;
|
||||
bool V, S;
|
||||
|
||||
char input_fname[] = "./data/input.bin";
|
||||
char param_fname[] = "./assets/model_file.bin";
|
||||
|
@ -44,9 +44,6 @@ void parse_args(int argc, char **argv) {
|
|||
case 's':
|
||||
S = true;
|
||||
break;
|
||||
case 'w':
|
||||
W = true;
|
||||
break;
|
||||
case 'h':
|
||||
print_help();
|
||||
exit(0);
|
||||
|
@ -61,7 +58,6 @@ void parse_args(int argc, char **argv) {
|
|||
fprintf(stdout, "\n=============================================\n");
|
||||
fprintf(stdout, " Model: GPT-2 (12 layers)\n");
|
||||
fprintf(stdout, "---------------------------------------------\n");
|
||||
fprintf(stdout, " Warmup: %s\n", W ? "ON" : "OFF");
|
||||
fprintf(stdout, " Validation: %s\n", V ? "ON" : "OFF");
|
||||
fprintf(stdout, " Save output: %s\n", S ? "ON" : "OFF");
|
||||
fprintf(stdout, " Number of Prompts: %d\n", N);
|
||||
|
@ -72,7 +68,7 @@ void parse_args(int argc, char **argv) {
|
|||
void print_help() {
|
||||
fprintf(stdout,
|
||||
" Usage: ./main [-i 'pth'] [-p 'pth'] [-o 'pth'] [-a 'pth']"
|
||||
" [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-w] [-h]\n");
|
||||
" [-t 'tokens'] [-n 'prompts'] [-v] [-s] [-h]\n");
|
||||
fprintf(stdout, " Options:\n");
|
||||
fprintf(stdout, " -i: Input binary path (default: data/input.bin)\n");
|
||||
fprintf(stdout, " -p: Model parameter path (default: assets/model_file.bin)\n");
|
||||
|
@ -82,7 +78,6 @@ void print_help() {
|
|||
fprintf(stdout, " -t: Number of tokens to generate (default: 5)\n");
|
||||
fprintf(stdout, " -v: Enable validation (default: OFF)\n");
|
||||
fprintf(stdout, " -s: Enable saving output tensor (default: OFF)\n");
|
||||
fprintf(stdout, " -w: Enable warmup (default: OFF)\n");
|
||||
fprintf(stdout, " -h: Print manual and options (default: OFF)\n");
|
||||
}
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ using namespace std;
|
|||
|
||||
extern int N;
|
||||
extern int T;
|
||||
extern bool V, S, W;
|
||||
extern bool V, S;
|
||||
|
||||
extern char param_fname[100];
|
||||
extern char input_fname[100];
|
||||
|
|
Loading…
Reference in New Issue