diff --git a/noise/functions.py b/noise/functions.py index 7c59219..80eb20d 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -1,3 +1,5 @@ +import abc + from .crypto import ed448 from cryptography.hazmat.backends import default_backend @@ -18,9 +20,9 @@ class DH(object): else: raise NotImplementedError('DH method: {}'.format(method)) - def _25519_generate_keypair(self) -> 'KeyPair': + def _25519_generate_keypair(self) -> '_KeyPair': private_key = x25519.X25519PrivateKey.generate() - return KeyPair(private_key, private_key.public_key()) + return _KeyPair(private_key, private_key.public_key()) def _25519_dh(self, keypair: 'x25519.X25519PrivateKey', public_key: 'x25519.X25519PublicKey') -> bytes: return keypair.exchange(public_key) @@ -90,19 +92,33 @@ class Hash(object): return digest.finalize() -class KeyPair(object): +class _KeyPair(object): + __metaclass__ = abc.ABCMeta + def __init__(self, private=None, public=None): self.private = private self.public = public @classmethod - def _25519_from_private_bytes(cls, private_bytes): + @abc.abstractmethod + def from_private_bytes(cls, private_bytes): + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def from_public_bytes(cls, public_bytes): + raise NotImplementedError + + +class KeyPair25519(_KeyPair): + @classmethod + def from_private_bytes(cls, private_bytes): private = x25519.X25519PrivateKey._from_private_bytes(private_bytes) public = private.public_key().public_bytes() return cls(private=private, public=public) @classmethod - def _25519_from_public_bytes(cls, public_bytes): + def from_public_bytes(cls, public_bytes): return cls(public=x25519.X25519PublicKey.from_public_bytes(public_bytes).public_bytes()) @@ -126,3 +142,8 @@ hash_map = { 'SHA256': Hash('SHA256'), 'SHA512': Hash('SHA512') } + +keypair_map = { + '25519': KeyPair25519, + # '448': DH('ed448') # TODO uncomment when ed448 is implemented +} \ No newline at end of file diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 34cb4d4..8689787 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,7 +1,7 @@ from typing import Tuple from .constants import MAX_PROTOCOL_NAME_LEN, Empty -from .functions import dh_map, cipher_map, hash_map +from .functions import dh_map, cipher_map, hash_map, keypair_map, KeyPair25519 from .patterns import patterns_map @@ -13,7 +13,8 @@ class NoiseProtocol(object): 'pattern': patterns_map, 'dh': dh_map, 'cipher': cipher_map, - 'hash': hash_map + 'hash': hash_map, + 'keypair': keypair_map } def __init__(self, protocol_name: bytes): @@ -34,6 +35,7 @@ class NoiseProtocol(object): self.dh_fn = mappings['dh'] self.cipher_fn = mappings['cipher'] self.hash_fn = mappings['hash'] + self.keypair_fn = mappings['keypair'] self.psks = None # Placeholder for PSKs @@ -62,6 +64,7 @@ class NoiseProtocol(object): 'dh': unpacked[2], 'cipher': unpacked[3], 'hash': unpacked[4], + 'keypair': unpacked[2], 'pattern_modifiers': modifiers} mapped_data = {} diff --git a/noise/state.py b/noise/state.py index 3151e19..db67d61 100644 --- a/noise/state.py +++ b/noise/state.py @@ -1,4 +1,4 @@ -from .constants import Empty +from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK class CipherState(object): @@ -139,6 +139,16 @@ class HandshakeState(object): The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern. """ + def __init__(self): + self.noise_protocol = None + self.symmetric_state = None + self.initiator = None + self.s = None + self.e = None + self.rs = None + self.re = None + self.message_patterns = None + @classmethod def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': @@ -196,23 +206,122 @@ class HandshakeState(object): return instance - def write_message(self, payload, message_buffer): + def write_message(self, payload: bytes, message_buffer): """ - - :param payload: - :param message_buffer: - :return: + Comments below are mostly copied from specification. + :param payload: byte sequence which may be zero-length + :param message_buffer: buffer-like object + :return: None or result of SymmetricState.split() - tuple (CipherState, CipherState) """ - pass + # Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token + # from the message pattern + message_pattern = self.message_patterns.pop(0) + for token in message_pattern: + if token == TOKEN_E: + # Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key) + self.e = self.noise_protocol.dh_fn.generate_keypair() + message_buffer.write(self.e.public) + self.symmetric_state.mix_hash(self.e.public) - def read_message(self, message, payload_buffer): + elif token == TOKEN_S: + # Appends EncryptAndHash(s.public_key) to the buffer + message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public)) + + elif token == TOKEN_EE: + # Calls MixKey(DH(e, re)) + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public)) + + elif token == TOKEN_ES: + # Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder + if self.initiator: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public)) + else: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public)) + + elif token == TOKEN_SE: + # Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder + if self.initiator: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public)) + else: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public)) + + elif token == TOKEN_SS: + # Calls MixKey(DH(s, rs)) + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public)) + + elif token == TOKEN_PSK: + raise NotImplementedError + + else: + raise NotImplementedError('Pattern token: {}'.format(token)) + + # Appends EncryptAndHash(payload) to the buffer + message_buffer.write(self.symmetric_state.encrypt_and_hash(payload)) + + # If there are no more message patterns returns two new CipherState objects by calling Split() + if len(self.message_patterns) == 0: + return self.symmetric_state.split() + + def read_message(self, message: bytes, payload_buffer): """ - - :param message: - :param payload_buffer: - :return: + Comments below are mostly copied from specification. + :param message: byte sequence containing a Noise handshake message + :param payload_buffer: buffer-like object + :return: None or result of SymmetricState.split() - tuple (CipherState, CipherState) """ - pass + # Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token + # from the message pattern + dhlen = self.noise_protocol.dh_fn.dhlen + message_pattern = self.message_patterns.pop(0) + for token in message_pattern: + if token == TOKEN_E: + # Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key). + self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen)) + self.symmetric_state.mix_hash(self.re.public) + + elif token == TOKEN_S: + # Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes + # otherwise. Sets rs to DecryptAndHash(temp). + if self.noise_protocol.cipher_state.has_key(): + temp = message.read(dhlen + 16) + else: + temp = message.read(dhlen) + self.rs = self.symmetric_state.decrypt_and_hash(temp) + + elif token == TOKEN_EE: + # Calls MixKey(DH(e, re)). + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public)) + + elif token == TOKEN_ES: + # Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder + if self.initiator: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public)) + else: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public)) + + elif token == TOKEN_SE: + # Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder + if self.initiator: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public)) + else: + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public)) + + elif token == TOKEN_SS: + # Calls MixKey(DH(s, rs)) + self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public)) + + elif token == TOKEN_PSK: + raise NotImplementedError + + else: + raise NotImplementedError('Pattern token: {}'.format(token)) + + # Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer. + payload_buffer.write(self.symmetric_state.decrypt_and_hash(message)) # TODO remaining bytes! + + # If there are no more message patterns returns two new CipherState objects by calling Split() + if len(self.message_patterns) == 0: + return self.symmetric_state.split() def _get_local_keypair(self, token: str) -> 'KeyPair': keypair = getattr(self, token) # Maybe explicitly handle exception when getting improper keypair diff --git a/tests/test_vectors.py b/tests/test_vectors.py index af181d3..525a46a 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -4,7 +4,7 @@ import os import pytest -from noise.functions import KeyPair +from noise.functions import KeyPair25519 from noise.state import HandshakeState from noise.noise_protocol import NoiseProtocol @@ -29,8 +29,8 @@ def _prepare_test_vectors(): vectors_list = json.load(fd) for vector in vectors_list: - if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name']: - continue # TODO REMOVE WHEN ed448/ChaCha SUPPORT IS IMPLEMENTED + if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name'] or 'psk' in vector['protocol_name']: + continue # TODO REMOVE WHEN ed448/ChaCha/psk SUPPORT IS IMPLEMENTED for key, value in vector.copy().items(): if key in byte_fields: vector[key] = value.encode() @@ -59,9 +59,9 @@ class TestVectors(object): role_key = role + '_' + key if role_key in vector: if key in ['static', 'ephemeral']: - kwargs[role][kwarg] = KeyPair._25519_from_private_bytes(vector[role_key]) # TODO unify after adding 448 + kwargs[role][kwarg] = KeyPair25519.from_private_bytes(vector[role_key]) # TODO unify after adding 448 elif key == 'remote_static': - kwargs[role][kwarg] = KeyPair._25519_from_public_bytes(vector[role_key]) # TODO unify after adding 448 + kwargs[role][kwarg] = KeyPair25519.from_public_bytes(vector[role_key]) # TODO unify after adding 448 return kwargs def test_vector(self, vector):