Enabling PSK support. Core functionality ready!

noise/noise_protocol.py
* PSKs should be now delivered to NoiseProtocol while initialising
* New field `is_psk_handshake` in NoiseProtocol

noise/patterns.py
* Fixed erronenous super call in OneWayPattern
* Changed class variables to instance variables in Patterns, fixes
things.

noise/state.py
* Added missing mix_key_and_hash to SymmetricState
* Added required calls when in PSK handshake (TOKEN_E and TOKEN_PSK),
both in write_message and read_message of HandshakeState

tests/test_vectors.py
* Enabled PSK tests, some minor fixes to make them work
This commit is contained in:
Piotr Lizonczyk
2017-08-19 01:27:59 +02:00
parent 69aafd92d8
commit 52fd5058bc
4 changed files with 196 additions and 128 deletions

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import Tuple
from typing import Tuple, List
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
@@ -18,7 +18,7 @@ class NoiseProtocol(object):
'keypair': keypair_map
}
def __init__(self, protocol_name: bytes):
def __init__(self, protocol_name: bytes, psks: List[bytes]=None):
if not isinstance(protocol_name, bytes):
raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name)))
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
@@ -33,6 +33,16 @@ class NoiseProtocol(object):
if self.pattern_modifiers:
self.pattern.apply_pattern_modifiers(pattern_modifiers)
# Handle PSK handshake options
self.psks = psks
self.is_psk_handshake = False if not self.psks else True
if self.is_psk_handshake:
if any([len(psk) != 32 for psk in self.psks]):
raise ValueError('Invalid psk length!')
if len(self.psks) != self.pattern.psk_count:
raise ValueError('Bad number of PSKs provided to this protocol! {} are required, given {}'.format(
self.pattern.psk_count, len(self.psks)))
self.dh_fn = mappings['dh']
self.cipher_fn = mappings['cipher']
self.hash_fn = mappings['hash']
@@ -43,7 +53,6 @@ class NoiseProtocol(object):
self.initiator = None
self.one_way = False
self.handshake_hash = None
self.psks = None # Placeholder for PSKs
self.handshake_state = Empty()
self.symmetric_state = Empty()
@@ -87,9 +96,6 @@ class NoiseProtocol(object):
return mapped_data, modifiers
def set_psks(self, psks: list) -> None:
self.psks = psks
def handshake_done(self):
self.initiator = self.handshake_state.initiator
if self.pattern.one_way:

View File

