diff --git a/noise/builder.py b/noise/builder.py index 1485886..c64fd90 100644 --- a/noise/builder.py +++ b/noise/builder.py @@ -1,7 +1,10 @@ from enum import Enum, auto from typing import Union, List -from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError +from cryptography.exceptions import InvalidTag + +from noise.constants import MAX_MESSAGE_LEN +from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError, NoiseInvalidMessage from .noise_protocol import NoiseProtocol @@ -123,12 +126,17 @@ class NoiseBuilder(object): return buffer def encrypt(self, data: bytes): - if not isinstance(data, bytes) or len(data) > 65535: - raise Exception #todo + if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN: + raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN)) return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data) def decrypt(self, data: bytes): - return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data) + if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN: + raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN)) + try: + return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data) + except InvalidTag: + raise NoiseInvalidMessage('Failed authentication of message') def get_handshake_hash(self) -> bytes: return self.noise_protocol.handshake_hash diff --git a/noise/exceptions.py b/noise/exceptions.py index 95530e8..afb6673 100644 --- a/noise/exceptions.py +++ b/noise/exceptions.py @@ -12,3 +12,15 @@ class NoiseValueError(Exception): class NoiseHandshakeError(Exception): pass + + +class NoiseInvalidMessage(Exception): + pass + + +class NoiseMaxNonceError(Exception): + pass + + +class NoiseValidationError(Exception): + pass diff --git a/noise/functions.py b/noise/functions.py index ce91b18..cfbd305 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -1,6 +1,6 @@ import abc import warnings -from functools import partial +from functools import partial # Turn back on when Cryptography gets fixed import hashlib import hmac import os @@ -208,9 +208,6 @@ class KeyPair448(_KeyPair): return cls(private=private, public=public, public_bytes=public) -# Available crypto functions -# TODO: Check if it's safe to use one instance globally per cryptoalgorithm - i.e. if wrapper only provides interface -# If not - switch to partials(?) dh_map = { '25519': DH('ed25519'), '448': DH('ed448') @@ -222,7 +219,6 @@ cipher_map = { } hash_map = { - # TODO benchmark pycryptodome vs hashlib implementation 'BLAKE2s': Hash('BLAKE2s'), 'BLAKE2b': Hash('BLAKE2b'), 'SHA256': Hash('SHA256'), diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index a87c39f..49f1d88 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,7 +1,8 @@ +import warnings from functools import partial from typing import Tuple -from noise.exceptions import NoiseProtocolNameError, NoisePSKError +from noise.exceptions import NoiseProtocolNameError, NoisePSKError, NoiseValidationError from noise.state import HandshakeState from .constants import MAX_PROTOCOL_NAME_LEN, Empty from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf @@ -120,10 +121,17 @@ class NoiseProtocol(object): raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, ' 'given {}'.format(self.pattern.psk_count, len(self.psks))) - # TODO: Validate keypairs - # TODO: Validate initiator set - # TODO: Validate buffers set - # TODO: Warn about ephemerals + if self.initiator is None: + raise NoiseValidationError('You need to set role with NoiseBuilder.set_as_initiator ' + 'or NoiseBuilder.set_as_responder') + + for keypair in self.pattern.get_required_keypairs(self.initiator): + if self.keypairs[keypair] is None: + raise NoiseValidationError('Keypair {} has to be set for chosen handshake pattern'.format(keypair)) + + if not isinstance(self.keypairs['e'], Empty) or not isinstance(self.keypairs['re'], Empty): + warnings.warn('One of ephemeral keypairs is already set. ' + 'This is OK for testing, but should NEVER happen in production!') def initialise_handshake_state(self): kwargs = {'initiator': self.initiator} diff --git a/noise/patterns.py b/noise/patterns.py index 575ca46..9d68198 100644 --- a/noise/patterns.py +++ b/noise/patterns.py @@ -19,6 +19,7 @@ class Pattern(object): # List of lists of valid tokens, alternating between tokens for initiator and responder self.tokens = [] + self.name = '' self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages)) self.one_way = False self.psk_count = 0 @@ -54,6 +55,20 @@ class Pattern(object): else: raise ValueError('Unknown pattern modifier {}'.format(modifier)) + def get_required_keypairs(self, initiator: bool) -> list: + required = [] + if initiator: + if self.name[0] in ['K', 'X', 'I']: + required.append('s') + if self.one_way or self.name[1] == 'K': + required.append('rs') + else: + if self.name[0] == 'K': + required.append('rs') + if self.one_way or self.name[1] in ['K', 'X']: + required.append('s') + return required + # One-way patterns @@ -66,6 +81,7 @@ class OneWayPattern(Pattern): class PatternN(OneWayPattern): def __init__(self): super(PatternN, self).__init__() + self.name = 'N' self.pre_messages = [ [], @@ -79,6 +95,7 @@ class PatternN(OneWayPattern): class PatternK(OneWayPattern): def __init__(self): super(PatternK, self).__init__() + self.name = 'K' self.pre_messages = [ [TOKEN_S], @@ -92,6 +109,7 @@ class PatternK(OneWayPattern): class PatternX(OneWayPattern): def __init__(self): super(PatternX, self).__init__() + self.name = 'X' self.pre_messages = [ [], @@ -107,6 +125,7 @@ class PatternX(OneWayPattern): class PatternNN(Pattern): def __init__(self): super(PatternNN, self).__init__() + self.name = 'NN' self.tokens = [ [TOKEN_E], @@ -117,6 +136,7 @@ class PatternNN(Pattern): class PatternKN(Pattern): def __init__(self): super(PatternKN, self).__init__() + self.name = 'KN' self.pre_messages = [ [TOKEN_S], @@ -131,6 +151,7 @@ class PatternKN(Pattern): class PatternNK(Pattern): def __init__(self): super(PatternNK, self).__init__() + self.name = 'NK' self.pre_messages = [ [], @@ -145,6 +166,7 @@ class PatternNK(Pattern): class PatternKK(Pattern): def __init__(self): super(PatternKK, self).__init__() + self.name = 'KK' self.pre_messages = [ [TOKEN_S], @@ -159,6 +181,7 @@ class PatternKK(Pattern): class PatternNX(Pattern): def __init__(self): super(PatternNX, self).__init__() + self.name = 'NX' self.tokens = [ [TOKEN_E], @@ -169,6 +192,7 @@ class PatternNX(Pattern): class PatternKX(Pattern): def __init__(self): super(PatternKX, self).__init__() + self.name = 'KX' self.pre_messages = [ [TOKEN_S], @@ -183,6 +207,7 @@ class PatternKX(Pattern): class PatternXN(Pattern): def __init__(self): super(PatternXN, self).__init__() + self.name = 'XN' self.tokens = [ [TOKEN_E], @@ -194,6 +219,7 @@ class PatternXN(Pattern): class PatternIN(Pattern): def __init__(self): super(PatternIN, self).__init__() + self.name = 'IN' self.tokens = [ [TOKEN_E, TOKEN_S], @@ -204,6 +230,7 @@ class PatternIN(Pattern): class PatternXK(Pattern): def __init__(self): super(PatternXK, self).__init__() + self.name = 'XK' self.pre_messages = [ [], @@ -219,6 +246,7 @@ class PatternXK(Pattern): class PatternIK(Pattern): def __init__(self): super(PatternIK, self).__init__() + self.name = 'IK' self.pre_messages = [ [], @@ -233,6 +261,7 @@ class PatternIK(Pattern): class PatternXX(Pattern): def __init__(self): super(PatternXX, self).__init__() + self.name = 'XX' self.tokens = [ [TOKEN_E], @@ -244,6 +273,7 @@ class PatternXX(Pattern): class PatternIX(Pattern): def __init__(self): super(PatternIX, self).__init__() + self.name = 'IX' self.tokens = [ [TOKEN_E, TOKEN_S], diff --git a/noise/state.py b/noise/state.py index f6e30f9..9c0be6b 100644 --- a/noise/state.py +++ b/noise/state.py @@ -1,5 +1,6 @@ from typing import Union +from noise.exceptions import NoiseMaxNonceError from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE @@ -53,7 +54,7 @@ class CipherState(object): :return: plaintext bytes sequence """ if self.n == 2**64 - 1: - raise Exception('Nonce has depleted!') + raise NoiseMaxNonceError('Nonce has depleted!') if not self.has_key(): return ciphertext @@ -210,8 +211,8 @@ class HandshakeState(object): self.message_patterns = None @classmethod - def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None, - e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': # TODO update typing (keypair) + def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: '_KeyPair'=None, + e: '_KeyPair'=None, rs: '_KeyPair'=None, re: '_KeyPair'=None) -> 'HandshakeState': """ Constructor method. Comments below are mostly copied from specification. @@ -278,7 +279,7 @@ class HandshakeState(object): for token in message_pattern: if token == TOKEN_E: # Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key) - self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock + self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e message_buffer += self.e.public_bytes self.symmetric_state.mix_hash(self.e.public_bytes) if self.noise_protocol.is_psk_handshake: diff --git a/tests/test_vectors.py b/tests/test_vectors.py index d19764f..15d893b 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -55,7 +55,8 @@ def idfn(vector): return vector['protocol_name'] -@pytest.mark.filterwarnings('ignore: This implementation') +@pytest.mark.filterwarnings('ignore: This implementation of ed448') +@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs') class TestVectors(object): @pytest.fixture(params=_prepare_test_vectors(), ids=idfn) def vector(self, request):