diff --git a/noise/constants.py b/noise/constants.py index 0d0b866..0241ab1 100644 --- a/noise/constants.py +++ b/noise/constants.py @@ -13,4 +13,4 @@ TOKEN_PSK = 'psk' # In bytes, as in Section 8 of specification (rev 32) -MAX_PROTOCOL_NAME_LEN = 255 \ No newline at end of file +MAX_PROTOCOL_NAME_LEN = 255 diff --git a/noise/crypto.py b/noise/crypto.py new file mode 100644 index 0000000..6e3856c --- /dev/null +++ b/noise/crypto.py @@ -0,0 +1,2 @@ +def ed448(*args, **kwargs): + raise NotImplementedError diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index e391581..5d722d3 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,14 +1,18 @@ from functools import partial +from typing import Tuple from .patterns import patterns_map from .constants import MAX_PROTOCOL_NAME_LEN +from .crypto import ed448 + from Crypto.Cipher import AES, ChaCha20 from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512 import ed25519 + dh_map = { '25519': ed25519.create_keypair, - '448': None # TODO implement + '448': ed448 # TODO implement } cipher_map = { @@ -17,6 +21,7 @@ cipher_map = { } hash_map = { + # TODO benchmark vs hashlib implementation 'BLAKE2b': BLAKE2b, # TODO PARTIALS 'BLAKE2s': BLAKE2s, # TODO PARTIALS 'SHA256': SHA256, # TODO PARTIALS @@ -25,35 +30,46 @@ hash_map = { class NoiseProtocol(object): + """ + TODO: Document + """ methods = { 'pattern': patterns_map, 'dh': dh_map, - + 'cipher': cipher_map, + 'hash': hash_map } + def __init__(self, protocol_name: bytes): if len(protocol_name) > MAX_PROTOCOL_NAME_LEN: raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN)) self.name = protocol_name - data_dict = self._split_protocol_name() - self.pattern = patterns_map[data_dict['pattern']] - self.pattern_modifiers = None - self.dh = None - self.cipher = None - self.hash = None + mappings, pattern_modifiers = self._parse_protocol_name() - def _split_protocol_name(self): + self.pattern = mappings['pattern']() + self.pattern_modifiers = pattern_modifiers + if self.pattern_modifiers: + self.pattern.apply_pattern_modifiers(pattern_modifiers) + + self.dh = mappings['pattern'] + self.cipher = mappings['pattern'] + self.hash = mappings['pattern'] + + def _parse_protocol_name(self) -> Tuple[dict, list]: unpacked = self.name.split('_') if unpacked[0] != 'Noise': - raise ValueError(f'Noise protocol name shall begin with Noise! Provided: {self.name}') + raise ValueError(f'Noise Protocol name shall begin with Noise! Provided: {self.name}') + # Extract pattern name and pattern modifiers pattern = '' modifiers_str = None for i, char in enumerate(unpacked[1]): if char.isupper(): pattern += char else: - modifiers_str = unpacked[1][i+1:] # Will be empty string if it exceeds string size + # End of pattern, now look for modifiers + modifiers_str = unpacked[1][i:] # Will be empty string if it exceeds string size break modifiers = modifiers_str.split('+') if modifiers_str else [] @@ -63,10 +79,16 @@ class NoiseProtocol(object): 'hash': unpacked[4], 'pattern_modifiers': modifiers} - # Validate if we know everything that Noise Protocol is supposed to use - # TODO validation + mapped_data = {} - return data + # Validate if we know everything that Noise Protocol is supposed to use and map appropriate functions + for key, map_dict in self.methods.items(): + func = map_dict.get(data[key]) + if not func: + raise ValueError(f'Unknown {key} in Noise Protocol name, given {data[key]}, known {" ".join(map_dict)}') + mapped_data[key] = func + + return mapped_data, modifiers class KeyPair(object): diff --git a/noise/patterns.py b/noise/patterns.py index 526ac1c..47c3d18 100644 --- a/noise/patterns.py +++ b/noise/patterns.py @@ -1,3 +1,5 @@ +from typing import List + from .constants import TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK @@ -8,12 +10,13 @@ class Pattern(object): # As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity, # pre_messages shall be a list of two lists: # the first for the initiator's pre-messages, the second for the responder - pre_messages = [ + pre_messages: List[list] = [ [], [] ] - # TODO Comment - tokens = [] + + # List of lists of valid tokens, alternating between tokens for initiator and responder + tokens: List[list] = [] def __init__(self): self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages)) @@ -24,6 +27,30 @@ class Pattern(object): def get_responder_pre_messages(self) -> list: return self.pre_messages[1] + def apply_pattern_modifiers(self, modifiers: List[str]) -> None: + # Applies given pattern modifiers to self.tokens of the Pattern instance. + for modifier in modifiers: + if modifier.startswith('psk'): + try: + index = int(modifier.replace('psk', '', 1)) + except ValueError: + raise ValueError(f'Improper psk modifier {modifier}') + + if index // 2 > len(self.tokens): + raise ValueError(f'Modifier {modifier} cannot be applied - pattern has not enough messages') + + # Add TOKEN_PSK in the correct place in the correct message + if index % 2 == 0: + self.tokens[index//2].insert(0, TOKEN_PSK) + else: + self.tokens[index//2].append(TOKEN_PSK) + + elif modifier == 'fallback': + raise NotImplementedError # TODO implement + + else: + raise ValueError(f'Unknown pattern modifier {modifier}') + # One-way patterns diff --git a/noise/state.py b/noise/state.py index 7694000..2a8d967 100644 --- a/noise/state.py +++ b/noise/state.py @@ -49,10 +49,10 @@ class SymmetricState(object): """ @classmethod - def initialize_symmetric(cls, protocol_name) -> 'SymmetricState': + def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState': """ - :param protocol_name: + :param noise_protocol: :return: """ instance = cls() @@ -104,7 +104,7 @@ class HandshakeState(object): The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState. """ @classmethod - def initialize(cls, handshake_pattern: 'Pattern', protocol_name: 'NoiseProtocol', initiator: bool, + def initialize(cls, noise_protocol: 'NoiseProtocol', handshake_pattern: 'Pattern', initiator: bool, prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': """ @@ -112,7 +112,7 @@ class HandshakeState(object): Comments below are mostly copied from specification. :param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32)) - :param protocol_name: a valid NoiseProtocol instance + :param noise_protocol: a valid NoiseProtocol instance :param initiator: boolean indicating the initiator or responder role :param prologue: byte sequence which may be zero-length, or which may contain context information that both parties want to confirm is identical @@ -128,11 +128,11 @@ class HandshakeState(object): # Originally in specification: # "Derives a protocol_name byte sequence by combining the names for # the handshake pattern and crypto functions, as specified in Section 8." - # Instead, we supply the protocol name to the function. It should already be validated. We only check if the - # handshake pattern specified as an argument is the same as in the protocol name + # Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated. + # We only check if the handshake pattern specified as an argument is the same as in the protocol name - # Calls InitializeSymmetric(protocol_name) - instance.symmetric_state = SymmetricState.initialize_symmetric(protocol_name) + # Calls InitializeSymmetric(noise_protocol) + instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol) # Calls MixHash(prologue) instance.symmetric_state.mix_hash(prologue) diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 7125007..f325efa 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -30,7 +30,7 @@ def test_vector(vector): logging.info(f"Testing vector {vector['protocol_name']}") init_protocol = NoiseProtocol(vector['protocol_name']) resp_protocol = NoiseProtocol(vector['protocol_name']) - initiator = HandshakeState.initialize(handshake_pattern=init_protocol.pattern, protocol_name=init_protocol.name, + initiator = HandshakeState.initialize(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern, initiator=True, prologue=vector['init_prologue']) - responder = HandshakeState.initialize(handshake_pattern=resp_protocol.pattern, protocol_name=resp_protocol.name, + responder = HandshakeState.initialize(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern, initiator=True, prologue=vector['resp_prologue'])