mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
noise/builder.py
- Added methods for rekeying cipher states
- Added method for getting the handshake hash (for channel binding)
noise/functions.py
- Added default rekey behavior and set it for AESGCM and ChaCha20
noise/constants.py
- Added MAX_NONCE
noise/state.py
- Added rekey method to CipherState
- Removed writing to the noise_protocol instance in SymmetricState; NoiseProtocol now fills the appropriate field by taking the data from HandshakeState.
141 lines
5.7 KiB
Python
141 lines
5.7 KiB
Python
from enum import Enum, auto
|
|
from typing import Union, List
|
|
|
|
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
|
|
from .noise_protocol import NoiseProtocol
|
|
|
|
|
|
class Keypair(Enum):
    """Identifies which key slot of the protocol a keypair is stored in.

    Members are mapped to the protocol's short slot names by the
    module-level ``_keypairs`` dictionary.
    """

    # Local long-term key.
    STATIC = auto()
    # Remote party's long-term key.
    REMOTE_STATIC = auto()
    # Local ephemeral key.
    EPHEMERAL = auto()
    # Remote party's ephemeral key.
    REMOTE_EPHEMERAL = auto()
|
|
|
|
|
|
# Maps each Keypair member to the key under which the keypair is stored in
# ``NoiseProtocol.keypairs`` ('s', 'rs', 'e', 're').
_keypairs = {
    Keypair.STATIC: 's',
    Keypair.REMOTE_STATIC: 'rs',
    Keypair.EPHEMERAL: 'e',
    Keypair.REMOTE_EPHEMERAL: 're',
}
|
|
|
|
|
|
class NoiseBuilder(object):
    """Convenience wrapper that configures a NoiseProtocol and drives it.

    Typical call order, as enforced by the checks in this class:
    ``from_name`` -> configuration setters (role, keypairs, PSKs, prologue)
    -> ``start_handshake`` -> alternating ``write_message``/``read_message``
    until ``handshake_finished`` is True -> ``encrypt``/``decrypt``.
    """

    # Upper bound on the length of data accepted by ``encrypt``.
    # NOTE(review): the limit is applied to the plaintext; the cipher may add
    # an authentication tag on top of it - confirm against the cipher in use.
    _MAX_MESSAGE_LEN = 65535

    def __init__(self):
        """Create an unconfigured builder; prefer ``from_name``."""
        self.noise_protocol = None
        self.protocol_name = None
        self.handshake_finished = False
        self._handshake_started = False
        # Bound method (write_message or read_message) expected to be called next.
        self._next_fn = None

    @classmethod
    def from_name(cls, name: Union[str, bytes]):
        """Build an instance for the protocol called ``name``.

        :param name: protocol name as bytes or as an ASCII-only string
        :raises NoiseValueError: if ``name`` is a string with non-ASCII characters
        :return: a new builder with ``noise_protocol`` set
        """
        instance = cls()
        # Forgiving passing string. Bytes are good too, anything else will fail inside NoiseProtocol
        try:
            instance.protocol_name = name.encode('ascii') if isinstance(name, str) else name
        except UnicodeEncodeError:
            raise NoiseValueError('If passing string as protocol name, it must contain only ASCII characters')

        # NOTE(review): the original ``name`` (not the encoded
        # ``protocol_name``) is handed over - confirm which form
        # NoiseProtocol expects.
        instance.noise_protocol = NoiseProtocol(protocol_name=name)
        return instance

    def set_psks(self, psk: Union[bytes, str] = None, psks: List[Union[str, bytes]] = None):
        """Store pre-shared key(s) on the protocol as a list of bytes.

        Exactly one of the two arguments must be given.

        :param psk: a single PSK (bytes or ASCII-only string)
        :param psks: a list of PSKs (each bytes or ASCII-only string)
        :raises NoisePSKError: if both or neither argument is provided, if an
            element is neither str nor bytes, or if a str PSK is not ASCII
        """
        if psk and psks:
            raise NoisePSKError('Provide single PSK as psk or list of PSKs as psks')
        if not psk and not psks:
            raise NoisePSKError('No PSKs provided')

        candidates = psks or [psk]
        if not all(isinstance(item, (bytes, str)) for item in candidates):
            raise NoisePSKError('PSKs must be strings or bytes')

        try:
            self.noise_protocol.psks = [
                item.encode('ascii') if isinstance(item, str) else item
                for item in candidates
            ]
        except UnicodeEncodeError:
            raise NoisePSKError('If providing psks as (unicode) string, it must only contain ASCII characters')

    def set_prologue(self, prologue: Union[bytes, str]):
        """Store the prologue on the protocol as bytes.

        :param prologue: bytes, or an ASCII-only string that gets encoded
        :raises NoiseValueError: if ``prologue`` is neither bytes nor an
            ASCII-only string
        """
        if isinstance(prologue, bytes):
            self.noise_protocol.prologue = prologue
        elif isinstance(prologue, str):
            try:
                self.noise_protocol.prologue = prologue.encode('ascii')
            except UnicodeEncodeError:
                raise NoiseValueError('Prologue must be ASCII string or bytes')
        else:
            raise NoiseValueError('Prologue must be ASCII string or bytes')

    def set_as_initiator(self):
        """Mark this side as initiator; its first handshake step is a write."""
        self.noise_protocol.initiator = True
        self._next_fn = self.write_message

    def set_as_responder(self):
        """Mark this side as responder; its first handshake step is a read."""
        self.noise_protocol.initiator = False
        self._next_fn = self.read_message

    def set_keypair_from_private_bytes(self, keypair: 'Keypair', private_bytes: bytes):
        """Load the keypair for slot ``keypair`` from raw private key bytes.

        :param keypair: slot to fill
        :param private_bytes: bytes passed to the DH keypair class'
            ``from_private_bytes``
        """
        keypair_cls = self.noise_protocol.dh_fn.keypair_cls
        self.noise_protocol.keypairs[_keypairs[keypair]] = keypair_cls.from_private_bytes(private_bytes)

    def set_keypair_from_public_bytes(self, keypair: 'Keypair', private_bytes: bytes):
        """Load the keypair for slot ``keypair`` from raw public key bytes.

        :param keypair: slot to fill
        :param private_bytes: despite its name (kept so keyword callers keep
            working), this is passed to the DH keypair class'
            ``from_public_bytes``
        """
        keypair_cls = self.noise_protocol.dh_fn.keypair_cls
        self.noise_protocol.keypairs[_keypairs[keypair]] = keypair_cls.from_public_bytes(private_bytes)

    def set_keypair_from_private_path(self, keypair: 'Keypair', path: str):
        """Load the keypair for slot ``keypair`` from a private key file.

        :param keypair: slot to fill
        :param path: file whose entire binary content is passed to the DH
            keypair class' ``from_private_bytes``
        """
        with open(path, 'rb') as fd:
            key_bytes = fd.read()
        keypair_cls = self.noise_protocol.dh_fn.keypair_cls
        self.noise_protocol.keypairs[_keypairs[keypair]] = keypair_cls.from_private_bytes(key_bytes)

    def set_keypair_from_public_path(self, keypair: 'Keypair', path: str):
        """Load the keypair for slot ``keypair`` from a public key file.

        :param keypair: slot to fill
        :param path: file whose entire binary content is passed to the DH
            keypair class' ``from_public_bytes``
        """
        with open(path, 'rb') as fd:
            key_bytes = fd.read()
        keypair_cls = self.noise_protocol.dh_fn.keypair_cls
        self.noise_protocol.keypairs[_keypairs[keypair]] = keypair_cls.from_public_bytes(key_bytes)

    def start_handshake(self):
        """Call the protocol's ``validate`` and ``initialise_handshake_state``,
        then allow ``write_message``/``read_message`` to be used."""
        self.noise_protocol.validate()
        self.noise_protocol.initialise_handshake_state()
        self._handshake_started = True

    def write_message(self, payload: bytes = b'') -> bytearray:
        """Produce the next outgoing handshake message.

        :param payload: optional payload handed to the handshake state
        :raises NoiseHandshakeError: if the handshake was not started, is
            already finished, or it is this side's turn to read
        :return: buffer filled by the handshake state's ``write_message``
        """
        if not self._handshake_started:
            raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
        # Checked before the turn order so a finished handshake never yields
        # the misleading "read_message has to be called now" message.
        if self.handshake_finished:
            raise NoiseHandshakeError('Handshake finished. NoiseBuilder.encrypt should be used now')
        if self._next_fn != self.write_message:
            raise NoiseHandshakeError('NoiseBuilder.read_message has to be called now')
        self._next_fn = self.read_message

        buffer = bytearray()
        result = self.noise_protocol.handshake_state.write_message(payload, buffer)
        # A truthy result from the handshake state marks the end of the handshake.
        if result:
            self.handshake_finished = True
        return buffer

    def read_message(self, data: bytes) -> bytearray:
        """Consume the next incoming handshake message.

        :param data: message received from the other party
        :raises NoiseHandshakeError: if the handshake was not started, is
            already finished, or it is this side's turn to write
        :return: buffer filled by the handshake state's ``read_message``
        """
        if not self._handshake_started:
            raise NoiseHandshakeError('Call NoiseBuilder.start_handshake first')
        # Checked before the turn order so a finished handshake never yields
        # the misleading "write_message has to be called now" message.
        if self.handshake_finished:
            raise NoiseHandshakeError('Handshake finished. NoiseBuilder.decrypt should be used now')
        if self._next_fn != self.read_message:
            raise NoiseHandshakeError('NoiseBuilder.write_message has to be called now')
        self._next_fn = self.write_message

        buffer = bytearray()
        result = self.noise_protocol.handshake_state.read_message(data, buffer)
        # A truthy result from the handshake state marks the end of the handshake.
        if result:
            self.handshake_finished = True
        return buffer

    def encrypt(self, data: bytes):
        """Encrypt ``data`` with the outbound cipher state (no associated data).

        :param data: bytes, at most 65535 bytes long
        :raises NoiseHandshakeError: if the handshake is not finished yet
        :raises NoiseValueError: if ``data`` is not bytes or is too long
        :return: result of the cipher state's ``encrypt_with_ad``
        """
        if not self.handshake_finished:
            raise NoiseHandshakeError('Handshake not finished yet. Transport encryption is not available')
        if not isinstance(data, bytes) or len(data) > self._MAX_MESSAGE_LEN:
            raise NoiseValueError(
                'Data must be bytes and at most {} bytes long'.format(self._MAX_MESSAGE_LEN))
        return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)

    def decrypt(self, data: bytes):
        """Decrypt ``data`` with the inbound cipher state (no associated data).

        :param data: ciphertext received from the other party
        :raises NoiseHandshakeError: if the handshake is not finished yet
        :return: result of the cipher state's ``decrypt_with_ad``
        """
        if not self.handshake_finished:
            raise NoiseHandshakeError('Handshake not finished yet. Transport decryption is not available')
        return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)

    def get_handshake_hash(self) -> bytes:
        """Return the ``handshake_hash`` stored on the protocol object."""
        return self.noise_protocol.handshake_hash

    def rekey_inbound_cipher(self):
        """Call ``rekey`` on the cipher state used by ``decrypt``."""
        self.noise_protocol.cipher_state_decrypt.rekey()

    def rekey_outbound_cipher(self):
        """Call ``rekey`` on the cipher state used by ``encrypt``."""
        self.noise_protocol.cipher_state_encrypt.rekey()
|