Implemented HandshakeState.initialize()

constants.py
* Adding max Noise protocol name constant

noise_protocol.py
* Provisioning NoiseProtocol and KeyPair classes

patterns.py
* Switching to proper intra-package relative imports
* Adding getter functions for pre-messages

state.py
* Switching to proper intra-package relative imports
* Removed __init__ of HandshakeState, leaving only initialize() as
constructor function.
* Implemented initialize() along with helper functions for retrieving
keypairs
* Modified SymmetricState, removing __init__ and leaving
initialize_symmetric as a constructor function (only provisioned)
This commit is contained in:
Piotr Lizonczyk
2017-08-07 00:43:22 +02:00
parent 19e78f1583
commit a6eec85ef7
4 changed files with 110 additions and 30 deletions

View File

@@ -10,3 +10,7 @@ TOKEN_ES = 'es'
TOKEN_SE = 'se'
TOKEN_SS = 'ss'
TOKEN_PSK = 'psk'
# In bytes, as in Section 8 of specification (rev 32)
MAX_PROTOCOL_NAME_LEN = 255

23
noise/noise_protocol.py Normal file
View File

@@ -0,0 +1,23 @@
from .constants import MAX_PROTOCOL_NAME_LEN
class NoiseProtocol(object):
def __init__(self, protocol_name: bytes):
if len(protocol_name) > MAX_PROTOCOL_NAME_LEN:
raise Exception('Protocol name too long, has to be at most {} chars long'.format(MAX_PROTOCOL_NAME_LEN))
self.pattern = None
self.pattern_modifiers = None
self.dh = None
self.dh_modifiers = None
self.cipher = None
self.cipher_modifiers = None
self.hash = None
self.hash_modifiers = None
class KeyPair(object):
def __init__(self, public='', private=''):
# TODO: Maybe switch to properties?
self.public = public
self.private = private

View File

@@ -1,4 +1,4 @@
from noise.constants import TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
from .constants import TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK
class Pattern(object):
@@ -15,6 +15,15 @@ class Pattern(object):
# TODO Comment
tokens = []
def __init__(self):
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
def get_initiator_pre_messages(self) -> list:
return self.pre_messages[0]
def get_responder_pre_messages(self) -> list:
return self.pre_messages[1]
# One-way patterns

View File

@@ -1,4 +1,4 @@
from noise.constants import Empty
from .constants import Empty
class CipherState(object):
@@ -48,17 +48,16 @@ class SymmetricState(object):
"""
"""
def __init__(self):
self.ck = None
self.h = None
def initialize_symmetric(self, protocol_name):
@classmethod
def initialize_symmetric(cls, protocol_name) -> 'SymmetricState':
"""
:param protocol_name:
:return:
"""
pass
instance = cls()
# TODO
return instance
def mix_key(self, input_key_material):
"""
@@ -97,33 +96,66 @@ class SymmetricState(object):
"""
pass
class HandshakeState(object):
"""
"""
def __init__(self):
self.symmetric_state = Empty()
self.handshake_pattern = None
self.initiator = None
self.prologue = b''
self.s = Empty()
self.e = Empty()
self.rs = Empty()
self.re = Empty()
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.3.
def initialize(self, handshake_pattern, initiator, prologue=b'', s=None, e=None, rs=None, re=None):
The initialize() function takes additional required argument - protocol_name - to provide it to SymmetricState.
"""
@classmethod
def initialize(cls, handshake_pattern: 'Pattern', protocol_name: 'NoiseProtocol', initiator: bool,
prologue: bytes=b'', s: bytes=None, e: bytes=None, rs: bytes=None,
re: bytes=None) -> 'HandshakeState':
"""
:param handshake_pattern:
:param initiator:
:param prologue:
:param s:
:param e:
:param rs:
:param re:
:return:
Constructor method.
Comments below are mostly copied from specification.
:param handshake_pattern: a valid Pattern instance (see Section 7 of specification (rev 32))
:param protocol_name: a valid NoiseProtocol instance
:param initiator: boolean indicating the initiator or responder role
:param prologue: byte sequence which may be zero-length, or which may contain context information that both
parties want to confirm is identical
:param s: local static key pair
:param e: local ephemeral key pair
:param rs: remote partys static public key
:param re: remote partys ephemeral public key
:return: initialized HandshakeState instance
"""
pass
# Create HandshakeState
instance = cls()
# Originally in specification:
# "Derives a protocol_name byte sequence by combining the names for
# the handshake pattern and crypto functions, as specified in Section 8."
# Instead, we supply the protocol name to the function. It should already be validated. We only check if the
# handshake pattern specified as an argument is the same as in the protocol name
# Calls InitializeSymmetric(protocol_name)
instance.symmetric_state = SymmetricState.initialize_symmetric(protocol_name)
# Calls MixHash(prologue)
instance.symmetric_state.mix_hash(prologue)
# Sets the initiator, s, e, rs, and re variables to the corresponding arguments
instance.initiator = initiator
instance.s = s if s is not None else Empty()
instance.e = e if e is not None else Empty()
instance.rs = rs if rs is not None else Empty()
instance.re = re if re is not None else Empty()
# Calls MixHash() once for each public key listed in the pre-messages from handshake_pattern, with the specified
# public key as input (...). If both initiator and responder have pre-messages, the initiators public keys are
# hashed first
for keypair in map(instance._get_local_keypair, handshake_pattern.get_initiator_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public)
for keypair in map(instance._get_remote_keypair, handshake_pattern.get_responder_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public)
# Sets message_patterns to the message patterns from handshake_pattern
instance.message_patterns = handshake_pattern.tokens
return instance
def write_message(self, payload, message_buffer):
"""
@@ -142,3 +174,15 @@ class HandshakeState(object):
:return:
"""
pass
def _get_local_keypair(self, token: str) -> 'KeyPair':
keypair = getattr(self, token) # Maybe explicitly handle exception when getting improper keypair
if isinstance(keypair, Empty):
raise Exception('Required keypair {} is empty!'.format(token)) # Maybe subclassed exception
return keypair
def _get_remote_keypair(self, token: str) -> 'KeyPair':
keypair = getattr(self, 'r' + token) # Maybe explicitly handle exception when getting improper keypair
if isinstance(keypair, Empty):
raise Exception('Required keypair {} is empty!'.format('r' + token)) # Maybe subclassed exception
return keypair