Implemented cipher rekeying

noise/builder.py
- Added methods for rekeying cipherstates
- Added method for getting handshake hash (for channel binding)

noise/functions.py
- Added default rekey behavior and set it for AESGCM and ChaCha20

noise/constants.py
- Added MAX_NONCE

noise/state.py
- Added rekey method to CipherState
- Removed writing to noise_protocol instance in SymmetricState.
NoiseProtocol fills the appropriate field by taking the data from
HandshakeState now.
This commit is contained in:
Piotr Lizonczyk
2017-09-02 17:32:47 +02:00
parent 46825bb075
commit 865bbfe5ba
5 changed files with 26 additions and 7 deletions

View File

@@ -69,11 +69,11 @@ class NoiseBuilder(object):
self.noise_protocol.initiator = False
self._next_fn = self.read_message
def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes):
def set_keypair_from_private_bytes(self, keypair: Keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes):
def set_keypair_from_public_bytes(self, keypair: Keypair, private_bytes: bytes):
self.noise_protocol.keypairs[_keypairs[keypair]] = \
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
@@ -129,3 +129,12 @@ class NoiseBuilder(object):
def decrypt(self, data: bytes):
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
def get_handshake_hash(self) -> bytes:
return self.noise_protocol.handshake_hash
def rekey_inbound_cipher(self):
self.noise_protocol.cipher_state_decrypt.rekey()
def rekey_outbound_cipher(self):
self.noise_protocol.cipher_state_encrypt.rekey()

View File

@@ -16,3 +16,5 @@ TOKEN_PSK = 'psk'
MAX_PROTOCOL_NAME_LEN = 255
MAX_MESSAGE_LEN = 65535
MAX_NONCE = 2 ** 64 - 1

View File

@@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
# from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed
from noise.constants import MAX_NONCE
from .crypto import X448
backend = default_backend()
@@ -53,10 +53,12 @@ class Cipher(object):
self._cipher = AESGCM
self.encrypt = self._aesgcm_encrypt
self.decrypt = self._aesgcm_decrypt
self.rekey = self._default_rekey
elif method == 'ChaCha20':
self._cipher = ChaCha20Poly1305
self.encrypt = self._chacha20_encrypt
self.decrypt = self._chacha20_decrypt
self.rekey = self._default_rekey
else:
raise NotImplementedError('Cipher method: {}'.format(method))
@@ -85,6 +87,9 @@ class Cipher(object):
def _chacha20_nonce(self, n):
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
def _default_rekey(self, k):
return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32]
class Hash(object):
def __init__(self, method):

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import Tuple, List
from typing import Tuple
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
from noise.state import HandshakeState
@@ -133,3 +133,4 @@ class NoiseProtocol(object):
if value:
kwargs[keypair] = value
self.handshake_state = HandshakeState.initialize(self, **kwargs)
self.symmetric_state = self.handshake_state.symmetric_state

View File

@@ -1,6 +1,6 @@
from typing import Union
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
class CipherState(object):
@@ -34,7 +34,7 @@ class CipherState(object):
:param plaintext: bytes sequence
:return: ciphertext bytes sequence
"""
if self.n == 2**64 - 1:
if self.n == MAX_NONCE:
raise Exception('Nonce has depleted!')
if not self.has_key():
@@ -62,6 +62,9 @@ class CipherState(object):
self.n = self.n + 1
return plaintext
def rekey(self):
self.k = self.noise_protocol.cipher_fn.rekey(self.k)
class SymmetricState(object):
"""
@@ -87,7 +90,6 @@ class SymmetricState(object):
# Create SymmetricState
instance = cls()
instance.noise_protocol = noise_protocol
noise_protocol.symmetric_state = instance
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).