diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 060e44e..3eb065e 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -27,13 +27,16 @@ class NoiseProtocol(object): self.psks = None self.is_psk_handshake = any([modifier.startswith('psk') for modifier in self.pattern_modifiers]) + # Preinitialized self.dh_fn = mappings['dh']() self.hash_fn = mappings['hash']() - self.cipher_fn = mappings['cipher'] - self.keypair_fn = mappings['keypair'] self.hmac = partial(backend.hmac, algorithm=self.hash_fn.fn) self.hkdf = partial(backend.hkdf, hmac_hash_fn=self.hmac) + # Initialized where needed + self.cipher_class = mappings['cipher'] + self.keypair_class = mappings['keypair'] + self.prologue = None self.initiator = None self.handshake_hash = None @@ -60,7 +63,7 @@ class NoiseProtocol(object): del self.initiator del self.dh_fn del self.hash_fn - del self.keypair_fn + del self.keypair_class def validate(self): if self.is_psk_handshake: diff --git a/noise/state.py b/noise/state.py index 63e53b6..312550f 100644 --- a/noise/state.py +++ b/noise/state.py @@ -16,7 +16,7 @@ class CipherState(object): def __init__(self, noise_protocol): self.k = Empty() self.n = None - self.cipher = noise_protocol.cipher_fn() + self.cipher = noise_protocol.cipher_class() def initialize_key(self, key): """ @@ -363,7 +363,7 @@ class HandshakeState(object): for token in message_pattern: if token == TOKEN_E: # Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key). - self.re = self.noise_protocol.keypair_fn.from_public_bytes(bytes(message[:dhlen])) + self.re = self.noise_protocol.keypair_class.from_public_bytes(bytes(message[:dhlen])) message = message[dhlen:] self.symmetric_state.mix_hash(self.re.public_bytes) if self.noise_protocol.is_psk_handshake: @@ -378,7 +378,9 @@ class HandshakeState(object): else: temp = bytes(message[:dhlen]) message = message[dhlen:] - self.rs = self.noise_protocol.keypair_fn.from_public_bytes(self.symmetric_state.decrypt_and_hash(temp)) + self.rs = self.noise_protocol.keypair_class.from_public_bytes( + self.symmetric_state.decrypt_and_hash(temp) + ) elif token == TOKEN_EE: # Calls MixKey(DH(e, re)).