Do not reinitialise cipher class every time

Now CipherState holds instance of Cipher wrapper and manages
initialization of underlying cipher method with keys.

Closes #6
This commit is contained in:
Piotr Lizonczyk
2017-10-14 17:05:54 +02:00
parent d636c506d3
commit 2bac81d05c
2 changed files with 20 additions and 17 deletions

View File

@@ -65,28 +65,22 @@ class Cipher(object):
self.rekey = self._default_rekey self.rekey = self._default_rekey
else: else:
raise NotImplementedError('Cipher method: {}'.format(method)) raise NotImplementedError('Cipher method: {}'.format(method))
self.cipher = None
def _aesgcm_encrypt(self, k, n, ad, plaintext): def _aesgcm_encrypt(self, k, n, ad, plaintext):
# Might be expensive to initialise AESGCM with the same key every time. The key should be (as per spec) kept in return self.cipher.encrypt(nonce=self._aesgcm_nonce(n), data=plaintext, associated_data=ad)
# CipherState, but we may as well hold an initialised AESGCM and manage reinitialisation on CipherState.rekey
cipher = self._cipher(k)
return cipher.encrypt(nonce=self._aesgcm_nonce(n), data=plaintext, associated_data=ad)
def _aesgcm_decrypt(self, k, n, ad, ciphertext): def _aesgcm_decrypt(self, k, n, ad, ciphertext):
cipher = self._cipher(k) return self.cipher.decrypt(nonce=self._aesgcm_nonce(n), data=ciphertext, associated_data=ad)
return cipher.decrypt(nonce=self._aesgcm_nonce(n), data=ciphertext, associated_data=ad)
def _aesgcm_nonce(self, n): def _aesgcm_nonce(self, n):
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='big') return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='big')
def _chacha20_encrypt(self, k, n, ad, plaintext): def _chacha20_encrypt(self, k, n, ad, plaintext):
# Same comment as with AESGCM return self.cipher.encrypt(nonce=self._chacha20_nonce(n), data=plaintext, associated_data=ad)
cipher = self._cipher(k)
return cipher.encrypt(nonce=self._chacha20_nonce(n), data=plaintext, associated_data=ad)
def _chacha20_decrypt(self, k, n, ad, ciphertext): def _chacha20_decrypt(self, k, n, ad, ciphertext):
cipher = self._cipher(k) return self.cipher.decrypt(nonce=self._chacha20_nonce(n), data=ciphertext, associated_data=ad)
return cipher.decrypt(nonce=self._chacha20_nonce(n), data=ciphertext, associated_data=ad)
def _chacha20_nonce(self, n): def _chacha20_nonce(self, n):
return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little') return b'\x00\x00\x00\x00' + n.to_bytes(length=8, byteorder='little')
@@ -94,6 +88,9 @@ class Cipher(object):
def _default_rekey(self, k): def _default_rekey(self, k):
return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32] return self.encrypt(k, MAX_NONCE, b'', b'\x00' * 32)[:32]
def initialize(self, key):
self.cipher = self._cipher(key)
class Hash(object): class Hash(object):
def __init__(self, method): def __init__(self, method):
@@ -209,8 +206,8 @@ dh_map = {
} }
cipher_map = { cipher_map = {
'AESGCM': Cipher('AESGCM'), 'AESGCM': partial(Cipher, 'AESGCM'),
'ChaChaPoly': Cipher('ChaCha20') 'ChaChaPoly': partial(Cipher, 'ChaCha20')
} }
hash_map = { hash_map = {

View File

@@ -9,11 +9,14 @@ class CipherState(object):
Implemented as per Noise Protocol specification (rev 32) - paragraph 5.1. Implemented as per Noise Protocol specification (rev 32) - paragraph 5.1.
The initialize_key() function takes additional required argument - noise_protocol. The initialize_key() function takes additional required argument - noise_protocol.
This class holds an instance of Cipher wrapper. It manages initialisation of underlying cipher function
with appropriate key in initialize_key() and rekey() methods.
""" """
def __init__(self, noise_protocol): def __init__(self, noise_protocol):
self.k = Empty() self.k = Empty()
self.n = None self.n = None
self.noise_protocol = noise_protocol self.cipher = noise_protocol.cipher_fn()
def initialize_key(self, key): def initialize_key(self, key):
""" """
@@ -21,6 +24,8 @@ class CipherState(object):
""" """
self.k = key self.k = key
self.n = 0 self.n = 0
if self.has_key():
self.cipher.initialize(key)
def has_key(self): def has_key(self):
""" """
@@ -41,7 +46,7 @@ class CipherState(object):
if not self.has_key(): if not self.has_key():
return plaintext return plaintext
ciphertext = self.noise_protocol.cipher_fn.encrypt(self.k, self.n, ad, plaintext) ciphertext = self.cipher.encrypt(self.k, self.n, ad, plaintext)
self.n = self.n + 1 self.n = self.n + 1
return ciphertext return ciphertext
@@ -59,12 +64,13 @@ class CipherState(object):
if not self.has_key(): if not self.has_key():
return ciphertext return ciphertext
plaintext = self.noise_protocol.cipher_fn.decrypt(self.k, self.n, ad, ciphertext) plaintext = self.cipher.decrypt(self.k, self.n, ad, ciphertext)
self.n = self.n + 1 self.n = self.n + 1
return plaintext return plaintext
def rekey(self): def rekey(self):
self.k = self.noise_protocol.cipher_fn.rekey(self.k) self.k = self.cipher.rekey(self.k)
self.cipher.initialize(self.k)
class SymmetricState(object): class SymmetricState(object):