mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Added NoiseBuilder class as final interface. (#1)
noise/__init__.py - __all__ containing builder module noise/builder.py - NoiseBuilder class providing interface for use with other apps. Allows for setting up all required data for Noise protocol, using appropriate methods. Enforces proper path of handshake execution noise/constants.py - Added maximum Noise message length constant noise/exceptions.py - A few exceptions created for proper signaling of errors noise/noise_protocol.py - handshake_done does proper cleanup now - new validation method that should be ran before starting handshake (checks presence of prerequisites for current settings) - new HandshakeState initialization method noise/state.py - Modified read_message and write_message methods of HandshakeState to operate on bytes/bytearray as message/payload and bytearray as message_buffer/payload_buffer. It is application's responsibility to provide data in this form, underlying Noise code doesn't do buffer reading/writing anymore. tests/test_vectors.py - Changed tests to comply with new code
This commit is contained in:
@@ -0,0 +1 @@
|
||||
__all__ = ['builder']
|
||||
|
||||
131
noise/builder.py
Normal file
131
noise/builder.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Union, List
|
||||
|
||||
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
|
||||
from .noise_protocol import NoiseProtocol
|
||||
|
||||
|
||||
class Keypair(Enum):
|
||||
STATIC = auto()
|
||||
REMOTE_STATIC = auto()
|
||||
EPHEMERAL = auto()
|
||||
REMOTE_EPHEMERAL = auto()
|
||||
|
||||
|
||||
_keypairs = {Keypair.STATIC: 's', Keypair.REMOTE_STATIC: 'rs',
|
||||
Keypair.EPHEMERAL: 'e', Keypair.REMOTE_EPHEMERAL: 're'}
|
||||
|
||||
|
||||
class NoiseBuilder(object):
|
||||
def __init__(self):
|
||||
self.noise_protocol = None
|
||||
self.protocol_name = None
|
||||
self.handshake_finished = False
|
||||
self._handshake_started = False
|
||||
self._next_fn = None
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: Union[str, bytes]):
|
||||
instance = cls()
|
||||
# Forgiving passing string. Bytes are good too, anything else will fail inside NoiseProtocol
|
||||
try:
|
||||
instance.protocol_name = name.encode('ascii') if isinstance(name, str) else name
|
||||
except ValueError:
|
||||
raise NoiseValueError('If passing string as protocol name, it must contain only ASCII characters')
|
||||
instance.noise_protocol = NoiseProtocol(protocol_name=name)
|
||||
return instance
|
||||
|
||||
def set_psks(self, psk: Union[bytes, str] = None, psks: List[Union[str, bytes]] = None):
|
||||
if psk and psks:
|
||||
raise NoisePSKError('Provide single PSK as psk or list of PSKs as psks')
|
||||
if not psk and not psks:
|
||||
raise NoisePSKError('No PSKs provided')
|
||||
|
||||
psks = psks or [psk]
|
||||
if not all([isinstance(psk, (bytes, str)) for psk in psks]):
|
||||
raise NoisePSKError('PSKs must be strings or bytes')
|
||||
|
||||
try:
|
||||
self.noise_protocol.psks = [psk.encode('ascii') if isinstance(psk, str) else psk for psk in psks]
|
||||
except UnicodeEncodeError:
|
||||
raise NoisePSKError('If providing psks as (unicode) string, it must only contain ASCII characters')
|
||||
|
||||
def set_prologue(self, prologue: Union[bytes, str]):
|
||||
if isinstance(prologue, bytes):
|
||||
self.noise_protocol.prologue = prologue
|
||||
elif isinstance(prologue, str):
|
||||
try:
|
||||
self.noise_protocol.prologue = prologue.encode('ascii')
|
||||
except UnicodeEncodeError:
|
||||
raise NoiseValueError('Prologue must be ASCII string or bytes')
|
||||
else:
|
||||
raise NoiseValueError('Prologue must be ASCII string or bytes')
|
||||
|
||||
def set_as_initiator(self):
|
||||
self.noise_protocol.initiator = True
|
||||
self._next_fn = self.write_message
|
||||
|
||||
def set_as_responder(self):
|
||||
self.noise_protocol.initiator = False
|
||||
self._next_fn = self.read_message
|
||||
|
||||
def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes):
|
||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
|
||||
|
||||
def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes):
|
||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
|
||||
|
||||
def set_keypair_from_private_path(self, keypair: Keypair, path: str):
|
||||
with open(path, 'rb') as fd:
|
||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(fd.read())
|
||||
|
||||
def set_keypair_from_public_path(self, keypair: Keypair, path: str):
|
||||
with open(path, 'rb') as fd:
|
||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(fd.read())
|
||||
|
||||
def start_handshake(self):
|
||||
self.noise_protocol.validate()
|
||||
self.noise_protocol.initialise_handshake_state()
|
||||
self._handshake_started = True
|
||||
|
||||
def write_message(self, payload: bytes=b'') -> bytearray:
|
||||
if not self._handshake_started:
|
||||
raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
|
||||
if self._next_fn != self.write_message:
|
||||
raise NoiseHandshakeError('NoiseBuilder.read_message has to be called now')
|
||||
if self.handshake_finished:
|
||||
raise NoiseHandshakeError('Handshake finished. NoiseBuilder.encrypt should be used now')
|
||||
self._next_fn = self.read_message
|
||||
|
||||
buffer = bytearray()
|
||||
result = self.noise_protocol.handshake_state.write_message(payload, buffer)
|
||||
if result:
|
||||
self.handshake_finished = True
|
||||
return buffer
|
||||
|
||||
def read_message(self, data: bytes) -> bytearray:
|
||||
if not self._handshake_started:
|
||||
raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
|
||||
if self._next_fn != self.read_message:
|
||||
raise NoiseHandshakeError('NoiseBuilder.write_message has to be called now')
|
||||
if self.handshake_finished:
|
||||
raise NoiseHandshakeError('Handshake finished. NoiseBuilder.decrypt should be used now')
|
||||
self._next_fn = self.write_message
|
||||
|
||||
buffer = bytearray()
|
||||
result = self.noise_protocol.handshake_state.read_message(data, buffer)
|
||||
if result:
|
||||
self.handshake_finished = True
|
||||
return buffer
|
||||
|
||||
def encrypt(self, data: bytes):
|
||||
if not isinstance(data, bytes) or len(data) > 65535:
|
||||
raise Exception #todo
|
||||
return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)
|
||||
|
||||
def decrypt(self, data: bytes):
|
||||
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
|
||||
@@ -14,3 +14,5 @@ TOKEN_PSK = 'psk'
|
||||
|
||||
# In bytes, as in Section 8 of specification (rev 32)
|
||||
MAX_PROTOCOL_NAME_LEN = 255
|
||||
|
||||
MAX_MESSAGE_LEN = 65535
|
||||
|
||||
14
noise/exceptions.py
Normal file
14
noise/exceptions.py
Normal file
@@ -0,0 +1,14 @@
|
||||
class NoiseProtocolNameError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoisePSKError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoiseValueError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoiseHandshakeError(Exception):
|
||||
pass
|
||||
@@ -1,6 +1,8 @@
|
||||
from functools import partial
|
||||
from typing import Tuple, List
|
||||
|
||||
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
|
||||
from noise.state import HandshakeState
|
||||
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
|
||||
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
|
||||
from .patterns import patterns_map
|
||||
@@ -18,11 +20,12 @@ class NoiseProtocol(object):
|
||||
'keypair': keypair_map
|
||||
}
|
||||
|
||||
def __init__(self, protocol_name: bytes, psks: List[bytes]=None):
|
||||
def __init__(self, protocol_name: bytes):
|
||||
if not isinstance(protocol_name, bytes):
|
||||
raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name)))
|
||||
raise NoiseProtocolNameError('Protocol name has to be of type "bytes" not {}'.format(type(protocol_name)))
|
||||
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
|
||||
raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
|
||||
raise NoiseProtocolNameError('Protocol name too long, has to be at most '
|
||||
'{} chars long'.format(MAX_PROTOCOL_NAME_LEN))
|
||||
|
||||
self.name = protocol_name
|
||||
mappings, pattern_modifiers = self._parse_protocol_name()
|
||||
@@ -34,14 +37,8 @@ class NoiseProtocol(object):
|
||||
self.pattern.apply_pattern_modifiers(pattern_modifiers)
|
||||
|
||||
# Handle PSK handshake options
|
||||
self.psks = psks
|
||||
self.is_psk_handshake = False if not self.psks else True
|
||||
if self.is_psk_handshake:
|
||||
if any([len(psk) != 32 for psk in self.psks]):
|
||||
raise ValueError('Invalid psk length!')
|
||||
if len(self.psks) != self.pattern.psk_count:
|
||||
raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format(
|
||||
self.pattern.psk_count, len(self.psks)))
|
||||
self.psks = None
|
||||
self.is_psk_handshake = any([modifier.startswith('psk') for modifier in self.pattern_modifiers])
|
||||
|
||||
self.dh_fn = mappings['dh']
|
||||
self.cipher_fn = mappings['cipher']
|
||||
@@ -50,6 +47,7 @@ class NoiseProtocol(object):
|
||||
self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn)
|
||||
self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac)
|
||||
|
||||
self.prologue = None
|
||||
self.initiator = None
|
||||
self.one_way = False
|
||||
self.handshake_hash = None
|
||||
@@ -60,10 +58,12 @@ class NoiseProtocol(object):
|
||||
self.cipher_state_encrypt = Empty()
|
||||
self.cipher_state_decrypt = Empty()
|
||||
|
||||
self.keypairs = {'s': None, 'e': None, 'rs': None, 're': None}
|
||||
|
||||
def _parse_protocol_name(self) -> Tuple[dict, list]:
|
||||
unpacked = self.name.decode().split('_')
|
||||
if unpacked[0] != 'Noise':
|
||||
raise ValueError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))
|
||||
raise NoiseProtocolNameError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))
|
||||
|
||||
# Extract pattern name and pattern modifiers
|
||||
pattern = ''
|
||||
@@ -90,17 +90,46 @@ class NoiseProtocol(object):
|
||||
for key, map_dict in self.methods.items():
|
||||
func = map_dict.get(data[key])
|
||||
if not func:
|
||||
raise ValueError('Unknown {} in Noise Protocol name, given {}, known {}'.format(
|
||||
key, data[key], " ".join(map_dict)))
|
||||
raise NoiseProtocolNameError('Unknown {} in Noise Protocol name, given {}, known {}'.format(
|
||||
key, data[key], " ".join(map_dict)))
|
||||
mapped_data[key] = func
|
||||
|
||||
return mapped_data, modifiers
|
||||
|
||||
def handshake_done(self):
|
||||
self.initiator = self.handshake_state.initiator
|
||||
if self.pattern.one_way:
|
||||
if self.initiator:
|
||||
del self.cipher_state_decrypt
|
||||
self.cipher_state_decrypt = None
|
||||
else:
|
||||
del self.cipher_state_encrypt
|
||||
self.cipher_state_encrypt = None
|
||||
self.handshake_hash = self.symmetric_state.h
|
||||
del self.handshake_state
|
||||
del self.symmetric_state
|
||||
del self.cipher_state_handshake
|
||||
del self.prologue
|
||||
del self.initiator
|
||||
del self.dh_fn
|
||||
del self.hash_fn
|
||||
del self.keypair_fn
|
||||
|
||||
def validate(self):
|
||||
if self.is_psk_handshake:
|
||||
if any([len(psk) != 32 for psk in self.psks]):
|
||||
raise NoisePSKError('Invalid psk length! Has to be 32 bytes long')
|
||||
if len(self.psks) != self.pattern.psk_count:
|
||||
raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
|
||||
'given {}'.format(self.pattern.psk_count, len(self.psks)))
|
||||
|
||||
# TODO: Validate keypairs
|
||||
# TODO: Validate initiator set
|
||||
# TODO: Validate buffers set
|
||||
# TODO: Warn about ephemerals
|
||||
|
||||
def initialise_handshake_state(self):
|
||||
kwargs = {'initiator': self.initiator}
|
||||
if self.prologue:
|
||||
kwargs['prologue'] = self.prologue
|
||||
for keypair, value in self.keypairs.items():
|
||||
if value:
|
||||
kwargs[keypair] = value
|
||||
self.handshake_state = HandshakeState.initialize(self, **kwargs)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
|
||||
|
||||
|
||||
@@ -61,7 +63,6 @@ class CipherState(object):
|
||||
return plaintext
|
||||
|
||||
|
||||
|
||||
class SymmetricState(object):
|
||||
"""
|
||||
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2.
|
||||
@@ -228,7 +229,6 @@ class HandshakeState(object):
|
||||
# Create HandshakeState
|
||||
instance = cls()
|
||||
instance.noise_protocol = noise_protocol
|
||||
noise_protocol.handshake_state = instance
|
||||
|
||||
# Originally in specification:
|
||||
# "Derives a protocol_name byte sequence by combining the names for
|
||||
@@ -263,7 +263,7 @@ class HandshakeState(object):
|
||||
|
||||
return instance
|
||||
|
||||
def write_message(self, payload: bytes, message_buffer):
|
||||
def write_message(self, payload: Union[bytes, bytearray], message_buffer: bytearray):
|
||||
"""
|
||||
Comments below are mostly copied from specification.
|
||||
:param payload: byte sequence which may be zero-length
|
||||
@@ -277,14 +277,14 @@ class HandshakeState(object):
|
||||
if token == TOKEN_E:
|
||||
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
|
||||
message_buffer.write(self.e.public_bytes)
|
||||
message_buffer += self.e.public_bytes
|
||||
self.symmetric_state.mix_hash(self.e.public_bytes)
|
||||
if self.noise_protocol.is_psk_handshake:
|
||||
self.symmetric_state.mix_key(self.e.public_bytes)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Appends EncryptAndHash(s.public_key) to the buffer
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public_bytes))
|
||||
message_buffer += self.symmetric_state.encrypt_and_hash(self.s.public_bytes)
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
# Calls MixKey(DH(e, re))
|
||||
@@ -315,13 +315,13 @@ class HandshakeState(object):
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
# Appends EncryptAndHash(payload) to the buffer
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(payload))
|
||||
message_buffer += self.symmetric_state.encrypt_and_hash(payload)
|
||||
|
||||
# If there are no more message patterns returns two new CipherState objects by calling Split()
|
||||
if len(self.message_patterns) == 0:
|
||||
return self.symmetric_state.split()
|
||||
|
||||
def read_message(self, message: bytes, payload_buffer):
|
||||
def read_message(self, message: Union[bytes, bytearray], payload_buffer: bytearray):
|
||||
"""
|
||||
Comments below are mostly copied from specification.
|
||||
:param message: byte sequence containing a Noise handshake message
|
||||
@@ -335,7 +335,8 @@ class HandshakeState(object):
|
||||
for token in message_pattern:
|
||||
if token == TOKEN_E:
|
||||
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
|
||||
self.re = self.noise_protocol.keypair_fn.from_public_bytes(bytes(message[:dhlen]))
|
||||
message = message[dhlen:]
|
||||
self.symmetric_state.mix_hash(self.re.public_bytes)
|
||||
if self.noise_protocol.is_psk_handshake:
|
||||
self.symmetric_state.mix_key(self.re.public_bytes)
|
||||
@@ -344,9 +345,11 @@ class HandshakeState(object):
|
||||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
|
||||
# otherwise. Sets rs to DecryptAndHash(temp).
|
||||
if self.noise_protocol.cipher_state_handshake.has_key():
|
||||
temp = message.read(dhlen + 16)
|
||||
temp = bytes(message[:dhlen + 16])
|
||||
message = message[dhlen + 16:]
|
||||
else:
|
||||
temp = message.read(dhlen)
|
||||
temp = bytes(message[:dhlen])
|
||||
message = message[dhlen:]
|
||||
self.rs = self.noise_protocol.keypair_fn.from_public_bytes(self.symmetric_state.decrypt_and_hash(temp))
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
@@ -378,7 +381,7 @@ class HandshakeState(object):
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
|
||||
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message.read()))
|
||||
payload_buffer += self.symmetric_state.decrypt_and_hash(bytes(message))
|
||||
|
||||
# If there are no more message patterns returns two new CipherState objects by calling Split()
|
||||
if len(self.message_patterns) == 0:
|
||||
|
||||
@@ -5,8 +5,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from noise.state import HandshakeState, CipherState
|
||||
from noise.noise_protocol import NoiseProtocol
|
||||
from noise.state import CipherState
|
||||
from noise.builder import NoiseBuilder, Keypair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,78 +60,73 @@ class TestVectors(object):
|
||||
def vector(self, request):
|
||||
yield request.param
|
||||
|
||||
def _prepare_handshake_state_kwargs(self, vector, dh_fn):
|
||||
# TODO: This is ugly af, refactor it :/
|
||||
kwargs = {'init': {}, 'resp': {}}
|
||||
for role in ['init', 'resp']:
|
||||
for key, kwarg in [('static', 's'), ('ephemeral', 'e'), ('remote_static', 'rs')]:
|
||||
role_key = role + '_' + key
|
||||
if role_key in vector:
|
||||
if key in ['static', 'ephemeral']:
|
||||
kwargs[role][kwarg] = dh_fn.keypair_cls.from_private_bytes(vector[role_key])
|
||||
elif key == 'remote_static':
|
||||
kwargs[role][kwarg] = dh_fn.keypair_cls.from_public_bytes(vector[role_key])
|
||||
return kwargs
|
||||
def _set_keypairs(self, vector, builder):
|
||||
role = 'init' if builder.noise_protocol.initiator else 'resp'
|
||||
setters = [
|
||||
(builder.set_keypair_from_private_bytes, Keypair.STATIC, role + '_static'),
|
||||
(builder.set_keypair_from_private_bytes, Keypair.EPHEMERAL, role + '_ephemeral'),
|
||||
(builder.set_keypair_from_public_bytes, Keypair.REMOTE_STATIC, role + '_remote_static')
|
||||
]
|
||||
for fn, keypair, name in setters:
|
||||
if name in vector:
|
||||
fn(keypair, vector[name])
|
||||
|
||||
def test_vector(self, vector):
|
||||
initiator = NoiseBuilder.from_name(vector['protocol_name'])
|
||||
responder = NoiseBuilder.from_name(vector['protocol_name'])
|
||||
if 'init_psks' in vector and 'resp_psks' in vector:
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks'])
|
||||
else:
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
initiator.set_psks(psks=vector['init_psks'])
|
||||
responder.set_psks(psks=vector['resp_psks'])
|
||||
|
||||
kwargs = self._prepare_handshake_state_kwargs(vector, init_protocol.dh_fn)
|
||||
initiator.set_prologue(vector['init_prologue'])
|
||||
initiator.set_as_initiator()
|
||||
self._set_keypairs(vector, initiator)
|
||||
|
||||
kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue'])
|
||||
kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])
|
||||
responder.set_prologue(vector['resp_prologue'])
|
||||
responder.set_as_responder()
|
||||
self._set_keypairs(vector, responder)
|
||||
|
||||
initiator.start_handshake()
|
||||
responder.start_handshake()
|
||||
|
||||
initiator = HandshakeState.initialize(**kwargs['init'])
|
||||
responder = HandshakeState.initialize(**kwargs['resp'])
|
||||
initiator_to_responder = True
|
||||
|
||||
handshake_finished = False
|
||||
for message in vector['messages']:
|
||||
if not handshake_finished:
|
||||
message_buffer = io.BytesIO()
|
||||
payload_buffer = io.BytesIO()
|
||||
if initiator_to_responder:
|
||||
sender, receiver = initiator, responder
|
||||
else:
|
||||
sender, receiver = responder, initiator
|
||||
|
||||
sender_result = sender.write_message(message['payload'], message_buffer)
|
||||
assert message_buffer.getbuffer().tobytes() == message['ciphertext']
|
||||
sender_result = sender.write_message(message['payload'])
|
||||
assert sender_result == message['ciphertext']
|
||||
|
||||
message_buffer.seek(0)
|
||||
receiver_result = receiver.read_message(message_buffer, payload_buffer)
|
||||
assert payload_buffer.getbuffer().tobytes() == message['payload']
|
||||
receiver_result = receiver.read_message(sender_result)
|
||||
assert receiver_result == message['payload']
|
||||
|
||||
if sender_result is None or receiver_result is None:
|
||||
if not (sender.handshake_finished and receiver.handshake_finished):
|
||||
# Not finished with handshake, fail if one would finish before other
|
||||
assert sender_result == receiver_result
|
||||
assert sender.handshake_finished == receiver.handshake_finished
|
||||
else:
|
||||
# Handshake done
|
||||
handshake_finished = True
|
||||
assert isinstance(sender_result[0], CipherState)
|
||||
assert isinstance(sender_result[1], CipherState)
|
||||
assert isinstance(receiver_result[0], CipherState)
|
||||
assert isinstance(receiver_result[1], CipherState)
|
||||
|
||||
# Verify handshake hash
|
||||
assert init_protocol.symmetric_state.h == resp_protocol.symmetric_state.h == vector['handshake_hash']
|
||||
assert initiator.noise_protocol.handshake_hash == responder.noise_protocol.handshake_hash == vector['handshake_hash']
|
||||
|
||||
# Verify split cipherstates keys
|
||||
assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k
|
||||
if not init_protocol.pattern.one_way:
|
||||
assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k
|
||||
assert initiator.noise_protocol.cipher_state_encrypt.k == responder.noise_protocol.cipher_state_decrypt.k
|
||||
if not initiator.noise_protocol.pattern.one_way:
|
||||
assert initiator.noise_protocol.cipher_state_decrypt.k == responder.noise_protocol.cipher_state_encrypt.k
|
||||
else:
|
||||
assert initiator.noise_protocol.cipher_state_decrypt is responder.noise_protocol.cipher_state_encrypt is None
|
||||
else:
|
||||
if init_protocol.pattern.one_way or initiator_to_responder:
|
||||
sender, receiver = init_protocol, resp_protocol
|
||||
if initiator.noise_protocol.pattern.one_way or initiator_to_responder:
|
||||
sender, receiver = initiator, responder
|
||||
else:
|
||||
sender, receiver = resp_protocol, init_protocol
|
||||
ciphertext = sender.cipher_state_encrypt.encrypt_with_ad(None, message['payload'])
|
||||
sender, receiver = responder, initiator
|
||||
ciphertext = sender.encrypt(message['payload'])
|
||||
assert ciphertext == message['ciphertext']
|
||||
plaintext = receiver.cipher_state_decrypt.decrypt_with_ad(None, message['ciphertext'])
|
||||
plaintext = receiver.decrypt(message['ciphertext'])
|
||||
assert plaintext == message['payload']
|
||||
initiator_to_responder = not initiator_to_responder
|
||||
|
||||
Reference in New Issue
Block a user