mirror of
https://github.com/morgan9e/noiseprotocol
synced 2026-04-14 00:14:05 +09:00
Tons of fixes, working except Blake and PSK
noise/functions.py * Enabling ChaCha20 usage (from Cryptography) * Switching to per-cipher nonce formatting function * Changes to KeyPair interface - now wrappers exist for every ECDH * Fixing hmac_hash bug in implementation noise/noise_protocol.py * Added placeholders for multiple datafields in __init__, as well as for transport mode cipher states * Added handshake_done method for cleanup (post-handshake, pre-transport), not finished though noise/patterns.py * Now Pattern holds boolean telling if it's oneway. OneWayPattern class created for derivation by PatternN, PatternK, PatternX * Fixed wrong mapping of PatternK and PatternX in patterns_map noise/state.py * CipherState now takes noise_protocol in __init__, so that initialize_key() only reinitalizes CipherState instead of creating it. * Changed CipherState creation in SymmetricState to reflect change above * Fixing wrong sequence of concatenation hash and data in mix_hash() * SymmetricState's split() fixed and calling noise_protocol's handshake_done() * Pattern tokens are now copied to HandshakeState instead of modifying original Pattern * Changes in HandshakeState's writemessage and readmessage to reflect changes in KeyPair interface * Added workaround for tests (usage of pre-generated ephemeral keypair), to be removed in future tests/test_vectors.py * Individual test now is properly described in pytest with protocol name * Finished main test case, fully utilises test vectors (and all their messages) tests/vectors/noise-c-basic.txttests/vectors/noise-c-basic.txt * Forked rev30 test vector from noise-c
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from functools import partial
|
||||
|
||||
from cryptography.hazmat.primitives.hmac import HMAC
|
||||
|
||||
@@ -6,7 +7,7 @@ from .crypto import ed448
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
|
||||
|
||||
@@ -24,7 +25,8 @@ class DH(object):
|
||||
|
||||
def _25519_generate_keypair(self) -> '_KeyPair':
|
||||
private_key = x25519.X25519PrivateKey.generate()
|
||||
return _KeyPair(private_key, private_key.public_key())
|
||||
public_key = private_key.public_key()
|
||||
return _KeyPair(private_key, public_key, public_key.public_bytes())
|
||||
|
||||
def _25519_dh(self, keypair: 'x25519.X25519PrivateKey', public_key: 'x25519.X25519PublicKey') -> bytes:
|
||||
return keypair.exchange(public_key)
|
||||
@@ -37,7 +39,9 @@ class Cipher(object):
|
||||
self.encrypt = self._aesgcm_encrypt
|
||||
self.decrypt = self._aesgcm_decrypt
|
||||
elif method == 'ChaCha20':
|
||||
raise NotImplementedError
|
||||
self._cipher = ChaCha20Poly1305
|
||||
self.encrypt = self._chacha20_encrypt
|
||||
self.decrypt = self._chacha20_decrypt
|
||||
else:
|
||||
raise NotImplementedError('Cipher method: {}'.format(method))
|
||||
|
||||
@@ -45,11 +49,26 @@ class Cipher(object):
|
||||
# Might be expensive to initialise AESGCM with the same key every time. The key should be (as per spec) kept in
|
||||
# CipherState, but we may as well hold an initialised AESGCM and manage reinitialisation on CipherState.rekey
|
||||
cipher = self._cipher(k)
|
||||
return cipher.encrypt(nonce=n, data=plaintext, associated_data=ad)
|
||||
return cipher.encrypt(nonce=self._aesgcm_nonce(n), data=plaintext, associated_data=ad)
|
||||
|
||||
def _aesgcm_decrypt(self, k, n, ad, ciphertext):
|
||||
cipher = self._cipher(k)
|
||||
return cipher.encrypt(nonce=n, data=ciphertext, associated_data=ad)
|
||||
return cipher.decrypt(nonce=self._aesgcm_nonce(n), data=ciphertext, associated_data=ad)
|
||||
|
||||
def _aesgcm_nonce(self, n):
|
||||
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='big')
|
||||
|
||||
def _chacha20_encrypt(self, k, n, ad, plaintext):
|
||||
# Same comment as with AESGCM
|
||||
cipher = self._cipher(k)
|
||||
return cipher.encrypt(nonce=self._chacha20_nonce(n), data=plaintext, associated_data=ad)
|
||||
|
||||
def _chacha20_decrypt(self, k, n, ad, ciphertext):
|
||||
cipher = self._cipher(k)
|
||||
return cipher.decrypt(nonce=self._chacha20_nonce(n), data=ciphertext, associated_data=ad)
|
||||
|
||||
def _chacha20_nonce(self, n):
|
||||
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
|
||||
|
||||
|
||||
class Hash(object):
|
||||
@@ -68,12 +87,12 @@ class Hash(object):
|
||||
self.hashlen = 32
|
||||
self.blocklen = 64
|
||||
self.hash = self._hash_blake2s
|
||||
self.fn = hashes.BLAKE2s
|
||||
self.fn = partial(hashes.BLAKE2s, digest_size=self.hashlen)
|
||||
elif method == 'BLAKE2b':
|
||||
self.hashlen = 64
|
||||
self.blocklen = 128
|
||||
self.hash = self._hash_blake2b
|
||||
self.fn = hashes.BLAKE2b
|
||||
self.fn = partial(hashes.BLAKE2b, digest_size=self.hashlen)
|
||||
else:
|
||||
raise NotImplementedError('Hash method: {}'.format(method))
|
||||
|
||||
@@ -101,9 +120,10 @@ class Hash(object):
|
||||
class _KeyPair(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, private=None, public=None):
|
||||
def __init__(self, private=None, public=None, public_bytes=None):
|
||||
self.private = private
|
||||
self.public = public
|
||||
self.public_bytes = public_bytes
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
@@ -120,12 +140,13 @@ class KeyPair25519(_KeyPair):
|
||||
@classmethod
|
||||
def from_private_bytes(cls, private_bytes):
|
||||
private = x25519.X25519PrivateKey._from_private_bytes(private_bytes)
|
||||
public = private.public_key().public_bytes()
|
||||
return cls(private=private, public=public)
|
||||
public = private.public_key()
|
||||
return cls(private=private, public=public, public_bytes=public.public_bytes())
|
||||
|
||||
@classmethod
|
||||
def from_public_bytes(cls, public_bytes):
|
||||
return cls(public=x25519.X25519PublicKey.from_public_bytes(public_bytes).public_bytes())
|
||||
public = x25519.X25519PublicKey.from_public_bytes(public_bytes)
|
||||
return cls(public=public, public_bytes=public.public_bytes())
|
||||
|
||||
|
||||
# Available crypto functions
|
||||
@@ -138,7 +159,7 @@ dh_map = {
|
||||
|
||||
cipher_map = {
|
||||
'AESGCM': Cipher('AESGCM'),
|
||||
# 'ChaChaPoly': Cipher('ChaCha20') # TODO ugh cryptography lacks chacha primitive. use nacl I guess.
|
||||
'ChaChaPoly': Cipher('ChaCha20')
|
||||
}
|
||||
|
||||
hash_map = {
|
||||
@@ -157,9 +178,9 @@ keypair_map = {
|
||||
|
||||
def hmac_hash(key, data, algorithm):
|
||||
# Applies HMAC using the HASH() function.
|
||||
hmac = HMAC(key=key, algorithm=algorithm, backend=default_backend())
|
||||
hmac = HMAC(key=key, algorithm=algorithm(), backend=default_backend())
|
||||
hmac.update(data=data)
|
||||
return hmac
|
||||
return hmac.finalize()
|
||||
|
||||
|
||||
def hkdf(chaining_key, input_key_material, num_outputs, hmac_hash_fn):
|
||||
|
||||
@@ -40,11 +40,16 @@ class NoiseProtocol(object):
|
||||
self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn)
|
||||
self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac)
|
||||
|
||||
self.initiator = None
|
||||
self.one_way = False
|
||||
self.handshake_hash = None
|
||||
self.psks = None # Placeholder for PSKs
|
||||
|
||||
self.handshake_state = Empty()
|
||||
self.symmetric_state = Empty()
|
||||
self.cipher_state = Empty()
|
||||
self.cipher_state_handshake = Empty()
|
||||
self.cipher_state_encrypt = Empty()
|
||||
self.cipher_state_decrypt = Empty()
|
||||
|
||||
def _parse_protocol_name(self) -> Tuple[dict, list]:
|
||||
unpacked = self.name.decode().split('_')
|
||||
@@ -84,3 +89,12 @@ class NoiseProtocol(object):
|
||||
|
||||
def set_psks(self, psks: list) -> None:
|
||||
self.psks = psks
|
||||
|
||||
def handshake_done(self):
|
||||
self.initiator = self.handshake_state.initiator
|
||||
if self.pattern.one_way:
|
||||
if self.initiator:
|
||||
del self.cipher_state_decrypt
|
||||
else:
|
||||
del self.cipher_state_encrypt
|
||||
self.handshake_hash = self.symmetric_state.h
|
||||
|
||||
@@ -20,12 +20,13 @@ class Pattern(object):
|
||||
|
||||
def __init__(self):
|
||||
self.has_pre_messages = any(map(lambda x: len(x) > 0, self.pre_messages))
|
||||
self.one_way = False
|
||||
|
||||
def get_initiator_pre_messages(self) -> list:
|
||||
return self.pre_messages[0]
|
||||
return self.pre_messages[0].copy()
|
||||
|
||||
def get_responder_pre_messages(self) -> list:
|
||||
return self.pre_messages[1]
|
||||
return self.pre_messages[1].copy()
|
||||
|
||||
def apply_pattern_modifiers(self, modifiers: List[str]) -> None:
|
||||
# Applies given pattern modifiers to self.tokens of the Pattern instance.
|
||||
@@ -54,7 +55,13 @@ class Pattern(object):
|
||||
|
||||
# One-way patterns
|
||||
|
||||
class PatternN(Pattern):
|
||||
class OneWayPattern(Pattern):
|
||||
def __init__(self):
|
||||
super(Pattern, self).__init__()
|
||||
self.one_way = True
|
||||
|
||||
|
||||
class PatternN(OneWayPattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
@@ -64,7 +71,7 @@ class PatternN(Pattern):
|
||||
]
|
||||
|
||||
|
||||
class PatternK(Pattern):
|
||||
class PatternK(OneWayPattern):
|
||||
pre_messages = [
|
||||
[TOKEN_S],
|
||||
[TOKEN_S]
|
||||
@@ -74,7 +81,7 @@ class PatternK(Pattern):
|
||||
]
|
||||
|
||||
|
||||
class PatternX(Pattern):
|
||||
class PatternX(OneWayPattern):
|
||||
pre_messages = [
|
||||
[],
|
||||
[TOKEN_S]
|
||||
@@ -199,8 +206,8 @@ class PatternIX(Pattern):
|
||||
|
||||
patterns_map = {
|
||||
'PatternN': PatternN,
|
||||
'PatternK': PatternN,
|
||||
'PatternX': PatternN,
|
||||
'PatternK': PatternK,
|
||||
'PatternX': PatternX,
|
||||
'PatternNN': PatternNN,
|
||||
'PatternKN': PatternKN,
|
||||
'PatternNK': PatternNK,
|
||||
|
||||
100
noise/state.py
100
noise/state.py
@@ -7,26 +7,17 @@ class CipherState(object):
|
||||
|
||||
The initialize_key() function takes additional required argument - noise_protocol.
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, noise_protocol):
|
||||
self.k = Empty()
|
||||
self.n = None
|
||||
self.noise_protocol = None
|
||||
self.noise_protocol = noise_protocol
|
||||
|
||||
@classmethod
|
||||
def initialize_key(cls, key, noise_protocol: 'NoiseProtocol') -> 'CipherState': # TODO: fix for split case
|
||||
def initialize_key(self, key):
|
||||
"""
|
||||
|
||||
:param key:
|
||||
:param noise_protocol: a valid NoiseProtocol instance
|
||||
:return: initialised CipherState instance
|
||||
:param key: Key to set within CipherState
|
||||
"""
|
||||
instance = cls()
|
||||
instance.noise_protocol = noise_protocol
|
||||
noise_protocol.cipher_state = instance
|
||||
|
||||
instance.k = key
|
||||
instance.n = 0
|
||||
return instance
|
||||
self.k = key
|
||||
self.n = 0
|
||||
|
||||
def has_key(self):
|
||||
"""
|
||||
@@ -47,7 +38,7 @@ class CipherState(object):
|
||||
if not self.has_key():
|
||||
return plaintext
|
||||
|
||||
ciphertext = self.noise_protocol.cipher.encrypt(self.k, self.n, ad, plaintext)
|
||||
ciphertext = self.noise_protocol.cipher_fn.encrypt(self.k, self.n, ad, plaintext)
|
||||
self.n = self.n + 1
|
||||
return ciphertext
|
||||
|
||||
@@ -65,11 +56,12 @@ class CipherState(object):
|
||||
if not self.has_key():
|
||||
return ciphertext
|
||||
|
||||
plaintext = self.noise_protocol.cipher.decrypt(self.k, self.n, ad, ciphertext)
|
||||
plaintext = self.noise_protocol.cipher_fn.decrypt(self.k, self.n, ad, ciphertext)
|
||||
self.n = self.n + 1
|
||||
return plaintext
|
||||
|
||||
|
||||
|
||||
class SymmetricState(object):
|
||||
"""
|
||||
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.2.
|
||||
@@ -107,7 +99,9 @@ class SymmetricState(object):
|
||||
instance.ck = instance.h
|
||||
|
||||
# Calls InitializeKey(empty).
|
||||
CipherState.initialize_key(Empty(), noise_protocol)
|
||||
cipher_state = CipherState(noise_protocol)
|
||||
cipher_state.initialize_key(Empty())
|
||||
noise_protocol.cipher_state_handshake = cipher_state
|
||||
|
||||
return instance
|
||||
|
||||
@@ -123,14 +117,14 @@ class SymmetricState(object):
|
||||
temp_k = temp_k[:32]
|
||||
|
||||
# Calls InitializeKey(temp_k).
|
||||
self.noise_protocol.cipher_state.initialize_key(temp_k) # TODO check for memory leaks here
|
||||
self.noise_protocol.cipher_state_handshake.initialize_key(temp_k)
|
||||
|
||||
def mix_hash(self, data: bytes):
|
||||
"""
|
||||
Sets h = HASH(h + data).
|
||||
:param data: bytes sequence
|
||||
"""
|
||||
self.h = self.noise_protocol.hash_fn.hash(data + self.h)
|
||||
self.h = self.noise_protocol.hash_fn.hash(self.h + data)
|
||||
|
||||
def encrypt_and_hash(self, plaintext: bytes) -> bytes:
|
||||
"""
|
||||
@@ -139,7 +133,7 @@ class SymmetricState(object):
|
||||
:param plaintext: bytes sequence
|
||||
:return: ciphertext bytes sequence
|
||||
"""
|
||||
ciphertext = self.noise_protocol.cipher_state.encrypt_with_ad(self.h, plaintext)
|
||||
ciphertext = self.noise_protocol.cipher_state_handshake.encrypt_with_ad(self.h, plaintext)
|
||||
self.mix_hash(ciphertext)
|
||||
return ciphertext
|
||||
|
||||
@@ -150,7 +144,7 @@ class SymmetricState(object):
|
||||
:param ciphertext: bytes sequence
|
||||
:return: plaintext bytes sequence
|
||||
"""
|
||||
plaintext = self.noise_protocol.cipher_state.decrypt_with_ad(self.h, ciphertext)
|
||||
plaintext = self.noise_protocol.cipher_state_handshake.decrypt_with_ad(self.h, ciphertext)
|
||||
self.mix_hash(ciphertext)
|
||||
return plaintext
|
||||
|
||||
@@ -169,8 +163,17 @@ class SymmetricState(object):
|
||||
|
||||
# Creates two new CipherState objects c1 and c2.
|
||||
# Calls c1.InitializeKey(temp_k1) and c2.InitializeKey(temp_k2).
|
||||
c1 = CipherState.initialize_key(temp_k1, self.noise_protocol) # TODO WRONG!
|
||||
c2 = CipherState.initialize_key(temp_k2, self.noise_protocol) # TODO WRONG!
|
||||
c1, c2 = CipherState(self.noise_protocol), CipherState(self.noise_protocol)
|
||||
c1.initialize_key(temp_k1)
|
||||
c2.initialize_key(temp_k2)
|
||||
if self.noise_protocol.handshake_state.initiator:
|
||||
self.noise_protocol.cipher_state_encrypt = c1
|
||||
self.noise_protocol.cipher_state_decrypt = c2
|
||||
else:
|
||||
self.noise_protocol.cipher_state_encrypt = c2
|
||||
self.noise_protocol.cipher_state_decrypt = c1
|
||||
|
||||
self.noise_protocol.handshake_done()
|
||||
|
||||
# Returns the pair (c1, c2).
|
||||
return c1, c2
|
||||
@@ -194,7 +197,7 @@ class HandshakeState(object):
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: bytes=None,
|
||||
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState':
|
||||
e: bytes=None, rs: bytes=None, re: bytes=None) -> 'HandshakeState': # TODO update typing (keypair)
|
||||
"""
|
||||
Constructor method.
|
||||
Comments below are mostly copied from specification.
|
||||
@@ -240,12 +243,12 @@ class HandshakeState(object):
|
||||
initiator_keypair_getter = instance._get_local_keypair if initiator else instance._get_remote_keypair
|
||||
responder_keypair_getter = instance._get_remote_keypair if initiator else instance._get_local_keypair
|
||||
for keypair in map(initiator_keypair_getter, noise_protocol.pattern.get_initiator_pre_messages()):
|
||||
instance.symmetric_state.mix_hash(keypair.public)
|
||||
instance.symmetric_state.mix_hash(keypair.public_bytes)
|
||||
for keypair in map(responder_keypair_getter, noise_protocol.pattern.get_responder_pre_messages()):
|
||||
instance.symmetric_state.mix_hash(keypair.public)
|
||||
instance.symmetric_state.mix_hash(keypair.public_bytes)
|
||||
|
||||
# Sets message_patterns to the message patterns from handshake_pattern
|
||||
instance.message_patterns = noise_protocol.pattern.tokens
|
||||
instance.message_patterns = noise_protocol.pattern.tokens.copy()
|
||||
|
||||
return instance
|
||||
|
||||
@@ -262,35 +265,36 @@ class HandshakeState(object):
|
||||
for token in message_pattern:
|
||||
if token == TOKEN_E:
|
||||
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair()
|
||||
message_buffer.write(self.e.public)
|
||||
self.symmetric_state.mix_hash(self.e.public)
|
||||
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e # TODO: it's workaround, otherwise use mock
|
||||
message_buffer.write(self.e.public_bytes)
|
||||
self.symmetric_state.mix_hash(self.e.public_bytes)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Appends EncryptAndHash(s.public_key) to the buffer
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public))
|
||||
message_buffer.write(self.symmetric_state.encrypt_and_hash(self.s.public_bytes))
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
# Calls MixKey(DH(e, re))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.re.public))
|
||||
|
||||
elif token == TOKEN_ES:
|
||||
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
|
||||
|
||||
elif token == TOKEN_SE:
|
||||
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
|
||||
|
||||
elif token == TOKEN_SS:
|
||||
# Calls MixKey(DH(s, rs))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
|
||||
pass
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
@@ -320,38 +324,38 @@ class HandshakeState(object):
|
||||
if token == TOKEN_E:
|
||||
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
|
||||
self.re = self.noise_protocol.keypair_fn.from_public_bytes(message.read(dhlen))
|
||||
self.symmetric_state.mix_hash(self.re.public)
|
||||
self.symmetric_state.mix_hash(self.re.public_bytes)
|
||||
|
||||
elif token == TOKEN_S:
|
||||
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
|
||||
# otherwise. Sets rs to DecryptAndHash(temp).
|
||||
if self.noise_protocol.cipher_state.has_key():
|
||||
if self.noise_protocol.cipher_state_handshake.has_key():
|
||||
temp = message.read(dhlen + 16)
|
||||
else:
|
||||
temp = message.read(dhlen)
|
||||
self.rs = self.symmetric_state.decrypt_and_hash(temp)
|
||||
self.rs = self.noise_protocol.keypair_fn.from_public_bytes(self.symmetric_state.decrypt_and_hash(temp))
|
||||
|
||||
elif token == TOKEN_EE:
|
||||
# Calls MixKey(DH(e, re)).
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.re.public))
|
||||
|
||||
elif token == TOKEN_ES:
|
||||
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
|
||||
|
||||
elif token == TOKEN_SE:
|
||||
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
|
||||
if self.initiator:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.re.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
|
||||
else:
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
|
||||
|
||||
elif token == TOKEN_SS:
|
||||
# Calls MixKey(DH(s, rs))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s, self.rs.public))
|
||||
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
|
||||
|
||||
elif token == TOKEN_PSK:
|
||||
raise NotImplementedError
|
||||
@@ -360,7 +364,7 @@ class HandshakeState(object):
|
||||
raise NotImplementedError('Pattern token: {}'.format(token))
|
||||
|
||||
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
|
||||
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message)) # TODO remaining bytes!
|
||||
payload_buffer.write(self.symmetric_state.decrypt_and_hash(message.read()))
|
||||
|
||||
# If there are no more message patterns returns two new CipherState objects by calling Split()
|
||||
if len(self.message_patterns) == 0:
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import logging
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from noise.functions import KeyPair25519
|
||||
from noise.state import HandshakeState
|
||||
from noise.state import HandshakeState, CipherState
|
||||
from noise.noise_protocol import NoiseProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
vector_files = ['vectors/cacophony.txt']
|
||||
vector_files = [
|
||||
'vectors/cacophony.txt',
|
||||
'vectors/noise-c-basic.txt'
|
||||
]
|
||||
|
||||
# As in test vectors specification (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors)
|
||||
# We use this to cast read strings into bytes
|
||||
@@ -29,8 +33,10 @@ def _prepare_test_vectors():
|
||||
vectors_list = json.load(fd)
|
||||
|
||||
for vector in vectors_list:
|
||||
if '_448_' in vector['protocol_name'] or 'ChaCha' in vector['protocol_name'] or 'psk' in vector['protocol_name']:
|
||||
continue # TODO REMOVE WHEN ed448/ChaCha/psk SUPPORT IS IMPLEMENTED
|
||||
if 'name' in vector and not 'protocol_name' in vector: # noise-c-* workaround
|
||||
vector['protocol_name'] = vector['name']
|
||||
if '_448_' in vector['protocol_name'] or 'psk' in vector['protocol_name'] or 'BLAKE' in vector['protocol_name'] or 'PSK' in vector['protocol_name']:
|
||||
continue # TODO REMOVE WHEN ed448/psk/blake SUPPORT IS IMPLEMENTED/FIXED
|
||||
for key, value in vector.copy().items():
|
||||
if key in byte_fields:
|
||||
vector[key] = value.encode()
|
||||
@@ -41,13 +47,17 @@ def _prepare_test_vectors():
|
||||
if key == dict_field:
|
||||
vector[key] = []
|
||||
for dictionary in value:
|
||||
vector[key].append({k: v.encode() for k, v in dictionary.items()})
|
||||
vector[key].append({k: bytes.fromhex(v) for k, v in dictionary.items()})
|
||||
vectors.append(vector)
|
||||
return vectors
|
||||
|
||||
|
||||
def idfn(vector):
|
||||
return vector['protocol_name']
|
||||
|
||||
|
||||
class TestVectors(object):
|
||||
@pytest.fixture(params=_prepare_test_vectors())
|
||||
@pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
|
||||
def vector(self, request):
|
||||
yield request.param
|
||||
|
||||
@@ -65,8 +75,6 @@ class TestVectors(object):
|
||||
return kwargs
|
||||
|
||||
def test_vector(self, vector):
|
||||
logging.info('Testing vector {}'.format(vector['protocol_name']))
|
||||
|
||||
kwargs = self._prepare_handshake_state_kwargs(vector)
|
||||
|
||||
init_protocol = NoiseProtocol(vector['protocol_name'])
|
||||
@@ -80,3 +88,50 @@ class TestVectors(object):
|
||||
|
||||
initiator = HandshakeState.initialize(**kwargs['init'])
|
||||
responder = HandshakeState.initialize(**kwargs['resp'])
|
||||
initiator_to_responder = True
|
||||
|
||||
handshake_finished = False
|
||||
for message in vector['messages']:
|
||||
if not handshake_finished:
|
||||
message_buffer = io.BytesIO()
|
||||
payload_buffer = io.BytesIO()
|
||||
if initiator_to_responder:
|
||||
sender, receiver = initiator, responder
|
||||
else:
|
||||
sender, receiver = responder, initiator
|
||||
|
||||
sender_result = sender.write_message(message['payload'], message_buffer)
|
||||
assert message_buffer.getbuffer().tobytes() == message['ciphertext']
|
||||
|
||||
message_buffer.seek(0)
|
||||
receiver_result = receiver.read_message(message_buffer, payload_buffer)
|
||||
assert payload_buffer.getbuffer().tobytes() == message['payload']
|
||||
|
||||
if sender_result is None or receiver_result is None:
|
||||
# Not finished with handshake, fail if one would finish before other
|
||||
assert sender_result == receiver_result
|
||||
else:
|
||||
# Handshake done
|
||||
handshake_finished = True
|
||||
assert isinstance(sender_result[0], CipherState)
|
||||
assert isinstance(sender_result[1], CipherState)
|
||||
assert isinstance(receiver_result[0], CipherState)
|
||||
assert isinstance(receiver_result[1], CipherState)
|
||||
|
||||
# Verify handshake hash
|
||||
assert init_protocol.symmetric_state.h == resp_protocol.symmetric_state.h == vector['handshake_hash']
|
||||
|
||||
# Verify split cipherstates keys
|
||||
assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k
|
||||
if not init_protocol.pattern.one_way:
|
||||
assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k
|
||||
else:
|
||||
if init_protocol.pattern.one_way or initiator_to_responder:
|
||||
sender, receiver = init_protocol, resp_protocol
|
||||
else:
|
||||
sender, receiver = resp_protocol, init_protocol
|
||||
ciphertext = sender.cipher_state_encrypt.encrypt_with_ad(None, message['payload'])
|
||||
assert ciphertext == message['ciphertext']
|
||||
plaintext = receiver.cipher_state_decrypt.decrypt_with_ad(None, message['ciphertext'])
|
||||
assert plaintext == message['payload']
|
||||
initiator_to_responder = not initiator_to_responder
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
Test vectors:
|
||||
- cacophony.txt from https://github.com/centromere/cacophony/blob/master/vectors/cacophony.txt, stripped from outer dict (so that it is just a list of json objects)
|
||||
- cacophony.txt from https://github.com/centromere/cacophony/blob/master/vectors/cacophony.txt, stripped from outer dict (so that it is just a list of json objects)
|
||||
- noise-c-basic.txt from https://github.com/rweather/noise-c/blob/master/tests/vector/noise-c-basic.txt, stripped from outer dict (so that it is just a list of json objects)
|
||||
19682
tests/vectors/noise-c-basic.txt
Normal file
19682
tests/vectors/noise-c-basic.txt
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user