diff --git a/noise/functions.py b/noise/functions.py index 80eb20d..66bfd1e 100644 --- a/noise/functions.py +++ b/noise/functions.py @@ -1,5 +1,7 @@ import abc +from cryptography.hazmat.primitives.hmac import HMAC + from .crypto import ed448 from cryptography.hazmat.backends import default_backend @@ -56,18 +58,22 @@ class Hash(object): self.hashlen = 32 self.blocklen = 64 self.hash = self._hash_sha256 + self.fn = hashes.SHA256 elif method == 'SHA512': self.hashlen = 64 self.blocklen = 128 self.hash = self._hash_sha512 + self.fn = hashes.SHA512 elif method == 'BLAKE2s': self.hashlen = 32 self.blocklen = 64 self.hash = self._hash_blake2s + self.fn = hashes.BLAKE2s elif method == 'BLAKE2b': self.hashlen = 64 self.blocklen = 128 self.hash = self._hash_blake2b + self.fn = hashes.BLAKE2b else: raise NotImplementedError('Hash method: {}'.format(method)) @@ -146,4 +152,32 @@ hash_map = { keypair_map = { '25519': KeyPair25519, # '448': DH('ed448') # TODO uncomment when ed448 is implemented -} \ No newline at end of file +} + + +def hmac_hash(key, data, algorithm): + # Applies HMAC using the HASH() function. + hmac = HMAC(key=key, algorithm=algorithm, backend=default_backend()) + hmac.update(data=data) + return hmac + + +def hkdf(chaining_key, input_key_material, num_outputs, hmac_hash_fn): + # Sets temp_key = HMAC-HASH(chaining_key, input_key_material). + temp_key = hmac_hash_fn(chaining_key, input_key_material) + + # Sets output1 = HMAC-HASH(temp_key, byte(0x01)). + output1 = hmac_hash_fn(temp_key, b'\x01') + + # Sets output2 = HMAC-HASH(temp_key, output1 || byte(0x02)). + output2 = hmac_hash_fn(temp_key, output1 + b'\x02') + + # If num_outputs == 2 then returns the pair (output1, output2). + if num_outputs == 2: + return output1, output2 + + # Sets output3 = HMAC-HASH(temp_key, output2 || byte(0x03)). + output3 = hmac_hash_fn(temp_key, output2 + b'\x03') + + # Returns the triple (output1, output2, output3). + return output1, output2, output3 diff --git a/noise/noise_protocol.py b/noise/noise_protocol.py index 8689787..964fa90 100644 --- a/noise/noise_protocol.py +++ b/noise/noise_protocol.py @@ -1,7 +1,8 @@ +from functools import partial from typing import Tuple from .constants import MAX_PROTOCOL_NAME_LEN, Empty -from .functions import dh_map, cipher_map, hash_map, keypair_map, KeyPair25519 +from .functions import dh_map, cipher_map, hash_map, keypair_map, hmac_hash, hkdf from .patterns import patterns_map @@ -36,6 +37,8 @@ class NoiseProtocol(object): self.cipher_fn = mappings['cipher'] self.hash_fn = mappings['hash'] self.keypair_fn = mappings['keypair'] + self.hmac = partial(hmac_hash, algorithm=self.hash_fn.fn) + self.hkdf = partial(hkdf, hmac_hash_fn=self.hmac) self.psks = None # Placeholder for PSKs diff --git a/noise/state.py b/noise/state.py index db67d61..ea0993e 100644 --- a/noise/state.py +++ b/noise/state.py @@ -13,7 +13,7 @@ class CipherState(object): self.noise_protocol = None @classmethod - def initialize_key(cls, key, noise_protocol: 'NoiseProtocol') -> 'CipherState': + def initialize_key(cls, key, noise_protocol: 'NoiseProtocol', create=True) -> 'CipherState': # TODO: fix for split case """ :param key: @@ -94,43 +94,69 @@ class SymmetricState(object): return instance - def mix_key(self, input_key_material): + def mix_key(self, input_key_material: bytes): """ - :param input_key_material: :return: """ + # Sets ck, temp_k = HKDF(ck, input_key_material, 2). + self.ck, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 2) + # If HASHLEN is 64, then truncates temp_k to 32 bytes. + if self.noise_protocol.hash_fn.hashlen == 64: + temp_k = temp_k[:32] - def mix_hash(self, data): + # Calls InitializeKey(temp_k). + self.noise_protocol.cipher_state.initialize_key(temp_k) # TODO check for memory leaks here + + def mix_hash(self, data: bytes): """ - - :param data: - :return: + Sets h = HASH(h + data). + :param data: bytes sequence """ self.h = self.noise_protocol.hash_fn.hash(data + self.h) - def encrypt_and_hash(self, plaintext): + def encrypt_and_hash(self, plaintext: bytes) -> bytes: """ - - :param plaintext: - :return: + Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns ciphertext. Note that if + k is empty, the EncryptWithAd() call will set ciphertext equal to plaintext. + :param plaintext: bytes sequence + :return: ciphertext bytes sequence """ - pass + ciphertext = self.noise_protocol.cipher_state.encrypt_with_ad(self.h, plaintext) + self.mix_hash(ciphertext) + return ciphertext - def decrypt_and_hash(self, ciphertext): + def decrypt_and_hash(self, ciphertext: bytes) -> bytes: """ - - :param ciphertext: - :return: + Sets plaintext = DecryptWithAd(h, ciphertext), calls MixHash(ciphertext), and returns plaintext. Note that if + k is empty, the DecryptWithAd() call will set plaintext equal to ciphertext. + :param ciphertext: bytes sequence + :return: plaintext bytes sequence """ - pass + plaintext = self.noise_protocol.cipher_state.decrypt_with_ad(self.h, ciphertext) + self.mix_hash(ciphertext) + return plaintext def split(self): """ - - :return: + Returns a pair of CipherState objects for encrypting/decrypting transport messages. + :return: tuple (CipherState, CipherState) """ - pass + # Sets temp_k1, temp_k2 = HKDF(ck, b'', 2). + temp_k1, temp_k2 = self.noise_protocol.hkdf(self.ck, b'', 2) + + # If HASHLEN is 64, then truncates temp_k1 and temp_k2 to 32 bytes. + if self.noise_protocol.hash_fn.hashlen == 64: + temp_k1 = temp_k1[:32] + temp_k2 = temp_k2[:32] + + # Creates two new CipherState objects c1 and c2. + # Calls c1.InitializeKey(temp_k1) and c2.InitializeKey(temp_k2). + c1 = CipherState.initialize_key(temp_k1, self.noise_protocol) # TODO WRONG! + c2 = CipherState.initialize_key(temp_k2, self.noise_protocol) # TODO WRONG! + + # Returns the pair (c1, c2). + return c1, c2 class HandshakeState(object):