chundoong-lab-ta/SHPC2022/final-project/model.py

53 lines
1.8 KiB
Python
Raw Normal View History

2022-11-15 13:45:21 +09:00
import torch
import torch.nn as nn
import arguments
from constants import EOS_ID, SOS_ID
class LMModel(nn.Module):
    """GRU language model over token ids with an autoregressive sampler.

    Architecture: ``nn.Embedding`` -> multi-layer ``nn.GRU`` (batch-first)
    -> ``nn.Linear`` projection to vocabulary logits. ``generate`` samples
    fixed-length sequences framed by ``SOS_ID`` / ``EOS_ID`` (presumably
    passwords, given the ``password_length`` setting — confirm with callers).

    Args:
        hidden_size: Size of the GRU hidden state.
        vocab_size: Number of tokens; size of the embedding table and of the
            output logits.
        num_layers: Number of stacked GRU layers.
        args: Optional pre-parsed configuration exposing ``embedding_dim``,
            ``device`` and ``password_length``. When omitted (the original
            behaviour), it is obtained from ``arguments.parse_args()``, which
            parses the process command line.
    """

    def __init__(self, hidden_size, vocab_size, num_layers=2, args=None):
        super().__init__()
        self.args = arguments.parse_args() if args is None else args
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, self.args.embedding_dim)
        # PyTorch applies GRU dropout only between stacked layers and warns
        # when num_layers == 1; passing 0.0 there is behaviourally identical.
        self.backbone = nn.GRU(
            input_size=self.args.embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, input, hidden=None):
        """Run the model over a batch of token-id sequences.

        Args:
            input: Integer token ids, batch-first ``(batch, seq_len)`` since
                the GRU is built with ``batch_first=True``.
            hidden: Optional initial GRU state of shape
                ``(num_layers, batch, hidden_size)``; ``None`` lets the GRU
                start from zeros.

        Returns:
            ``(logits, h_n)``: logits of shape ``(batch, seq_len, vocab_size)``
            and the final GRU hidden state.
        """
        embedded = self.embedding(input)
        features, h_n = self.backbone(embedded, hidden)
        return self.linear(features), h_n

    def initHidden(self, batch_size):
        """Return an all-zero GRU state ``(num_layers, batch_size, hidden_size)`` on ``args.device``."""
        return torch.zeros(
            (self.num_layers, batch_size, self.hidden_size),
            device=self.args.device,
        )

    @torch.no_grad()
    def generate(self, num_gen=1):
        """Sample ``num_gen`` sequences autoregressively.

        Each row starts with ``SOS_ID``, continues with
        ``args.password_length`` tokens drawn from the softmax of the model's
        logits, and ends with ``EOS_ID``.

        Runs under ``torch.no_grad()``: sampled ids are not differentiable, so
        recording an autograd graph would only waste memory. The module's
        train/eval mode is left unchanged, so GRU dropout is active if the
        model is in training mode.

        Args:
            num_gen: Number of sequences to sample.

        Returns:
            Integer tensor of shape ``(num_gen, password_length + 2)`` on
            ``args.device``.
        """
        device = self.args.device
        token = torch.full((num_gen, 1), SOS_ID, dtype=torch.long, device=device)
        pieces = [token]
        hidden = None
        for _ in range(self.args.password_length):
            logits, hidden = self.forward(token, hidden)
            # Input has seq_len == 1, so the last (only) step gives (num_gen, vocab).
            probs = nn.functional.softmax(logits[:, -1, :], dim=1)
            token = torch.multinomial(probs, num_samples=1)
            pieces.append(token)
        pieces.append(
            torch.full((num_gen, 1), EOS_ID, dtype=torch.long, device=device))
        # Concatenate once instead of growing the result tensor every step.
        return torch.cat(pieces, dim=1)