import torch
from torch.utils.data import Dataset
from constants import EOS, EOS_ID, SOS, SOS_ID, PAD, PAD_ID
import numpy as np


def is_ascii(s: str) -> bool:
    try:
        s.encode('ascii')
    except UnicodeEncodeError:
        return False
    else:
        return True


def build_uniquelist_from_file(fname: str, password_length_limit: int) -> list[str]:
    """Read one password per line, keeping unique ASCII passwords of valid length."""
    passwdset = set()
    with open(fname, 'r') as f:
        raw_list = [x.strip() for x in f]
    for passwd in raw_list:
        if not is_ascii(passwd):
            continue
        if len(passwd) > password_length_limit:
            continue
        if len(passwd) < 1:
            continue
        passwdset.add(passwd)
    return list(passwdset)


def ord_to_string(l: list[list[int]], inv_charmap: dict[int, str]) -> list[str]:
    """Map each inner list of token ids back to a string via inv_charmap."""
    ret = []
    for pl in l:
        s = ""
        for token_id in pl:
            s += inv_charmap[token_id]
        ret.append(s)
    return ret


# Left- and right-pad the list to exactly password_length + 2 tokens.
# The +2 guarantees room for SOS at the beginning and at least one EOS.
def refine_password_char(l: list[str], password_length: int) -> list[str]:
    if l[0] != SOS:
        l = [SOS] + l
    # Truncate at the first EOS (if any), then re-append a single EOS.
    first_eos = -1
    for i in range(len(l)):
        if l[i] == EOS:
            first_eos = i
            break
    if first_eos > -1:
        l = l[:first_eos]
    ret = l + [EOS]
    ret = ret + [PAD for _ in range(password_length + 2 - len(ret))]
    assert len(ret) == password_length + 2
    return ret


def refine_password_int(l: list[int], password_length: int) -> list[int]:
    """Same as refine_password_char, but over token ids instead of characters."""
    if l[0] != SOS_ID:
        l = [SOS_ID] + l
    first_eos = -1
    for i in range(len(l)):
        if l[i] == EOS_ID:
            first_eos = i
            break
    if first_eos > -1:
        l = l[:first_eos]
    ret = l + [EOS_ID]
    ret = ret + [PAD_ID for _ in range(password_length + 2 - len(ret))]
    assert len(ret) == password_length + 2
    return ret


def load_dataset(path, max_length, tokenize=False, max_vocab_size=2048):
    lines = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip empty or over-long lines; refine_password_char asserts
            # the padded length, so over-long input would crash it.
            if len(line) > max_length or len(line) < 1:
                continue
            lines.append(refine_password_char(list(line), max_length))
    np.random.shuffle(lines)
    print("loaded {} lines in dataset".format(len(lines)))
    return lines


class NaiveTokenizer():
    """Byte-level tokenizer: every code point in 3..255 maps to its own id.

    Ids 0..2 are reserved for the special tokens, so this assumes EOS_ID,
    SOS_ID, and PAD_ID occupy {0, 1, 2} in constants.py.
    """

    def __init__(self, dataset_paths: list[str]):
        self.password_list = []
        self.char2id = {EOS: EOS_ID, SOS: SOS_ID, PAD: PAD_ID}
        self.id2char = {EOS_ID: EOS, SOS_ID: SOS, PAD_ID: PAD}
        for i in range(3, 256):
            self.char2id[chr(i)] = i
            self.id2char[i] = chr(i)

    def vocab_size(self) -> int:
        return len(self.id2char)
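
# Illustrative sketch (not part of the original module): the intended round
# trip from a raw password to a fixed-length id sequence and back. It assumes
# constants.py defines SOS/EOS/PAD as single characters whose ids are 0..2;
# the example password and default length are arbitrary placeholders.
def _tokenize_example(password: str = "hunter2", password_length: int = 12):
    tokenizer = NaiveTokenizer([])
    # Pad/truncate to password_length + 2 characters: SOS ... EOS PAD...
    chars = refine_password_char(list(password), password_length)
    ids = torch.tensor([tokenizer.char2id[c] for c in chars], dtype=torch.long)
    # Decoding reverses the mapping; ord_to_string expects a batch (2-D list).
    decoded = ord_to_string([ids.tolist()], tokenizer.id2char)[0]
    assert decoded == ''.join(chars)
    return ids, decoded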
class PasswordDataset(Dataset):
    def __init__(self, dataset_path, password_length, tokenizer, args):
        self.tokenizer = tokenizer
        self.password_length = password_length
        self.password_list = []
        with open(dataset_path, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) > password_length or len(line) < 1:
                    continue
                self.password_list.append(refine_password_char(
                    list(line), password_length))
        self.args = args
        self.data_len = len(self.password_list)
        self.password_string_set = set(
            ''.join(x) for x in self.password_list)
        # TODO: this loop takes too long on large datasets...
        self.password_tensor_list = []
        for password in self.password_list:
            ids = [self.tokenizer.char2id[x] for x in password]
            self.password_tensor_list.append(
                torch.tensor(ids, dtype=torch.long))

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        return self.password_tensor_list[idx]

    def idtensor_to_string(self, id_tensor2d):
        """Decode a batch of id tensors back into padded password strings."""
        id_list2d = id_tensor2d.cpu().detach().tolist()
        refined = []
        for ids in id_list2d:
            refined.append(refine_password_int(ids, self.password_length))
        return ord_to_string(refined, self.tokenizer.id2char)

    def contains(self, password: str) -> bool:
        return password in self.password_string_set

    def count_hit(self, password_list) -> int:
        """Count how many of the given passwords appear in this dataset."""
        cnt = 0
        for password in password_list:
            # Left-pad with SOS.
            if password[0] != SOS:
                password = SOS + password
            # Truncate at the first EOS, then pad the same way
            # refine_password_char does: one EOS followed by PAD.
            i = password.find(EOS)
            if i > -1:
                password = password[:i]
            password = password + EOS
            password = password + PAD * \
                (self.password_length + 2 - len(password))
            if password in self.password_string_set:
                cnt += 1
        return cnt
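
# Usage sketch (hypothetical, not from the original module): wires the
# dataset into a standard DataLoader. The path 'data/passwords.txt' and
# args=None are placeholders; a real training script supplies its own.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    tokenizer = NaiveTokenizer([])
    dataset = PasswordDataset('data/passwords.txt', password_length=12,
                              tokenizer=tokenizer, args=None)
    # Every item is already a fixed-length LongTensor of shape
    # (password_length + 2,), so default batching stacks them cleanly.
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    batch = next(iter(loader))
    print(batch.shape)  # e.g. torch.Size([64, 14]) for password_length=12
    print(dataset.idtensor_to_string(batch[:2]))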