mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Improved validation, various additions
noise/builder.py: - Added guard for data length in decrypt - Handling InvalidTag exception when AEAD fails - New NoiseInvalidMessage exception class noise/exceptions.py - Three new exception classes noise/noise_protocol.py - Implemented rest of validation, now checks for required keypairs, setting initiator/responder role, warns if ephemeral keypairs are set. noise/patterns.py: - added name field to every Pattern with pattern name - added get_required_keypairs method that returns list of keypairs required for given handshake pattern noise/state.py - new NoiseMaxNonceError exception Overall: some TODOs resolved
This commit is contained in:
committed by
Szarlejowiec
parent
eaecac6af4
commit
368d401701
@@ -1,7 +1,10 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Union, List
|
||||
|
||||
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError
|
||||
from cryptography.exceptions import InvalidTag
|
||||
|
||||
from noise.constants import MAX_MESSAGE_LEN
|
||||
from noise.exceptions import NoisePSKError, NoiseValueError, NoiseHandshakeError, NoiseInvalidMessage
|
||||
from .noise_protocol import NoiseProtocol
|
||||
|
||||
|
||||
@@ -123,12 +126,17 @@ class NoiseBuilder(object):
|
||||
return buffer
|
||||
|
||||
def encrypt(self, data: bytes):
|
||||
if not isinstance(data, bytes) or len(data) > 65535:
|
||||
raise Exception #todo
|
||||
if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN:
|
||||
raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN))
|
||||
return self.noise_protocol.cipher_state_encrypt.encrypt_with_ad(None, data)
|
||||
|
||||
def decrypt(self, data: bytes):
|
||||
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
|
||||
if not isinstance(data, bytes) or len(data) > MAX_MESSAGE_LEN:
|
||||
raise NoiseInvalidMessage('Data must be bytes and less or equal {} bytes in length'.format(MAX_MESSAGE_LEN))
|
||||
try:
|
||||
return self.noise_protocol.cipher_state_decrypt.decrypt_with_ad(None, data)
|
||||
except InvalidTag:
|
||||
raise NoiseInvalidMessage('Failed authentication of message')
|
||||
|
||||
def get_handshake_hash(self) -> bytes:
|
||||
return self.noise_protocol.handshake_hash
|
||||
|
||||
@@ -12,3 +12,15 @@ class NoiseValueError(Exception):
|
||||
|
||||
class NoiseHandshakeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoiseInvalidMessage(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoiseMaxNonceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NoiseValidationError(Exception):
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import abc
|
||||
import warnings
|
||||
from functools import partial
|
||||
from functools import partial # Turn back on when Cryptography gets fixed
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
@@ -208,9 +208,6 @@ class KeyPair448(_KeyPair):
|
||||
return cls(private=private, public=public, public_bytes=public)
|
||||
|
||||
|
||||
# Available crypto functions
|
||||
# TODO: Check if it's safe to use one instance globally per cryptoalgorithm - i.e. if wrapper only provides interface
|
||||
# If not - switch to partials(?)
|
||||
dh_map = {
|
||||
'25519': DH('ed25519'),
|
||||
'448': DH('ed448')
|
||||
@@ -222,7 +219,6 @@ cipher_map = {
|
||||
}
|
||||
|
||||
hash_map = {
|
||||
# TODO benchmark pycryptodome vs hashlib implementation
|
||||
'BLAKE2s': Hash('BLAKE2s'),
|
||||
'BLAKE2b': Hash('BLAKE2b'),
|
||||
'SHA256': Hash('SHA256'),
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
from noise.exceptions import NoiseProtocolNameError, NoisePSKError
|
||||
from noise.exceptions import NoiseProtocolNameError, NoisePSKError, NoiseValidationError
|
||||
from noise.state import HandshakeState
|
||||
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
|
||||
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
|
||||
@@ -120,10 +121,17 @@ class NoiseProtocol(object):
|
||||
raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
|
||||
'given {}'.format(self.pattern.psk_count, len(self.psks)))
|
||||
|
||||
# TODO: Validate keypairs
|
||||
# TODO: Validate initiator set
|
||||
# TODO: Validate buffers set
|
||||
# TODO: Warn about ephemerals
|
||||
if self.initiator is None:
|
||||
raise NoiseValidationError('You need to set role with NoiseBuilder.set_as_initiator '
|
||||
'or NoiseBuilder.set_as_responder')
|
||||
|
||||
for keypair in self.pattern.get_required_keypairs(self.initiator):
|
||||
if self.keypairs[keypair] is None:
|
||||
raise NoiseValidationError('Keypair {} has to be set for chosen handshake pattern'.format(keypair))
|
||||
|
||||
if not isinstance(self.keypairs['e'], Empty) or not isinstance(self.keypairs['re'], Empty):
|
||||
warnings.warn('One of ephemeral keypairs is already set. '
|
||||
'This is OK for testing, but should NEVER happen in production!')
|
||||
|
||||
def initialise_handshake_state(self):
|
||||
kwargs = {'initiator': self.initiator}
|
||||
|
||||
@@ -19,6 +19,7 @@ class Pattern(object):
|
||||
# List of lists of valid tokens, alternating between tokens for initiator and responder
|
||||
self.tokens = []
|
||||
|
||||
self.name = ''
|
||||
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
|
||||
self.one_way = False
|
||||
self.psk_count = 0
|
||||
@@ -54,6 +55,20 @@ class Pattern(object):
|
||||
else:
|
||||
raise ValueError('Unknown pattern modifier {}'.format(modifier))
|
||||
|
||||
def get_required_keypairs(self, initiator: bool) -> list:
|
||||
required = []
|
||||
if initiator:
|
||||
if self.name[0] in ['K', 'X', 'I']:
|
||||
required.append('s')
|
||||
if self.one_way or self.name[1] == 'K':
|
||||
required.append('rs')
|
||||
else:
|
||||
if self.name[0] == 'K':
|
||||
required.append('rs')
|
||||
if self.one_way or self.name[1] in ['K', 'X']:
|
||||
required.append('s')
|
||||
return required
|
||||
|
||||
|
||||
# One-way patterns
|
||||
|
||||
@@ -66,6 +81,7 @@ class OneWayPattern(Pattern):
|
||||
class PatternN(OneWayPattern):
|
||||
def __init__(self):
|
||||
super(PatternN, self).__init__()
|
||||
self.name = 'N'
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
@@ -79,6 +95,7 @@ class PatternN(OneWayPattern):
|
||||
class PatternK(OneWayPattern):
|
||||
def __init__(self):
|
||||
super(PatternK, self).__init__()
|
||||
self.name = 'K'
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
@@ -92,6 +109,7 @@ class PatternK(OneWayPattern):
|
||||
class PatternX(OneWayPattern):
|
||||
def __init__(self):
|
||||
super(PatternX, self).__init__()
|
||||
self.name = 'X'
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
@@ -107,6 +125,7 @@ class PatternX(OneWayPattern):
|
||||
class PatternNN(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternNN, self).__init__()
|
||||
self.name = 'NN'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
@@ -117,6 +136,7 @@ class PatternNN(Pattern):
|
||||
class PatternKN(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternKN, self).__init__()
|
||||
self.name = 'KN'
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
@@ -131,6 +151,7 @@ class PatternKN(Pattern):
|
||||
class PatternNK(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternNK, self).__init__()
|
||||
self.name = 'NK'
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
@@ -145,6 +166,7 @@ class PatternNK(Pattern):
|
||||
class PatternKK(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternKK, self).__init__()
|
||||
self.name = 'KK'
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
@@ -159,6 +181,7 @@ class PatternKK(Pattern):
|
||||
class PatternNX(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternNX, self).__init__()
|
||||
self.name = 'NX'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
@@ -169,6 +192,7 @@ class PatternNX(Pattern):
|
||||
class PatternKX(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternKX, self).__init__()
|
||||
self.name = 'KX'
|
||||
|
||||
self.pre_messages = [
|
||||
[TOKEN_S],
|
||||
@@ -183,6 +207,7 @@ class PatternKX(Pattern):
|
||||
class PatternXN(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternXN, self).__init__()
|
||||
self.name = 'XN'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
@@ -194,6 +219,7 @@ class PatternXN(Pattern):
|
||||
class PatternIN(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternIN, self).__init__()
|
||||
self.name = 'IN'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
@@ -204,6 +230,7 @@ class PatternIN(Pattern):
|
||||
class PatternXK(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternXK, self).__init__()
|
||||
self.name = 'XK'
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
@@ -219,6 +246,7 @@ class PatternXK(Pattern):
|
||||
class PatternIK(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternIK, self).__init__()
|
||||
self.name = 'IK'
|
||||
|
||||
self.pre_messages = [
|
||||
[],
|
||||
@@ -233,6 +261,7 @@ class PatternIK(Pattern):
|
||||
class PatternXX(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternXX, self).__init__()
|
||||
self.name = 'XX'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E],
|
||||
@@ -244,6 +273,7 @@ class PatternXX(Pattern):
|
||||
class PatternIX(Pattern):
|
||||
def __init__(self):
|
||||
super(PatternIX, self).__init__()
|
||||
self.name = 'IX'
|
||||
|
||||
self.tokens = [
|
||||
[TOKEN_E, TOKEN_S],
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from noise.exceptions import NoiseMaxNonceError
|
||||
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
|
||||
|
||||
|
||||
@@ -53,7 +54,7 @@ class CipherState(object):
|
||||
:return: plaintext bytes sequence
|
||||
"""
|
||||
if self.n == 2**64 - 1:
|
||||
raise Exception('Nonce has depleted!')
|
||||
raise NoiseMaxNonceError('Nonce has depleted!')
|
||||
|
||||
if not self.has_key():
|
||||
return ciphertext
|
||||
@@ -210,8 +211,8 @@ class HandshakeState(object):
|
||||
self.message_patterns = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
|
||||
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': # TODO update typing (keypair)
|
||||
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: '_KeyPair'=None,
|
||||
e: '_KeyPair'=None, rs: '_KeyPair'=None, re: '_KeyPair'=None) -> 'HandshakeState':
|
||||
"""
|
||||
Constructor method.
|
||||
Comments below are mostly copied from specification.
|
||||
@@ -278,7 +279,7 @@ class HandshakeState(object):
|
||||
for token in message_pattern:
|
||||
if token == TOKEN_E:
|
||||
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e
|
||||
message_buffer += self.e.public_bytes
|
||||
self.symmetric_state.mix_hash(self.e.public_bytes)
|
||||
if self.noise_protocol.is_psk_handshake:
|
||||
|
||||
@@ -55,7 +55,8 @@ def idfn(vector):
|
||||
return vector['protocol_name']
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings('ignore: This implementation')
|
||||
@pytest.mark.filterwarnings('ignore: This implementation of ed448')
|
||||
@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs')
|
||||
class TestVectors(object):
|
||||
@pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
|
||||
def vector(self, request):
|
||||
|
||||
Reference in New Issue
Block a user