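"""Train a character-level language model on a password list."""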
import os
import pickle
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import arguments
import dataset
import model
from constants import REAL_LABEL, FAKE_LABEL, EOS_ID, SOS_ID
from dataset import NaiveTokenizer, PasswordDataset


def trainLMModel():
    args = arguments.parse_args()

    tokenizer = NaiveTokenizer([args.train_path])
    traindataset = PasswordDataset(
        args.train_path, args.password_length, tokenizer, args)
    train_dataloader = DataLoader(
        traindataset, batch_size=args.batch_size, shuffle=True)
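    # NaiveTokenizer presumably derives a character-level vocabulary from the
    # training file itself, so vocab_size() below matches the data exactly.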

    LMModel = model.LMModel(
        args.hidden_dim, tokenizer.vocab_size())
    LMModel = LMModel.to(args.device)
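    # The forward pass below unpacks (output, hidden), so LMModel is assumed
    # to be a recurrent language model returning per-step logits plus state.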

    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(LMModel.parameters(), lr=args.lr)
    optimizer = optim.Adam(LMModel.parameters(), lr=args.lr, betas=(0.5, 0.9))
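    # betas=(0.5, 0.9) lowers Adam's default (0.9, 0.999) moment decay, a
    # setting often seen in GAN-style training loops.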

    train_iter = 0
    final_loss = 0
    # the epoch count is hardcoded to 30 in this script
    for epoch in range(30):
        for i, data in enumerate(train_dataloader):
            data = data.to(args.device)

            # data: [batch_size x password_length] tensor of character ids
            train_iter += 1
            iter_start_time = time.time()

            LMModel.zero_grad()

            # target: the input shifted LEFT by one (next-character
            # prediction), with SOS_ID filling the now-empty last column
            target = torch.cat(
                (data[:, 1:], torch.full([data.size(0), 1], SOS_ID, device=args.device)), dim=1
            )
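            # e.g. if a row of data is [SOS, c1, ..., cn], its target row is
            # [c1, ..., cn, SOS]: position t is trained to predict position t+1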

            inputs = data
            output, hidden = LMModel(inputs)
            # CrossEntropyLoss expects the class (vocab) dimension second, so
            # move logits from [batch, seq_len, vocab] to [batch, vocab, seq_len]
            output = output.transpose(1, 2)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            iter_end_time = time.time()
            if train_iter % args.print_freq == 0:
                print(
                    f"[Epoch {epoch} / Iter {train_iter}] loss {loss.item() / inputs.size(1):.3f}, {iter_end_time - iter_start_time:.3f} sec/it")
                # sample two passwords from the current model as a sanity check
                idlist = LMModel.generate(2)
                print(train_dataloader.dataset.idtensor_to_string(idlist))

            # keep the latest reported loss value for the checkpoint filename
            final_loss = loss.item() / inputs.size(1)

    torch.save(LMModel.state_dict(), f"model-{final_loss:.4f}.pt")
    sys.exit(0)
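

# A minimal reload-and-sample sketch (not part of the original training code):
# it shows how a saved model-*.pt checkpoint could be restored for generation,
# assuming model.LMModel and NaiveTokenizer can be rebuilt exactly as in
# trainLMModel() above. sampleFromCheckpoint, checkpoint_path, and num_samples
# are hypothetical names introduced here for illustration.
def sampleFromCheckpoint(checkpoint_path, num_samples=10):
    args = arguments.parse_args()
    tokenizer = NaiveTokenizer([args.train_path])
    lm = model.LMModel(args.hidden_dim, tokenizer.vocab_size())
    lm.load_state_dict(torch.load(checkpoint_path, map_location=args.device))
    lm = lm.to(args.device)
    lm.eval()
    with torch.no_grad():
        idlist = lm.generate(num_samples)
    # decode ids back to strings via the dataset helper, mirroring the
    # training-time logging above
    ds = PasswordDataset(args.train_path, args.password_length, tokenizer, args)
    return ds.idtensor_to_string(idlist)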


if __name__ == "__main__":
    trainLMModel()