mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Enabling PSK support. Core functionality ready!
noise/noise_protocol.py * PSKs should be now delivered to NoiseProtocol while initialising * New field `is_psk_handshake` in NoiseProtocol noise/patterns.py * Fixed erronenous super call in OneWayPattern * Changed class variables to instance variables in Patterns, fixes things. noise/state.py * Added missing mix_key_and_hash to SymmetricState * Added required calls when in PSK handshake (TOKEN_E and TOKEN_PSK), both in write_message and read_message of HandshakeState tests/test_vectors.py * Enabled PSK tests, some minor fixes to make them work
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List
|
||||
|
||||
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
|
||||
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
|
||||
@@ -18,7 +18,7 @@ class NoiseProtocol(object):
|
||||
'keypair': keypair_map
|
||||
}
|
||||
|
||||
def __init__(self, protocol_name: bytes):
|
||||
def __init__(self, protocol_name: bytes, psks: List[bytes]=None):
|
||||
if not isinstance(protocol_name, bytes):
|
||||
raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name)))
|
||||
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
|
||||
@@ -33,6 +33,16 @@ class NoiseProtocol(object):
|
||||
if self.pattern_modifiers:
|
||||
self.pattern.apply_pattern_modifiers(pattern_modifiers)
|
||||
|
||||
# Handle PSK handshake options
|
||||
self.psks = psks
|
||||
self.is_psk_handshake = False if not self.psks else True
|
||||
if self.is_psk_handshake:
|
||||
if any([len(psk) != 32 for psk in self.psks]):
|
||||
raise ValueError('Invalid psk length!')
|
||||
if len(self.psks) != self.pattern.psk_count:
|
||||
raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format(
|
||||
self.pattern.psk_count, len(self.psks)))
|
||||
|
||||
self.dh_fn = mappings['dh']
|
||||
self.cipher_fn = mappings['cipher']
|
||||
self.hash_fn = mappings['hash']
|
||||
@@ -43,7 +53,6 @@ class NoiseProtocol(object):
|
||||
self.initiator = None
|
||||
self.one_way = False
|
||||
self.handshake_hash = None
|
||||
self.psks = None # Placeholder for PSKs
|
||||
|
||||
self.handshake_state = Empty()
|
||||
self.symmetric_state = Empty()
|
||||
@@ -87,9 +96,6 @@ class NoiseProtocol(object):
|
||||
|
||||
return mapped_data, modifiers
|
||||
|
||||
def set_psks(self, psks: list) -> None:
|
||||
self.psks = psks
|
||||
|
||||
def handshake_done(self):
|
||||
self.initiator = self.handshake_state.initiator
|
||||
if self.pattern.one_way:
|
||||
|
||||
@@ -7,20 +7,21 @@ class Pattern(object):
|
||||
"""
|
||||
TODO document
|
||||
"""
|
||||
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
|
||||
# pre_messages shall be a list of two lists:
|
||||
# the first for the initiator's pre-messages, the second for the responder
|
||||
pre_messages = [
|
||||
[],
|
||||
[]
|
||||
]
|
||||
|
||||
# List of lists of valid tokens, alternating between tokens for initiator and responder
|
||||
tokens = []
|
||||
|
||||
def __init__(self):
|
||||
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
|
||||
# pre_messages shall be a list of two lists:
|
||||
# the first for the initiator's pre-messages, the second for the responder
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[]
|
||||
]
|
||||
|
||||
# List of lists of valid tokens, alternating between tokens for initiator and responder
|
||||
self.tokens = []
|
||||
|
||||
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
|
||||
self.one_way = False
|
||||
self.psk_count = 0
|
||||
|
||||
def get_initiator_pre_messages(self) -> list:
|
||||
return self.pre_messages[0].copy()
|
||||
@@ -41,10 +42,11 @@ class Pattern(object):
|
||||
raise ValueError('Modifier {} cannot be applied - pattern has not enough messages'.format(modifier))
|
||||
|
||||
# Add TOKEN_PSK in the correct place in the correct message
|
||||
if index % 2 == 0:
|
||||
self.tokens[index//2].insert(0, TOKEN_PSK)
|
||||
else:
|
||||
self.tokens[index//2].append(TOKEN_PSK)
|
||||
if index == 0: # if 0, insert at the beginning of first message
|
||||
self.tokens[0].insert(0, TOKEN_PSK)
|
||||
else: # if bigger than zero, append at the end of first, second etc.
|
||||
self.tokens[index - 1].append(TOKEN_PSK)
|
||||
self.psk_count += 1
|
||||
|
||||
elif modifier == 'fallback':
|
||||
raise NotImplementedError # TODO implement
|
||||
@@ -57,151 +59,196 @@ class Pattern(object):
|
||||
|
||||
class OneWayPattern(Pattern):
|
||||
def __init__(self):
|
||||
super(Pattern, self).__init__()
|
||||
super(OneWayPattern, self).__init__()
|
||||
self.one_way = True
|
||||
|
||||
|
||||
class PatternN(OneWayPattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternN, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES]
|
||||
]
|
||||
|
||||
|
||||
class PatternK(OneWayPattern):
|
||||
pre_messages = [
|
||||
[TOKEN_S],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_SS]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternK, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_SS]
|
||||
]
|
||||
|
||||
|
||||
class PatternX(OneWayPattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternX, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS]
|
||||
]
|
||||
|
||||
|
||||
# Interactive patterns
|
||||
|
||||
class PatternNN(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternNN, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE]
|
||||
]
|
||||
|
||||
|
||||
class PatternKN(Pattern):
|
||||
pre_messages = [
|
||||
[TOKEN_S],
|
||||
[]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternKN, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
[]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternNK(Pattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES],
|
||||
[TOKEN_E, TOKEN_EE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternNK, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES],
|
||||
[TOKEN_E, TOKEN_EE]
|
||||
]
|
||||
|
||||
|
||||
class PatternKK(Pattern):
|
||||
pre_messages = [
|
||||
[TOKEN_S],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_SS],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternKK, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_SS],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternNX(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternNX, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
|
||||
|
||||
class PatternKX(Pattern):
|
||||
pre_messages = [
|
||||
[TOKEN_S],
|
||||
[]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternKX, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
[]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
|
||||
|
||||
class PatternXN(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternXN, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternIN(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternIN, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternXK(Pattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES],
|
||||
[TOKEN_E, TOKEN_EE],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternXK, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES],
|
||||
[TOKEN_E, TOKEN_EE],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternIK(Pattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternIK, self).__init__()
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
]
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternXX(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternXX, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES],
|
||||
[TOKEN_S, TOKEN_SE]
|
||||
]
|
||||
|
||||
|
||||
class PatternIX(Pattern):
|
||||
tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
def __init__(self):
|
||||
super(PatternIX, self).__init__()
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
|
||||
]
|
||||
|
||||
|
||||
patterns_map = {
|
||||
|
||||
@@ -126,6 +126,17 @@ class SymmetricState(object):
|
||||
"""
|
||||
self.h = self.noise_protocol.hash_fn.hash(self.h + data)
|
||||
|
||||
def mix_key_and_hash(self, input_key_material: bytes):
|
||||
# Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3).
|
||||
self.ck, temp_h, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 3)
|
||||
# Calls MixHash(temp_h).
|
||||
self.mix_hash(temp_h)
|
||||
# If HASHLEN is 64, then truncates temp_k to 32 bytes.
|
||||
if self.noise_protocol.hash_fn.hashlen == 64:
|
||||
temp_k = temp_k[:32]
|
||||
# Calls InitializeKey(temp_k).
|
||||
self.noise_protocol.cipher_state_handshake.initialize_key(temp_k)
|
||||
|
||||
def encrypt_and_hash(self, plaintext: bytes) -> bytes:
|
||||
"""
|
||||
Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns ciphertext. Note that if
|
||||
@@ -268,6 +279,8 @@ class HandshakeState(object):
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
|
||||
message_buffer.write(self.e.public_bytes)
|
||||
self.symmetric_state.mix_hash(self.e.public_bytes)
|
||||
if self.noise_protocol.is_psk_handshake:
|
||||
self.symmetric_state.mix_key(self.e.public_bytes)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Appends EncryptAndHash(s.public_key) to the buffer
|
||||
@@ -294,10 +307,9 @@ class HandshakeState(object):
|
||||
elif token == TOKEN_SS:
|
||||
# Calls MixKey(DH(s, rs))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
|
||||
pass
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
@@ -325,6 +337,8 @@ class HandshakeState(object):
|
||||
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
|
||||
self.symmetric_state.mix_hash(self.re.public_bytes)
|
||||
if self.noise_protocol.is_psk_handshake:
|
||||
self.symmetric_state.mix_key(self.re.public_bytes)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
|
||||
@@ -358,7 +372,7 @@ class HandshakeState(object):
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
@@ -35,15 +35,15 @@ def _prepare_test_vectors():
|
||||
for vector in vectors_list:
|
||||
if 'name' in vector and not 'protocol_name' in vector: # noise-c-* workaround
|
||||
vector['protocol_name'] = vector['name']
|
||||
if '_448_' in vector['protocol_name'] or 'psk' in vector['protocol_name'] or 'PSK' in vector['protocol_name']:
|
||||
continue # TODO REMOVE WHEN ed448/psk/blake SUPPORT IS IMPLEMENTED/FIXED
|
||||
if '_448_' in vector['protocol_name'] or 'PSK' in vector['protocol_name']: # no old NoisePSK tests
|
||||
continue # TODO REMOVE WHEN ed448 SUPPORT IS IMPLEMENTED/FIXED
|
||||
for key, value in vector.copy().items():
|
||||
if key in byte_fields:
|
||||
vector[key] = value.encode()
|
||||
if key in hexbyte_fields:
|
||||
vector[key] = bytes.fromhex(value)
|
||||
if key in list_fields:
|
||||
vector[key] = [k.encode() for k in value]
|
||||
vector[key] = [bytes.fromhex(k) for k in value]
|
||||
if key == dict_field:
|
||||
vector[key] = []
|
||||
for dictionary in value:
|
||||
@@ -77,11 +77,12 @@ class TestVectors(object):
|
||||
def test_vector(self, vector):
|
||||
kwargs = self._prepare_handshake_state_kwargs(vector)
|
||||
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
if 'init_psks' in vector and 'resp_psks' in vector:
|
||||
init_protocol.set_psks(vector['init_psks'])
|
||||
resp_protocol.set_psks(vector['resp_psks'])
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks'])
|
||||
else:
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
|
||||
kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue'])
|
||||
kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])
|
||||
|
||||
Reference in New Issue
Block a user