Improved validation, various additions

noise/builder.py:
- Added guard for data length in decrypt
- Handling InvalidTag exception when AEAD fails
- New NoiseInvalidMessage exception class

noise/exceptions.py
- Three new exception classes

noise/noise_protocol.py
- Implemented rest of validation, now checks for required keypairs, setting initiator/responder role, warns if ephemeral keypairs are set.

noise/patterns.py:
- added name field to every Pattern with pattern name
- added get_required_keypairs method that returns list of keypairs required for given handshake pattern

noise/state.py
- new NoiseMaxNonceError exception

Overall: some TODOs resolved
This commit is contained in:
Piotr Lizonczyk
2017-09-03 13:27:34 +02:00
committed by Szarlejowiec
parent eaecac6af4
commit 368d401701
7 changed files with 75 additions and 19 deletions

View File

@@ -1,7 +1,10 @@
from enum import Enum, auto
from typing import Union, List
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
from cryptography.exceptions import InvalidTag
from noise.constants import MAX_MESSAGE_LEN
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError, NoiseInvalidMessage
from .noise_protocol import NoiseProtocol
@@ -123,12 +126,17 @@ class NoiseBuilder(object):
return buffer
def encrypt(self, data: bytes):
if not isinstance(data, bytes) or len(data) > 65535:
raise Exception #todo
if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN:
raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN))
return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)
def decrypt(self, data: bytes):
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN:
raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN))
try:
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
except InvalidTag:
raise NoiseInvalidMessage('Failed authentication of message')
def get_handshake_hash(self) -> bytes:
return self.noise_protocol.handshake_hash

View File

@@ -12,3 +12,15 @@ class NoiseValueError(Exception):
class NoiseHandshakeError(Exception):
pass
class NoiseInvalidMessage(Exception):
pass
class NoiseMaxNonceError(Exception):
pass
class NoiseValidationError(Exception):
pass

View File

