Files
noiseprotocol/noise/noise_protocol.py
Piotr Lizonczyk 7366c45e72 0.2.0 pre-release commit
Renamed NoiseBuilder to NoiseConnection
Bumped Cryptography from 2.1.1 to 2.1.2
Ready to release

Closes #5
2017-11-01 11:17:07 +01:00

144 lines
5.7 KiB
Python

import warnings
from functools import partial
from typing import Tuple
from noise.exceptions import NoiseProtocolNameError, NoisePSKError, NoiseValidationError
from noise.state import HandshakeState
from .constants import MAX_PROTOCOL_NAME_LEN, Empty
from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf
from .patterns import patterns_map
class NoiseProtocol(object):
    """
    Parsed Noise Protocol name together with the state used while running a handshake.

    The protocol name has the form ``Noise_<PATTERN><modifiers>_<DH>_<CIPHER>_<HASH>``. On construction each
    component is resolved through the lookup tables in ``methods``. Pattern modifiers (if any) are applied to the
    pattern instance.

    The remaining attributes start out empty (``None`` / ``Empty()``) and are filled in before and during the
    handshake: prologue, role, keypairs, PSKs, and the handshake/symmetric/cipher states. The error messages in
    ``validate`` suggest this is driven by ``NoiseConnection``.
    """

    # Maps each protocol-name component to the table used to resolve it. The iteration order of this dict is
    # also the order in which unknown components are reported by _parse_protocol_name.
    methods = {
        'pattern': patterns_map,
        'dh': dh_map,
        'cipher': cipher_map,
        'hash': hash_map,
        'keypair': keypair_map
    }

    def __init__(self, protocol_name: bytes):
        """
        :param protocol_name: full Noise Protocol name, as bytes.
        :raises NoiseProtocolNameError: if the name is not bytes, is longer than MAX_PROTOCOL_NAME_LEN, or cannot
            be parsed / mapped to known primitives (see _parse_protocol_name).
        """
        if not isinstance(protocol_name, bytes):
            raise NoiseProtocolNameError('Protocol name has to be of type "bytes" not {}'.format(type(protocol_name)))
        if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
            raise NoiseProtocolNameError('Protocol name too long, has to be at most '
                                         '{} chars long'.format(MAX_PROTOCOL_NAME_LEN))

        self.name = protocol_name
        mappings, pattern_modifiers = self._parse_protocol_name()

        # A valid Pattern instance (see Section 7 of specification (rev 32))
        self.pattern = mappings['pattern']()
        self.pattern_modifiers = pattern_modifiers
        if self.pattern_modifiers:
            self.pattern.apply_pattern_modifiers(pattern_modifiers)

        # Handle PSK handshake options; the PSKs themselves are supplied later and checked in validate().
        self.psks = None
        self.is_psk_handshake = any(modifier.startswith('psk') for modifier in self.pattern_modifiers)

        # Crypto primitives resolved from the protocol name.
        self.dh_fn = mappings['dh']
        self.cipher_fn = mappings['cipher']
        self.hash_fn = mappings['hash']
        self.keypair_fn = mappings['keypair']
        # HMAC and HKDF bound to the chosen hash function.
        self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn)
        self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac)

        # Per-connection configuration, set before validate() / initialise_handshake_state().
        self.prologue = None
        self.initiator = None
        self.handshake_hash = None

        # Placeholders until the handshake is set up; handshake_state and symmetric_state are assigned in
        # initialise_handshake_state(). The cipher states are presumably assigned elsewhere — verify against callers.
        self.handshake_state = Empty()
        self.symmetric_state = Empty()
        self.cipher_state_handshake = Empty()
        self.cipher_state_encrypt = Empty()
        self.cipher_state_decrypt = Empty()

        # Local static/ephemeral ('s', 'e') and remote static/ephemeral ('rs', 're') keypairs.
        self.keypairs = {'s': None, 'e': None, 'rs': None, 're': None}

    def _parse_protocol_name(self) -> Tuple[dict, list]:
        """
        Split self.name into its components and resolve each of them through ``methods``.

        :return: tuple of (dict mapping 'pattern'/'dh'/'cipher'/'hash'/'keypair' to the looked-up values,
            list of pattern modifier strings, e.g. ['psk0', 'psk2'] for 'XXpsk0+psk2').
        :raises NoiseProtocolNameError: if the name is not valid UTF-8, does not start with "Noise", does not
            consist of exactly five "_"-separated parts, or names an unknown pattern/DH/cipher/hash.
        """
        try:
            unpacked = self.name.decode().split('_')
        except UnicodeDecodeError as exc:
            raise NoiseProtocolNameError('Protocol name has to be valid UTF-8, given {}'.format(self.name)) from exc

        if unpacked[0] != 'Noise':
            raise NoiseProtocolNameError('Noise Protocol name shall begin with Noise! Provided: {}'.format(self.name))
        # Every component below is indexed directly, so a wrong number of parts must be rejected here rather than
        # surfacing as an IndexError (too few) or being silently ignored (too many).
        if len(unpacked) != 5:
            raise NoiseProtocolNameError('Noise Protocol name has to consist of 5 "_"-separated parts '
                                         '(Noise_PATTERN_DH_CIPHER_HASH), given {}'.format(self.name))

        # The pattern name is the leading run of uppercase letters (e.g. 'XX' in 'XXpsk3'); whatever follows
        # is a '+'-separated list of pattern modifiers.
        pattern_part = unpacked[1]
        pattern = ''
        modifiers_str = ''
        for i, char in enumerate(pattern_part):
            if not char.isupper():
                modifiers_str = pattern_part[i:]
                break
            pattern += char
        modifiers = modifiers_str.split('+') if modifiers_str else []

        data = {'pattern': 'Pattern' + pattern,
                'dh': unpacked[2],
                'cipher': unpacked[3],
                'hash': unpacked[4],
                # keypair_map is keyed by the DH name as well
                'keypair': unpacked[2],
                'pattern_modifiers': modifiers}

        mapped_data = {}
        # Validate if we know everything that Noise Protocol is supposed to use and map appropriate functions
        for key, map_dict in self.methods.items():
            func = map_dict.get(data[key])
            if not func:
                raise NoiseProtocolNameError('Unknown {} in Noise Protocol name, given {}, known {}'.format(
                    key, data[key], " ".join(map_dict)))
            mapped_data[key] = func

        return mapped_data, modifiers

    def handshake_done(self):
        """
        Finalise this object once the handshake has completed.

        Stores the handshake hash taken from the symmetric state. For a one-way pattern it sets the unused cipher
        state to None: decryption for the initiator, encryption for the responder. It then deletes the
        handshake-only attributes.

        Calling this a second time raises AttributeError, because the deleted attributes (e.g. symmetric_state)
        no longer exist.
        """
        if self.pattern.one_way:
            if self.initiator:
                self.cipher_state_decrypt = None
            else:
                self.cipher_state_encrypt = None
        self.handshake_hash = self.symmetric_state.get_handshake_hash()

        # Drop state that is only needed while the handshake is in progress.
        del self.handshake_state
        del self.symmetric_state
        del self.cipher_state_handshake
        del self.prologue
        del self.initiator
        del self.dh_fn
        del self.hash_fn
        del self.keypair_fn

    def validate(self):
        """
        Check that everything required to start the handshake has been configured.

        Emits a warning if an ephemeral keypair ('e' or 're') has already been set.

        :raises NoisePSKError: for a PSK handshake, if any PSK is not 32 bytes long, or if the number of PSKs does
            not match the pattern's psk_count. This includes the case where no PSKs were set at all.
        :raises NoiseValidationError: if no role was set, or if a keypair required by the pattern for this role is
            missing.
        """
        if self.is_psk_handshake:
            # psks is None until the caller sets it; treat that as "no PSKs given" so the count check below
            # reports a NoisePSKError instead of iterating over None (TypeError).
            psks = self.psks if self.psks is not None else []
            if any(len(psk) != 32 for psk in psks):
                raise NoisePSKError('Invalid psk length! Has to be 32 bytes long')
            if len(psks) != self.pattern.psk_count:
                raise NoisePSKError('Bad number of PSKs provided to this protocol! {} are required, '
                                    'given {}'.format(self.pattern.psk_count, len(psks)))

        if self.initiator is None:
            raise NoiseValidationError('You need to set role with NoiseConnection.set_as_initiator '
                                       'or NoiseConnection.set_as_responder')

        for keypair in self.pattern.get_required_keypairs(self.initiator):
            if self.keypairs[keypair] is None:
                raise NoiseValidationError('Keypair {} has to be set for chosen handshake pattern'.format(keypair))

        if self.keypairs['e'] is not None or self.keypairs['re'] is not None:
            warnings.warn('One of ephemeral keypairs is already set. '
                          'This is OK for testing, but should NEVER happen in production!')

    def initialise_handshake_state(self):
        """
        Create the HandshakeState for this protocol and keep a reference to its symmetric state.

        The role is always passed to HandshakeState.initialize. The prologue is passed only when truthy, and only
        keypairs that have been set (truthy values) are passed; anything else is left out of the call.
        """
        kwargs = {'initiator': self.initiator}
        if self.prologue:
            kwargs['prologue'] = self.prologue
        for keypair, value in self.keypairs.items():
            if value:
                kwargs[keypair] = value
        self.handshake_state = HandshakeState.initialize(self, **kwargs)
        self.symmetric_state = self.handshake_state.symmetric_state