mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Implemented cipher rekeying
noise/builder.py - Added methods for rekeying cipherstates - Added method for getting handshake hash (for channel binding) noise/functions.py - Added default rekey behavior and set it for AESGCM and ChaCha20 noise/constants.py - Added MAX_NONCE noise/state.py - Added rekey method to CipherState - Removed writing to noise_protocol instance in SymmetricState. NoiseProtocol fills the appropriate field by taking the data from HandshakeState now.
This commit is contained in:
@@ -69,11 +69,11 @@ class NoiseBuilder(object):
|
|||||||
self.noise_protocol.initiator = False
|
self.noise_protocol.initiator = False
|
||||||
self._next_fn = self.read_message
|
self._next_fn = self.read_message
|
||||||
|
|
||||||
def set_keypair_from_private_bytes(self, keypair, private_bytes: bytes):
|
def set_keypair_from_private_bytes(self, keypair: Keypair, private_bytes: bytes):
|
||||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||||
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
|
self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)
|
||||||
|
|
||||||
def set_keypair_from_public_bytes(self, keypair, private_bytes: bytes):
|
def set_keypair_from_public_bytes(self, keypair: Keypair, private_bytes: bytes):
|
||||||
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
self.noise_protocol.keypairs[_keypairs[keypair]] = \
|
||||||
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
|
self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)
|
||||||
|
|
||||||
@@ -129,3 +129,12 @@ class NoiseBuilder(object):
|
|||||||
|
|
||||||
def decrypt(self, data: bytes):
|
def decrypt(self, data: bytes):
|
||||||
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
|
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
|
||||||
|
|
||||||
|
def get_handshake_hash(self) -> bytes:
|
||||||
|
return self.noise_protocol.handshake_hash
|
||||||
|
|
||||||
|
def rekey_inbound_cipher(self):
|
||||||
|
self.noise_protocol.cipher_state_decrypt.rekey()
|
||||||
|
|
||||||
|
def rekey_outbound_cipher(self):
|
||||||
|
self.noise_protocol.cipher_state_encrypt.rekey()
|
||||||
|
|||||||
@@ -16,3 +16,5 @@ TOKEN_PSK = 'psk'
|
|||||||
MAX_PROTOCOL_NAME_LEN = 255
|
MAX_PROTOCOL_NAME_LEN = 255
|
||||||
|
|
||||||
MAX_MESSAGE_LEN = 65535
|
MAX_MESSAGE_LEN = 65535
|
||||||
|
|
||||||
|
MAX_NONCE = 2 ** 64 - 1
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from cryptography.hazmat.backends import default_backend
|
|||||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
|
||||||
# from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed
|
# from cryptography.hazmat.primitives.hmac import HMAC # Turn back on when Cryptography gets fixed
|
||||||
|
from noise.constants import MAX_NONCE
|
||||||
from .crypto import X448
|
from .crypto import X448
|
||||||
|
|
||||||
backend = default_backend()
|
backend = default_backend()
|
||||||
@@ -53,10 +53,12 @@ class Cipher(object):
|
|||||||
self._cipher = AESGCM
|
self._cipher = AESGCM
|
||||||
self.encrypt = self._aesgcm_encrypt
|
self.encrypt = self._aesgcm_encrypt
|
||||||
self.decrypt = self._aesgcm_decrypt
|
self.decrypt = self._aesgcm_decrypt
|
||||||
|
self.rekey = self._default_rekey
|
||||||
elif method == 'ChaCha20':
|
elif method == 'ChaCha20':
|
||||||
self._cipher = ChaCha20Poly1305
|
self._cipher = ChaCha20Poly1305
|
||||||
self.encrypt = self._chacha20_encrypt
|
self.encrypt = self._chacha20_encrypt
|
||||||
self.decrypt = self._chacha20_decrypt
|
self.decrypt = self._chacha20_decrypt
|
||||||
|
self.rekey = self._default_rekey
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Cipher method: {}'.format(method))
|
raise NotImplementedError('Cipher method: {}'.format(method))
|
||||||
|
|
||||||
@@ -85,6 +87,9 @@ class Cipher(object):
|
|||||||
def _chacha20_nonce(self, n):
|
def _chacha20_nonce(self, n):
|
||||||
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
|
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
|
||||||
|
|
||||||
|
def _default_rekey(self, k):
|
||||||
|
return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32]
|
||||||
|
|
||||||
|
|
||||||
class Hash(object):
|
class Hash(object):
|
||||||
def __init__(self, method):
|
def __init__(self, method):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Tuple, List
|
from typing import Tuple
|
||||||
|
|
||||||
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
|
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
|
||||||
from noise.state import HandshakeState
|
from noise.state import HandshakeState
|
||||||
@@ -133,3 +133,4 @@ class NoiseProtocol(object):
|
|||||||
if value:
|
if value:
|
||||||
kwargs[keypair] = value
|
kwargs[keypair] = value
|
||||||
self.handshake_state = HandshakeState.initialize(self, **kwargs)
|
self.handshake_state = HandshakeState.initialize(self, **kwargs)
|
||||||
|
self.symmetric_state = self.handshake_state.symmetric_state
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
|
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
|
||||||
|
|
||||||
|
|
||||||
class CipherState(object):
|
class CipherState(object):
|
||||||
@@ -34,7 +34,7 @@ class CipherState(object):
|
|||||||
:param plaintext: bytes sequence
|
:param plaintext: bytes sequence
|
||||||
:return: ciphertext bytes sequence
|
:return: ciphertext bytes sequence
|
||||||
"""
|
"""
|
||||||
if self.n == 2**64 - 1:
|
if self.n == MAX_NONCE:
|
||||||
raise Exception('Nonce has depleted!')
|
raise Exception('Nonce has depleted!')
|
||||||
|
|
||||||
if not self.has_key():
|
if not self.has_key():
|
||||||
@@ -62,6 +62,9 @@ class CipherState(object):
|
|||||||
self.n = self.n + 1
|
self.n = self.n + 1
|
||||||
return plaintext
|
return plaintext
|
||||||
|
|
||||||
|
def rekey(self):
|
||||||
|
self.k = self.noise_protocol.cipher_fn.rekey(self.k)
|
||||||
|
|
||||||
|
|
||||||
class SymmetricState(object):
|
class SymmetricState(object):
|
||||||
"""
|
"""
|
||||||
@@ -87,7 +90,6 @@ class SymmetricState(object):
|
|||||||
# Create SymmetricState
|
# Create SymmetricState
|
||||||
instance = cls()
|
instance = cls()
|
||||||
instance.noise_protocol = noise_protocol
|
instance.noise_protocol = noise_protocol
|
||||||
noise_protocol.symmetric_state = instance
|
|
||||||
|
|
||||||
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
|
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
|
||||||
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).
|
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).
|
||||||
|
|||||||
Reference in New Issue
Block a user