mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
noise/builder.py: - Added guard for data length in decrypt - Handling InvalidTag exception when AEAD fails - New NoiseInvalidMessage exception class noise/exceptions.py - Three new exception classes noise/noise_protocol.py - Implemented rest of validation, now checks for required keypairs, setting initiator/responder role, warns if ephemeral keypairs are set. noise/patterns.py: - added name field to every Pattern with pattern name - added get_required_keypairs method that returns list of keypairs required for given handshake pattern noise/state.py - new NoiseMaxNonceError exception Overall: some TODOs resolved
135 lines
5.6 KiB
Python
135 lines
5.6 KiB
Python
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from noise.state import CipherState
|
|
from noise.builder import NoiseBuilder, Keypair
|
|
|
|
# Module-level logger for this test module.
logger = logging.getLogger(__name__)

# JSON test-vector files, given relative to the directory that contains this module.
vector_files = [
    'vectors/cacophony.txt',
    'vectors/noise-c-basic.txt'
]

# Field names follow the test vectors specification
# (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors).
# The groups below tell the loader how to convert each string read from JSON into bytes.

# Plain-text fields, converted with str.encode().
byte_fields = ['protocol_name']
# Hex-string fields, converted with bytes.fromhex().
hexbyte_fields = ['init_prologue', 'init_static', 'init_ephemeral', 'init_remote_static', 'resp_static',
                  'resp_prologue', 'resp_ephemeral', 'resp_remote_static', 'handshake_hash']
# Fields holding a list of hex strings; every element is converted with bytes.fromhex().
list_fields = ['init_psks', 'resp_psks']
# Field holding a list of dicts whose values are hex strings.
dict_field = 'messages'
|
|
|
|
|
|
def _decode_vector_fields(vector):
    """Convert the string fields of one parsed vector into ``bytes``, in place.

    Which conversion applies is decided by the module-level field groups
    ``byte_fields``, ``hexbyte_fields``, ``list_fields`` and ``dict_field``.
    Keys that belong to none of the groups are left untouched.
    """
    # Iterate over a snapshot because values are reassigned while looping.
    for key, value in list(vector.items()):
        if key in byte_fields:
            vector[key] = value.encode()
        elif key in hexbyte_fields:
            vector[key] = bytes.fromhex(value)
        elif key in list_fields:
            vector[key] = [bytes.fromhex(item) for item in value]
        elif key == dict_field:
            vector[key] = [
                {name: bytes.fromhex(hex_value) for name, hex_value in entry.items()}
                for entry in value
            ]


def _prepare_test_vectors():
    """Load every file in ``vector_files`` and return the usable vectors.

    Each file is parsed as JSON and is expected to contain a list of vector
    dicts. Vectors that only carry a ``name`` key get it copied to
    ``protocol_name``; vectors whose protocol name contains ``PSK`` are
    skipped. The remaining vectors have their string fields converted to
    ``bytes`` before being collected.

    Returns:
        A list of vector dicts, in file order.
    """
    vectors = []
    base_dir = os.path.dirname(__file__)
    for path in vector_files:
        # Test vectors are JSON, so read them as UTF-8 regardless of the platform default.
        with open(os.path.join(base_dir, path), encoding='utf-8') as fd:
            # Use the module logger: root-level logging.info() would implicitly
            # configure the root logger while tests are being collected.
            logger.info('Reading vectors from file %s', path)
            vectors_list = json.load(fd)

        for vector in vectors_list:
            if 'name' in vector and 'protocol_name' not in vector:  # noise-c-* workaround
                vector['protocol_name'] = vector['name']
            # protocol_name is still a str here; it is encoded to bytes below.
            if 'PSK' in vector['protocol_name']:  # no old NoisePSK tests
                continue  # TODO REMOVE WHEN rev30 SUPPORT IS IMPLEMENTED/FIXED
            _decode_vector_fields(vector)
            vectors.append(vector)
    return vectors
|
|
|
|
|
|
def idfn(vector):
    """Return the pytest id of a test vector, which is its ``protocol_name`` value."""
    protocol_name = vector['protocol_name']
    return protocol_name
