Files
noiseprotocol/noise/builder.py
Piotr Lizonczyk 865bbfe5ba Implemented cipher rekeying
noise/builder.py
- Added methods for rekeying cipherstates
- Added method for getting handshake hash (for channel binding)

noise/functions.py
- Added default rekey behavior and set it for AESGCM and ChaCha20

noise/constants.py
- Added MAX_NONCE

noise/state.py
- Added rekey method to CipherState
- Removed writing to the noise_protocol instance from SymmetricState.
NoiseProtocol now fills the appropriate field itself, taking the data from
HandshakeState.
2017-09-02 17:38:02 +02:00

141 lines
5.7 KiB
Python

from enum import Enum, auto
from typing import Union, List
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
from .noise_protocol import NoiseProtocol
class Keypair(Enum):
    """Selects which key-pair slot a NoiseBuilder.set_keypair_* call fills.

    Values are spelled out explicitly; they are the same 1..4 sequence that
    ``enum.auto()`` would assign in definition order.
    """

    STATIC = 1            # local static key pair
    REMOTE_STATIC = 2     # peer's static key pair
    EPHEMERAL = 3         # local ephemeral key pair
    REMOTE_EPHEMERAL = 4  # peer's ephemeral key pair
# Translates each Keypair member into the short token under which
# NoiseProtocol.keypairs stores the corresponding key pair.
_keypairs = {
    Keypair.STATIC: 's',             # local static
    Keypair.REMOTE_STATIC: 'rs',     # remote static
    Keypair.EPHEMERAL: 'e',          # local ephemeral
    Keypair.REMOTE_EPHEMERAL: 're',  # remote ephemeral
}
class NoiseBuilder(object):
    """Configure and drive a single Noise protocol session.

    Typical flow: create an instance with :meth:`from_name`, configure it
    (role, PSKs, prologue, keypairs), call :meth:`start_handshake`, then
    alternate :meth:`write_message` / :meth:`read_message` until
    ``handshake_finished`` is True. After that, :meth:`encrypt` and
    :meth:`decrypt` process transport messages.
    """

    # Upper bound on the plaintext length accepted by encrypt().
    # NOTE(review): the Noise spec caps a whole transport message (ciphertext
    # plus 16-byte tag) at 65535 bytes; this bound is kept unchanged for
    # backward compatibility.
    _MAX_MESSAGE_LEN = 65535

    def __init__(self):
        # NoiseProtocol instance holding all protocol state; set by from_name().
        self.noise_protocol = None
        # Protocol name as bytes (ASCII-encoded when supplied as str).
        self.protocol_name = None
        # Set to True once a handshake message call reports completion.
        self.handshake_finished = False
        self._handshake_started = False
        # Bound method (write_message or read_message) that must be called next
        # during the handshake; chosen by set_as_initiator()/set_as_responder().
        self._next_fn = None

    @classmethod
    def from_name(cls, name: Union[str, bytes]) -> 'NoiseBuilder':
        """Create a builder for the Noise protocol called *name*.

        :param name: full protocol name as an ASCII ``str`` or as ``bytes``.
            Any other type is passed through for NoiseProtocol to reject.
        :raises NoiseValueError: if *name* is a ``str`` containing non-ASCII
            characters.
        """
        instance = cls()
        # Forgiving passing string. Bytes are good too, anything else will fail inside NoiseProtocol
        try:
            instance.protocol_name = name.encode('ascii') if isinstance(name, str) else name
        except ValueError:  # UnicodeEncodeError is a subclass of ValueError
            raise NoiseValueError('If passing string as protocol name, it must contain only ASCII characters')
        # Pass the normalised bytes rather than the raw argument, so a str name
        # behaves exactly like its bytes equivalent.
        instance.noise_protocol = NoiseProtocol(protocol_name=instance.protocol_name)
        return instance

    def set_psks(self, psk: Union[bytes, str] = None, psks: List[Union[str, bytes]] = None):
        """Store the pre-shared key(s) on the underlying NoiseProtocol.

        Supply exactly one non-empty argument: *psk* (a single key) or *psks*
        (a list of keys). ``str`` keys are ASCII-encoded to ``bytes``.

        :raises NoisePSKError: if both or neither argument is supplied, if a key
            is neither ``str`` nor ``bytes``, or if a ``str`` key is non-ASCII.
        """
        if psk and psks:
            raise NoisePSKError('Provide single PSK as psk or list of PSKs as psks')
        if not psk and not psks:
            raise NoisePSKError('No PSKs provided')
        keys = psks or [psk]
        if not all(isinstance(key, (bytes, str)) for key in keys):
            raise NoisePSKError('PSKs must be strings or bytes')
        try:
            self.noise_protocol.psks = [key.encode('ascii') if isinstance(key, str) else key for key in keys]
        except UnicodeEncodeError:
            raise NoisePSKError('If providing psks as (unicode) string, it must only contain ASCII characters')

    def set_prologue(self, prologue: Union[bytes, str]):
        """Store the handshake prologue; ``str`` input is ASCII-encoded.

        :raises NoiseValueError: if *prologue* is neither ``bytes`` nor an
            ASCII-only ``str``.
        """
        if isinstance(prologue, bytes):
            self.noise_protocol.prologue = prologue
        elif isinstance(prologue, str):
            try:
                self.noise_protocol.prologue = prologue.encode('ascii')
            except UnicodeEncodeError:
                raise NoiseValueError('Prologue must be ASCII string or bytes')
        else:
            raise NoiseValueError('Prologue must be ASCII string or bytes')

    def set_as_initiator(self):
        """Mark this side as initiator; the first handshake call must be a write."""
        self.noise_protocol.initiator = True
        self._next_fn = self.write_message

    def set_as_responder(self):
        """Mark this side as responder; the first handshake call must be a read."""
        self.noise_protocol.initiator = False
        self._next_fn = self.read_message

    def set_keypair_from_private_bytes(self, keypair: Keypair, private_bytes: bytes):
        """Build the *keypair* slot from raw private-key bytes using the DH function's keypair class."""
        self.noise_protocol.keypairs[_keypairs[keypair]] = \
            self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(private_bytes)

    def set_keypair_from_public_bytes(self, keypair: Keypair, private_bytes: bytes):
        """Build the *keypair* slot from raw public-key bytes using the DH function's keypair class."""
        # NOTE: the parameter is misnamed; it carries PUBLIC key bytes. The name
        # is kept so existing keyword callers (private_bytes=...) keep working.
        self.noise_protocol.keypairs[_keypairs[keypair]] = \
            self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(private_bytes)

    def set_keypair_from_private_path(self, keypair: Keypair, path: str):
        """Like set_keypair_from_private_bytes, reading the raw bytes from the file at *path*."""
        with open(path, 'rb') as fd:
            self.noise_protocol.keypairs[_keypairs[keypair]] = \
                self.noise_protocol.dh_fn.keypair_cls.from_private_bytes(fd.read())

    def set_keypair_from_public_path(self, keypair: Keypair, path: str):
        """Like set_keypair_from_public_bytes, reading the raw bytes from the file at *path*."""
        with open(path, 'rb') as fd:
            self.noise_protocol.keypairs[_keypairs[keypair]] = \
                self.noise_protocol.dh_fn.keypair_cls.from_public_bytes(fd.read())

    def start_handshake(self):
        """Validate the configuration and initialise the handshake state."""
        self.noise_protocol.validate()
        self.noise_protocol.initialise_handshake_state()
        self._handshake_started = True

    def write_message(self, payload: bytes = b'') -> bytearray:
        """Produce the next outgoing handshake message carrying *payload*.

        :returns: the buffer filled by the handshake state's write_message.
        :raises NoiseHandshakeError: if the handshake was not started, has
            already finished, or it is the peer's turn to send.
        """
        if not self._handshake_started:
            raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
        # Check completion before turn order: after the final handshake message
        # _next_fn still points at a handshake call, and the turn check would
        # report a misleading error.
        if self.handshake_finished:
            raise NoiseHandshakeError('Handshake finished. NoiseBuilder.encrypt should be used now')
        if self._next_fn != self.write_message:
            raise NoiseHandshakeError('NoiseBuilder.read_message has to be called now')
        self._next_fn = self.read_message
        buffer = bytearray()
        result = self.noise_protocol.handshake_state.write_message(payload, buffer)
        # A truthy result from the handshake state signals completion.
        if result:
            self.handshake_finished = True
        return buffer

    def read_message(self, data: bytes) -> bytearray:
        """Process an incoming handshake message and return its payload.

        :returns: the buffer filled by the handshake state's read_message.
        :raises NoiseHandshakeError: if the handshake was not started, has
            already finished, or it is our turn to send.
        """
        if not self._handshake_started:
            raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
        # Same ordering rationale as in write_message.
        if self.handshake_finished:
            raise NoiseHandshakeError('Handshake finished. NoiseBuilder.decrypt should be used now')
        if self._next_fn != self.read_message:
            raise NoiseHandshakeError('NoiseBuilder.write_message has to be called now')
        self._next_fn = self.write_message
        buffer = bytearray()
        result = self.noise_protocol.handshake_state.read_message(data, buffer)
        if result:
            self.handshake_finished = True
        return buffer

    def _require_finished_handshake(self, action: str):
        """Raise NoiseHandshakeError unless the handshake has completed."""
        # Cipher states are only available after the handshake, so fail with a
        # clear error instead of an AttributeError on a missing cipher state.
        if not self.handshake_finished:
            raise NoiseHandshakeError('Handshake not finished yet. Complete it before calling NoiseBuilder.{}'.format(action))

    def encrypt(self, data: bytes):
        """Encrypt a transport message with the outbound cipher state.

        :raises NoiseHandshakeError: if the handshake has not finished yet.
        :raises NoiseValueError: if *data* is not ``bytes`` or is longer than
            65535 bytes.
        """
        self._require_finished_handshake('encrypt')
        if not isinstance(data, bytes) or len(data) > self._MAX_MESSAGE_LEN:
            raise NoiseValueError('Data must be bytes and at most {} bytes long'.format(self._MAX_MESSAGE_LEN))
        return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)

    def decrypt(self, data: bytes):
        """Decrypt a transport message with the inbound cipher state.

        :raises NoiseHandshakeError: if the handshake has not finished yet.
        """
        self._require_finished_handshake('decrypt')
        return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)

    def get_handshake_hash(self) -> bytes:
        """Return the handshake hash stored on the NoiseProtocol instance (e.g. for channel binding).

        NOTE(review): presumably only populated once the handshake completes;
        confirm against NoiseProtocol.
        """
        return self.noise_protocol.handshake_hash

    def rekey_inbound_cipher(self):
        """Rekey the inbound (decrypting) cipher state.

        :raises NoiseHandshakeError: if the handshake has not finished yet.
        """
        self._require_finished_handshake('rekey_inbound_cipher')
        self.noise_protocol.cipher_state_decrypt.rekey()

    def rekey_outbound_cipher(self):
        """Rekey the outbound (encrypting) cipher state.

        :raises NoiseHandshakeError: if the handshake has not finished yet.
        """
        self._require_finished_handshake('rekey_outbound_cipher')
        self.noise_protocol.cipher_state_encrypt.rekey()