chundoong-lab-ta/SHPC2022/final-project/train.py

import os
import pickle
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import arguments
import dataset
import model
from constants import REAL_LABEL, FAKE_LABEL, EOS_ID, SOS_ID
from dataset import NaiveTokenizer, PasswordDataset


def trainLMModel():
    args = arguments.parse_args()
    tokenizer = NaiveTokenizer([args.train_path])
    traindataset = PasswordDataset(
        args.train_path, args.password_length, tokenizer, args)
    train_dataloader = DataLoader(
        traindataset, batch_size=args.batch_size, shuffle=True)
    LMModel = model.LMModel(
        args.hidden_dim, tokenizer.vocab_size())
    LMModel = LMModel.to(args.device)
    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(LMModel.parameters(), lr=args.lr)
    optimizer = optim.Adam(LMModel.parameters(), lr=args.lr, betas=(0.5, 0.9))
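    # betas=(0.5, 0.9) lowers the first-moment momentum relative to Adam's
    # default (0.9, 0.999); presumably chosen to match the GAN-style training
    # elsewhere in this project (REAL_LABEL/FAKE_LABEL are imported above).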

    train_iter = 0
    final_loss = 0
    for epoch in range(0, 30):
        for i, data in enumerate(train_dataloader, 0):
            data = data.to(args.device)
            # data tensor: [batch_size x password_length], index of each character
            train_iter += 1
            iter_start_time = time.time()
            LMModel.zero_grad()
            # target tensor: [batch_size x password_length], the input shifted by one
            # position so that each position is trained to predict the next character
            target = torch.cat(
                (data[:, 1:], torch.full([data.size(0), 1], SOS_ID, device=args.device)), dim=1
            )
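            # Note: the target at the final position is SOS_ID; presumably SOS doubles
            # as the terminator here, since EOS_ID is imported but not used in this loop.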
            input = data
            output, hidden = LMModel(input)
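            # output is assumed to be [batch_size x password_length x vocab_size];
            # nn.CrossEntropyLoss expects the class dimension second, i.e.
            # [batch_size x vocab_size x password_length], hence the transpose.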
            output = output.transpose(1, 2)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            iter_end_time = time.time()
            if train_iter % args.print_freq == 0:
                print(
                    f"[Epoch {epoch} / Iter {train_iter}] loss {loss.item() / input.size(1):.3f}, "
                    f"{iter_end_time - iter_start_time:.3f} sec/it")
                idlist = LMModel.generate(2)
                print(train_dataloader.dataset.idtensor_to_string(idlist))
            final_loss = loss.item() / input.size(1)

    torch.save(LMModel.state_dict(), f"model-{final_loss:.4f}.pt")
    sys.exit(0)


if __name__ == "__main__":
    trainLMModel()
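
# A minimal sketch of how the saved checkpoint might be reloaded for sampling,
# assuming the same LMModel constructor and generate() interface used above
# (the checkpoint filename is illustrative):
#
#   args = arguments.parse_args()
#   tokenizer = NaiveTokenizer([args.train_path])
#   lm = model.LMModel(args.hidden_dim, tokenizer.vocab_size()).to(args.device)
#   lm.load_state_dict(torch.load("model-1.2345.pt", map_location=args.device))
#   lm.eval()
#   with torch.no_grad():
#       idlist = lm.generate(16)
#   traindataset = PasswordDataset(args.train_path, args.password_length, tokenizer, args)
#   print(traindataset.idtensor_to_string(idlist))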