Implementing write_message and read_message

noise/state.py
* Implemented HandshakeState's write_message and read_message
* Added variable placeholders in HandshakeState.__init__

noise/functions.py
* Refactored KeyPair into abstract class
* KeyPair25519 implements KeyPair with appropriate ed25519 methods

noise/noise_protocol.py
* Now holds proper KeyPair wrapper (chosen based on DH)

tests/test_vectors.py
* Skipping psk tests for now
This commit is contained in:
Piotr Lizonczyk
2017-08-15 00:06:36 +02:00
parent 512feb2029
commit bed5809cc1
4 changed files with 158 additions and 25 deletions

View File

@@ -1,3 +1,5 @@
import abc
from .crypto import ed448
from cryptography.hazmat.backends import default_backend
@@ -18,9 +20,9 @@ class DH(object):
else:
raise NotImplementedError('DH method: {}'.format(method))
def _25519_generate_keypair(self) -> 'KeyPair':
def _25519_generate_keypair(self) -> '_KeyPair':
private_key = x25519.X25519PrivateKey.generate()
return KeyPair(private_key, private_key.public_key())
return _KeyPair(private_key, private_key.public_key())
def _25519_dh(self, keypair: 'x25519.X25519PrivateKey', public_key: 'x25519.X25519PublicKey') -> bytes:
return keypair.exchange(public_key)
@@ -90,19 +92,33 @@ class Hash(object):
return digest.finalize()
class KeyPair(object):
class _KeyPair(object):
__metaclass__ = abc.ABCMeta
def __init__(self, private=None, public=None):
self.private = private
self.public = public
@classmethod
def _25519_from_private_bytes(cls, private_bytes):
@abc.abstractmethod
def from_private_bytes(cls, private_bytes):
raise NotImplementedError
@classmethod
@abc.abstractmethod
def from_public_bytes(cls, public_bytes):
raise NotImplementedError
class KeyPair25519(_KeyPair):
@classmethod
def from_private_bytes(cls, private_bytes):
private = x25519.X25519PrivateKey._from_private_bytes(private_bytes)
public = private.public_key().public_bytes()
return cls(private=private, public=public)
@classmethod
def _25519_from_public_bytes(cls, public_bytes):
def from_public_bytes(cls, public_bytes):
return cls(public=x25519.X25519PublicKey.from_public_bytes(public_bytes).public_bytes())
@@ -126,3 +142,8 @@ hash_map = {
'SHA256': Hash('SHA256'),
'SHA512': Hash('SHA512')
}
keypair_map = {
'25519': KeyPair25519,
# '448': DH('ed448') # TODO uncomment when ed448 is implemented
}

View File

@@ -1,7 +1,7 @@
from typing import Tuple
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map
from .functions import dh_map, cipher_map, hash_map, keypair_map, KeyPair25519
from .patterns import patterns_map
@@ -13,7 +13,8 @@ class NoiseProtocol(object):
'pattern': patterns_map,
'dh': dh_map,
'cipher': cipher_map,
'hash': hash_map
'hash': hash_map,
'keypair': keypair_map
}
def __init__(self, protocol_name: bytes):
@@ -34,6 +35,7 @@ class NoiseProtocol(object):
self.dh_fn = mappings['dh']
self.cipher_fn = mappings['cipher']
self.hash_fn = mappings['hash']
self.keypair_fn = mappings['keypair']
self.psks = None # Placeholder for PSKs
@@ -62,6 +64,7 @@ class NoiseProtocol(object):
'dh': unpacked[2],
'cipher': unpacked[3],
'hash': unpacked[4],
'keypair': unpacked[2],
'pattern_modifiers': modifiers}
mapped_data = {}

View File

