Functioning HandshakeState.initialize()

noise/functions.py
* Wrapped cryptoalgorithms in maps with appropriate wrapper classes
* Probably finished Hash wrapper (to verify when we test first outputs
in tests)

noise/noise_protocol.py
* Slightly renamed attributes containing wrapped cryptoalgorithms
* Added placeholders for protocol State objects
* Now checks if given protocol_name is of bytes()

noise/state.py
* HandshakeState: remove handshake_pattern argument and take it from
given NoiseProtocol instance instead.
* HandshakeState: save NoiseProtocol instance in the HandshakeState
instance and vice versa
* SymmetricState: implemented initialize_symmetric() and mix_hash()
* SymmetricState: save NoiseProtocol instance in the SymmetricState
instance and vice versa
* CipherState: implemented initialize_key() as class constructor
* CipherState: save NoiseProtocol instance in the CipherState
instance and vice versa

tests/test_vectors.py
* Changes to reflect new signature of HandshakeState
* Fix - strings read from .json are now casted to bytes()
This commit is contained in:
Piotr Lizonczyk
2017-08-14 17:18:31 +02:00
parent 2e85d7527b
commit bf054106ff
4 changed files with 142 additions and 55 deletions

View File

@@ -5,25 +5,6 @@ from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512
import ed25519
dh_map = {
'25519': ed25519,
'448': ed448 # TODO implement
}
cipher_map = {
'AESGCM': AES,
'ChaChaPoly': ChaCha20
}
hash_map = {
# TODO benchmark pycryptodome vs hashlib implementation
'BLAKE2b': BLAKE2b,
'BLAKE2s': BLAKE2s,
'SHA256': SHA256,
'SHA512': SHA512
}
class DH(object):
def __init__(self, method):
self.method = method
@@ -47,15 +28,38 @@ class Cipher(object):
class Hash(object):
def __init__(self, method):
self.hashlen = 0
self.blocklen = 0
if method == 'SHA256':
self.hashlen = 32
self.blocklen = 64
self.hash = self._hash_sha256
elif method == 'SHA512':
self.hashlen = 64
self.blocklen = 128
self.hash = self._hash_sha512
elif method == 'BLAKE2s':
self.hashlen = 32
self.blocklen = 64
self.hash = self._hash_blake2s
elif method == 'BLAKE2b':
self.hashlen = 64
self.blocklen = 128
self.hash = self._hash_blake2b
def hash(self):
pass
def _hash_sha256(self, data):
return SHA256.new(data).digest()
def _hash_sha512(self, data):
return SHA512.new(data).digest()
def _hash_blake2s(self, data):
return BLAKE2s.new(data=data, digest_bytes=self.hashlen).digest()
def _hash_blake2b(self, data):
return BLAKE2b.new(data=data, digest_bytes=self.hashlen).digest()
class KeyPair(object):
def __init__(self, public='', private=''):
def __init__(self, public=b'', private=b''):
# TODO: Maybe switch to properties?
self.public = public
self.private = private
@@ -64,3 +68,25 @@ class KeyPair(object):
def derive_public_key(self):
pass
# Available crypto functions
# TODO: Check if it's safe to use one instance globally per cryptoalgorithm - i.e. if wrapper only provides interface
# If not - switch to partials(?)
dh_map = {
'25519': DH('ed25519'),
'448': DH('ed448')
}
cipher_map = {
'AESGCM': Cipher('AESGCM'),
'ChaChaPoly': Cipher('ChaCha20')
}
hash_map = {
# TODO benchmark pycryptodome vs hashlib implementation
'BLAKE2s': Hash('BLAKE2s'),
'BLAKE2b': Hash('BLAKE2b'),
'SHA256': Hash('SHA256'),
'SHA512': Hash('SHA512')
}

View File

