Implemented cipher rekeying

noise/builder.py
- Added methods for rekeying cipherstates
- Added method for getting handshake hash (for channel binding)

noise/functions.py
- Added default rekey behavior and set it for AESGCM and ChaCha20

noise/constants.py
- Added MAX_NONCE

noise/state.py
- Added rekey method to CipherState
- Removed writing to noise_protocol instance in SymmetricState.
NoiseProtocol fills the appropriate field by taking the data from
HandshakeState now.
This commit is contained in:
Piotr Lizonczyk
2017-09-02 17:32:47 +02:00
parent 46825bb075
commit 865bbfe5ba
5 changed files with 26 additions and 7 deletions

View File

@@ -69,11 +69,11 @@ class NoiseBuilder(object):
self.noise_protocol.initiator = False self.noise_protocol.initiator = False
self._next_fn = self.read_message self._next_fn = self.read_message
def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes): def set_keypair_from_private_bytes(self, keypair: Keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \ self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes) self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes): def set_keypair_from_public_bytes(self, keypair: Keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \ self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes) self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
@@ -129,3 +129,12 @@ class NoiseBuilder(object):
def decrypt(self, data: bytes): def decrypt(self, data: bytes):
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data) return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
def get_handshake_hash(self) -> bytes:
return self.noise_protocol.handshake_hash
def rekey_inbound_cipher(self):
self.noise_protocol.cipher_state_decrypt.rekey()
def rekey_outbound_cipher(self):
self.noise_protocol.cipher_state_encrypt.rekey()

View File

@@ -16,3 +16,5 @@ TOKEN_PSK = 'psk'
MAX_PROTOCOL_NAME_LEN = 255 MAX_PROTOCOL_NAME_LEN = 255
MAX_MESSAGE_LEN = 65535 MAX_MESSAGE_LEN = 65535
MAX_NONCE = 2 ** 64 - 1

View File

@@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305 from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
# from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed # from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed
from noise.constants import MAX_NONCE
from .crypto import X448 from .crypto import X448
backend = default_backend() backend = default_backend()
@@ -53,10 +53,12 @@ class Cipher(object):
self._cipher = AESGCM self._cipher = AESGCM
self.encrypt = self._aesgcm_encrypt self.encrypt = self._aesgcm_encrypt
self.decrypt = self._aesgcm_decrypt self.decrypt = self._aesgcm_decrypt
self.rekey = self._default_rekey
elif method == 'ChaCha20': elif method == 'ChaCha20':
self._cipher = ChaCha20Poly1305 self._cipher = ChaCha20Poly1305
self.encrypt = self._chacha20_encrypt self.encrypt = self._chacha20_encrypt
self.decrypt = self._chacha20_decrypt self.decrypt = self._chacha20_decrypt
self.rekey = self._default_rekey
else: else:
raise NotImplementedError('Cipher method: {}'.format(method)) raise NotImplementedError('Cipher method: {}'.format(method))
@@ -85,6 +87,9 @@ class Cipher(object):
def _chacha20_nonce(self, n): def _chacha20_nonce(self, n):
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little') return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
def _default_rekey(self, k):
return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32]
class Hash(object): class Hash(object):
def __init__(self, method): def __init__(self, method):

View File

@@ -1,5 +1,5 @@
from functools import partial from functools import partial
from typing import Tuple, List from typing import Tuple
from noise.exceptions import NoiseProtocolNameError, NoisePSKError from noise.exceptions import NoiseProtocolNameError, NoisePSKError
from noise.state import HandshakeState from noise.state import HandshakeState
@@ -133,3 +133,4 @@ class NoiseProtocol(object):
if value: if value:
kwargs[keypair] = value kwargs[keypair] = value
self.handshake_state = HandshakeState.initialize(self, **kwargs) self.handshake_state = HandshakeState.initialize(self, **kwargs)
self.symmetric_state = self.handshake_state.symmetric_state

View File

@@ -1,6 +1,6 @@
from typing import Union from typing import Union
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
class CipherState(object): class CipherState(object):
@@ -34,7 +34,7 @@ class CipherState(object):
:param plaintext: bytes sequence :param plaintext: bytes sequence
:return: ciphertext bytes sequence :return: ciphertext bytes sequence
""" """
if self.n == 2**64 - 1: if self.n == MAX_NONCE:
raise Exception('Nonce has depleted!') raise Exception('Nonce has depleted!')
if not self.has_key(): if not self.has_key():
@@ -62,6 +62,9 @@ class CipherState(object):
self.n = self.n + 1 self.n = self.n + 1
return plaintext return plaintext
def rekey(self):
self.k = self.noise_protocol.cipher_fn.rekey(self.k)
class SymmetricState(object): class SymmetricState(object):
""" """
@@ -87,7 +90,6 @@ class SymmetricState(object):
# Create SymmetricState # Create SymmetricState
instance = cls() instance = cls()
instance.noise_protocol = noise_protocol instance.noise_protocol = noise_protocol
noise_protocol.symmetric_state = instance
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero # If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name). # bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).