Added NoiseBuilder class as final interface. (#1)

noise/__init__.py
- __all__ containing builder module

noise/builder.py
- NoiseBuilder class providing interface for use with other apps. Allows
for setting up all required data for Noise protocol, using appropriate
methods. Enforces proper path of handshake execution

noise/constants.py
- Added maximum Noise message length constant

noise/exceptions.py
- A few exceptions created for proper signaling of errors

noise/noise_protocol.py
- handshake_done does proper cleanup now
- new validation method that should be ran before starting handshake
(checks presence of prerequisites for current settings)
- new HandshakeState initialization method

noise/state.py
- Modified read_message and write_message methods of HandshakeState to
operate on bytes/bytearray as message/payload and bytearray as
message_buffer/payload_buffer. It is application's responsibility to
provide data in this form, underlying Noise code doesn't do buffer
reading/writing anymore.

tests/test_vectors.py
- Changed tests to comply with new code
This commit is contained in:
Piotr Lizończyk
2017-09-02 16:09:49 +02:00
committed by GitHub
parent 96f7ba9b6b
commit 46825bb075
7 changed files with 250 additions and 75 deletions

View File

@@ -0,0 +1 @@
__all__ = ['builder']

131
noise/builder.py Normal file
View File

@@ -0,0 +1,131 @@
from enum import Enum, auto
from typing import Union, List
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
from .noise_protocol import NoiseProtocol
class Keypair(Enum):
STATIC = auto()
REMOTE_STATIC = auto()
EPHEMERAL = auto()
REMOTE_EPHEMERAL = auto()
_keypairs = {Keypair.STATIC: 's', Keypair.REMOTE_STATIC: 'rs',
Keypair.EPHEMERAL: 'e', Keypair.REMOTE_EPHEMERAL: 're'}
class NoiseBuilder(object):
def __init__(self):
self.noise_protocol = None
self.protocol_name = None
self.handshake_finished = False
self._handshake_started = False
self._next_fn = None
@classmethod
def from_name(cls, name: Union[str, bytes]):
instance = cls()
# Forgiving passing string. Bytes are good too, anything else will fail inside NoiseProtocol
try:
instance.protocol_name = name.encode('ascii') if isinstance(name, str) else name
except ValueError:
raise NoiseValueError('If passing string as protocol name, it must contain only ASCII characters')
instance.noise_protocol = NoiseProtocol(protocol_name=name)
return instance
def set_psks(self, psk: Union[bytes, str] = None, psks: List[Union[str, bytes]] = None):
if psk and psks:
raise NoisePSKError('Provide single PSK as psk or list of PSKs as psks')
if not psk and not psks:
raise NoisePSKError('No PSKs provided')
psks = psks or [psk]
if not all([isinstance(psk, (bytes, str)) for psk in psks]):
raise NoisePSKError('PSKs must be strings or bytes')
try:
self.noise_protocol.psks = [psk.encode('ascii') if isinstance(psk, str) else psk for psk in psks]
except UnicodeEncodeError:
raise NoisePSKError('If providing psks as (unicode) string, it must only contain ASCII characters')
def set_prologue(self, prologue: Union[bytes, str]):
if isinstance(prologue, bytes):
self.noise_protocol.prologue = prologue
elif isinstance(prologue, str):
try:
self.noise_protocol.prologue = prologue.encode('ascii')
except UnicodeEncodeError:
raise NoiseValueError('Prologue must be ASCII string or bytes')
else:
raise NoiseValueError('Prologue must be ASCII string or bytes')
def set_as_initiator(self):
self.noise_protocol.initiator = True
self._next_fn = self.write_message
def set_as_responder(self):
self.noise_protocol.initiator = False
self._next_fn = self.read_message
def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
def set_keypair_from_private_path(self, keypair: Keypair, path: str):
with open(path, 'rb') as fd:
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(fd.read())
def set_keypair_from_public_path(self, keypair: Keypair, path: str):
with open(path, 'rb') as fd:
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(fd.read())
def start_handshake(self):
self.noise_protocol.validate()
self.noise_protocol.initialise_handshake_state()
self._handshake_started = True
def write_message(self, payload: bytes=b'') -> bytearray:
if not self._handshake_started:
raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
if self._next_fn != self.write_message:
raise NoiseHandshakeError('NoiseBuilder.read_message has to be called now')
if self.handshake_finished:
raise NoiseHandshakeError('Handshake finished. NoiseBuilder.encrypt should be used now')
self._next_fn = self.read_message
buffer = bytearray()
result = self.noise_protocol.handshake_state.write_message(payload, buffer)
if result:
self.handshake_finished = True
return buffer
def read_message(self, data: bytes) -> bytearray:
if not self._handshake_started:
raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
if self._next_fn != self.read_message:
raise NoiseHandshakeError('NoiseBuilder.write_message has to be called now')
if self.handshake_finished:
raise NoiseHandshakeError('Handshake finished. NoiseBuilder.decrypt should be used now')
self._next_fn = self.write_message
buffer = bytearray()
result = self.noise_protocol.handshake_state.read_message(data, buffer)
if result:
self.handshake_finished = True
return buffer
def encrypt(self, data: bytes):
if not isinstance(data, bytes) or len(data) > 65535:
raise Exception #todo
return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)
def decrypt(self, data: bytes):
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)

