diff --git a/noise/__init__.py b/noise/__init__.py index e69de29..97064bd 100644 --- a/noise/__init__.py +++ b/noise/__init__.py @@ -0,0 +1 @@ +__all__ = ['builder'] diff --git a/noise/builder.py b/noise/builder.py new file mode 100644 index 0000000..b4b5562 --- /dev/null +++ b/noise/builder.py @@ -0,0 +1,131 @@ +from enum import Enum, auto +from typing import Union, List + +from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError +from .noise_protocol import NoiseProtocol + + +class Keypair(Enum): + STATIC = auto() + REMOTE_STATIC = auto() + EPHEMERAL = auto() + REMOTE_EPHEMERAL = auto() + + +_keypairs = {Keypair.STATIC: 's', Keypair.REMOTE_STATIC: 'rs', + Keypair.EPHEMERAL: 'e', Keypair.REMOTE_EPHEMERAL: 're'} + + +class NoiseBuilder(object): + def __init__(self): + self.noise_protocol = None + self.protocol_name = None + self.handshake_finished = False + self._handshake_started = False + self._next_fn = None + + @classmethod + def from_name(cls, name: Union[str, bytes]): + instance = cls() + # Forgiving passing string. Bytes are good too, anything else will fail inside NoiseProtocol + try: + instance.protocol_name = name.encode('ascii') if isinstance(name, str) else name + except ValueError: + raise NoiseValueError('If passing string as protocol name, it must contain only ASCII characters') + instance.noise_protocol = NoiseProtocol(protocol_name=name) + return instance + + def set_psks(self, psk: Union[bytes, str] = None, psks: List[Union[str, bytes]] = None): + if psk and psks: + raise NoisePSKError('Provide single PSK as psk or list of PSKs as psks') + if not psk and not psks: + raise NoisePSKError('No PSKs provided') + + psks = psks or [psk] + if not all([isinstance(psk, (bytes, str)) for psk in psks]): + raise NoisePSKError('PSKs must be strings or bytes') + + try: + self.noise_protocol.psks = [psk.encode('ascii') if isinstance(psk, str) else psk for psk in psks] + except UnicodeEncodeError: + raise NoisePSKError('If providing psks as (unicode) string, it must only contain ASCII characters') + + def set_prologue(self, prologue: Union[bytes, str]): + if isinstance(prologue, bytes): + self.noise_protocol.prologue = prologue + elif isinstance(prologue, str): + try: + self.noise_protocol.prologue = prologue.encode('ascii') + except UnicodeEncodeError: + raise NoiseValueError('Prologue must be ASCII string or bytes') + else: + raise NoiseValueError('Prologue must be ASCII string or bytes') + + def set_as_initiator(self): + self.noise_protocol.initiator = True + self._next_fn = self.write_message + + def set_as_responder(self): + self.noise_protocol.initiator = False + self._next_fn = self.read_message + + def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes): + self.noise_protocol.keypairs[_keypairs[keypair]] = \ + self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes) + + def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes): + self.noise_protocol.keypairs[_keypairs[keypair]] = \ + self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes) + + def set_keypair_from_private_path(self, keypair: Keypair, path: str): + with open(path, 'rb') as fd: + self.noise_protocol.keypairs[_keypairs[keypair]] = \ + self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(fd.read()) + + def set_keypair_from_public_path(self, keypair: Keypair, path: str): + with open(path, 'rb') as fd: + self.noise_protocol.keypairs[_keypairs[keypair]] = \ + self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(fd.read()) + + def start_handshake(self): + self.noise_protocol.validate() + self.noise_protocol.initialise_handshake_state() + self._handshake_started = True + + def write_message(self, payload: bytes=b'') -> bytearray: + if not self._handshake_started: + raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first') + if self._next_fn != self.write_message: + raise NoiseHandshakeError('NoiseBuilder.read_message has to be called now') + if self.handshake_finished: + raise NoiseHandshakeError('Handshake finished. NoiseBuilder.encrypt should be used now') + self._next_fn = self.read_message + + buffer = bytearray() + result = self.noise_protocol.handshake_state.write_message(payload, buffer) + if result: + self.handshake_finished = True + return buffer + + def read_message(self, data: bytes) -> bytearray: + if not self._handshake_started: + raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first') + if self._next_fn != self.read_message: + raise NoiseHandshakeError('NoiseBuilder.write_message has to be called now') + if self.handshake_finished: + raise NoiseHandshakeError('Handshake finished. NoiseBuilder.decrypt should be used now') + self._next_fn = self.write_message + + buffer = bytearray() + result = self.noise_protocol.handshake_state.read_message(data, buffer) + if result: + self.handshake_finished = True + return buffer + + def encrypt(self, data: bytes): + if not isinstance(data, bytes) or len(data) > 65535: + raise Exception #todo + return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data) + + def decrypt(self, data: bytes): + return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data) diff --git a/noise/constants.py b/noise/constants.py index 0241ab1..c32e42e 100644 --- a/noise/constants.py +++ b/noise/constants.py @@ -14,3 +14,5 @@ TOKEN_PSK = 'psk' # In bytes, as in Section 8 of specification (rev 32) MAX_PROTOCOL_NAME_LEN = 255 + +MAX_MESSAGE_LEN = 65535 diff --git a/noise/exceptions.py b/noise/exceptions.py new file mode 100644 index 0000000..95530e8 --- /dev/null +++ b/noise/exceptions.py @@ -0,0 +1,14 @@ +class NoiseProtocolNameError(Exception): + pass + + +class NoisePSKError(Exception): + pass + + +class NoiseValueError(Exception): + pass + + +class NoiseHandshakeError(Exception): + pass diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 6619092..032ceb8 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,6 +1,8 @@ from functools import partial from typing import Tuple, List +from noise.exceptions import NoiseProtocolNameError, NoisePSKError +from noise.state import HandshakeState from .constants import MAX_PROTOCOL_NAME_LEN, Empty from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf from .patterns import patterns_map @@ -18,11 +20,12 @@ class NoiseProtocol(object): 'keypair': keypair_map } - def __init__(self, protocol_name: bytes, psks: List[bytes]=None): + def __init__(self, protocol_name: bytes): if not isinstance(protocol_name, bytes): - raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name))) + raise NoiseProtocolNameError('Protocol name has to be of type "bytes" not {}'.format(type(protocol_name))) if len(protocol_name) > MAX_PROTOCOL_NAME_LEN: - raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN)) + raise NoiseProtocolNameError('Protocol name too long, has to be at most ' + '{} chars long'.format(MAX_PROTOCOL_NAME_LEN)) self.name = protocol_name mappings, pattern_modifiers = self._parse_protocol_name() @@ -34,14 +37,8 @@ class NoiseProtocol(object): self.pattern.apply_pattern_modifiers(pattern_modifiers) # Handle PSK handshake options - self.psks = psks - self.is_psk_handshake = False if not self.psks else True - if self.is_psk_handshake: - if any([len(psk) != 32 for psk in self.psks]): - raise ValueError('Invalid psk length!') - if len(self.psks) != self.pattern.psk_count: - raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format( - self.pattern.psk_count, len(self.psks))) + self.psks = None + self.is_psk_handshake = any([modifier.startswith('psk') for modifier in self.pattern_modifiers]) self.dh_fn = mappings['dh'] self.cipher_fn = mappings['cipher'] @@ -50,6 +47,7 @@ class NoiseProtocol(object): self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn) self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac) + self.prologue = None self.initiator = None self.one_way = False self.handshake_hash = None @@ -60,10 +58,12 @@ class NoiseProtocol(object): self.cipher_state_encrypt = Empty() self.cipher_state_decrypt = Empty() + self.keypairs = {'s': None, 'e': None, 'rs': None, 're': None} + def _parse_protocol_name(self) -> Tuple[dict, list]: unpacked = self.name.decode().split('_') if unpacked[0] != 'Noise': - raise ValueError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name)) + raise NoiseProtocolNameError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name)) # Extract pattern name and pattern modifiers pattern = '' @@ -90,17 +90,46 @@ class NoiseProtocol(object): for key, map_dict in self.methods.items(): func = map_dict.get(data[key]) if not func: - raise ValueError('Unknown {} in Noise Protocol name, given {}, known {}'.format( - key, data[key], " ".join(map_dict))) + raise NoiseProtocolNameError('Unknown {} in Noise Protocol name, given {}, known {}'.format( + key, data[key], " ".join(map_dict))) mapped_data[key] = func return mapped_data, modifiers def handshake_done(self): - self.initiator = self.handshake_state.initiator if self.pattern.one_way: if self.initiator: - del self.cipher_state_decrypt + self.cipher_state_decrypt = None else: - del self.cipher_state_encrypt + self.cipher_state_encrypt = None self.handshake_hash = self.symmetric_state.h + del self.handshake_state + del self.symmetric_state + del self.cipher_state_handshake + del self.prologue + del self.initiator + del self.dh_fn + del self.hash_fn + del self.keypair_fn + + def validate(self): + if self.is_psk_handshake: + if any([len(psk) != 32 for psk in self.psks]): + raise NoisePSKError('Invalid psk length! Has to be 32 bytes long') + if len(self.psks) != self.pattern.psk_count: + raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, ' + 'given {}'.format(self.pattern.psk_count, len(self.psks))) + + # TODO: Validate keypairs + # TODO: Validate initiator set + # TODO: Validate buffers set + # TODO: Warn about ephemerals + + def initialise_handshake_state(self): + kwargs = {'initiator': self.initiator} + if self.prologue: + kwargs['prologue'] = self.prologue + for keypair, value in self.keypairs.items(): + if value: + kwargs[keypair] = value + self.handshake_state = HandshakeState.initialize(self, **kwargs) diff --git a/noise/state.py b/noise/state.py index 17ce57a..a84f3a7 100644 --- a/noise/state.py +++ b/noise/state.py @@ -1,3 +1,5 @@ +from typing import Union + from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK @@ -61,7 +63,6 @@ class CipherState(object): return plaintext - class SymmetricState(object): """ Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2. @@ -228,7 +229,6 @@ class HandshakeState(object): # Create HandshakeState instance = cls() instance.noise_protocol = noise_protocol - noise_protocol.handshake_state = instance # Originally in specification: # "Derives a protocol_name byte sequence by combining the names for @@ -263,7 +263,7 @@ class HandshakeState(object): return instance - def write_message(self, payload: bytes, message_buffer): + def write_message(self, payload: Union[bytes, bytearray], message_buffer: bytearray): """ Comments below are mostly copied from specification. :param payload: byte sequence which may be zero-length @@ -277,14 +277,14 @@ class HandshakeState(object): if token == TOKEN_E: # Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key) self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock - message_buffer.write(self.e.public_bytes) + message_buffer += self.e.public_bytes self.symmetric_state.mix_hash(self.e.public_bytes) if self.noise_protocol.is_psk_handshake: self.symmetric_state.mix_key(self.e.public_bytes) elif token == TOKEN_S: # Appends EncryptAndHash(s.public_key) to the buffer - message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public_bytes)) + message_buffer += self.symmetric_state.encrypt_and_hash(self.s.public_bytes) elif token == TOKEN_EE: # Calls MixKey(DH(e, re)) @@ -315,13 +315,13 @@ class HandshakeState(object): raise NotImplementedError('Pattern token: {}'.format(token)) # Appends EncryptAndHash(payload) to the buffer - message_buffer.write(self.symmetric_state.encrypt_and_hash(payload)) + message_buffer += self.symmetric_state.encrypt_and_hash(payload) # If there are no more message patterns returns two new CipherState objects by calling Split() if len(self.message_patterns) == 0: return self.symmetric_state.split() - def read_message(self, message: bytes, payload_buffer): + def read_message(self, message: Union[bytes, bytearray], payload_buffer: bytearray): """ Comments below are mostly copied from specification. :param message: byte sequence containing a Noise handshake message @@ -335,7 +335,8 @@ class HandshakeState(object): for token in message_pattern: if token == TOKEN_E: # Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key). - self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen)) + self.re = self.noise_protocol.keypair_fn.from_public_bytes(bytes(message[:dhlen])) + message = message[dhlen:] self.symmetric_state.mix_hash(self.re.public_bytes) if self.noise_protocol.is_psk_handshake: self.symmetric_state.mix_key(self.re.public_bytes) @@ -344,9 +345,11 @@ class HandshakeState(object): # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes # otherwise. Sets rs to DecryptAndHash(temp). if self.noise_protocol.cipher_state_handshake.has_key(): - temp = message.read(dhlen + 16) + temp = bytes(message[:dhlen + 16]) + message = message[dhlen + 16:] else: - temp = message.read(dhlen) + temp = bytes(message[:dhlen]) + message = message[dhlen:] self.rs = self.noise_protocol.keypair_fn.from_public_bytes(self.symmetric_state.decrypt_and_hash(temp)) elif token == TOKEN_EE: @@ -378,7 +381,7 @@ class HandshakeState(object): raise NotImplementedError('Pattern token: {}'.format(token)) # Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer. - payload_buffer.write(self.symmetric_state.decrypt_and_hash(message.read())) + payload_buffer += self.symmetric_state.decrypt_and_hash(bytes(message)) # If there are no more message patterns returns two new CipherState objects by calling Split() if len(self.message_patterns) == 0: diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 162ec65..d04b122 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -5,8 +5,8 @@ import os import pytest -from noise.state import HandshakeState, CipherState -from noise.noise_protocol import NoiseProtocol +from noise.state import CipherState +from noise.builder import NoiseBuilder, Keypair logger = logging.getLogger(__name__) @@ -60,78 +60,73 @@ class TestVectors(object): def vector(self, request): yield request.param - def _prepare_handshake_state_kwargs(self, vector, dh_fn): - # TODO: This is ugly af, refactor it :/ - kwargs = {'init': {}, 'resp': {}} - for role in ['init', 'resp']: - for key, kwarg in [('static', 's'), ('ephemeral', 'e'), ('remote_static', 'rs')]: - role_key = role + '_' + key - if role_key in vector: - if key in ['static', 'ephemeral']: - kwargs[role][kwarg] = dh_fn.keypair_cls.from_private_bytes(vector[role_key]) - elif key == 'remote_static': - kwargs[role][kwarg] = dh_fn.keypair_cls.from_public_bytes(vector[role_key]) - return kwargs + def _set_keypairs(self, vector, builder): + role = 'init' if builder.noise_protocol.initiator else 'resp' + setters = [ + (builder.set_keypair_from_private_bytes, Keypair.STATIC, role + '_static'), + (builder.set_keypair_from_private_bytes, Keypair.EPHEMERAL, role + '_ephemeral'), + (builder.set_keypair_from_public_bytes, Keypair.REMOTE_STATIC, role + '_remote_static') + ] + for fn, keypair, name in setters: + if name in vector: + fn(keypair, vector[name]) def test_vector(self, vector): + initiator = NoiseBuilder.from_name(vector['protocol_name']) + responder = NoiseBuilder.from_name(vector['protocol_name']) if 'init_psks' in vector and 'resp_psks' in vector: - init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks']) - resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks']) - else: - init_protocol = NoiseProtocol(vector['protocol_name']) - resp_protocol = NoiseProtocol(vector['protocol_name']) + initiator.set_psks(psks=vector['init_psks']) + responder.set_psks(psks=vector['resp_psks']) - kwargs = self._prepare_handshake_state_kwargs(vector, init_protocol.dh_fn) + initiator.set_prologue(vector['init_prologue']) + initiator.set_as_initiator() + self._set_keypairs(vector, initiator) - kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue']) - kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue']) + responder.set_prologue(vector['resp_prologue']) + responder.set_as_responder() + self._set_keypairs(vector, responder) + + initiator.start_handshake() + responder.start_handshake() - initiator = HandshakeState.initialize(**kwargs['init']) - responder = HandshakeState.initialize(**kwargs['resp']) initiator_to_responder = True - handshake_finished = False for message in vector['messages']: if not handshake_finished: - message_buffer = io.BytesIO() - payload_buffer = io.BytesIO() if initiator_to_responder: sender, receiver = initiator, responder else: sender, receiver = responder, initiator - sender_result = sender.write_message(message['payload'], message_buffer) - assert message_buffer.getbuffer().tobytes() == message['ciphertext'] + sender_result = sender.write_message(message['payload']) + assert sender_result == message['ciphertext'] - message_buffer.seek(0) - receiver_result = receiver.read_message(message_buffer, payload_buffer) - assert payload_buffer.getbuffer().tobytes() == message['payload'] + receiver_result = receiver.read_message(sender_result) + assert receiver_result == message['payload'] - if sender_result is None or receiver_result is None: + if not (sender.handshake_finished and receiver.handshake_finished): # Not finished with handshake, fail if one would finish before other - assert sender_result == receiver_result + assert sender.handshake_finished == receiver.handshake_finished else: # Handshake done handshake_finished = True - assert isinstance(sender_result[0], CipherState) - assert isinstance(sender_result[1], CipherState) - assert isinstance(receiver_result[0], CipherState) - assert isinstance(receiver_result[1], CipherState) # Verify handshake hash - assert init_protocol.symmetric_state.h == resp_protocol.symmetric_state.h == vector['handshake_hash'] + assert initiator.noise_protocol.handshake_hash == responder.noise_protocol.handshake_hash == vector['handshake_hash'] # Verify split cipherstates keys - assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k - if not init_protocol.pattern.one_way: - assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k + assert initiator.noise_protocol.cipher_state_encrypt.k == responder.noise_protocol.cipher_state_decrypt.k + if not initiator.noise_protocol.pattern.one_way: + assert initiator.noise_protocol.cipher_state_decrypt.k == responder.noise_protocol.cipher_state_encrypt.k + else: + assert initiator.noise_protocol.cipher_state_decrypt is responder.noise_protocol.cipher_state_encrypt is None else: - if init_protocol.pattern.one_way or initiator_to_responder: - sender, receiver = init_protocol, resp_protocol + if initiator.noise_protocol.pattern.one_way or initiator_to_responder: + sender, receiver = initiator, responder else: - sender, receiver = resp_protocol, init_protocol - ciphertext = sender.cipher_state_encrypt.encrypt_with_ad(None, message['payload']) + sender, receiver = responder, initiator + ciphertext = sender.encrypt(message['payload']) assert ciphertext == message['ciphertext'] - plaintext = receiver.cipher_state_decrypt.decrypt_with_ad(None, message['ciphertext']) + plaintext = receiver.decrypt(message['ciphertext']) assert plaintext == message['payload'] initiator_to_responder = not initiator_to_responder