169 lines
4.7 KiB
Python
169 lines
4.7 KiB
Python
import string
|
|
import time
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from constants import EOS, EOS_ID, SOS, SOS_ID, PAD, PAD_ID
|
|
import numpy as np
|
|
|
|
|
|
def is_ascii(s: str) -> bool:
    """Return True if every character of *s* is a 7-bit ASCII code point.

    The empty string counts as ASCII. Uses the built-in ``str.isascii``
    instead of attempting an encode and catching ``UnicodeEncodeError``.
    """
    return s.isascii()
|
|
|
|
|
|
def build_uniquelist_from_file(fname: str, password_length_limit: int) -> set[str]:
    """Return the distinct usable passwords found in the file *fname*.

    The file is read one password per line, with surrounding whitespace
    stripped. A password is kept only if it is non-empty, at most
    ``password_length_limit`` characters long, and pure ASCII.

    Returns:
        A ``set`` of password strings (duplicates collapse naturally).
    """
    passwdset = set()
    # errors='replace' turns undecodable bytes into U+FFFD; that character is
    # non-ASCII, so such lines are dropped by the filter below instead of
    # aborting the whole read with UnicodeDecodeError.
    with open(fname, 'r', errors='replace') as f:
        for line in f:
            passwd = line.strip()
            if not passwd or len(passwd) > password_length_limit:
                continue
            if not is_ascii(passwd):
                continue
            passwdset.add(passwd)
    return passwdset
|
|
|
|
|
|
def ord_to_string(l: list[list[int]], inv_charmap: dict[int, str]) -> list[str]:
    """Decode each id sequence in *l* into a string using *inv_charmap*.

    Every id is looked up in ``inv_charmap`` (an unknown id raises
    ``KeyError``) and the looked-up pieces are concatenated in order, so
    multi-character tokens are emitted verbatim.

    Returns:
        One decoded string per row of *l*, in the same order.
    """
    # str.join avoids the quadratic cost of repeated ``s += ...``.
    return [''.join(inv_charmap[token_id] for token_id in row) for row in l]
|
|
|
|
|
|
def refine_password_char(l: list[str], password_length: int) -> list[str]:
    """Normalise a character-token list to exactly ``password_length + 2`` tokens.

    The result has the layout ``[SOS, c1, ..., ck, EOS, PAD, ..., PAD]``:

    * ``SOS`` is prepended unless the list already starts with it (an empty
      list also gets one);
    * everything from the first ``EOS`` onward is dropped and a single ``EOS``
      is appended -- the ``+ 2`` leaves room for the leading ``SOS`` and at
      least one ``EOS``;
    * ``PAD`` fills the remainder.

    The input list is not modified.

    Raises:
        ValueError: if more than ``password_length`` tokens remain between
            ``SOS`` and the first ``EOS``.
    """
    if not l or l[0] != SOS:
        l = [SOS] + l
    if EOS in l:
        l = l[:l.index(EOS)]
    target_len = password_length + 2
    ret = l + [EOS]
    # An explicit check (instead of ``assert``) still fires under ``python -O``.
    if len(ret) > target_len:
        raise ValueError(
            "password has {} tokens, more than password_length={}".format(
                len(l) - 1, password_length))
    return ret + [PAD] * (target_len - len(ret))
|
|
|
|
|
|
def refine_password_int(l: list[int], password_length: int) -> list[int]:
    """Normalise a token-id sequence to exactly ``password_length + 2`` ids.

    Id-level counterpart of ``refine_password_char``; called by
    ``PasswordDataset.idtensor_to_string`` on each decoded row. The result is
    ``[SOS_ID, id_1, ..., id_k, EOS_ID, PAD_ID, ..., PAD_ID]``:

    * ``SOS_ID`` is prepended unless the sequence already starts with it (an
      empty sequence also gets one);
    * everything from the first ``EOS_ID`` onward is dropped;
    * at most ``password_length`` content ids are kept, so a row with no
      ``EOS_ID`` (or too many ids before it) is truncated instead of
      overflowing the fixed length;
    * a single ``EOS_ID`` is appended and ``PAD_ID`` fills the remainder.

    The input list is not modified.
    """
    if not l or l[0] != SOS_ID:
        l = [SOS_ID] + l
    if EOS_ID in l:
        l = l[:l.index(EOS_ID)]
    # Keep SOS plus at most password_length content ids, leaving room for EOS.
    l = l[:password_length + 1]
    ret = l + [EOS_ID]
    return ret + [PAD_ID] * (password_length + 2 - len(ret))
|
|
|
|
|
|
def load_dataset(path, max_length, tokenize=False, max_vocab_size=2048):
    """Load passwords from *path* as fixed-length character-token lists.

    Each line is stripped and converted with ``refine_password_char`` into a
    list of ``max_length + 2`` tokens. Blank lines and lines longer than
    ``max_length`` are skipped, matching the filter ``PasswordDataset``
    applies.

    ``tokenize`` and ``max_vocab_size`` are not used by this function; they
    are kept so existing call sites keep working.

    The list is shuffled in place with NumPy's global RNG (seed ``np.random``
    for a reproducible order), a count is printed, and the list is returned.
    """
    lines = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not 1 <= len(line) <= max_length:
                continue
            lines.append(refine_password_char(list(line), max_length))

    np.random.shuffle(lines)

    print("loaded {} lines in dataset".format(len(lines)))

    return lines
