diff --git a/noise/functions.py b/noise/functions.py index 3af6f9f..231301f 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -5,25 +5,6 @@ from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512 import ed25519 -dh_map = { - '25519': ed25519, - '448': ed448 # TODO implement -} - -cipher_map = { - 'AESGCM': AES, - 'ChaChaPoly': ChaCha20 -} - -hash_map = { - # TODO benchmark pycryptodome vs hashlib implementation - 'BLAKE2b': BLAKE2b, - 'BLAKE2s': BLAKE2s, - 'SHA256': SHA256, - 'SHA512': SHA512 -} - - class DH(object): def __init__(self, method): self.method = method @@ -47,15 +28,38 @@ class Cipher(object): class Hash(object): def __init__(self, method): - self.hashlen = 0 - self.blocklen = 0 + if method == 'SHA256': + self.hashlen = 32 + self.blocklen = 64 + self.hash = self._hash_sha256 + elif method == 'SHA512': + self.hashlen = 64 + self.blocklen = 128 + self.hash = self._hash_sha512 + elif method == 'BLAKE2s': + self.hashlen = 32 + self.blocklen = 64 + self.hash = self._hash_blake2s + elif method == 'BLAKE2b': + self.hashlen = 64 + self.blocklen = 128 + self.hash = self._hash_blake2b - def hash(self): - pass + def _hash_sha256(self, data): + return SHA256.new(data).digest() + + def _hash_sha512(self, data): + return SHA512.new(data).digest() + + def _hash_blake2s(self, data): + return BLAKE2s.new(data=data, digest_bytes=self.hashlen).digest() + + def _hash_blake2b(self, data): + return BLAKE2b.new(data=data, digest_bytes=self.hashlen).digest() class KeyPair(object): - def __init__(self, public='', private=''): + def __init__(self, public=b'', private=b''): # TODO: Maybe switch to properties? self.public = public self.private = private @@ -64,3 +68,25 @@ class KeyPair(object): def derive_public_key(self): pass + + +# Available crypto functions +# TODO: Check if it's safe to use one instance globally per cryptoalgorithm - i.e. if wrapper only provides interface +# If not - switch to partials(?) +dh_map = { + '25519': DH('ed25519'), + '448': DH('ed448') +} + +cipher_map = { + 'AESGCM': Cipher('AESGCM'), + 'ChaChaPoly': Cipher('ChaCha20') +} + +hash_map = { + # TODO benchmark pycryptodome vs hashlib implementation + 'BLAKE2s': Hash('BLAKE2s'), + 'BLAKE2b': Hash('BLAKE2b'), + 'SHA256': Hash('SHA256'), + 'SHA512': Hash('SHA512') +} diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index df95424..cc18a3f 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,6 +1,6 @@ from typing import Tuple -from .constants import MAX_PROTOCOL_NAME_LEN +from .constants import MAX_PROTOCOL_NAME_LEN, Empty from .functions import dh_map, cipher_map, hash_map from .patterns import patterns_map @@ -17,25 +17,32 @@ class NoiseProtocol(object): } def __init__(self, protocol_name: bytes): + if not isinstance(protocol_name, bytes): + raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name))) if len(protocol_name) > MAX_PROTOCOL_NAME_LEN: raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN)) self.name = protocol_name mappings, pattern_modifiers = self._parse_protocol_name() + # A valid Pattern instance (see Section 7 of specification (rev 32)) self.pattern = mappings['pattern']() self.pattern_modifiers = pattern_modifiers if self.pattern_modifiers: self.pattern.apply_pattern_modifiers(pattern_modifiers) - self.dh = mappings['pattern'] - self.cipher = mappings['pattern'] - self.hash = mappings['pattern'] + self.dh_fn: 'DH' = mappings['dh'] + self.cipher_fn: 'Cipher' = mappings['cipher'] + self.hash_fn: 'Hash' = mappings['hash'] - self.psks = None # Placeholder for PSKs + self.psks: list = None # Placeholder for PSKs + + self.handshake_state: 'HandshakeState' = Empty() + self.symmetric_state: 'SymmetricState' = Empty() + self.cipher_state: 'CipherState' = Empty() def _parse_protocol_name(self) -> Tuple[dict, list]: - unpacked = self.name.split('_') + unpacked = self.name.decode().split('_') if unpacked[0] != 'Noise': raise ValueError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name)) diff --git a/noise/state.py b/noise/state.py index 5e635c7..3151e19 100644 --- a/noise/state.py +++ b/noise/state.py @@ -3,24 +3,33 @@ from .constants import Empty class CipherState(object): """ - + Implemented as per Noise Protocol specification (rev 32) - paragraph 5.1. + + The initialize_key() function takes additional required argument - noise_protocol. """ def __init__(self): self.k = Empty() self.n = None + self.noise_protocol = None - def initialize_key(self, key): + @classmethod + def initialize_key(cls, key, noise_protocol: 'NoiseProtocol') -> 'CipherState': """ :param key: - :return: + :param noise_protocol: a valid NoiseProtocol instance + :return: initialised CipherState instance """ - self.k = key - self.n = 0 + instance = cls() + instance.noise_protocol = noise_protocol + noise_protocol.cipher_state = instance + + instance.k = key + instance.n = 0 + return instance def has_key(self): """ - :return: True if self.k is not an instance of Empty """ return not isinstance(self.k, Empty) @@ -46,17 +55,43 @@ class CipherState(object): class SymmetricState(object): """ - + Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2. + + The initialize_symmetric function takes different required argument - noise_protocol, which contains protocol_name. """ + def __init__(self): + self.h = None + self.ck = None + self.noise_protocol = None + @classmethod def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState': """ - - :param noise_protocol: - :return: + Instead of taking protocol_name as an argument, we take full NoiseProtocol object, that way we have access to + protocol name and crypto functions + + Comments below are mostly copied from specification. + :param noise_protocol: a valid NoiseProtocol instance + :return: initialised SymmetricState instance """ + # Create SymmetricState instance = cls() - # TODO + instance.noise_protocol = noise_protocol + noise_protocol.symmetric_state = instance + + # If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero + # bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name). + if len(noise_protocol.name) <= noise_protocol.hash_fn.hashlen: + instance.h = noise_protocol.name.ljust(noise_protocol.hash_fn.hashlen, b'\0') + else: + instance.h = noise_protocol.hash_fn.hash(noise_protocol.name) + + # Sets ck = h. + instance.ck = instance.h + + # Calls InitializeKey(empty). + CipherState.initialize_key(Empty(), noise_protocol) + return instance def mix_key(self, input_key_material): @@ -72,6 +107,7 @@ class SymmetricState(object): :param data: :return: """ + self.h = self.noise_protocol.hash_fn.hash(data + self.h) def encrypt_and_hash(self, plaintext): """ @@ -101,17 +137,17 @@ class HandshakeState(object): """ Implemented as per Noise Protocol specification (rev 32) - paragraph 5.3. - The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState. + The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern. """ @classmethod - def initialize(cls, noise_protocol: 'NoiseProtocol', handshake_pattern: 'Pattern', initiator: bool, - prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None, - re: bytes=None) -> 'HandshakeState': + def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None, + e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': """ Constructor method. Comments below are mostly copied from specification. + Instead of taking handshake_pattern as an argument, we take full NoiseProtocol object, that way we have access + to protocol name and crypto functions - :param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32)) :param noise_protocol: a valid NoiseProtocol instance :param initiator: boolean indicating the initiator or responder role :param prologue: byte sequence which may be zero-length, or which may contain context information that both @@ -124,12 +160,13 @@ class HandshakeState(object): """ # Create HandshakeState instance = cls() + instance.noise_protocol = noise_protocol + noise_protocol.handshake_state = instance # Originally in specification: # "Derives a protocol_name byte sequence by combining the names for # the handshake pattern and crypto functions, as specified in Section 8." # Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated. - # We only check if the handshake pattern specified as an argument is the same as in the protocol name # Calls InitializeSymmetric(noise_protocol) instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol) @@ -149,13 +186,13 @@ class HandshakeState(object): # hashed first initiator_keypair_getter = instance._get_local_keypair if initiator else instance._get_remote_keypair responder_keypair_getter = instance._get_remote_keypair if initiator else instance._get_local_keypair - for keypair in map(initiator_keypair_getter, handshake_pattern.get_initiator_pre_messages()): + for keypair in map(initiator_keypair_getter, noise_protocol.pattern.get_initiator_pre_messages()): instance.symmetric_state.mix_hash(keypair.public) - for keypair in map(responder_keypair_getter, handshake_pattern.get_responder_pre_messages()): + for keypair in map(responder_keypair_getter, noise_protocol.pattern.get_responder_pre_messages()): instance.symmetric_state.mix_hash(keypair.public) # Sets message_patterns to the message patterns from handshake_pattern - instance.message_patterns = handshake_pattern.tokens + instance.message_patterns = noise_protocol.pattern.tokens return instance diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 5fc5bb0..e1e6977 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -12,13 +12,32 @@ logger = logging.getLogger(__name__) vector_files = ['vectors/cacophony.txt'] +# As in test vectors specification (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors) +# We use this to cast read strings into bytes +string_fields = ['protocol_name', 'init_prologue', 'init_static', 'init_ephemeral', 'init_remote_static', + 'resp_prologue', 'resp_static', 'resp_ephemeral', 'resp_remote_static', 'handshake_hash'] +list_fields = ['init_psks', 'resp_psks'] +dict_field = 'messages' + def _prepare_test_vectors(): vectors = [] for path in vector_files: with open(os.path.join(os.path.dirname(__file__), path)) as fd: logging.info('Reading vectors from file {}'.format(path)) - vectors.extend(json.load(fd)) + vectors_list = json.load(fd) + + for vector in vectors_list: + for key, value in vector.copy().items(): + if key in string_fields: + vector[key] = value.encode() + if key in list_fields: + vector[key] = [k.encode() for k in value] + if key == dict_field: + vector[key] = [] + for dictionary in value: + vector[key].append({k: v.encode() for k, v in dictionary.items()}) + vectors.append(vector) return vectors @@ -51,10 +70,8 @@ class TestVectors(object): init_protocol.set_psks(vector['init_psks']) resp_protocol.set_psks(vector['resp_psks']) - kwargs['init'].update(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern, initiator=True, - prologue=vector['init_prologue']) - kwargs['resp'].update(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern, initiator=False, - prologue=vector['resp_prologue']) + kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue']) + kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue']) initiator = HandshakeState.initialize(**kwargs['init']) responder = HandshakeState.initialize(**kwargs['resp'])