# chundoong-lab-ta/SHPC2022/final-project/dataset.py
# (169 lines, 4.7 KiB, Python; revision dated 2022-11-15 13:45:21 +09:00)
import string
import time
import torch
from torch.utils.data import Dataset
from constants import EOS, EOS_ID, SOS, SOS_ID, PAD, PAD_ID
import numpy as np
def is_ascii(s: str) -> bool:
    """Report whether every character of ``s`` can be encoded as ASCII.

    The check simply attempts an ASCII encode; an encoding failure means at
    least one character lies outside the ASCII range.
    """
    try:
        s.encode('ascii')
        return True
    except UnicodeEncodeError:
        return False
def build_uniquelist_from_file(fname: str, password_length_limit: int) -> set[str]:
    """Collect the distinct usable passwords listed in a text file.

    Each line is stripped of surrounding whitespace and kept only if it is
    non-empty, pure ASCII, and at most ``password_length_limit`` characters.

    Args:
        fname: Path of a text file with one password per line.
        password_length_limit: Maximum accepted password length (inclusive).

    Returns:
        A set of unique passwords. Despite the function name, the result is
        (and always has been) a ``set``.
    """
    passwords = set()
    # errors='replace' turns undecodable bytes into U+FFFD instead of raising
    # UnicodeDecodeError. Such lines are non-ASCII, so the filter below drops
    # them, which is the intended outcome. `with` guarantees the file closes.
    with open(fname, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            passwd = line.strip()
            if 1 <= len(passwd) <= password_length_limit and passwd.isascii():
                passwords.add(passwd)
    return passwords
def ord_to_string(l, inv_charmap) -> list[str]:
    """Decode each sequence of token ids into a string.

    Args:
        l: An iterable of id sequences (e.g. a list of lists of ints).
        inv_charmap: Mapping from token id to its string form.

    Returns:
        One string per input sequence, formed by concatenating
        ``inv_charmap[id]`` for every id in order.

    Raises:
        Whatever ``inv_charmap[id]`` raises for an unknown id (``KeyError``
        for a plain dict).
    """
    # str.join avoids the quadratic cost of repeated `s += ...`.
    return [''.join(inv_charmap[tok] for tok in seq) for seq in l]
def refine_password_char(l: list[str], password_length: int) -> list[str]:
    """Normalize a character-token password to a fixed-length token list.

    The result always has ``password_length + 2`` tokens: a leading SOS, the
    password characters, exactly one EOS, then PAD tokens filling the rest.
    The ``+2`` reserves room for the SOS and at least one EOS.

    - SOS is prepended only if ``l`` does not already start with it.
    - Everything from the first EOS onward is discarded before the single
      EOS is re-appended.

    Args:
        l: Password as a list of character tokens; may be empty.
        password_length: Maximum number of password characters.

    Returns:
        A new list; ``l`` itself is not modified.

    Raises:
        AssertionError: if more than ``password_length`` characters remain
            (checked with ``assert``, so skipped under ``python -O``).
    """
    # Guard the empty case so l[0] is never evaluated on [].
    if not l or l[0] != SOS:
        l = [SOS] + l
    if EOS in l:
        l = l[:l.index(EOS)]
    ret = l + [EOS]
    ret += [PAD] * (password_length + 2 - len(ret))
    assert len(ret) == password_length + 2
    return ret
def refine_password_int(l: list[int], password_length: int) -> list[int]:
    """Normalize a token-id password to a fixed-length id list.

    This is the integer-id counterpart of ``refine_password_char``. The result
    has ``password_length + 2`` ids: a leading SOS_ID, the password ids,
    exactly one EOS_ID, then PAD_ID filling the rest.

    - SOS_ID is prepended only if ``l`` does not already start with it.
    - Everything from the first EOS_ID onward is dropped before the single
      EOS_ID is re-appended.

    Args:
        l: Password as a list of token ids; may be empty.
        password_length: Maximum number of password ids.

    Returns:
        A new list; ``l`` itself is not modified.

    Raises:
        AssertionError: if more than ``password_length`` ids remain
            (checked with ``assert``, so skipped under ``python -O``).
    """
    # Guard the empty case so l[0] is never evaluated on [].
    if not l or l[0] != SOS_ID:
        l = [SOS_ID] + l
    if EOS_ID in l:
        l = l[:l.index(EOS_ID)]
    ret = l + [EOS_ID]
    ret += [PAD_ID] * (password_length + 2 - len(ret))
    assert len(ret) == password_length + 2
    return ret
def load_dataset(path, max_length, tokenize=False, max_vocab_size=2048):
    """Load, normalize and shuffle the passwords in a text file.

    Each line is stripped. Blank lines and lines longer than ``max_length``
    characters are skipped, the same filter PasswordDataset applies. Every
    remaining line is converted with ``refine_password_char``.

    Args:
        path: Text file with one password per line.
        max_length: Maximum password length in characters.
        tokenize: Unused; kept for call compatibility.
        max_vocab_size: Unused; kept for call compatibility.

    Returns:
        A list of fixed-length token lists, shuffled in place with NumPy's
        global RNG.
    """
    lines = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip blank and over-long lines. Over-long lines would otherwise
            # fail refine_password_char's length assert.
            if not 1 <= len(line) <= max_length:
                continue
            lines.append(refine_password_char(list(line), max_length))
    np.random.shuffle(lines)
    print("loaded {} lines in dataset".format(len(lines)))
    return lines
class NaiveTokenizer:
    """Character-level tokenizer that uses code points as token ids.

    The special tokens EOS, SOS and PAD map to EOS_ID, SOS_ID and PAD_ID.
    Every character with code point 3..255 maps to its own code point.
    ``dataset_paths`` is accepted but never read.

    NOTE(review): if any special id fell inside 3..255, the range mapping
    would overwrite its ``id2char`` entry; confirm the ids in ``constants``.
    """

    def __init__(self, dataset_paths: list[str]):
        self.password_list = []
        specials = ((EOS, EOS_ID), (SOS, SOS_ID), (PAD, PAD_ID))
        self.char2id = dict(specials)
        self.id2char = {tok_id: tok for tok, tok_id in specials}
        self.char2id.update({chr(code): code for code in range(3, 256)})
        self.id2char.update({code: chr(code) for code in range(3, 256)})

    def vocab_size(self) -> int:
        """Return the number of ids known to ``id2char``."""
        return len(self.id2char)
class PasswordDataset(Dataset):
    """Torch dataset of fixed-length, tokenized passwords read from a text file.

    Each kept line becomes a ``torch.long`` tensor of ``password_length + 2``
    token ids, built from ``refine_password_char`` output (SOS, characters,
    EOS, PAD...).
    """

    def __init__(self, dataset_path, password_length, tokenizer, args):
        """Read and tokenize every usable password in ``dataset_path``.

        Args:
            dataset_path: Text file with one password per line.
            password_length: Maximum password length. Blank and longer lines
                are skipped.
            tokenizer: Object exposing ``char2id`` and ``id2char`` mappings.
            args: Stored unchanged as ``self.args``.
        """
        self.tokenizer = tokenizer
        self.password_length = password_length
        self.password_list = []
        with open(dataset_path, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) > password_length or len(line) < 1:
                    continue
                self.password_list.append(refine_password_char(
                    list(line), password_length))
        self.args = args
        self.data_len = len(self.password_list)
        # Padded string forms (SOS + chars + EOS + PAD...). count_hit
        # normalizes candidates to this exact form before lookup.
        self.password_string_set = {''.join(x) for x in self.password_list}
        # TODO: this step is slow on large files (one torch.tensor per row).
        char2id = self.tokenizer.char2id
        self.password_tensor_list = [
            torch.tensor([char2id[x] for x in password], dtype=torch.long)
            for password in self.password_list
        ]

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        return self.password_tensor_list[idx]

    def idtensor_to_string(self, id_tensor2d):
        """Decode a 2-D tensor of token ids into one string per row.

        Each row is normalized with ``refine_password_int`` and then mapped
        through ``tokenizer.id2char``. The id tensor is assumed to be
        (batch, seq); TODO confirm against callers.
        """
        id_list2d = id_tensor2d.cpu().detach().tolist()
        refined = []
        for l in id_list2d:
            refined.append(refine_password_int(l, self.password_length))
        return ord_to_string(refined, self.tokenizer.id2char)

    def contains(self, password: str) -> bool:
        """Return True if ``password`` (already in padded form) is stored."""
        return password in self.password_string_set

    def _canonical_string(self, password: str) -> str:
        """Normalize a candidate to the padded form in ``password_string_set``.

        Strips an optional leading SOS, cuts at the first EOS, then rebuilds
        SOS + body + EOS + PAD * (password_length - len(body)). This matches
        what ``refine_password_char`` produces in ``__init__``. An over-long
        body gets no padding, so it simply never matches (no assert).
        """
        body = password[len(SOS):] if password.startswith(SOS) else password
        eos_at = body.find(EOS)
        if eos_at > -1:
            body = body[:eos_at]
        return SOS + body + EOS + PAD * (self.password_length - len(body))

    def count_hit(self, password_list) -> int:
        """Count how many candidate passwords appear in this dataset.

        Args:
            password_list: Iterable of password strings, with or without the
                SOS/EOS/PAD decoration (e.g. output of idtensor_to_string).

        Returns:
            Number of candidates (duplicates counted each time) whose
            canonical padded form is in ``password_string_set``.
        """
        # The previous version read a nonexistent `self.password_set` and
        # padded with EOS rather than PAD, so it could not match stored keys.
        return sum(
            1 for password in password_list
            if self._canonical_string(password) in self.password_string_set
        )