@@ -1,6 +1,6 @@
import abc
import warnings
from functools import partial
from functools import partial # Turn back on when Cryptography gets fixed
import hashlib
import hmac
import os
@@ -208,9 +208,6 @@ class KeyPair448(_KeyPair):
return cls(private=private, public=public, public_bytes=public)
# Available crypto functions
# TODO: Check if it's safe to use one instance globally per cryptoalgorithm - i.e. if wrapper only provides interface
# If not - switch to partials(?)
dh_map = {
'25519': DH('ed25519'),
'448': DH('ed448')
@@ -222,7 +219,6 @@ cipher_map = {
}
hash_map = {
# TODO benchmark pycryptodome vs hashlib implementation
'BLAKE2s': Hash('BLAKE2s'),
'BLAKE2b': Hash('BLAKE2b'),
'SHA256': Hash('SHA256'),

View File

@@ -1,7 +1,8 @@
import warnings
from functools import partial
from typing import Tuple
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
from noise.exceptions import NoiseProtocolNameError, NoisePSKError, NoiseValidationError
from noise.state import HandshakeState
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
@@ -120,10 +121,17 @@ class NoiseProtocol(object):
raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
'given {}'.format(self.pattern.psk_count, len(self.psks)))
# TODO: Validate keypairs
# TODO: Validate initiator set
# TODO: Validate buffers set
# TODO: Warn about ephemerals
if self.initiator is None:
raise NoiseValidationError('You need to set role with NoiseBuilder.set_as_initiator '
'or NoiseBuilder.set_as_responder')
for keypair in self.pattern.get_required_keypairs(self.initiator):
if self.keypairs[keypair] is None:
raise NoiseValidationError('Keypair {} has to be set for chosen handshake pattern'.format(keypair))
if not isinstance(self.keypairs['e'], Empty) or not isinstance(self.keypairs['re'], Empty):
warnings.warn('One of ephemeral keypairs is already set. '
'This is OK for testing, but should NEVER happen in production!')
def initialise_handshake_state(self):
kwargs = {'initiator': self.initiator}

View File

@@ -19,6 +19,7 @@ class Pattern(object):
# List of lists of valid tokens, alternating between tokens for initiator and responder
self.tokens = []
self.name = ''
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
self.one_way = False
self.psk_count = 0
@@ -54,6 +55,20 @@ class Pattern(object):
else:
raise ValueError('Unknown pattern modifier {}'.format(modifier))
def get_required_keypairs(self, initiator: bool) -> list:
required = []
if initiator:
if self.name[0] in ['K', 'X', 'I']:
required.append('s')
if self.one_way or self.name[1] == 'K':
required.append('rs')
else:
if self.name[0] == 'K':
required.append('rs')
if self.one_way or self.name[1] in ['K', 'X']:
required.append('s')
return required
# One-way patterns
@@ -66,6 +81,7 @@ class OneWayPattern(Pattern):
class PatternN(OneWayPattern):
def __init__(self):
super(PatternN, self).__init__()
self.name = 'N'
self.pre_messages = [
[],
@@ -79,6 +95,7 @@ class PatternN(OneWayPattern):
class PatternK(OneWayPattern):
def __init__(self):
super(PatternK, self).__init__()
self.name = 'K'
self.pre_messages = [
[TOKEN_S],
@@ -92,6 +109,7 @@ class PatternK(OneWayPattern):
class PatternX(OneWayPattern):
def __init__(self):
super(PatternX, self).__init__()
self.name = 'X'
self.pre_messages = [
[],
@@ -107,6 +125,7 @@ class PatternX(OneWayPattern):
class PatternNN(Pattern):
def __init__(self):
super(PatternNN, self).__init__()
self.name = 'NN'
self.tokens = [
[TOKEN_E],
@@ -117,6 +136,7 @@ class PatternNN(Pattern):
class PatternKN(Pattern):
def __init__(self):
super(PatternKN, self).__init__()
self.name = 'KN'
self.pre_messages = [
[TOKEN_S],
@@ -131,6 +151,7 @@ class PatternKN(Pattern):
class PatternNK(Pattern):
def __init__(self):
super(PatternNK, self).__init__()
self.name = 'NK'
self.pre_messages = [
[],
@@ -145,6 +166,7 @@ class PatternNK(Pattern):
class PatternKK(Pattern):
def __init__(self):
super(PatternKK, self).__init__()
self.name = 'KK'
self.pre_messages = [
[TOKEN_S],
@@ -159,6 +181,7 @@ class PatternKK(Pattern):
class PatternNX(Pattern):
def __init__(self):
super(PatternNX, self).__init__()
self.name = 'NX'
self.tokens = [
[TOKEN_E],
@@ -169,6 +192,7 @@ class PatternNX(Pattern):
class PatternKX(Pattern):
def __init__(self):
super(PatternKX, self).__init__()
self.name = 'KX'
self.pre_messages = [
[TOKEN_S],
@@ -183,6 +207,7 @@ class PatternKX(Pattern):
class PatternXN(Pattern):
def __init__(self):
super(PatternXN, self).__init__()
self.name = 'XN'
self.tokens = [
[TOKEN_E],
@@ -194,6 +219,7 @@ class PatternXN(Pattern):
class PatternIN(Pattern):
def __init__(self):
super(PatternIN, self).__init__()
self.name = 'IN'
self.tokens = [
[TOKEN_E, TOKEN_S],
@@ -204,6 +230,7 @@ class PatternIN(Pattern):
class PatternXK(Pattern):
def __init__(self):
super(PatternXK, self).__init__()
self.name = 'XK'
self.pre_messages = [
[],
@@ -219,6 +246,7 @@ class PatternXK(Pattern):
class PatternIK(Pattern):
def __init__(self):
super(PatternIK, self).__init__()
self.name = 'IK'
self.pre_messages = [
[],
@@ -233,6 +261,7 @@ class PatternIK(Pattern):
class PatternXX(Pattern):
def __init__(self):
super(PatternXX, self).__init__()
self.name = 'XX'
self.tokens = [
[TOKEN_E],
@@ -244,6 +273,7 @@ class PatternXX(Pattern):
class PatternIX(Pattern):
def __init__(self):
super(PatternIX, self).__init__()
self.name = 'IX'
self.tokens = [
[TOKEN_E, TOKEN_S],

View File

@@ -1,5 +1,6 @@
from typing import Union
from noise.exceptions import NoiseMaxNonceError
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
@@ -53,7 +54,7 @@ class CipherState(object):
:return: plaintext bytes sequence
"""
if self.n == 2**64 - 1:
raise Exception('Nonce has depleted!')
raise NoiseMaxNonceError('Nonce has depleted!')
if not self.has_key():
return ciphertext
@@ -210,8 +211,8 @@ class HandshakeState(object):
self.message_patterns = None
@classmethod
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': # TODO update typing (keypair)
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: '_KeyPair'=None,
e: '_KeyPair'=None, rs: '_KeyPair'=None, re: '_KeyPair'=None) -> 'HandshakeState':
"""
Constructor method.
Comments below are mostly copied from specification.
@@ -278,7 +279,7 @@ class HandshakeState(object):
for token in message_pattern:
if token == TOKEN_E:
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e
message_buffer += self.e.public_bytes
self.symmetric_state.mix_hash(self.e.public_bytes)
if self.noise_protocol.is_psk_handshake:

View File

@@ -55,7 +55,8 @@ def idfn(vector):
return vector['protocol_name']
@pytest.mark.filterwarnings('ignore: This implementation')
@pytest.mark.filterwarnings('ignore: This implementation of ed448')
@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs')
class TestVectors(object):
@pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
def vector(self, request):