|
|
|
|
|
|
@pytest.mark.filterwarnings('ignore: This implementation of ed448')
@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs')
class TestVectors(object):
    """Replay each prepared test vector between an initiator and a responder.

    Every vector returned by ``_prepare_test_vectors()`` becomes one
    parametrized run of ``test_vector``, identified by ``idfn``.
    """

    @pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
    def vector(self, request):
        """Yield the prepared vector selected for the current parametrized run."""
        yield request.param

    def _set_keypairs(self, vector, builder):
        """Load the keys that ``vector`` defines for the role of ``builder``.

        The field prefix (``init`` or ``resp``) is chosen from
        ``builder.noise_protocol.initiator``; ``test_vector`` sets the role
        before calling this. Fields missing from the vector are skipped.
        """
        role = 'init' if builder.noise_protocol.initiator else 'resp'
        # (setter, keypair slot, vector field name)
        setters = [
            (builder.set_keypair_from_private_bytes, Keypair.STATIC, role + '_static'),
            (builder.set_keypair_from_private_bytes, Keypair.EPHEMERAL, role + '_ephemeral'),
            (builder.set_keypair_from_public_bytes, Keypair.REMOTE_STATIC, role + '_remote_static')
        ]
        for fn, keypair, name in setters:
            if name in vector:
                fn(keypair, vector[name])

    def _verify_completed_handshake(self, vector, initiator, responder):
        """Check the handshake hash and the split cipher states of both parties."""
        init_protocol = initiator.noise_protocol
        resp_protocol = responder.noise_protocol

        # Verify handshake hash
        assert init_protocol.handshake_hash == resp_protocol.handshake_hash == vector['handshake_hash']

        # Verify split cipherstates keys
        assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k
        if not init_protocol.pattern.one_way:
            assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k
        else:
            # For one-way patterns the reverse direction is expected to be unset.
            assert init_protocol.cipher_state_decrypt is resp_protocol.cipher_state_encrypt is None

    def test_vector(self, vector):
        """Run all messages of one vector and compare them with the expected bytes.

        Messages are exchanged as handshake messages until both parties report
        ``handshake_finished``; the handshake result is then verified and the
        remaining messages are sent as transport messages. The sending side
        alternates with every message, except that for one-way patterns the
        initiator sends every transport message.
        """
        protocol_name = vector['protocol_name']
        initiator = NoiseBuilder.from_name(protocol_name)
        responder = NoiseBuilder.from_name(protocol_name)
        if 'init_psks' in vector and 'resp_psks' in vector:
            initiator.set_psks(psks=vector['init_psks'])
            responder.set_psks(psks=vector['resp_psks'])

        initiator.set_prologue(vector['init_prologue'])
        initiator.set_as_initiator()
        self._set_keypairs(vector, initiator)

        responder.set_prologue(vector['resp_prologue'])
        responder.set_as_responder()
        self._set_keypairs(vector, responder)

        initiator.start_handshake()
        responder.start_handshake()

        initiator_to_responder = True
        handshake_finished = False
        for message in vector['messages']:
            if not handshake_finished:
                if initiator_to_responder:
                    sender, receiver = initiator, responder
                else:
                    sender, receiver = responder, initiator

                sender_result = sender.write_message(message['payload'])
                assert sender_result == message['ciphertext']

                receiver_result = receiver.read_message(sender_result)
                assert receiver_result == message['payload']

                if sender.handshake_finished and receiver.handshake_finished:
                    # Handshake done
                    handshake_finished = True
                    self._verify_completed_handshake(vector, initiator, responder)
                else:
                    # Not finished with handshake, fail if one would finish before other
                    assert sender.handshake_finished == receiver.handshake_finished
            else:
                if initiator.noise_protocol.pattern.one_way or initiator_to_responder:
                    sender, receiver = initiator, responder
                else:
                    sender, receiver = responder, initiator
                ciphertext = sender.encrypt(message['payload'])
                assert ciphertext == message['ciphertext']
                plaintext = receiver.decrypt(message['ciphertext'])
                assert plaintext == message['payload']
            initiator_to_responder = not initiator_to_responder

        # Without this check a vector whose messages run out before the handshake
        # completes (or that has no messages) would pass without verifying anything.
        assert handshake_finished, 'handshake did not complete within the messages of this vector'
|