mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Improvements to initialization of NoiseProtocol
noise/noise_protocol.py * Added validation of given function names vs available crypto methods * Members of NoiseProtocol should now refer to proper classes/methods after initialization of an instance noise/patterns.py * Added method for application of pattern modifiers noise/crypto.py * Provisioned ed448 function noise/state.py * Changed references to NoiseProtocol instances to make it more consistent throughout the code
This commit is contained in:
@@ -13,4 +13,4 @@ TOKEN_PSK = 'psk'
|
||||
|
||||
|
||||
# In bytes, as in Section 8 of specification (rev 32)
|
||||
MAX_PROTOCOL_NAME_LEN = 255
|
||||
MAX_PROTOCOL_NAME_LEN = 255
|
||||
|
||||
2
noise/crypto.py
Normal file
2
noise/crypto.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def ed448(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
@@ -1,14 +1,18 @@
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
from .patterns import patterns_map
|
||||
from .constants import MAX_PROTOCOL_NAME_LEN
|
||||
from .crypto import ed448
|
||||
|
||||
from Crypto.Cipher import AES, ChaCha20
|
||||
from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512
|
||||
import ed25519
|
||||
|
||||
|
||||
dh_map = {
|
||||
'25519': ed25519.create_keypair,
|
||||
'448': None # TODO implement
|
||||
'448': ed448 # TODO implement
|
||||
}
|
||||
|
||||
cipher_map = {
|
||||
@@ -17,6 +21,7 @@ cipher_map = {
|
||||
}
|
||||
|
||||
hash_map = {
|
||||
# TODO benchmark vs hashlib implementation
|
||||
'BLAKE2b': BLAKE2b, # TODO PARTIALS
|
||||
'BLAKE2s': BLAKE2s, # TODO PARTIALS
|
||||
'SHA256': SHA256, # TODO PARTIALS
|
||||
@@ -25,35 +30,46 @@ hash_map = {
|
||||
|
||||
|
||||
class NoiseProtocol(object):
|
||||
"""
|
||||
TODO: Document
|
||||
"""
|
||||
methods = {
|
||||
'pattern': patterns_map,
|
||||
'dh': dh_map,
|
||||
|
||||
'cipher': cipher_map,
|
||||
'hash': hash_map
|
||||
}
|
||||
|
||||
def __init__(self, protocol_name: bytes):
|
||||
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
|
||||
raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
|
||||
|
||||
self.name = protocol_name
|
||||
data_dict = self._split_protocol_name()
|
||||
self.pattern = patterns_map[data_dict['pattern']]
|
||||
self.pattern_modifiers = None
|
||||
self.dh = None
|
||||
self.cipher = None
|
||||
self.hash = None
|
||||
mappings, pattern_modifiers = self._parse_protocol_name()
|
||||
|
||||
def _split_protocol_name(self):
|
||||
self.pattern = mappings['pattern']()
|
||||
self.pattern_modifiers = pattern_modifiers
|
||||
if self.pattern_modifiers:
|
||||
self.pattern.apply_pattern_modifiers(pattern_modifiers)
|
||||
|
||||
self.dh = mappings['pattern']
|
||||
self.cipher = mappings['pattern']
|
||||
self.hash = mappings['pattern']
|
||||
|
||||
def _parse_protocol_name(self) -> Tuple[dict, list]:
|
||||
unpacked = self.name.split('_')
|
||||
if unpacked[0] != 'Noise':
|
||||
raise ValueError(f'Noise protocol name shall begin with Noise! Provided: {self.name}')
|
||||
raise ValueError(f'Noise Protocol name shall begin with Noise! Provided: {self.name}')
|
||||
|
||||
# Extract pattern name and pattern modifiers
|
||||
pattern = ''
|
||||
modifiers_str = None
|
||||
for i, char in enumerate(unpacked[1]):
|
||||
if char.isupper():
|
||||
pattern += char
|
||||
else:
|
||||
modifiers_str = unpacked[1][i+1:] # Will be empty string if it exceeds string size
|
||||
# End of pattern, now look for modifiers
|
||||
modifiers_str = unpacked[1][i:] # Will be empty string if it exceeds string size
|
||||
break
|
||||
modifiers = modifiers_str.split('+') if modifiers_str else []
|
||||
|
||||
@@ -63,10 +79,16 @@ class NoiseProtocol(object):
|
||||
'hash': unpacked[4],
|
||||
'pattern_modifiers': modifiers}
|
||||
|
||||
# Validate if we know everything that Noise Protocol is supposed to use
|
||||
# TODO validation
|
||||
mapped_data = {}
|
||||
|
||||
return data
|
||||
# Validate if we know everything that Noise Protocol is supposed to use and map appropriate functions
|
||||
for key, map_dict in self.methods.items():
|
||||
func = map_dict.get(data[key])
|
||||
if not func:
|
||||
raise ValueError(f'Unknown {key} in Noise Protocol name, given {data[key]}, known {" ".join(map_dict)}')
|
||||
mapped_data[key] = func
|
||||
|
||||
return mapped_data, modifiers
|
||||
|
||||
|
||||
class KeyPair(object):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from .constants import TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
|
||||
|
||||
|
||||
@@ -8,12 +10,13 @@ class Pattern(object):
|
||||
# As per specification, if both parties have pre-messages, the initiator is listed first. To reduce complexity,
|
||||
# pre_messages shall be a list of two lists:
|
||||
# the first for the initiator's pre-messages, the second for the responder
|
||||
pre_messages = [
|
||||
pre_messages: List[list] = [
|
||||
[],
|
||||
[]
|
||||
]
|
||||
# TODO Comment
|
||||
tokens = []
|
||||
|
||||
# List of lists of valid tokens, alternating between tokens for initiator and responder
|
||||
tokens: List[list] = []
|
||||
|
||||
def __init__(self):
|
||||
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
|
||||
@@ -24,6 +27,30 @@ class Pattern(object):
|
||||
def get_responder_pre_messages(self) -> list:
|
||||
return self.pre_messages[1]
|
||||
|
||||
def apply_pattern_modifiers(self, modifiers: List[str]) -> None:
|
||||
# Applies given pattern modifiers to self.tokens of the Pattern instance.
|
||||
for modifier in modifiers:
|
||||
if modifier.startswith('psk'):
|
||||
try:
|
||||
index = int(modifier.replace('psk', '', 1))
|
||||
except ValueError:
|
||||
raise ValueError(f'Improper psk modifier {modifier}')
|
||||
|
||||
if index // 2 > len(self.tokens):
|
||||
raise ValueError(f'Modifier {modifier} cannot be applied - pattern has not enough messages')
|
||||
|
||||
# Add TOKEN_PSK in the correct place in the correct message
|
||||
if index % 2 == 0:
|
||||
self.tokens[index//2].insert(0, TOKEN_PSK)
|
||||
else:
|
||||
self.tokens[index//2].append(TOKEN_PSK)
|
||||
|
||||
elif modifier == 'fallback':
|
||||
raise NotImplementedError # TODO implement
|
||||
|
||||
else:
|
||||
raise ValueError(f'Unknown pattern modifier {modifier}')
|
||||
|
||||
|
||||
# One-way patterns
|
||||
|
||||
|
||||
@@ -49,10 +49,10 @@ class SymmetricState(object):
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
def initialize_symmetric(cls, protocol_name) -> 'SymmetricState':
|
||||
def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState':
|
||||
"""
|
||||
|
||||
:param protocol_name:
|
||||
:param noise_protocol:
|
||||
:return:
|
||||
"""
|
||||
instance = cls()
|
||||
@@ -104,7 +104,7 @@ class HandshakeState(object):
|
||||
The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState.
|
||||
"""
|
||||
@classmethod
|
||||
def initialize(cls, handshake_pattern: 'Pattern', protocol_name: 'NoiseProtocol', initiator: bool,
|
||||
def initialize(cls, noise_protocol: 'NoiseProtocol', handshake_pattern: 'Pattern', initiator: bool,
|
||||
prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None,
|
||||
re: bytes=None) -> 'HandshakeState':
|
||||
"""
|
||||
@@ -112,7 +112,7 @@ class HandshakeState(object):
|
||||
Comments below are mostly copied from specification.
|
||||
|
||||
:param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32))
|
||||
:param protocol_name: a valid NoiseProtocol instance
|
||||
:param noise_protocol: a valid NoiseProtocol instance
|
||||
:param initiator: boolean indicating the initiator or responder role
|
||||
:param prologue: byte sequence which may be zero-length, or which may contain context information that both
|
||||
parties want to confirm is identical
|
||||
@@ -128,11 +128,11 @@ class HandshakeState(object):
|
||||
# Originally in specification:
|
||||
# "Derives a protocol_name byte sequence by combining the names for
|
||||
# the handshake pattern and crypto functions, as specified in Section 8."
|
||||
# Instead, we supply the protocol name to the function. It should already be validated. We only check if the
|
||||
# handshake pattern specified as an argument is the same as in the protocol name
|
||||
# Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated.
|
||||
# We only check if the handshake pattern specified as an argument is the same as in the protocol name
|
||||
|
||||
# Calls InitializeSymmetric(protocol_name)
|
||||
instance.symmetric_state = SymmetricState.initialize_symmetric(protocol_name)
|
||||
# Calls InitializeSymmetric(noise_protocol)
|
||||
instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol)
|
||||
|
||||
# Calls MixHash(prologue)
|
||||
instance.symmetric_state.mix_hash(prologue)
|
||||
|
||||
@@ -30,7 +30,7 @@ def test_vector(vector):
|
||||
logging.info(f"Testing vector {vector['protocol_name']}")
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
resp_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
initiator = HandshakeState.initialize(handshake_pattern=init_protocol.pattern, protocol_name=init_protocol.name,
|
||||
initiator = HandshakeState.initialize(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern,
|
||||
initiator=True, prologue=vector['init_prologue'])
|
||||
responder = HandshakeState.initialize(handshake_pattern=resp_protocol.pattern, protocol_name=resp_protocol.name,
|
||||
responder = HandshakeState.initialize(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern,
|
||||
initiator=True, prologue=vector['resp_prologue'])
|
||||
|
||||
Reference in New Issue
Block a user