diff --git a/noise/functions.py b/noise/functions.py new file mode 100644 index 0000000..3af6f9f --- /dev/null +++ b/noise/functions.py @@ -0,0 +1,66 @@ +from .crypto import ed448 + +from Crypto.Cipher import AES, ChaCha20 +from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512 +import ed25519 + + +dh_map = { + '25519': ed25519, + '448': ed448 # TODO implement +} + +cipher_map = { + 'AESGCM': AES, + 'ChaChaPoly': ChaCha20 +} + +hash_map = { + # TODO benchmark pycryptodome vs hashlib implementation + 'BLAKE2b': BLAKE2b, + 'BLAKE2s': BLAKE2s, + 'SHA256': SHA256, + 'SHA512': SHA512 +} + + +class DH(object): + def __init__(self, method): + self.method = method + self.dhlen = 0 + self.dh = None + + def generate_keypair(self) -> 'KeyPair': + pass + + +class Cipher(object): + def __init__(self, method): + pass + + def encrypt(self, k, n, ad, plaintext): + pass + + def decrypt(self, k, n, ad, ciphertext): + pass + + +class Hash(object): + def __init__(self, method): + self.hashlen = 0 + self.blocklen = 0 + + def hash(self): + pass + + +class KeyPair(object): + def __init__(self, public='', private=''): + # TODO: Maybe switch to properties? + self.public = public + self.private = private + if private and not public: + self.derive_public_key() + + def derive_public_key(self): + pass diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 0aedf94..df95424 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,32 +1,8 @@ -from functools import partial from typing import Tuple -from .patterns import patterns_map from .constants import MAX_PROTOCOL_NAME_LEN -from .crypto import ed448 - -from Crypto.Cipher import AES, ChaCha20 -from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512 -import ed25519 - - -dh_map = { - '25519': ed25519.create_keypair, - '448': ed448 # TODO implement -} - -cipher_map = { - 'AESGCM': partial(AES.new, mode=AES.MODE_GCM), - 'ChaChaPoly': lambda key: ChaCha20.new(key=key) -} - -hash_map = { - # TODO benchmark vs hashlib implementation - 'BLAKE2b': BLAKE2b, # TODO PARTIALS - 'BLAKE2s': BLAKE2s, # TODO PARTIALS - 'SHA256': SHA256, # TODO PARTIALS - 'SHA512': SHA512 # TODO PARTIALS -} +from .functions import dh_map, cipher_map, hash_map +from .patterns import patterns_map class NoiseProtocol(object): @@ -56,6 +32,8 @@ class NoiseProtocol(object): self.cipher = mappings['pattern'] self.hash = mappings['pattern'] + self.psks = None # Placeholder for PSKs + def _parse_protocol_name(self) -> Tuple[dict, list]: unpacked = self.name.split('_') if unpacked[0] != 'Noise': @@ -91,9 +69,5 @@ class NoiseProtocol(object): return mapped_data, modifiers - -class KeyPair(object): - def __init__(self, public='', private=''): - # TODO: Maybe switch to properties? - self.public = public - self.private = private + def set_psks(self, psks: list) -> None: + self.psks = psks diff --git a/noise/state.py b/noise/state.py index 2a8d967..5e635c7 100644 --- a/noise/state.py +++ b/noise/state.py @@ -147,9 +147,11 @@ class HandshakeState(object): # Calls MixHash() once for each public key listed in the pre-messages from handshake_pattern, with the specified # public key as input (...). If both initiator and responder have pre-messages, the initiator’s public keys are # hashed first - for keypair in map(instance._get_local_keypair, handshake_pattern.get_initiator_pre_messages()): + initiator_keypair_getter = instance._get_local_keypair if initiator else instance._get_remote_keypair + responder_keypair_getter = instance._get_remote_keypair if initiator else instance._get_local_keypair + for keypair in map(initiator_keypair_getter, handshake_pattern.get_initiator_pre_messages()): instance.symmetric_state.mix_hash(keypair.public) - for keypair in map(instance._get_remote_keypair, handshake_pattern.get_responder_pre_messages()): + for keypair in map(responder_keypair_getter, handshake_pattern.get_responder_pre_messages()): instance.symmetric_state.mix_hash(keypair.public) # Sets message_patterns to the message patterns from handshake_pattern diff --git a/tests/test_vectors.py b/tests/test_vectors.py index ee99b26..5fc5bb0 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -4,6 +4,7 @@ import os import pytest +from noise.functions import KeyPair from noise.state import HandshakeState from noise.noise_protocol import NoiseProtocol @@ -12,7 +13,7 @@ logger = logging.getLogger(__name__) vector_files = ['vectors/cacophony.txt'] -def prepare_test_vectors(): +def _prepare_test_vectors(): vectors = [] for path in vector_files: with open(os.path.join(os.path.dirname(__file__), path)) as fd: @@ -21,16 +22,39 @@ def prepare_test_vectors(): return vectors -@pytest.fixture(params=prepare_test_vectors()) -def vector(request): - yield request.param +class TestVectors(object): + @pytest.fixture(params=_prepare_test_vectors()) + def vector(self, request): + yield request.param + def _prepare_handshake_state_kwargs(self, vector): + # TODO: This is ugly af, refactor it :/ + kwargs = {'init': {}, 'resp': {}} + for role in ['init', 'resp']: + for key, kwarg in [('static', 's'), ('ephemeral', 'e'), ('remote_static', 'rs')]: + role_key = role + '_' + key + if role_key in vector: + if key in ['static', 'ephemeral']: + kwargs[role][kwarg] = KeyPair(private=vector[role_key]) + else: + kwargs[role][kwarg] = KeyPair(public=vector[role_key]) + return kwargs -def test_vector(vector): - logging.info('Testing vector {}'.format(vector['protocol_name'])) - init_protocol = NoiseProtocol(vector['protocol_name']) - resp_protocol = NoiseProtocol(vector['protocol_name']) - initiator = HandshakeState.initialize(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern, - initiator=True, prologue=vector['init_prologue']) - responder = HandshakeState.initialize(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern, - initiator=True, prologue=vector['resp_prologue']) + def test_vector(self, vector): + logging.info('Testing vector {}'.format(vector['protocol_name'])) + + kwargs = self._prepare_handshake_state_kwargs(vector) + + init_protocol = NoiseProtocol(vector['protocol_name']) + resp_protocol = NoiseProtocol(vector['protocol_name']) + if 'init_psks' in vector and 'resp_psks' in vector: + init_protocol.set_psks(vector['init_psks']) + resp_protocol.set_psks(vector['resp_psks']) + + kwargs['init'].update(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern, initiator=True, + prologue=vector['init_prologue']) + kwargs['resp'].update(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern, initiator=False, + prologue=vector['resp_prologue']) + + initiator = HandshakeState.initialize(**kwargs['init']) + responder = HandshakeState.initialize(**kwargs['resp'])