@@ -1,4 +1,4 @@
from .constants import Empty
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
class CipherState(object):
@@ -139,6 +139,16 @@ class HandshakeState(object):
The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern.
"""
def __init__(self):
self.noise_protocol = None
self.symmetric_state = None
self.initiator = None
self.s = None
self.e = None
self.rs = None
self.re = None
self.message_patterns = None
@classmethod
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState':
@@ -196,23 +206,122 @@ class HandshakeState(object):
return instance
def write_message(self, payload, message_buffer):
def write_message(self, payload: bytes, message_buffer):
"""
:param payload:
:param message_buffer:
:return:
Comments below are mostly copied from specification.
:param payload: byte sequence which may be zero-length
:param message_buffer: buffer-like object
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
"""
pass
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
# from the message pattern
message_pattern = self.message_patterns.pop(0)
for token in message_pattern:
if token == TOKEN_E:
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
self.e = self.noise_protocol.dh_fn.generate_keypair()
message_buffer.write(self.e.public)
self.symmetric_state.mix_hash(self.e.public)
def read_message(self, message, payload_buffer):
elif token == TOKEN_S:
# Appends EncryptAndHash(s.public_key) to the buffer
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public))
elif token == TOKEN_EE:
# Calls MixKey(DH(e, re))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
elif token == TOKEN_ES:
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
elif token == TOKEN_SE:
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
elif token == TOKEN_SS:
# Calls MixKey(DH(s, rs))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
elif token == TOKEN_PSK:
raise NotImplementedError
else:
raise NotImplementedError('Pattern token: {}'.format(token))
# Appends EncryptAndHash(payload) to the buffer
message_buffer.write(self.symmetric_state.encrypt_and_hash(payload))
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:
return self.symmetric_state.split()
def read_message(self, message: bytes, payload_buffer):
"""
:param message:
:param payload_buffer:
:return:
Comments below are mostly copied from specification.
:param message: byte sequence containing a Noise handshake message
:param payload_buffer: buffer-like object
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
"""
pass
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
# from the message pattern
dhlen = self.noise_protocol.dh_fn.dhlen
message_pattern = self.message_patterns.pop(0)
for token in message_pattern:
if token == TOKEN_E:
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
self.symmetric_state.mix_hash(self.re.public)
elif token == TOKEN_S:
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
# otherwise. Sets rs to DecryptAndHash(temp).
if self.noise_protocol.cipher_state.has_key():
temp = message.read(dhlen + 16)
else:
temp = message.read(dhlen)
self.rs = self.symmetric_state.decrypt_and_hash(temp)
elif token == TOKEN_EE:
# Calls MixKey(DH(e, re)).
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
elif token == TOKEN_ES:
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
elif token == TOKEN_SE:
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
elif token == TOKEN_SS:
# Calls MixKey(DH(s, rs))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
elif token == TOKEN_PSK:
raise NotImplementedError
else:
raise NotImplementedError('Pattern token: {}'.format(token))
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message)) # TODO remaining bytes!
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:
return self.symmetric_state.split()
def _get_local_keypair(self, token: str) -> 'KeyPair':
keypair = getattr(self, token) # Maybe explicitly handle exception when getting improper keypair

View File

@@ -4,7 +4,7 @@ import os
import pytest
from noise.functions import KeyPair
from noise.functions import KeyPair25519
from noise.state import HandshakeState
from noise.noise_protocol import NoiseProtocol
@@ -29,8 +29,8 @@ def _prepare_test_vectors():
vectors_list = json.load(fd)
for vector in vectors_list:
if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name']:
continue # TODO REMOVE WHEN ed448/ChaCha SUPPORT IS IMPLEMENTED
if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name'] or 'psk' in vector['protocol_name']:
continue # TODO REMOVE WHEN ed448/ChaCha/psk SUPPORT IS IMPLEMENTED
for key, value in vector.copy().items():
if key in byte_fields:
vector[key] = value.encode()
@@ -59,9 +59,9 @@ class TestVectors(object):
role_key = role + '_' + key
if role_key in vector:
if key in ['static', 'ephemeral']:
kwargs[role][kwarg] = KeyPair._25519_from_private_bytes(vector[role_key]) # TODO unify after adding 448
kwargs[role][kwarg] = KeyPair25519.from_private_bytes(vector[role_key]) # TODO unify after adding 448
elif key == 'remote_static':
kwargs[role][kwarg] = KeyPair._25519_from_public_bytes(vector[role_key]) # TODO unify after adding 448
kwargs[role][kwarg] = KeyPair25519.from_public_bytes(vector[role_key]) # TODO unify after adding 448
return kwargs
def test_vector(self, vector):