diff --git a/noise/backends/default/diffie_hellmans.py b/noise/backends/default/diffie_hellmans.py index 5ddbc65..008cca0 100644 --- a/noise/backends/default/diffie_hellmans.py +++ b/noise/backends/default/diffie_hellmans.py @@ -1,4 +1,5 @@ from cryptography.hazmat.primitives.asymmetric import x25519, x448 +from cryptography.hazmat.primitives import serialization from noise.backends.default.keypairs import KeyPair25519, KeyPair448 from noise.exceptions import NoiseValueError @@ -17,7 +18,9 @@ class ED25519(DH): def generate_keypair(self) -> 'KeyPair': private_key = x25519.X25519PrivateKey.generate() public_key = private_key.public_key() - return KeyPair25519(private_key, public_key, public_key.public_bytes()) + return KeyPair25519(private_key, public_key, + public_key.public_bytes(serialization.Encoding.Raw, + serialization.PublicFormat.Raw)) def dh(self, private_key, public_key) -> bytes: if not isinstance(private_key, x25519.X25519PrivateKey) or not isinstance(public_key, x25519.X25519PublicKey): @@ -37,7 +40,9 @@ class ED448(DH): def generate_keypair(self) -> 'KeyPair': private_key = x448.X448PrivateKey.generate() public_key = private_key.public_key() - return KeyPair448(private_key, public_key, public_key.public_bytes()) + return KeyPair448(private_key, public_key, + public_key.public_bytes(serialization.Encoding.Raw, + serialization.PublicFormat.Raw)) def dh(self, private_key, public_key) -> bytes: if not isinstance(private_key, x448.X448PrivateKey) or not isinstance(public_key, x448.X448PublicKey): diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..3995b53 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,34 @@ +from noise.connection import NoiseConnection + +class TestConnection(object): + def do_test_connection(self, name): + key = b"\x00" * 32 + left = NoiseConnection.from_name(name) + left.set_psks(key) + left.set_as_initiator() + left.start_handshake() + + right = NoiseConnection.from_name(name) + right.set_psks(key) + right.set_as_responder() + right.start_handshake() + + h = left.write_message() + _ = right.read_message(h) + h2 = right.write_message() + left.read_message(h2) + + assert left.handshake_finished + assert right.handshake_finished + + enc = left.encrypt(b"hello") + dec = right.decrypt(enc) + assert dec == b"hello" + + def test_25519(self): + name = b"Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s" + self.do_test_connection(name) + + def test_448(self): + name = b"Noise_NNpsk0_448_ChaChaPoly_BLAKE2s" + self.do_test_connection(name)