View File

@@ -14,3 +14,5 @@ TOKEN_PSK = 'psk'
# In bytes, as in Section 8 of specification (rev 32)
MAX_PROTOCOL_NAME_LEN = 255
MAX_MESSAGE_LEN = 65535

14
noise/exceptions.py Normal file
View File

@@ -0,0 +1,14 @@
class NoiseProtocolNameError(Exception):
pass
class NoisePSKError(Exception):
pass
class NoiseValueError(Exception):
pass
class NoiseHandshakeError(Exception):
pass

View File

@@ -1,6 +1,8 @@
from functools import partial
from typing import Tuple, List
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
from noise.state import HandshakeState
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
from .patterns import patterns_map
@@ -18,11 +20,12 @@ class NoiseProtocol(object):
'keypair': keypair_map
}
def __init__(self, protocol_name: bytes, psks: List[bytes]=None):
def __init__(self, protocol_name: bytes):
if not isinstance(protocol_name, bytes):
raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name)))
raise NoiseProtocolNameError('Protocol name has to be of type "bytes" not {}'.format(type(protocol_name)))
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
raise NoiseProtocolNameError('Protocol name too long, has to be at most '
'{} chars long'.format(MAX_PROTOCOL_NAME_LEN))
self.name = protocol_name
mappings, pattern_modifiers = self._parse_protocol_name()
@@ -34,14 +37,8 @@ class NoiseProtocol(object):
self.pattern.apply_pattern_modifiers(pattern_modifiers)
# Handle PSK handshake options
self.psks = psks
self.is_psk_handshake = False if not self.psks else True
if self.is_psk_handshake:
if any([len(psk) != 32 for psk in self.psks]):
raise ValueError('Invalid psk length!')
if len(self.psks) != self.pattern.psk_count:
raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format(
self.pattern.psk_count, len(self.psks)))
self.psks = None
self.is_psk_handshake = any([modifier.startswith('psk') for modifier in self.pattern_modifiers])
self.dh_fn = mappings['dh']
self.cipher_fn = mappings['cipher']
@@ -50,6 +47,7 @@ class NoiseProtocol(object):
self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn)
self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac)
self.prologue = None
self.initiator = None
self.one_way = False
self.handshake_hash = None
@@ -60,10 +58,12 @@ class NoiseProtocol(object):
self.cipher_state_encrypt = Empty()
self.cipher_state_decrypt = Empty()
self.keypairs = {'s': None, 'e': None, 'rs': None, 're': None}
def _parse_protocol_name(self) -> Tuple[dict, list]:
unpacked = self.name.decode().split('_')
if unpacked[0] != 'Noise':
raise ValueError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))
raise NoiseProtocolNameError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))
# Extract pattern name and pattern modifiers
pattern = ''
@@ -90,17 +90,46 @@ class NoiseProtocol(object):
for key, map_dict in self.methods.items():
func = map_dict.get(data[key])
if not func:
raise ValueError('Unknown {} in Noise Protocol name, given {}, known {}'.format(
key, data[key], " ".join(map_dict)))
raise NoiseProtocolNameError('Unknown {} in Noise Protocol name, given {}, known {}'.format(
key, data[key], " ".join(map_dict)))
mapped_data[key] = func
return mapped_data, modifiers
def handshake_done(self):
self.initiator = self.handshake_state.initiator
if self.pattern.one_way:
if self.initiator:
del self.cipher_state_decrypt
self.cipher_state_decrypt = None
else:
del self.cipher_state_encrypt
self.cipher_state_encrypt = None
self.handshake_hash = self.symmetric_state.h
del self.handshake_state
del self.symmetric_state
del self.cipher_state_handshake
del self.prologue
del self.initiator
del self.dh_fn
del self.hash_fn
del self.keypair_fn
def validate(self):
if self.is_psk_handshake:
if any([len(psk) != 32 for psk in self.psks]):
raise NoisePSKError('Invalid psk length! Has to be 32 bytes long')
if len(self.psks) != self.pattern.psk_count:
raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
'given {}'.format(self.pattern.psk_count, len(self.psks)))
# TODO: Validate keypairs
# TODO: Validate initiator set
# TODO: Validate buffers set
# TODO: Warn about ephemerals
def initialise_handshake_state(self):
kwargs = {'initiator': self.initiator}
if self.prologue:
kwargs['prologue'] = self.prologue
for keypair, value in self.keypairs.items():
if value:
kwargs[keypair] = value
self.handshake_state = HandshakeState.initialize(self, **kwargs)

