diff --git a/noise/functions.py b/noise/functions.py index 231301f..ecce98f 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -1,18 +1,26 @@ from .crypto import ed448 -from Crypto.Cipher import AES, ChaCha20 -from Crypto.Hash import BLAKE2b, BLAKE2s, SHA256, SHA512 -import ed25519 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import x25519 class DH(object): def __init__(self, method): - self.method = method - self.dhlen = 0 - self.dh = None + if method == 'ed25519': + self.method = method + self.dhlen = 32 + self.generate_keypair = self._25519_generate_keypair + self.dh = self._25519_dh + elif method == 'ed448': + raise NotImplementedError - def generate_keypair(self) -> 'KeyPair': - pass + def _25519_generate_keypair(self) -> 'KeyPair': + private_key = x25519.X25519PrivateKey.generate() + return KeyPair(private_key, private_key.public_key()) + + def _25519_dh(self, keypair: 'x25519.X25519PrivateKey', public_key: 'x25519.X25519PublicKey') -> bytes: + return keypair.exchange(public_key) class Cipher(object): @@ -46,28 +54,40 @@ class Hash(object): self.hash = self._hash_blake2b def _hash_sha256(self, data): - return SHA256.new(data).digest() + digest = hashes.Hash(hashes.SHA256(), default_backend()) + digest.update(data) + return digest.finalize() def _hash_sha512(self, data): - return SHA512.new(data).digest() + digest = hashes.Hash(hashes.SHA512(), default_backend()) + digest.update(data) + return digest.finalize() def _hash_blake2s(self, data): - return BLAKE2s.new(data=data, digest_bytes=self.hashlen).digest() + digest = hashes.Hash(hashes.BLAKE2s(digest_size=self.hashlen), default_backend()) + digest.update(data) + return digest.finalize() def _hash_blake2b(self, data): - return BLAKE2b.new(data=data, digest_bytes=self.hashlen).digest() + digest = hashes.Hash(hashes.BLAKE2b(digest_size=self.hashlen), default_backend()) + digest.update(data) + return digest.finalize() class KeyPair(object): - def __init__(self, public=b'', private=b''): - # TODO: Maybe switch to properties? - self.public = public + def __init__(self, private=None, public=None): self.private = private - if private and not public: - self.derive_public_key() + self.public = public - def derive_public_key(self): - pass + @classmethod + def _25519_from_private_bytes(cls, private_bytes): + private = x25519.X25519PrivateKey._from_private_bytes(private_bytes) + public = private.public_key().public_bytes() + return cls(private=private, public=public) + + @classmethod + def _25519_from_public_bytes(cls, public_bytes): + return cls(public=x25519.X25519PublicKey.from_public_bytes(public_bytes).public_bytes()) # Available crypto functions @@ -75,7 +95,7 @@ class KeyPair(object): # If not - switch to partials(?) dh_map = { '25519': DH('ed25519'), - '448': DH('ed448') + # '448': DH('ed448') # TODO uncomment when ed448 is implemented } cipher_map = { diff --git a/requirements.txt b/requirements.txt index 2a6ccea..8d51f78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ pytest -pycryptodome -ed25519 +cryptography diff --git a/tests/test_vectors.py b/tests/test_vectors.py index e1e6977..a93ec02 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -14,8 +14,9 @@ vector_files = ['vectors/cacophony.txt'] # As in test vectors specification (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors) # We use this to cast read strings into bytes -string_fields = ['protocol_name', 'init_prologue', 'init_static', 'init_ephemeral', 'init_remote_static', - 'resp_prologue', 'resp_static', 'resp_ephemeral', 'resp_remote_static', 'handshake_hash'] +byte_fields = ['protocol_name'] +hexbyte_fields = ['init_prologue', 'init_static', 'init_ephemeral', 'init_remote_static', 'resp_static', + 'resp_prologue', 'resp_ephemeral', 'resp_remote_static', 'handshake_hash'] list_fields = ['init_psks', 'resp_psks'] dict_field = 'messages' @@ -28,9 +29,13 @@ def _prepare_test_vectors(): vectors_list = json.load(fd) for vector in vectors_list: + if '_448_' in vector['protocol_name']: + continue # TODO REMOVE WHEN ed448 IS IMPLEMENTED for key, value in vector.copy().items(): - if key in string_fields: + if key in byte_fields: vector[key] = value.encode() + if key in hexbyte_fields: + vector[key] = bytes.fromhex(value) if key in list_fields: vector[key] = [k.encode() for k in value] if key == dict_field: @@ -54,9 +59,9 @@ class TestVectors(object): role_key = role + '_' + key if role_key in vector: if key in ['static', 'ephemeral']: - kwargs[role][kwarg] = KeyPair(private=vector[role_key]) - else: - kwargs[role][kwarg] = KeyPair(public=vector[role_key]) + kwargs[role][kwarg] = KeyPair._25519_from_private_bytes(vector[role_key]) # TODO unify after adding 448 + elif key == 'remote_static': + kwargs[role][kwarg] = KeyPair._25519_from_public_bytes(vector[role_key]) # TODO unify after adding 448 return kwargs def test_vector(self, vector):