diff --git a/noise/backends/default/backend.py b/noise/backends/default/backend.py index e6bc780..c22c95c 100644 --- a/noise/backends/default/backend.py +++ b/noise/backends/default/backend.py @@ -1,7 +1,7 @@ from noise.backends.default.ciphers import ChaCha20Cipher, AESGCMCipher -from noise.backends.default.diffie_hellmans import ED25519, ED448 from noise.backends.default.hashes import hmac_hash, BLAKE2sHash, BLAKE2bHash, SHA256Hash, SHA512Hash -from noise.backends.default.keypairs import KeyPair25519, KeyPair448 +from noise.backends.default.diffie_hellmans import ED25519, ED448, SECP256R1 +from noise.backends.default.keypairs import KeyPair25519, KeyPair448, KeyPairSecp256r1 from noise.backends.noise_backend import NoiseBackend @@ -15,7 +15,8 @@ class DefaultNoiseBackend(NoiseBackend): self.diffie_hellmans = { '25519': ED25519, - '448': ED448 + '448': ED448, + 'secp256r1': SECP256R1 } self.ciphers = { @@ -32,7 +33,8 @@ class DefaultNoiseBackend(NoiseBackend): self.keypairs = { '25519': KeyPair25519, - '448': KeyPair448 + '448': KeyPair448, + 'secp256r1': KeyPairSecp256r1 } self.hmac = hmac_hash diff --git a/noise/backends/default/diffie_hellmans.py b/noise/backends/default/diffie_hellmans.py index 008cca0..07502ef 100644 --- a/noise/backends/default/diffie_hellmans.py +++ b/noise/backends/default/diffie_hellmans.py @@ -1,7 +1,7 @@ -from cryptography.hazmat.primitives.asymmetric import x25519, x448 +from cryptography.hazmat.primitives.asymmetric import x25519, x448, ec from cryptography.hazmat.primitives import serialization -from noise.backends.default.keypairs import KeyPair25519, KeyPair448 +from noise.backends.default.keypairs import KeyPair25519, KeyPair448, KeyPairSecp256r1 from noise.exceptions import NoiseValueError from noise.functions.dh import DH @@ -48,3 +48,27 @@ class ED448(DH): if not isinstance(private_key, x448.X448PrivateKey) or not isinstance(public_key, x448.X448PublicKey): raise NoiseValueError('Invalid keys! Must be x448.X448PrivateKey and x448.X448PublicKey instances') return private_key.exchange(public_key) + + +class SECP256R1(DH): + @property + def klass(self): + return KeyPairSecp256r1 + + @property + def dhlen(self): + return 65 + + def generate_keypair(self) -> 'KeyPair': + private_key = ec.generate_private_key(ec.SECP256R1()) + public_key = private_key.public_key() + public_bytes = public_key.public_bytes(serialization.Encoding.X962, + serialization.PublicFormat.UncompressedPoint) + return KeyPairSecp256r1(private_key, public_key, public_bytes) + + def dh(self, private_key, public_key) -> bytes: + if not isinstance(private_key, ec.EllipticCurvePrivateKey) or not isinstance(public_key, ec.EllipticCurvePublicKey): + raise NoiseValueError('Invalid keys! Must be secp256r1 private and public key instances') + if not isinstance(private_key.curve, ec.SECP256R1) or not isinstance(public_key.curve, ec.SECP256R1): + raise NoiseValueError('Invalid curve for secp256r1 DH') + return private_key.exchange(ec.ECDH(), public_key) diff --git a/noise/backends/default/keypairs.py b/noise/backends/default/keypairs.py index 5e74ddf..f8d499c 100644 --- a/noise/backends/default/keypairs.py +++ b/noise/backends/default/keypairs.py @@ -1,5 +1,5 @@ from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import x25519, x448 +from cryptography.hazmat.primitives.asymmetric import x25519, x448, ec from noise.exceptions import NoiseValueError from noise.functions.keypair import KeyPair @@ -37,3 +37,31 @@ class KeyPair448(KeyPair): raise NoiseValueError('Invalid length of private_bytes! Should be 56') public = x448.X448PublicKey.from_public_bytes(public_bytes) return cls(public=public, public_bytes=public.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)) + + +class KeyPairSecp256r1(KeyPair): + @classmethod + def from_private_bytes(cls, private_bytes): + if len(private_bytes) != 32: + raise NoiseValueError('Invalid length of private_bytes! Should be 32') + private_int = int.from_bytes(private_bytes, byteorder='big') + try: + private = ec.derive_private_key(private_int, ec.SECP256R1()) + except ValueError as exc: + raise NoiseValueError('Invalid secp256r1 private key scalar') from exc + public = private.public_key() + public_bytes = public.public_bytes(encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint) + return cls(private=private, public=public, public_bytes=public_bytes) + + @classmethod + def from_public_bytes(cls, public_bytes): + if len(public_bytes) != 65: + raise NoiseValueError('Invalid length of public_bytes! Should be 65') + try: + public = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), public_bytes) + except ValueError as exc: + raise NoiseValueError('Invalid secp256r1 public key bytes') from exc + public_bytes = public.public_bytes(encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint) + return cls(public=public, public_bytes=public_bytes) diff --git a/tests/test_connection.py b/tests/test_connection.py index 3995b53..e9364f0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,6 @@ -from noise.connection import NoiseConnection +from cryptography.hazmat.primitives.asymmetric import ec + +from noise.connection import NoiseConnection, Keypair class TestConnection(object): def do_test_connection(self, name): @@ -32,3 +34,39 @@ class TestConnection(object): def test_448(self): name = b"Noise_NNpsk0_448_ChaChaPoly_BLAKE2s" self.do_test_connection(name) + + def test_secp256r1_xxpsk3(self): + name = b"Noise_XXpsk3_secp256r1_AESGCM_SHA256" + psk = b"\x01" * 32 + + left = NoiseConnection.from_name(name) + left.set_psks(psk) + left.set_as_initiator() + left_static = ec.generate_private_key(ec.SECP256R1()) + left_private_bytes = left_static.private_numbers().private_value.to_bytes(32, "big") + left.set_keypair_from_private_bytes(Keypair.STATIC, left_private_bytes) + left.start_handshake() + + right = NoiseConnection.from_name(name) + right.set_psks(psk) + right.set_as_responder() + right_static = ec.generate_private_key(ec.SECP256R1()) + right_private_bytes = right_static.private_numbers().private_value.to_bytes(32, "big") + right.set_keypair_from_private_bytes(Keypair.STATIC, right_private_bytes) + right.start_handshake() + + message1 = left.write_message() + right.read_message(message1) + + message2 = right.write_message() + left.read_message(message2) + + message3 = left.write_message() + right.read_message(message3) + + assert left.handshake_finished + assert right.handshake_finished + + ciphertext = left.encrypt(b"hello") + plaintext = right.decrypt(ciphertext) + assert plaintext == b"hello"