@@ -7,20 +7,21 @@ class Pattern(object):
"""
TODO document
"""
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
# pre_messages shall be a list of two lists:
# the first for the initiator's pre-messages, the second for the responder
pre_messages = [
[],
[]
]
# List of lists of valid tokens, alternating between tokens for initiator and responder
tokens = []
def __init__(self):
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
# pre_messages shall be a list of two lists:
# the first for the initiator's pre-messages, the second for the responder
self.pre_messages = [
[],
[]
]
# List of lists of valid tokens, alternating between tokens for initiator and responder
self.tokens = []
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
self.one_way = False
self.psk_count = 0
def get_initiator_pre_messages(self) -> list:
return self.pre_messages[0].copy()
@@ -41,10 +42,11 @@ class Pattern(object):
raise ValueError('Modifier {} cannot be applied - pattern has not enough messages'.format(modifier))
# Add TOKEN_PSK in the correct place in the correct message
if index % 2 == 0:
self.tokens[index//2].insert(0, TOKEN_PSK)
else:
self.tokens[index//2].append(TOKEN_PSK)
if index == 0: # if 0, insert at the beginning of first message
self.tokens[0].insert(0, TOKEN_PSK)
else: # if bigger than zero, append at the end of first, second etc.
self.tokens[index - 1].append(TOKEN_PSK)
self.psk_count += 1
elif modifier == 'fallback':
raise NotImplementedError # TODO implement
@@ -57,151 +59,196 @@ class Pattern(object):
class OneWayPattern(Pattern):
def __init__(self):
super(Pattern, self).__init__()
super(OneWayPattern, self).__init__()
self.one_way = True
class PatternN(OneWayPattern):
pre_messages = [
[],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES]
]
def __init__(self):
super(PatternN, self).__init__()
self.pre_messages = [
[],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES]
]
class PatternK(OneWayPattern):
pre_messages = [
[TOKEN_S],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_SS]
]
def __init__(self):
super(PatternK, self).__init__()
self.pre_messages = [
[TOKEN_S],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_SS]
]
class PatternX(OneWayPattern):
pre_messages = [
[],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS]
]
def __init__(self):
super(PatternX, self).__init__()
self.pre_messages = [
[],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS]
]
# Interactive patterns
class PatternNN(Pattern):
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE]
]
def __init__(self):
super(PatternNN, self).__init__()
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE]
]
class PatternKN(Pattern):
pre_messages = [
[TOKEN_S],
[]
]
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
def __init__(self):
super(PatternKN, self).__init__()
self.pre_messages = [
[TOKEN_S],
[]
]
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
class PatternNK(Pattern):
pre_messages = [
[],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES],
[TOKEN_E, TOKEN_EE]
]
def __init__(self):
super(PatternNK, self).__init__()
self.pre_messages = [
[],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES],
[TOKEN_E, TOKEN_EE]
]
class PatternKK(Pattern):
pre_messages = [
[TOKEN_S],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_SS],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
def __init__(self):
super(PatternKK, self).__init__()
self.pre_messages = [
[TOKEN_S],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_SS],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
class PatternNX(Pattern):
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES]
]
def __init__(self):
super(PatternNX, self).__init__()
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES]
]
class PatternKX(Pattern):
pre_messages = [
[TOKEN_S],
[]
]
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
]
def __init__(self):
super(PatternKX, self).__init__()
self.pre_messages = [
[TOKEN_S],
[]
]
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
]
class PatternXN(Pattern):
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE],
[TOKEN_S, TOKEN_SE]
]
def __init__(self):
super(PatternXN, self).__init__()
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE],
[TOKEN_S, TOKEN_SE]
]
class PatternIN(Pattern):
tokens = [
[TOKEN_E, TOKEN_S],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
def __init__(self):
super(PatternIN, self).__init__()
self.tokens = [
[TOKEN_E, TOKEN_S],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
class PatternXK(Pattern):
pre_messages = [
[],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES],
[TOKEN_E, TOKEN_EE],
[TOKEN_S, TOKEN_SE]
]
def __init__(self):
super(PatternXK, self).__init__()
self.pre_messages = [
[],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES],
[TOKEN_E, TOKEN_EE],
[TOKEN_S, TOKEN_SE]
]
class PatternIK(Pattern):
pre_messages = [
[],
[TOKEN_S]
]
tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
def __init__(self):
super(PatternIK, self).__init__()
self.pre_messages = [
[],
[TOKEN_S]
]
self.tokens = [
[TOKEN_E, TOKEN_ES, TOKEN_S, TOKEN_SS],
[TOKEN_E, TOKEN_EE, TOKEN_SE]
]
class PatternXX(Pattern):
tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES],
[TOKEN_S, TOKEN_SE]
]
def __init__(self):
super(PatternXX, self).__init__()
self.tokens = [
[TOKEN_E],
[TOKEN_E, TOKEN_EE, TOKEN_S, TOKEN_ES],
[TOKEN_S, TOKEN_SE]
]
class PatternIX(Pattern):
tokens = [
[TOKEN_E, TOKEN_S],
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
]
def __init__(self):
super(PatternIX, self).__init__()
self.tokens = [
[TOKEN_E, TOKEN_S],
[TOKEN_E, TOKEN_EE, TOKEN_SE, TOKEN_S, TOKEN_ES]
]
patterns_map = {

View File

@@ -126,6 +126,17 @@ class SymmetricState(object):
"""
self.h = self.noise_protocol.hash_fn.hash(self.h + data)
def mix_key_and_hash(self, input_key_material: bytes):
# Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3).
self.ck, temp_h, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 3)
# Calls MixHash(temp_h).
self.mix_hash(temp_h)
# If HASHLEN is 64, then truncates temp_k to 32 bytes.
if self.noise_protocol.hash_fn.hashlen == 64:
temp_k = temp_k[:32]
# Calls InitializeKey(temp_k).
self.noise_protocol.cipher_state_handshake.initialize_key(temp_k)
def encrypt_and_hash(self, plaintext: bytes) -> bytes:
"""
Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns ciphertext. Note that if
@@ -268,6 +279,8 @@ class HandshakeState(object):
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
message_buffer.write(self.e.public_bytes)
self.symmetric_state.mix_hash(self.e.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.e.public_bytes)
elif token == TOKEN_S:
# Appends EncryptAndHash(s.public_key) to the buffer
@@ -294,10 +307,9 @@ class HandshakeState(object):
elif token == TOKEN_SS:
# Calls MixKey(DH(s, rs))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
pass
elif token == TOKEN_PSK:
raise NotImplementedError
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
else:
raise NotImplementedError('Pattern token: {}'.format(token))
@@ -325,6 +337,8 @@ class HandshakeState(object):
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
self.symmetric_state.mix_hash(self.re.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.re.public_bytes)
elif token == TOKEN_S:
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
@@ -358,7 +372,7 @@ class HandshakeState(object):
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
elif token == TOKEN_PSK:
raise NotImplementedError
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
else:
raise NotImplementedError('Pattern token: {}'.format(token))

View File

@@ -35,15 +35,15 @@ def _prepare_test_vectors():
for vector in vectors_list:
if 'name' in vector and not 'protocol_name' in vector: # noise-c-* workaround
vector['protocol_name'] = vector['name']
if '_448_' in vector['protocol_name'] or 'psk' in vector['protocol_name'] or 'PSK' in vector['protocol_name']:
continue # TODO REMOVE WHEN ed448/psk/blake SUPPORT IS IMPLEMENTED/FIXED
if '_448_' in vector['protocol_name'] or 'PSK' in vector['protocol_name']: # no old NoisePSK tests
continue # TODO REMOVE WHEN ed448 SUPPORT IS IMPLEMENTED/FIXED
for key, value in vector.copy().items():
if key in byte_fields:
vector[key] = value.encode()
if key in hexbyte_fields:
vector[key] = bytes.fromhex(value)
if key in list_fields:
vector[key] = [k.encode() for k in value]
vector[key] = [bytes.fromhex(k) for k in value]
if key == dict_field:
vector[key] = []
for dictionary in value:
@@ -77,11 +77,12 @@ class TestVectors(object):
def test_vector(self, vector):
kwargs = self._prepare_handshake_state_kwargs(vector)
init_protocol = NoiseProtocol(vector['protocol_name'])
resp_protocol = NoiseProtocol(vector['protocol_name'])
if 'init_psks' in vector and 'resp_psks' in vector:
init_protocol.set_psks(vector['init_psks'])
resp_protocol.set_psks(vector['resp_psks'])
init_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['init_psks'])
resp_protocol = NoiseProtocol(vector['protocol_name'], psks=vector['resp_psks'])
else:
init_protocol = NoiseProtocol(vector['protocol_name'])
resp_protocol = NoiseProtocol(vector['protocol_name'])
kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue'])
kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])