|
|
|
|
|
|
class NaiveTokenizer:
    """Fixed character-level vocabulary.

    The special tokens ``EOS``, ``SOS`` and ``PAD`` map to ``EOS_ID``,
    ``SOS_ID`` and ``PAD_ID``; every character ``chr(code)`` with
    ``3 <= code < 256`` maps to its own code point. ``char2id`` and
    ``id2char`` are the two directions of that mapping.
    """

    def __init__(self, dataset_paths: list[str]):
        # NOTE(review): dataset_paths is accepted but never read -- the
        # vocabulary is fixed and does not depend on any data.
        self.password_list = []
        codes = range(3, 256)
        self.char2id = {EOS: EOS_ID, SOS: SOS_ID, PAD: PAD_ID}
        self.char2id.update((chr(code), code) for code in codes)
        self.id2char = {EOS_ID: EOS, SOS_ID: SOS, PAD_ID: PAD}
        self.id2char.update((code, chr(code)) for code in codes)

    def vocab_size(self) -> int:
        """Return the number of distinct ids in the vocabulary."""
        return len(self.id2char)
|
|
|
|
|
|
class PasswordDataset(Dataset):
    """Map-style torch ``Dataset`` of fixed-length, tokenised passwords.

    Passwords are read one per line from ``dataset_path``. Each kept password
    is stored three ways:

    * ``password_list``: the token list from ``refine_password_char``
      (``SOS``, the characters, ``EOS``, then ``PAD`` up to
      ``password_length + 2`` tokens);
    * ``password_string_set``: those token lists joined into strings, used by
      ``contains`` and ``count_hit``;
    * ``password_tensor_list``: the token ids (via ``tokenizer.char2id``) as
      1-D ``torch.long`` tensors, returned by ``__getitem__``.
    """

    def __init__(self, dataset_path, password_length, tokenizer, args):
        self.tokenizer = tokenizer
        self.password_length = password_length
        self.password_list = []
        char2id = tokenizer.char2id

        # errors='replace' maps undecodable bytes to U+FFFD rather than
        # aborting the whole load; such lines are then dropped by the
        # vocabulary check below unless the tokenizer knows U+FFFD.
        with open(dataset_path, 'r', errors='replace') as f:
            for line in f:
                line = line.strip()
                if len(line) > password_length or len(line) < 1:
                    continue
                # A character missing from the tokenizer would raise KeyError
                # during tokenisation, so skip such passwords up front.
                if not all(ch in char2id for ch in line):
                    continue
                self.password_list.append(
                    refine_password_char(list(line), password_length))

        self.args = args
        self.data_len = len(self.password_list)
        self.password_string_set = {''.join(x) for x in self.password_list}

        # One torch.tensor call for the whole dataset instead of one per
        # password (the per-row call is what made loading slow). Every row has
        # password_length + 2 tokens, so the nested list is rectangular;
        # unbind(0) yields one 1-D row tensor per password (an empty dataset
        # gives an empty list).
        ids = torch.tensor(
            [[char2id[tok] for tok in pw] for pw in self.password_list],
            dtype=torch.long,
        )
        self.password_tensor_list = list(ids.unbind(0))

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        return self.password_tensor_list[idx]

    def idtensor_to_string(self, id_tensor2d):
        """Decode a tensor of token ids, one row per sample, into strings.

        Each row is normalised with ``refine_password_int`` and then mapped
        back to text through ``tokenizer.id2char``.
        """
        id_list2d = id_tensor2d.cpu().detach().tolist()
        refined = [refine_password_int(row, self.password_length)
                   for row in id_list2d]
        return ord_to_string(refined, self.tokenizer.id2char)

    def contains(self, password: str) -> bool:
        """Return True if *password* (in joined-token form) is in the dataset."""
        return password in self.password_string_set

    def count_hit(self, password_list) -> int:
        """Count how many candidate strings are passwords of this dataset.

        Each candidate is canonicalised to the joined-token form stored in
        ``password_string_set``: a leading ``SOS`` is dropped if present,
        everything from the first ``EOS`` onward is discarded, and the remaining
        text is rebuilt as ``SOS + text + EOS + PAD * (password_length - len(text))``.
        Candidates whose text is empty or longer than ``password_length`` are
        never counted, since ``__init__`` never stores such passwords.
        """
        cnt = 0
        for password in password_list:
            if password.startswith(SOS):
                password = password[len(SOS):]
            i = password.find(EOS)
            if i > -1:
                password = password[:i]
            if not 1 <= len(password) <= self.password_length:
                continue
            # Same layout refine_password_char produces: one EOS, then PADs.
            canonical = (SOS + password + EOS
                         + PAD * (self.password_length - len(password)))
            if canonical in self.password_string_set:
                cnt += 1
        return cnt
|