@@ -1,6 +1,6 @@
from typing import Tuple
from .constants import MAX_PROTOCOL_NAME_LEN
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map
from .patterns import patterns_map
@@ -17,25 +17,32 @@ class NoiseProtocol(object):
}
def __init__(self, protocol_name: bytes):
if not isinstance(protocol_name, bytes):
raise ValueError('Protocol name has to be of type "bytes", not {}'.format(type(protocol_name)))
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
raise ValueError('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
self.name = protocol_name
mappings, pattern_modifiers = self._parse_protocol_name()
# A valid Pattern instance (see Section 7 of specification (rev 32))
self.pattern = mappings['pattern']()
self.pattern_modifiers = pattern_modifiers
if self.pattern_modifiers:
self.pattern.apply_pattern_modifiers(pattern_modifiers)
self.dh = mappings['pattern']
self.cipher = mappings['pattern']
self.hash = mappings['pattern']
self.dh_fn: 'DH' = mappings['dh']
self.cipher_fn: 'Cipher' = mappings['cipher']
self.hash_fn: 'Hash' = mappings['hash']
self.psks = None # Placeholder for PSKs
self.psks: list = None # Placeholder for PSKs
self.handshake_state: 'HandshakeState' = Empty()
self.symmetric_state: 'SymmetricState' = Empty()
self.cipher_state: 'CipherState' = Empty()
def _parse_protocol_name(self) -> Tuple[dict, list]:
unpacked = self.name.split('_')
unpacked = self.name.decode().split('_')
if unpacked[0] != 'Noise':
raise ValueError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))

View File

@@ -3,24 +3,33 @@ from .constants import Empty
class CipherState(object):
"""
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.1.
The initialize_key() function takes additional required argument - noise_protocol.
"""
def __init__(self):
self.k = Empty()
self.n = None
self.noise_protocol = None
def initialize_key(self, key):
@classmethod
def initialize_key(cls, key, noise_protocol: 'NoiseProtocol') -> 'CipherState':
"""
:param key:
:return:
:param noise_protocol: a valid NoiseProtocol instance
:return: initialised CipherState instance
"""
self.k = key
self.n = 0
instance = cls()
instance.noise_protocol = noise_protocol
noise_protocol.cipher_state = instance
instance.k = key
instance.n = 0
return instance
def has_key(self):
"""
:return: True if self.k is not an instance of Empty
"""
return not isinstance(self.k, Empty)
@@ -46,17 +55,43 @@ class CipherState(object):
class SymmetricState(object):
"""
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2.
The initialize_symmetric function takes different required argument - noise_protocol, which contains protocol_name.
"""
def __init__(self):
self.h = None
self.ck = None
self.noise_protocol = None
@classmethod
def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState':
"""
:param noise_protocol:
:return:
Instead of taking protocol_name as an argument, we take full NoiseProtocol object, that way we have access to
protocol name and crypto functions
Comments below are mostly copied from specification.
:param noise_protocol: a valid NoiseProtocol instance
:return: initialised SymmetricState instance
"""
# Create SymmetricState
instance = cls()
# TODO
instance.noise_protocol = noise_protocol
noise_protocol.symmetric_state = instance
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).
if len(noise_protocol.name) <= noise_protocol.hash_fn.hashlen:
instance.h = noise_protocol.name.ljust(noise_protocol.hash_fn.hashlen, b'\0')
else:
instance.h = noise_protocol.hash_fn.hash(noise_protocol.name)
# Sets ck = h.
instance.ck = instance.h
# Calls InitializeKey(empty).
CipherState.initialize_key(Empty(), noise_protocol)
return instance
def mix_key(self, input_key_material):
@@ -72,6 +107,7 @@ class SymmetricState(object):
:param data:
:return:
"""
self.h = self.noise_protocol.hash_fn.hash(data + self.h)
def encrypt_and_hash(self, plaintext):
"""
@@ -101,17 +137,17 @@ class HandshakeState(object):
"""
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.3.
The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState.
The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern.
"""
@classmethod
def initialize(cls, noise_protocol: 'NoiseProtocol', handshake_pattern: 'Pattern', initiator: bool,
prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None,
re: bytes=None) -> 'HandshakeState':
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState':
"""
Constructor method.
Comments below are mostly copied from specification.
Instead of taking handshake_pattern as an argument, we take full NoiseProtocol object, that way we have access
to protocol name and crypto functions
:param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32))
:param noise_protocol: a valid NoiseProtocol instance
:param initiator: boolean indicating the initiator or responder role
:param prologue: byte sequence which may be zero-length, or which may contain context information that both
@@ -124,12 +160,13 @@ class HandshakeState(object):
"""
# Create HandshakeState
instance = cls()
instance.noise_protocol = noise_protocol
noise_protocol.handshake_state = instance
# Originally in specification:
# "Derives a protocol_name byte sequence by combining the names for
# the handshake pattern and crypto functions, as specified in Section 8."
# Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated.
# We only check if the handshake pattern specified as an argument is the same as in the protocol name
# Calls InitializeSymmetric(noise_protocol)
instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol)
@@ -149,13 +186,13 @@ class HandshakeState(object):
# hashed first
initiator_keypair_getter = instance._get_local_keypair if initiator else instance._get_remote_keypair
responder_keypair_getter = instance._get_remote_keypair if initiator else instance._get_local_keypair
for keypair in map(initiator_keypair_getter, handshake_pattern.get_initiator_pre_messages()):
for keypair in map(initiator_keypair_getter, noise_protocol.pattern.get_initiator_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public)
for keypair in map(responder_keypair_getter, handshake_pattern.get_responder_pre_messages()):
for keypair in map(responder_keypair_getter, noise_protocol.pattern.get_responder_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public)
# Sets message_patterns to the message patterns from handshake_pattern
instance.message_patterns = handshake_pattern.tokens
instance.message_patterns = noise_protocol.pattern.tokens
return instance