View File

@@ -1,3 +1,5 @@
from typing import Union
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
@@ -61,7 +63,6 @@ class CipherState(object):
return plaintext
class SymmetricState(object):
"""
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2.
@@ -228,7 +229,6 @@ class HandshakeState(object):
# Create HandshakeState
instance = cls()
instance.noise_protocol = noise_protocol
noise_protocol.handshake_state = instance
# Originally in specification:
# "Derives a protocol_name byte sequence by combining the names for
@@ -263,7 +263,7 @@ class HandshakeState(object):
return instance
def write_message(self, payload: bytes, message_buffer):
def write_message(self, payload: Union[bytes, bytearray], message_buffer: bytearray):
"""
Comments below are mostly copied from specification.
:param payload: byte sequence which may be zero-length
@@ -277,14 +277,14 @@ class HandshakeState(object):
if token == TOKEN_E:
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
message_buffer.write(self.e.public_bytes)
message_buffer += self.e.public_bytes
self.symmetric_state.mix_hash(self.e.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.e.public_bytes)
elif token == TOKEN_S:
# Appends EncryptAndHash(s.public_key) to the buffer
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public_bytes))
message_buffer += self.symmetric_state.encrypt_and_hash(self.s.public_bytes)
elif token == TOKEN_EE:
# Calls MixKey(DH(e, re))
@@ -315,13 +315,13 @@ class HandshakeState(object):
raise NotImplementedError('Pattern token: {}'.format(token))
# Appends EncryptAndHash(payload) to the buffer
message_buffer.write(self.symmetric_state.encrypt_and_hash(payload))
message_buffer += self.symmetric_state.encrypt_and_hash(payload)
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:
return self.symmetric_state.split()
def read_message(self, message: bytes, payload_buffer):
def read_message(self, message: Union[bytes, bytearray], payload_buffer: bytearray):
"""
Comments below are mostly copied from specification.
:param message: byte sequence containing a Noise handshake message
@@ -335,7 +335,8 @@ class HandshakeState(object):
for token in message_pattern:
if token == TOKEN_E:
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
self.re = self.noise_protocol.keypair_fn.from_public_bytes(bytes(message[:dhlen]))
message = message[dhlen:]
self.symmetric_state.mix_hash(self.re.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.re.public_bytes)
@@ -344,9 +345,11 @@ class HandshakeState(object):
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
# otherwise. Sets rs to DecryptAndHash(temp).
if self.noise_protocol.cipher_state_handshake.has_key():
temp = message.read(dhlen + 16)
temp = bytes(message[:dhlen + 16])
message = message[dhlen + 16:]
else:
temp = message.read(dhlen)
temp = bytes(message[:dhlen])
message = message[dhlen:]
self.rs = self.noise_protocol.keypair_fn.from_public_bytes(self.symmetric_state.decrypt_and_hash(temp))
elif token == TOKEN_EE:
@@ -378,7 +381,7 @@ class HandshakeState(object):
raise NotImplementedError('Pattern token: {}'.format(token))
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message.read()))
payload_buffer += self.symmetric_state.decrypt_and_hash(bytes(message))
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:

View File

@@ -5,8 +5,8 @@ import os
import pytest
from noise.state import HandshakeState, CipherState
from noise.noise_protocol import NoiseProtocol
from noise.state import CipherState
from noise.builder import NoiseBuilder, Keypair
logger = logging.getLogger(__name__)
@@ -60,78 +60,73 @@ class TestVectors(object):
def vector(self, request):
yield request.param
def _prepare_handshake_state_kwargs(self, vector, dh_fn):
# TODO: This is ugly af, refactor it :/
kwargs = {'init': {}, 'resp': {}}
for role in ['init', 'resp']:
for key, kwarg in [('static', 's'), ('ephemeral', 'e'), ('remote_static', 'rs')]:
role_key = role + '_' + key
if role_key in vector:
if key in ['static', 'ephemeral']:
kwargs[role][kwarg] = dh_fn.keypair_cls.from_private_bytes(vector[role_key])
elif key == 'remote_static':
kwargs[role][kwarg] = dh_fn.keypair_cls.from_public_bytes(vector[role_key])
return kwargs
def _set_keypairs(self, vector, builder):
role = 'init' if builder.noise_protocol.initiator else 'resp'
setters = [
(builder.set_keypair_from_private_bytes, Keypair.STATIC, role + '_static'),
(builder.set_keypair_from_private_bytes, Keypair.EPHEMERAL, role + '_ephemeral'),
(builder.set_keypair_from_public_bytes, Keypair.REMOTE_STATIC, role + '_remote_static')
]
for fn, keypair, name in setters:
if name in vector:
fn(keypair, vector[name])
def test_vector(self, vector):
initiator = NoiseBuilder.from_name(vector['protocol_name'])
responder = NoiseBuilder.from_name(vector['protocol_name'])
if 'init_psks' in vector and 'resp_psks' in vector:
init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks'])
resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks'])
else:
init_protocol = NoiseProtocol(vector['protocol_name'])
resp_protocol = NoiseProtocol(vector['protocol_name'])
initiator.set_psks(psks=vector['init_psks'])
responder.set_psks(psks=vector['resp_psks'])
kwargs = self._prepare_handshake_state_kwargs(vector, init_protocol.dh_fn)
initiator.set_prologue(vector['init_prologue'])
initiator.set_as_initiator()
self._set_keypairs(vector, initiator)
kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue'])
kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])
responder.set_prologue(vector['resp_prologue'])
responder.set_as_responder()
self._set_keypairs(vector, responder)
initiator.start_handshake()
responder.start_handshake()
initiator = HandshakeState.initialize(**kwargs['init'])
responder = HandshakeState.initialize(**kwargs['resp'])
initiator_to_responder = True
handshake_finished = False
for message in vector['messages']:
if not handshake_finished:
message_buffer = io.BytesIO()
payload_buffer = io.BytesIO()
if initiator_to_responder:
sender, receiver = initiator, responder
else:
sender, receiver = responder, initiator
sender_result = sender.write_message(message['payload'], message_buffer)
assert message_buffer.getbuffer().tobytes() == message['ciphertext']
sender_result = sender.write_message(message['payload'])
assert sender_result == message['ciphertext']
message_buffer.seek(0)
receiver_result = receiver.read_message(message_buffer, payload_buffer)
assert payload_buffer.getbuffer().tobytes() == message['payload']
receiver_result = receiver.read_message(sender_result)
assert receiver_result == message['payload']
if sender_result is None or receiver_result is None:
if not (sender.handshake_finished and receiver.handshake_finished):
# Not finished with handshake, fail if one would finish before other
assert sender_result == receiver_result
assert sender.handshake_finished == receiver.handshake_finished
else:
# Handshake done
handshake_finished = True
assert isinstance(sender_result[0], CipherState)
assert isinstance(sender_result[1], CipherState)
assert isinstance(receiver_result[0], CipherState)
assert isinstance(receiver_result[1], CipherState)
# Verify handshake hash
assert init_protocol.symmetric_state.h == resp_protocol.symmetric_state.h == vector['handshake_hash']
assert initiator.noise_protocol.handshake_hash == responder.noise_protocol.handshake_hash == vector['handshake_hash']
# Verify split cipherstates keys
assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k
if not init_protocol.pattern.one_way:
assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k
assert initiator.noise_protocol.cipher_state_encrypt.k == responder.noise_protocol.cipher_state_decrypt.k
if not initiator.noise_protocol.pattern.one_way:
assert initiator.noise_protocol.cipher_state_decrypt.k == responder.noise_protocol.cipher_state_encrypt.k
else:
assert initiator.noise_protocol.cipher_state_decrypt is responder.noise_protocol.cipher_state_encrypt is None
else:
if init_protocol.pattern.one_way or initiator_to_responder:
sender, receiver = init_protocol, resp_protocol
if initiator.noise_protocol.pattern.one_way or initiator_to_responder:
sender, receiver = initiator, responder
else:
sender, receiver = resp_protocol, init_protocol
ciphertext = sender.cipher_state_encrypt.encrypt_with_ad(None, message['payload'])
sender, receiver = responder, initiator
ciphertext = sender.encrypt(message['payload'])
assert ciphertext == message['ciphertext']
plaintext = receiver.cipher_state_decrypt.decrypt_with_ad(None, message['ciphertext'])
plaintext = receiver.decrypt(message['ciphertext'])
assert plaintext == message['payload']
initiator_to_responder = not initiator_to_responder