mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Implementing write_message and read_message
noise/state.py * Implemented HandshakeState's write_message and read_message * Added variable placeholders in HandshakeState.__init__ noise/functions.py * Refactored KeyPair into abstract class * KeyPair25519 implements KeyPair with appropriate ed25519 methods noise/noise_protocol.py * Now holds proper KeyPair wrapper (chosen based on DH) tests/test_vectors.py * Skipping psk tests for now
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import abc
|
||||
|
||||
from .crypto import ed448
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
@@ -18,9 +20,9 @@ class DH(object):
|
||||
else:
|
||||
raise NotImplementedError('DH method: {}'.format(method))
|
||||
|
||||
def _25519_generate_keypair(self) -> 'KeyPair':
|
||||
def _25519_generate_keypair(self) -> '_KeyPair':
|
||||
private_key = x25519.X25519PrivateKey.generate()
|
||||
return KeyPair(private_key, private_key.public_key())
|
||||
return _KeyPair(private_key, private_key.public_key())
|
||||
|
||||
def _25519_dh(self, keypair: 'x25519.X25519PrivateKey', public_key: 'x25519.X25519PublicKey') -> bytes:
|
||||
return keypair.exchange(public_key)
|
||||
@@ -90,19 +92,33 @@ class Hash(object):
|
||||
return digest.finalize()
|
||||
|
||||
|
||||
class KeyPair(object):
|
||||
class _KeyPair(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, private=None, public=None):
|
||||
self.private = private
|
||||
self.public = public
|
||||
|
||||
@classmethod
|
||||
def _25519_from_private_bytes(cls, private_bytes):
|
||||
@abc.abstractmethod
|
||||
def from_private_bytes(cls, private_bytes):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def from_public_bytes(cls, public_bytes):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KeyPair25519(_KeyPair):
|
||||
@classmethod
|
||||
def from_private_bytes(cls, private_bytes):
|
||||
private = x25519.X25519PrivateKey._from_private_bytes(private_bytes)
|
||||
public = private.public_key().public_bytes()
|
||||
return cls(private=private, public=public)
|
||||
|
||||
@classmethod
|
||||
def _25519_from_public_bytes(cls, public_bytes):
|
||||
def from_public_bytes(cls, public_bytes):
|
||||
return cls(public=x25519.X25519PublicKey.from_public_bytes(public_bytes).public_bytes())
|
||||
|
||||
|
||||
@@ -126,3 +142,8 @@ hash_map = {
|
||||
'SHA256': Hash('SHA256'),
|
||||
'SHA512': Hash('SHA512')
|
||||
}
|
||||
|
||||
keypair_map = {
|
||||
'25519': KeyPair25519,
|
||||
# '448': DH('ed448') # TODO uncomment when ed448 is implemented
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Tuple
|
||||
|
||||
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
|
||||
from .functions import dh_map, cipher_map, hash_map
|
||||
from .functions import dh_map, cipher_map, hash_map, keypair_map, KeyPair25519
|
||||
from .patterns import patterns_map
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ class NoiseProtocol(object):
|
||||
'pattern': patterns_map,
|
||||
'dh': dh_map,
|
||||
'cipher': cipher_map,
|
||||
'hash': hash_map
|
||||
'hash': hash_map,
|
||||
'keypair': keypair_map
|
||||
}
|
||||
|
||||
def __init__(self, protocol_name: bytes):
|
||||
@@ -34,6 +35,7 @@ class NoiseProtocol(object):
|
||||
self.dh_fn = mappings['dh']
|
||||
self.cipher_fn = mappings['cipher']
|
||||
self.hash_fn = mappings['hash']
|
||||
self.keypair_fn = mappings['keypair']
|
||||
|
||||
self.psks = None # Placeholder for PSKs
|
||||
|
||||
@@ -62,6 +64,7 @@ class NoiseProtocol(object):
|
||||
'dh': unpacked[2],
|
||||
'cipher': unpacked[3],
|
||||
'hash': unpacked[4],
|
||||
'keypair': unpacked[2],
|
||||
'pattern_modifiers': modifiers}
|
||||
|
||||
mapped_data = {}
|
||||
|
||||
135
noise/state.py
135
noise/state.py
@@ -1,4 +1,4 @@
|
||||
from .constants import Empty
|
||||
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
|
||||
|
||||
|
||||
class CipherState(object):
|
||||
@@ -139,6 +139,16 @@ class HandshakeState(object):
|
||||
|
||||
The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.noise_protocol = None
|
||||
self.symmetric_state = None
|
||||
self.initiator = None
|
||||
self.s = None
|
||||
self.e = None
|
||||
self.rs = None
|
||||
self.re = None
|
||||
self.message_patterns = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
|
||||
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState':
|
||||
@@ -196,23 +206,122 @@ class HandshakeState(object):
|
||||
|
||||
return instance
|
||||
|
||||
def write_message(self, payload, message_buffer):
|
||||
def write_message(self, payload: bytes, message_buffer):
|
||||
"""
|
||||
|
||||
:param payload:
|
||||
:param message_buffer:
|
||||
:return:
|
||||
Comments below are mostly copied from specification.
|
||||
:param payload: byte sequence which may be zero-length
|
||||
:param message_buffer: buffer-like object
|
||||
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
|
||||
"""
|
||||
pass
|
||||
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
|
||||
# from the message pattern
|
||||
message_pattern = self.message_patterns.pop(0)
|
||||
for token in message_pattern:
|
||||
if token == TOKEN_E:
|
||||
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair()
|
||||
message_buffer.write(self.e.public)
|
||||
self.symmetric_state.mix_hash(self.e.public)
|
||||
|
||||
def read_message(self, message, payload_buffer):
|
||||
elif token == TOKEN_S:
|
||||
# Appends EncryptAndHash(s.public_key) to the buffer
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public))
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
# Calls MixKey(DH(e, re))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
|
||||
|
||||
elif token == TOKEN_ES:
|
||||
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
|
||||
elif token == TOKEN_SE:
|
||||
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
|
||||
elif token == TOKEN_SS:
|
||||
# Calls MixKey(DH(s, rs))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
# Appends EncryptAndHash(payload) to the buffer
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(payload))
|
||||
|
||||
# If there are no more message patterns returns two new CipherState objects by calling Split()
|
||||
if len(self.message_patterns) == 0:
|
||||
return self.symmetric_state.split()
|
||||
|
||||
def read_message(self, message: bytes, payload_buffer):
|
||||
"""
|
||||
|
||||
:param message:
|
||||
:param payload_buffer:
|
||||
:return:
|
||||
Comments below are mostly copied from specification.
|
||||
:param message: byte sequence containing a Noise handshake message
|
||||
:param payload_buffer: buffer-like object
|
||||
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
|
||||
"""
|
||||
pass
|
||||
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
|
||||
# from the message pattern
|
||||
dhlen = self.noise_protocol.dh_fn.dhlen
|
||||
message_pattern = self.message_patterns.pop(0)
|
||||
for token in message_pattern:
|
||||
if token == TOKEN_E:
|
||||
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
|
||||
self.symmetric_state.mix_hash(self.re.public)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
|
||||
# otherwise. Sets rs to DecryptAndHash(temp).
|
||||
if self.noise_protocol.cipher_state.has_key():
|
||||
temp = message.read(dhlen + 16)
|
||||
else:
|
||||
temp = message.read(dhlen)
|
||||
self.rs = self.symmetric_state.decrypt_and_hash(temp)
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
# Calls MixKey(DH(e, re)).
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
|
||||
|
||||
elif token == TOKEN_ES:
|
||||
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
|
||||
elif token == TOKEN_SE:
|
||||
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
|
||||
elif token == TOKEN_SS:
|
||||
# Calls MixKey(DH(s, rs))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
|
||||
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message)) # TODO remaining bytes!
|
||||
|
||||
# If there are no more message patterns returns two new CipherState objects by calling Split()
|
||||
if len(self.message_patterns) == 0:
|
||||
return self.symmetric_state.split()
|
||||
|
||||
def _get_local_keypair(self, token: str) -> 'KeyPair':
|
||||
keypair = getattr(self, token) # Maybe explicitly handle exception when getting improper keypair
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from noise.functions import KeyPair
|
||||
from noise.functions import KeyPair25519
|
||||
from noise.state import HandshakeState
|
||||
from noise.noise_protocol import NoiseProtocol
|
||||
|
||||
@@ -29,8 +29,8 @@ def _prepare_test_vectors():
|
||||
vectors_list = json.load(fd)
|
||||
|
||||
for vector in vectors_list:
|
||||
if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name']:
|
||||
continue # TODO REMOVE WHEN ed448/ChaCha SUPPORT IS IMPLEMENTED
|
||||
if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name'] or 'psk' in vector['protocol_name']:
|
||||
continue # TODO REMOVE WHEN ed448/ChaCha/psk SUPPORT IS IMPLEMENTED
|
||||
for key, value in vector.copy().items():
|
||||
if key in byte_fields:
|
||||
vector[key] = value.encode()
|
||||
@@ -59,9 +59,9 @@ class TestVectors(object):
|
||||
role_key = role + '_' + key
|
||||
if role_key in vector:
|
||||
if key in ['static', 'ephemeral']:
|
||||
kwargs[role][kwarg] = KeyPair._25519_from_private_bytes(vector[role_key]) # TODO unify after adding 448
|
||||
kwargs[role][kwarg] = KeyPair25519.from_private_bytes(vector[role_key]) # TODO unify after adding 448
|
||||
elif key == 'remote_static':
|
||||
kwargs[role][kwarg] = KeyPair._25519_from_public_bytes(vector[role_key]) # TODO unify after adding 448
|
||||
kwargs[role][kwarg] = KeyPair25519.from_public_bytes(vector[role_key]) # TODO unify after adding 448
|
||||
return kwargs
|
||||
|
||||
def test_vector(self, vector):
|
||||
|
||||
Reference in New Issue
Block a user