View File

@@ -12,13 +12,32 @@ logger = logging.getLogger(__name__)
vector_files = ['vectors/cacophony.txt']
# As in test vectors specification (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors)
# We use this to cast read strings into bytes
string_fields = ['protocol_name', 'init_prologue', 'init_static', 'init_ephemeral', 'init_remote_static',
'resp_prologue', 'resp_static', 'resp_ephemeral', 'resp_remote_static', 'handshake_hash']
list_fields = ['init_psks', 'resp_psks']
dict_field = 'messages'
def _prepare_test_vectors():
vectors = []
for path in vector_files:
with open(os.path.join(os.path.dirname(__file__), path)) as fd:
logging.info('Reading vectors from file {}'.format(path))
vectors.extend(json.load(fd))
vectors_list = json.load(fd)
for vector in vectors_list:
for key, value in vector.copy().items():
if key in string_fields:
vector[key] = value.encode()
if key in list_fields:
vector[key] = [k.encode() for k in value]
if key == dict_field:
vector[key] = []
for dictionary in value:
vector[key].append({k: v.encode() for k, v in dictionary.items()})
vectors.append(vector)
return vectors
@@ -51,10 +70,8 @@ class TestVectors(object):
init_protocol.set_psks(vector['init_psks'])
resp_protocol.set_psks(vector['resp_psks'])
kwargs['init'].update(noise_protocol=init_protocol, handshake_pattern=init_protocol.pattern, initiator=True,
prologue=vector['init_prologue'])
kwargs['resp'].update(noise_protocol=resp_protocol, handshake_pattern=resp_protocol.pattern, initiator=False,
prologue=vector['resp_prologue'])
kwargs['init'].update(noise_protocol=init_protocol, initiator=True, prologue=vector['init_prologue'])
kwargs['resp'].update(noise_protocol=resp_protocol, initiator=False, prologue=vector['resp_prologue'])
initiator = HandshakeState.initialize(**kwargs['init'])
responder = HandshakeState.initialize(**kwargs['resp'])