Improvements to initialization of NoiseProtocol

noise/noise_protocol.py
* Added validation of given function names vs available crypto methods
* Members of NoiseProtocol should now refer to proper classes/methods
after initialization of an instance

noise/patterns.py
* Added method for application of pattern modifiers

noise/crypto.py
* Provisioned ed448 function

noise/state.py
* Changed references to NoiseProtocol instances to make it more
consistent throughout the code
This commit is contained in:
Piotr Lizonczyk
2017-08-12 13:30:44 +02:00
parent de73505ac3
commit bcaceb9ccd
6 changed files with 79 additions and 28 deletions

View File

@@ -13,4 +13,4 @@ TOKEN_PSK = 'psk'
# In bytes, as in Section 8 of specification (rev 32)
MAX_PROTOCOL_NAME_LEN = 255
MAX_PROTOCOL_NAME_LEN = 255

2
noise/crypto.py Normal file
View File

@@ -0,0 +1,2 @@
def ed448(*args, **kwargs):
raise NotImplementedError

View File

@@ -1,14 +1,18 @@
from functools import partial
from typing import Tuple
from .patterns import patterns_map
from .constants import MAX_PROTOCOL_NAME_LEN
from .crypto import ed448
from Crypto.Cipher import AES, ChaCha20
from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512
import ed25519
dh_map = {
'25519': ed25519.create_keypair,
'448': None # TODO implement
'448': ed448 # TODO implement
}
cipher_map = {
@@ -17,6 +21,7 @@ cipher_map = {
}
hash_map = {
# TODO benchmark vs hashlib implementation
'BLAKE2b': BLAKE2b, # TODO PARTIALS
'BLAKE2s': BLAKE2s, # TODO PARTIALS
'SHA256': SHA256, # TODO PARTIALS
@@ -25,35 +30,46 @@ hash_map = {
class NoiseProtocol(object):
"""
TODO: Document
"""
methods = {
'pattern': patterns_map,
'dh': dh_map,
'cipher': cipher_map,
'hash': hash_map
}
def __init__(self, protocol_name: bytes):
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
self.name = protocol_name
data_dict = self._split_protocol_name()
self.pattern = patterns_map[data_dict['pattern']]
self.pattern_modifiers = None
self.dh = None
self.cipher = None
self.hash = None
mappings, pattern_modifiers = self._parse_protocol_name()
def _split_protocol_name(self):
self.pattern = mappings['pattern']()
self.pattern_modifiers = pattern_modifiers
if self.pattern_modifiers:
self.pattern.apply_pattern_modifiers(pattern_modifiers)
self.dh = mappings['pattern']
self.cipher = mappings['pattern']
self.hash = mappings['pattern']
def _parse_protocol_name(self) -> Tuple[dict, list]:
unpacked = self.name.split('_')
if unpacked[0] != 'Noise':
raise ValueError(f'Noise protocol name shall begin with Noise! Provided: {self.name}')
raise ValueError(f'Noise Protocol name shall begin with Noise! Provided: {self.name}')
# Extract pattern name and pattern modifiers
pattern = ''
modifiers_str = None
for i, char in enumerate(unpacked[1]):
if char.isupper():
pattern += char
else:
modifiers_str = unpacked[1][i+1:] # Will be empty string if it exceeds string size
# End of pattern, now look for modifiers
modifiers_str = unpacked[1][i:] # Will be empty string if it exceeds string size
break
modifiers = modifiers_str.split('+') if modifiers_str else []
@@ -63,10 +79,16 @@ class NoiseProtocol(object):
'hash': unpacked[4],
'pattern_modifiers': modifiers}
# Validate if we know everything that Noise Protocol is supposed to use
# TODO validation
mapped_data = {}
return data
# Validate if we know everything that Noise Protocol is supposed to use and map appropriate functions
for key, map_dict in self.methods.items():
func = map_dict.get(data[key])
if not func:
raise ValueError(f'Unknown {key} in Noise Protocol name, given {data[key]}, known {" ".join(map_dict)}')
mapped_data[key] = func
return mapped_data, modifiers
class KeyPair(object):

View File

@@ -1,3 +1,5 @@
from typing import List
from .constants import TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
@@ -8,12 +10,13 @@ class Pattern(object):
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
# pre_messages shall be a list of two lists:
# the first for the initiator's pre-messages, the second for the responder
pre_messages = [
pre_messages: List[list] = [
[],
[]
]
# TODO Comment
tokens = []
# List of lists of valid tokens, alternating between tokens for initiator and responder
tokens: List[list] = []
def __init__(self):
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
@@ -24,6 +27,30 @@ class Pattern(object):
def get_responder_pre_messages(self) -> list:
return self.pre_messages[1]
def apply_pattern_modifiers(self, modifiers: List[str]) -> None:
# Applies given pattern modifiers to self.tokens of the Pattern instance.
for modifier in modifiers:
if modifier.startswith('psk'):
try:
index = int(modifier.replace('psk', '', 1))
except ValueError:
raise ValueError(f'Improper psk modifier {modifier}')
if index // 2 > len(self.tokens):
raise ValueError(f'Modifier {modifier} cannot be applied - pattern has not enough messages')
# Add TOKEN_PSK in the correct place in the correct message
if index % 2 == 0:
self.tokens[index//2].insert(0, TOKEN_PSK)
else:
self.tokens[index//2].append(TOKEN_PSK)
elif modifier == 'fallback':
raise NotImplementedError # TODO implement
else:
raise ValueError(f'Unknown pattern modifier {modifier}')
# One-way patterns

View File

@@ -49,10 +49,10 @@ class SymmetricState(object):
"""
@classmethod
def initialize_symmetric(cls, protocol_name) -> 'SymmetricState':
def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState':
"""
:param protocol_name:
:param noise_protocol:
:return:
"""
instance = cls()
@@ -104,7 +104,7 @@ class HandshakeState(object):
The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState.
"""
@classmethod
def initialize(cls, handshake_pattern: 'Pattern', protocol_name: 'NoiseProtocol', initiator: bool,
def initialize(cls, noise_protocol: 'NoiseProtocol', handshake_pattern: 'Pattern', initiator: bool,
prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None,
re: bytes=None) -> 'HandshakeState':
"""
@@ -112,7 +112,7 @@ class HandshakeState(object):
Comments below are mostly copied from specification.
:param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32))
:param protocol_name: a valid NoiseProtocol instance
:param noise_protocol: a valid NoiseProtocol instance
:param initiator: boolean indicating the initiator or responder role
:param prologue: byte sequence which may be zero-length, or which may contain context information that both
parties want to confirm is identical
@@ -128,11 +128,11 @@ class HandshakeState(object):
# Originally in specification:
# "Derives a protocol_name byte sequence by combining the names for
# the handshake pattern and crypto functions, as specified in Section 8."
# Instead, we supply the protocol name to the function. It should already be validated. We only check if the
# handshake pattern specified as an argument is the same as in the protocol name
# Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated.
# We only check if the handshake pattern specified as an argument is the same as in the protocol name
# Calls InitializeSymmetric(protocol_name)
instance.symmetric_state = SymmetricState.initialize_symmetric(protocol_name)
# Calls InitializeSymmetric(noise_protocol)
instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol)
# Calls MixHash(prologue)
instance.symmetric_state.mix_hash(prologue)

View File

@@ -30,7 +30,7 @@ def test_vector(vector):
logging.info(f"Testing vector {vector['protocol_name']}")
init_protocol = NoiseProtocol(vector['protocol_name'])
resp_protocol = NoiseProtocol(vector['protocol_name'])
initiator = HandshakeState.initialize(handshake_pattern=init_protocol.pattern, protocol_name=init_protocol.name,
initiator = HandshakeState.initialize(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern,
initiator=True, prologue=vector['init_prologue'])
responder = HandshakeState.initialize(handshake_pattern=resp_protocol.pattern, protocol_name=resp_protocol.name,
responder = HandshakeState.initialize(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern,
initiator=True, prologue=vector['resp_prologue'])