diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 1f4847c..6619092 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple +from typing import Tuple, List from .constants import MAX_PROTOCOL_NAME_LEN, Empty from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf @@ -18,7 +18,7 @@ class NoiseProtocol(object): 'keypair': keypair_map } - def __init__(self, protocol_name: bytes): + def __init__(self, protocol_name: bytes, psks: List[bytes]=None): if not isinstance(protocol_name, bytes): raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name))) if len(protocol_name) > MAX_PROTOCOL_NAME_LEN: @@ -33,6 +33,16 @@ class NoiseProtocol(object): if self.pattern_modifiers: self.pattern.apply_pattern_modifiers(pattern_modifiers) + # Handle PSK handshake options + self.psks = psks + self.is_psk_handshake = False if not self.psks else True + if self.is_psk_handshake: + if any([len(psk) != 32 for psk in self.psks]): + raise ValueError('Invalid psk length!') + if len(self.psks) != self.pattern.psk_count: + raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format( + self.pattern.psk_count, len(self.psks))) + self.dh_fn = mappings['dh'] self.cipher_fn = mappings['cipher'] self.hash_fn = mappings['hash'] @@ -43,7 +53,6 @@ class NoiseProtocol(object): self.initiator = None self.one_way = False self.handshake_hash = None - self.psks = None # Placeholder for PSKs self.handshake_state = Empty() self.symmetric_state = Empty() @@ -87,9 +96,6 @@ class NoiseProtocol(object): return mapped_data, modifiers - def set_psks(self, psks: list) -> None: - self.psks = psks - def handshake_done(self): self.initiator = self.handshake_state.initiator if self.pattern.one_way: diff --git a/noise/patterns.py b/noise/patterns.py index f4a6dc6..575ca46 100644 --- a/noise/patterns.py +++ b/noise/patterns.py @@ -7,20 +7,21 @@ class Pattern(object): """ TODO document """ - # As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity, - # pre_messages shall be a list of two lists: - # the first for the initiator's pre-messages, the second for the responder - pre_messages = [ - [], - [] - ] - - # List of lists of valid tokens, alternating between tokens for initiator and responder - tokens = [] - def __init__(self): + # As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity, + # pre_messages shall be a list of two lists: + # the first for the initiator's pre-messages, the second for the responder + self.pre_messages = [ + [], + [] + ] + + # List of lists of valid tokens, alternating between tokens for initiator and responder + self.tokens = [] + self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages)) self.one_way = False + self.psk_count = 0 def get_initiator_pre_messages(self) -> list: return self.pre_messages[0].copy() @@ -41,10 +42,11 @@ class Pattern(object): raise ValueError('Modifier {} cannot be applied - pattern has not enough messages'.format(modifier)) # Add TOKEN_PSK in the correct place in the correct message - if index % 2 == 0: - self.tokens[index//2].insert(0, TOKEN_PSK) - else: - self.tokens[index//2].append(TOKEN_PSK) + if index == 0: # if 0, insert at the beginning of first message + self.tokens[0].insert(0, TOKEN_PSK) + else: # if bigger than zero, append at the end of first, second etc. + self.tokens[index - 1].append(TOKEN_PSK) + self.psk_count += 1 elif modifier == 'fallback': raise NotImplementedError # TODO implement @@ -57,151 +59,196 @@ class Pattern(object): class OneWayPattern(Pattern): def __init__(self): - super(Pattern, self).__init__() + super(OneWayPattern, self).__init__() self.one_way = True class PatternN(OneWayPattern): - pre_messages = [ - [], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES] - ] + def __init__(self): + super(PatternN, self).__init__() + + self.pre_messages = [ + [], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES] + ] class PatternK(OneWayPattern): - pre_messages = [ - [TOKEN_S], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES, TOKEN_SS] - ] + def __init__(self): + super(PatternK, self).__init__() + + self.pre_messages = [ + [TOKEN_S], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES, TOKEN_SS] + ] class PatternX(OneWayPattern): - pre_messages = [ - [], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS] - ] + def __init__(self): + super(PatternX, self).__init__() + + self.pre_messages = [ + [], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS] + ] # Interactive patterns class PatternNN(Pattern): - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE] - ] + def __init__(self): + super(PatternNN, self).__init__() + + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE] + ] class PatternKN(Pattern): - pre_messages = [ - [TOKEN_S], - [] - ] - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE, TOKEN_SE] - ] + def __init__(self): + super(PatternKN, self).__init__() + + self.pre_messages = [ + [TOKEN_S], + [] + ] + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE, TOKEN_SE] + ] class PatternNK(Pattern): - pre_messages = [ - [], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES], - [TOKEN_E, TOKEN_EE] - ] + def __init__(self): + super(PatternNK, self).__init__() + + self.pre_messages = [ + [], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES], + [TOKEN_E, TOKEN_EE] + ] class PatternKK(Pattern): - pre_messages = [ - [TOKEN_S], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES, TOKEN_SS], - [TOKEN_E, TOKEN_EE, TOKEN_SE] - ] + def __init__(self): + super(PatternKK, self).__init__() + + self.pre_messages = [ + [TOKEN_S], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES, TOKEN_SS], + [TOKEN_E, TOKEN_EE, TOKEN_SE] + ] class PatternNX(Pattern): - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES] - ] + def __init__(self): + super(PatternNX, self).__init__() + + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES] + ] class PatternKX(Pattern): - pre_messages = [ - [TOKEN_S], - [] - ] - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES] - ] + def __init__(self): + super(PatternKX, self).__init__() + + self.pre_messages = [ + [TOKEN_S], + [] + ] + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES] + ] class PatternXN(Pattern): - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE], - [TOKEN_S, TOKEN_SE] - ] + def __init__(self): + super(PatternXN, self).__init__() + + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE], + [TOKEN_S, TOKEN_SE] + ] class PatternIN(Pattern): - tokens = [ - [TOKEN_E, TOKEN_S], - [TOKEN_E, TOKEN_EE, TOKEN_SE] - ] + def __init__(self): + super(PatternIN, self).__init__() + + self.tokens = [ + [TOKEN_E, TOKEN_S], + [TOKEN_E, TOKEN_EE, TOKEN_SE] + ] class PatternXK(Pattern): - pre_messages = [ - [], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES], - [TOKEN_E, TOKEN_EE], - [TOKEN_S, TOKEN_SE] - ] + def __init__(self): + super(PatternXK, self).__init__() + + self.pre_messages = [ + [], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES], + [TOKEN_E, TOKEN_EE], + [TOKEN_S, TOKEN_SE] + ] class PatternIK(Pattern): - pre_messages = [ - [], - [TOKEN_S] - ] - tokens = [ - [TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS], - [TOKEN_E, TOKEN_EE, TOKEN_SE] - ] + def __init__(self): + super(PatternIK, self).__init__() + + self.pre_messages = [ + [], + [TOKEN_S] + ] + self.tokens = [ + [TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS], + [TOKEN_E, TOKEN_EE, TOKEN_SE] + ] class PatternXX(Pattern): - tokens = [ - [TOKEN_E], - [TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES], - [TOKEN_S, TOKEN_SE] - ] + def __init__(self): + super(PatternXX, self).__init__() + + self.tokens = [ + [TOKEN_E], + [TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES], + [TOKEN_S, TOKEN_SE] + ] class PatternIX(Pattern): - tokens = [ - [TOKEN_E, TOKEN_S], - [TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES] - ] + def __init__(self): + super(PatternIX, self).__init__() + + self.tokens = [ + [TOKEN_E, TOKEN_S], + [TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES] + ] patterns_map = { diff --git a/noise/state.py b/noise/state.py index fba5b25..17ce57a 100644 --- a/noise/state.py +++ b/noise/state.py @@ -126,6 +126,17 @@ class SymmetricState(object): """ self.h = self.noise_protocol.hash_fn.hash(self.h + data) + def mix_key_and_hash(self, input_key_material: bytes): + # Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3). + self.ck, temp_h, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 3) + # Calls MixHash(temp_h). + self.mix_hash(temp_h) + # If HASHLEN is 64, then truncates temp_k to 32 bytes. + if self.noise_protocol.hash_fn.hashlen == 64: + temp_k = temp_k[:32] + # Calls InitializeKey(temp_k). + self.noise_protocol.cipher_state_handshake.initialize_key(temp_k) + def encrypt_and_hash(self, plaintext: bytes) -> bytes: """ Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns ciphertext. Note that if @@ -268,6 +279,8 @@ class HandshakeState(object): self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock message_buffer.write(self.e.public_bytes) self.symmetric_state.mix_hash(self.e.public_bytes) + if self.noise_protocol.is_psk_handshake: + self.symmetric_state.mix_key(self.e.public_bytes) elif token == TOKEN_S: # Appends EncryptAndHash(s.public_key) to the buffer @@ -294,10 +307,9 @@ class HandshakeState(object): elif token == TOKEN_SS: # Calls MixKey(DH(s, rs)) self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public)) - pass elif token == TOKEN_PSK: - raise NotImplementedError + self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0)) else: raise NotImplementedError('Pattern token: {}'.format(token)) @@ -325,6 +337,8 @@ class HandshakeState(object): # Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key). self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen)) self.symmetric_state.mix_hash(self.re.public_bytes) + if self.noise_protocol.is_psk_handshake: + self.symmetric_state.mix_key(self.re.public_bytes) elif token == TOKEN_S: # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes @@ -358,7 +372,7 @@ class HandshakeState(object): self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public)) elif token == TOKEN_PSK: - raise NotImplementedError + self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0)) else: raise NotImplementedError('Pattern token: {}'.format(token)) diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 1958212..fa5cf86 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -35,15 +35,15 @@ def _prepare_test_vectors(): for vector in vectors_list: if 'name' in vector and not 'protocol_name' in vector: # noise-c-* workaround vector['protocol_name'] = vector['name'] - if '_448_' in vector['protocol_name'] or 'psk' in vector['protocol_name'] or 'PSK' in vector['protocol_name']: - continue # TODO REMOVE WHEN ed448/psk/blake SUPPORT IS IMPLEMENTED/FIXED + if '_448_' in vector['protocol_name'] or 'PSK' in vector['protocol_name']: # no old NoisePSK tests + continue # TODO REMOVE WHEN ed448 SUPPORT IS IMPLEMENTED/FIXED for key, value in vector.copy().items(): if key in byte_fields: vector[key] = value.encode() if key in hexbyte_fields: vector[key] = bytes.fromhex(value) if key in list_fields: - vector[key] = [k.encode() for k in value] + vector[key] = [bytes.fromhex(k) for k in value] if key == dict_field: vector[key] = [] for dictionary in value: @@ -77,11 +77,12 @@ class TestVectors(object): def test_vector(self, vector): kwargs = self._prepare_handshake_state_kwargs(vector) - init_protocol = NoiseProtocol(vector['protocol_name']) - resp_protocol = NoiseProtocol(vector['protocol_name']) if 'init_psks' in vector and 'resp_psks' in vector: - init_protocol.set_psks(vector['init_psks']) - resp_protocol.set_psks(vector['resp_psks']) + init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks']) + resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks']) + else: + init_protocol = NoiseProtocol(vector['protocol_name']) + resp_protocol = NoiseProtocol(vector['protocol_name']) kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue']) kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])