From 865bbfe5ba28ee29709f224b37ce0f0e529970cf Mon Sep 17 00:00:00 2001 From: Piotr Lizonczyk Date: Sat, 2 Sep 2017 17:32:47 +0200 Subject: [PATCH] Implemented cipher rekeying noise/builder.py - Added methods for rekeying cipherstates - Added method for getting handshake hash (for channel binding) noise/functions.py - Added default rekey behavior and set it for AESGCM and ChaCha20 noise/constants.py - Added MAX_NONCE noise/state.py - Added rekey method to CipherState - Removed writing to noise_protocol instance in SymmetricState. NoiseProtocol fills the appropriate field by taking the data from HandshakeState now. --- noise/builder.py | 13 +++++++++++-- noise/constants.py | 2 ++ noise/functions.py | 7 ++++++- noise/noise_protocol.py | 3 ++- noise/state.py | 8 +++++--- 5 files changed, 26 insertions(+), 7 deletions(-) diff --git a/noise/builder.py b/noise/builder.py index b4b5562..1485886 100644 --- a/noise/builder.py +++ b/noise/builder.py @@ -69,11 +69,11 @@ class NoiseBuilder(object): self.noise_protocol.initiator = False self._next_fn = self.read_message - def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes): + def set_keypair_from_private_bytes(self, keypair: Keypair, private_bytes: bytes): self.noise_protocol.keypairs[_keypairs[keypair]] = \ self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes) - def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes): + def set_keypair_from_public_bytes(self, keypair: Keypair, private_bytes: bytes): self.noise_protocol.keypairs[_keypairs[keypair]] = \ self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes) @@ -129,3 +129,12 @@ class NoiseBuilder(object): def decrypt(self, data: bytes): return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data) + + def get_handshake_hash(self) -> bytes: + return self.noise_protocol.handshake_hash + + def rekey_inbound_cipher(self): + self.noise_protocol.cipher_state_decrypt.rekey() + + def rekey_outbound_cipher(self): + self.noise_protocol.cipher_state_encrypt.rekey() diff --git a/noise/constants.py b/noise/constants.py index c32e42e..05eaa10 100644 --- a/noise/constants.py +++ b/noise/constants.py @@ -16,3 +16,5 @@ TOKEN_PSK = 'psk' MAX_PROTOCOL_NAME_LEN = 255 MAX_MESSAGE_LEN = 65535 + +MAX_NONCE = 2 ** 64 - 1 diff --git a/noise/functions.py b/noise/functions.py index 9e299a6..3d4028e 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305 # from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed - +from noise.constants import MAX_NONCE from .crypto import X448 backend = default_backend() @@ -53,10 +53,12 @@ class Cipher(object): self._cipher = AESGCM self.encrypt = self._aesgcm_encrypt self.decrypt = self._aesgcm_decrypt + self.rekey = self._default_rekey elif method == 'ChaCha20': self._cipher = ChaCha20Poly1305 self.encrypt = self._chacha20_encrypt self.decrypt = self._chacha20_decrypt + self.rekey = self._default_rekey else: raise NotImplementedError('Cipher method: {}'.format(method)) @@ -85,6 +87,9 @@ class Cipher(object): def _chacha20_nonce(self, n): return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little') + def _default_rekey(self, k): + return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32] + class Hash(object): def __init__(self, method): diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 032ceb8..a87c39f 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple, List +from typing import Tuple from noise.exceptions import NoiseProtocolNameError, NoisePSKError from noise.state import HandshakeState @@ -133,3 +133,4 @@ class NoiseProtocol(object): if value: kwargs[keypair] = value self.handshake_state = HandshakeState.initialize(self, **kwargs) + self.symmetric_state = self.handshake_state.symmetric_state diff --git a/noise/state.py b/noise/state.py index a84f3a7..f6e30f9 100644 --- a/noise/state.py +++ b/noise/state.py @@ -1,6 +1,6 @@ from typing import Union -from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK +from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE class CipherState(object): @@ -34,7 +34,7 @@ class CipherState(object): :param plaintext: bytes sequence :return: ciphertext bytes sequence """ - if self.n == 2**64 - 1: + if self.n == MAX_NONCE: raise Exception('Nonce has depleted!') if not self.has_key(): @@ -62,6 +62,9 @@ class CipherState(object): self.n = self.n + 1 return plaintext + def rekey(self): + self.k = self.noise_protocol.cipher_fn.rekey(self.k) + class SymmetricState(object): """ @@ -87,7 +90,6 @@ class SymmetricState(object): # Create SymmetricState instance = cls() instance.noise_protocol = noise_protocol - noise_protocol.symmetric_state = instance # If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero # bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).