[omemo] More refactoring

This commit is contained in:
Philipp Hörist
2019-02-16 10:58:13 +01:00
parent 0e25e70be3
commit 3a8fb991c8
2 changed files with 176 additions and 146 deletions

View File

@@ -17,6 +17,9 @@
import logging import logging
import time import time
import binascii
import textwrap
from collections import defaultdict
from nbxmpp.structs import OMEMOBundle from nbxmpp.structs import OMEMOBundle
from nbxmpp.structs import OMEMOMessage from nbxmpp.structs import OMEMOMessage
@@ -31,6 +34,7 @@ from axolotl.sessionbuilder import SessionBuilder
from axolotl.sessioncipher import SessionCipher from axolotl.sessioncipher import SessionCipher
from axolotl.state.prekeybundle import PreKeyBundle from axolotl.state.prekeybundle import PreKeyBundle
from axolotl.util.keyhelper import KeyHelper from axolotl.util.keyhelper import KeyHelper
from axolotl.duplicatemessagexception import DuplicateMessageException
from omemo.backend.aes import aes_decrypt, aes_encrypt from omemo.backend.aes import aes_decrypt, aes_encrypt
from omemo.backend.liteaxolotlstore import LiteAxolotlStore from omemo.backend.liteaxolotlstore import LiteAxolotlStore
@@ -52,7 +56,7 @@ class OmemoState:
def __init__(self, own_jid, db_con, account, xmpp_con): def __init__(self, own_jid, db_con, account, xmpp_con):
self.account = account self.account = account
self.xmpp_con = xmpp_con self.xmpp_con = xmpp_con
self.session_ciphers = {} self._session_ciphers = defaultdict(dict)
self.own_jid = own_jid self.own_jid = own_jid
self.device_ids = {} self.device_ids = {}
self.own_devices = [] self.own_devices = []
@@ -71,40 +75,31 @@ class OmemoState:
self.account, self.account,
self.store.getPreKeyCount()) self.store.getPreKeyCount())
def build_session(self, recipient_id, device_id, bundle): def build_session(self, jid, device_id, bundle):
sessionBuilder = SessionBuilder(self.store, self.store, self.store, session = SessionBuilder(self.store, self.store, self.store,
self.store, recipient_id, device_id) self.store, jid, device_id)
registration_id = self.store.getLocalRegistrationId() registration_id = self.store.getLocalRegistrationId()
prekey = bundle.pick_prekey() prekey = bundle.pick_prekey()
preKeyPublic = DjbECPublicKey(prekey['key'][1:]) otpk = DjbECPublicKey(prekey['key'][1:])
signedPreKeyPublic = DjbECPublicKey(bundle.spk['key'][1:]) spk = DjbECPublicKey(bundle.spk['key'][1:])
identityKey = IdentityKey(DjbECPublicKey(bundle.ik[1:])) ik = IdentityKey(DjbECPublicKey(bundle.ik[1:]))
prekey_bundle = PreKeyBundle( prekey_bundle = PreKeyBundle(registration_id,
registration_id, device_id, device_id,
prekey['id'], preKeyPublic, prekey['id'],
bundle.spk['id'], signedPreKeyPublic, otpk,
bundle.spk_signature, bundle.spk['id'],
identityKey) spk,
bundle.spk_signature,
ik)
sessionBuilder.processPreKeyBundle(prekey_bundle) session.processPreKeyBundle(prekey_bundle)
return self.get_session_cipher(recipient_id, device_id) return self._get_session_cipher(jid, device_id)
def set_devices(self, name, devices): def set_devices(self, name, devices):
""" Return a an.
Parameters
----------
jid : string
The contacts jid
devices: [int]
A list of devices
"""
self.device_ids[name] = devices self.device_ids[name] = devices
log.info('%s => Saved devices for %s', self.account, name) log.info('%s => Saved devices for %s', self.account, name)
@@ -146,62 +141,62 @@ class OmemoState:
@property @property
def bundle(self): def bundle(self):
self.checkPreKeyAmount() self._check_pre_key_count()
bundle = {'otpks': []} bundle = {'otpks': []}
for k in self.store.loadPendingPreKeys(): for k in self.store.loadPendingPreKeys():
key = k.getKeyPair().getPublicKey().serialize() key = k.getKeyPair().getPublicKey().serialize()
bundle['otpks'].append({'key': key, 'id': k.getId()}) bundle['otpks'].append({'key': key, 'id': k.getId()})
identityKeyPair = self.store.getIdentityKeyPair() ik_pair = self.store.getIdentityKeyPair()
bundle['ik'] = identityKeyPair.getPublicKey().serialize() bundle['ik'] = ik_pair.getPublicKey().serialize()
self.cycleSignedPreKey(identityKeyPair) self._cycle_signed_pre_key(ik_pair)
signedPreKey = self.store.loadSignedPreKey( spk = self.store.loadSignedPreKey(
self.store.getCurrentSignedPreKeyId()) self.store.getCurrentSignedPreKeyId())
bundle['spk_signature'] = signedPreKey.getSignature() bundle['spk_signature'] = spk.getSignature()
bundle['spk'] = {'key': signedPreKey.getKeyPair().getPublicKey().serialize(), bundle['spk'] = {'key': spk.getKeyPair().getPublicKey().serialize(),
'id': signedPreKey.getId()} 'id': spk.getId()}
return OMEMOBundle(**bundle) return OMEMOBundle(**bundle)
def decrypt_msg(self, omemo_message, jid): def decrypt_message(self, omemo_message, jid):
own_id = self.own_device_id if omemo_message.sid == self.own_device_id:
if omemo_message.sid == own_id:
log.info('Received previously sent message by us') log.info('Received previously sent message by us')
return raise SelfMessage
if own_id not in omemo_message.keys:
log.warning('OMEMO message does not contain our device key')
return
encrypted_key, prekey = omemo_message.keys[own_id] try:
encrypted_key, prekey = omemo_message.keys[self.own_device_id]
except KeyError:
log.info('Received message not for our device')
raise MessageNotForDevice
if prekey: try:
try: if prekey:
key = self.handlePreKeyWhisperMessage( key = self._process_pre_key_message(
jid, omemo_message.sid, encrypted_key) jid, omemo_message.sid, encrypted_key)
except Exception as error: else:
log.warning(error) key = self._process_message(
return
else:
try:
key = self.handleWhisperMessage(
jid, omemo_message.sid, encrypted_key) jid, omemo_message.sid, encrypted_key)
except Exception as error:
log.warning(error) except DuplicateMessageException:
return log.info('Received duplicated message')
raise DuplicateMessage
except Exception as error:
log.warning(error)
raise DecryptionFailed
if omemo_message.payload is None: if omemo_message.payload is None:
result = None
log.debug("Decrypted Key Exchange Message") log.debug("Decrypted Key Exchange Message")
else: raise KeyExchangeMessage
result = aes_decrypt(key, omemo_message.iv, omemo_message.payload)
log.debug("Decrypted Message => %s", result) result = aes_decrypt(key, omemo_message.iv, omemo_message.payload)
log.debug("Decrypted Message => %s", result)
return result return result
def create_msg(self, from_jid, jid, plaintext): def create_msg(self, jid, plaintext):
encrypted_keys = {} encrypted_keys = {}
devices_list = self.device_list_for(jid) devices_list = self.device_list_for(jid)
@@ -215,7 +210,7 @@ class OmemoState:
for device in devices_list: for device in devices_list:
try: try:
if self.isTrusted(jid, device) == TRUSTED: if self.isTrusted(jid, device) == TRUSTED:
cipher = self.get_session_cipher(jid, device) cipher = self._get_session_cipher(jid, device)
cipher_key = cipher.encrypt(result.key) cipher_key = cipher.encrypt(result.key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage) prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey) encrypted_keys[device] = (cipher_key.serialize(), prekey)
@@ -233,14 +228,14 @@ class OmemoState:
# Encrypt the message key with for each of our own devices # Encrypt the message key with for each of our own devices
for device in my_other_devices: for device in my_other_devices:
try: try:
if self.isTrusted(from_jid, device) == TRUSTED: if self.isTrusted(self.own_jid, device) == TRUSTED:
cipher = self.get_session_cipher(from_jid, device) cipher = self._get_session_cipher(self.own_jid, device)
cipher_key = cipher.encrypt(result.key) cipher_key = cipher.encrypt(result.key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage) prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey) encrypted_keys[device] = (cipher_key.serialize(), prekey)
else: else:
log.debug('Skipped own Device because Trust is: %s', log.debug('Skipped own Device because Trust is: %s',
self.isTrusted(from_jid, device)) self.isTrusted(self.own_jid, device))
except Exception: except Exception:
log.warning('Failed to find key for device: %s', device) log.warning('Failed to find key for device: %s', device)
@@ -260,7 +255,7 @@ class OmemoState:
result = aes_encrypt(plaintext) result = aes_encrypt(plaintext)
for tup in devices_list: for tup in devices_list:
self.get_session_cipher(tup[0], tup[1]) self._get_session_cipher(tup[0], tup[1])
# Encrypt the message key with for each of receivers devices # Encrypt the message key with for each of receivers devices
for nick in self.xmpp_con.groupchat[room]: for nick in self.xmpp_con.groupchat[room]:
@@ -269,9 +264,9 @@ class OmemoState:
continue continue
if jid_to in encrypted_jids: # We already encrypted to this JID if jid_to in encrypted_jids: # We already encrypted to this JID
continue continue
if jid_to not in self.session_ciphers: if jid_to not in self._session_ciphers:
continue continue
for rid, cipher in self.session_ciphers[jid_to].items(): for rid, cipher in self._session_ciphers[jid_to].items():
try: try:
if self.isTrusted(jid_to, rid) == TRUSTED: if self.isTrusted(jid_to, rid) == TRUSTED:
cipher_key = cipher.encrypt(result.key) cipher_key = cipher.encrypt(result.key)
@@ -289,7 +284,7 @@ class OmemoState:
# Encrypt the message key with for each of our own devices # Encrypt the message key with for each of our own devices
for dev in my_other_devices: for dev in my_other_devices:
try: try:
cipher = self.get_session_cipher(from_jid, dev) cipher = self._get_session_cipher(from_jid, dev)
if self.isTrusted(from_jid, dev) == TRUSTED: if self.isTrusted(from_jid, dev) == TRUSTED:
cipher_key = cipher.encrypt(result.key) cipher_key = cipher.encrypt(result.key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage) prekey = isinstance(cipher_key, PreKeyWhisperMessage)
@@ -382,23 +377,21 @@ class OmemoState:
self.account, jid, missing_devices) self.account, jid, missing_devices)
return missing_devices return missing_devices
def get_session_cipher(self, jid, device_id): def _get_session_cipher(self, jid, device_id):
if jid not in self.session_ciphers: try:
self.session_ciphers[jid] = {} return self._session_ciphers[jid][device_id]
except KeyError:
if device_id not in self.session_ciphers[jid]:
cipher = SessionCipher(self.store, self.store, self.store, cipher = SessionCipher(self.store, self.store, self.store,
self.store, jid, device_id) self.store, jid, device_id)
self.session_ciphers[jid][device_id] = cipher self._session_ciphers[jid][device_id] = cipher
return cipher
return self.session_ciphers[jid][device_id] def _process_pre_key_message(self, recipient_id, device_id, key):
def handlePreKeyWhisperMessage(self, recipient_id, device_id, key):
preKeyWhisperMessage = PreKeyWhisperMessage(serialized=key) preKeyWhisperMessage = PreKeyWhisperMessage(serialized=key)
if not preKeyWhisperMessage.getPreKeyId(): if not preKeyWhisperMessage.getPreKeyId():
raise Exception('Received PreKeyWhisperMessage ' raise Exception('Received PreKeyWhisperMessage '
'without PreKey => %s' % recipient_id) 'without PreKey => %s' % recipient_id)
sessionCipher = self.get_session_cipher(recipient_id, device_id) sessionCipher = self._get_session_cipher(recipient_id, device_id)
try: try:
log.debug('%s => Received PreKeyWhisperMessage from %s', log.debug('%s => Received PreKeyWhisperMessage from %s',
self.account, recipient_id) self.account, recipient_id)
@@ -413,12 +406,12 @@ class OmemoState:
'from Untrusted Fingerprint! => %s', 'from Untrusted Fingerprint! => %s',
self.account, error.getName()) self.account, error.getName())
def handleWhisperMessage(self, recipient_id, device_id, key): def _process_message(self, recipient_id, device_id, key):
whisperMessage = WhisperMessage(serialized=key) whisperMessage = WhisperMessage(serialized=key)
log.debug('%s => Received WhisperMessage from %s', log.debug('%s => Received WhisperMessage from %s',
self.account, recipient_id) self.account, recipient_id)
if self.isTrusted(recipient_id, device_id): if self.isTrusted(recipient_id, device_id):
sessionCipher = self.get_session_cipher(recipient_id, device_id) sessionCipher = self._get_session_cipher(recipient_id, device_id)
key = sessionCipher.decryptMsg(whisperMessage, textMsg=False) key = sessionCipher.decryptMsg(whisperMessage, textMsg=False)
self.add_device(recipient_id, device_id) self.add_device(recipient_id, device_id)
return key return key
@@ -426,24 +419,24 @@ class OmemoState:
raise Exception('Received WhisperMessage ' raise Exception('Received WhisperMessage '
'from Untrusted Fingerprint! => %s' % recipient_id) 'from Untrusted Fingerprint! => %s' % recipient_id)
def checkPreKeyAmount(self): def _check_pre_key_count(self):
# Check if enough PreKeys are available # Check if enough PreKeys are available
preKeyCount = self.store.getPreKeyCount() pre_key_count = self.store.getPreKeyCount()
if preKeyCount < MIN_PREKEY_AMOUNT: if pre_key_count < MIN_PREKEY_AMOUNT:
newKeys = DEFAULT_PREKEY_AMOUNT - preKeyCount missing_count = DEFAULT_PREKEY_AMOUNT - pre_key_count
self.store.generateNewPreKeys(newKeys) self.store.generateNewPreKeys(missing_count)
log.info('%s => %s PreKeys created', self.account, newKeys) log.info('%s => %s PreKeys created', self.account, missing_count)
def cycleSignedPreKey(self, identityKeyPair): def _cycle_signed_pre_key(self, ik_pair):
# Publish every SPK_CYCLE_TIME a new SignedPreKey # Publish every SPK_CYCLE_TIME a new SignedPreKey
# Delete all exsiting SignedPreKeys that are older # Delete all exsiting SignedPreKeys that are older
# then SPK_ARCHIVE_TIME # then SPK_ARCHIVE_TIME
# Check if SignedPreKey exist and create if not # Check if SignedPreKey exist and create if not
if not self.store.getCurrentSignedPreKeyId(): if not self.store.getCurrentSignedPreKeyId():
signedPreKey = KeyHelper.generateSignedPreKey( spk = KeyHelper.generateSignedPreKey(
identityKeyPair, self.store.getNextSignedPreKeyId()) ik_pair, self.store.getNextSignedPreKeyId())
self.store.storeSignedPreKey(signedPreKey.getId(), signedPreKey) self.store.storeSignedPreKey(spk.getId(), spk)
log.debug('%s => New SignedPreKey created, because none existed', log.debug('%s => New SignedPreKey created, because none existed',
self.account) self.account)
@@ -453,9 +446,9 @@ class OmemoState:
self.store.getCurrentSignedPreKeyId()) self.store.getCurrentSignedPreKeyId())
if int(timestamp) < now - SPK_CYCLE_TIME: if int(timestamp) < now - SPK_CYCLE_TIME:
signedPreKey = KeyHelper.generateSignedPreKey( spk = KeyHelper.generateSignedPreKey(
identityKeyPair, self.store.getNextSignedPreKeyId()) ik_pair, self.store.getNextSignedPreKeyId())
self.store.storeSignedPreKey(signedPreKey.getId(), signedPreKey) self.store.storeSignedPreKey(spk.getId(), spk)
log.debug('%s => Cycled SignedPreKey', self.account) log.debug('%s => Cycled SignedPreKey', self.account)
# Delete all SignedPreKeys that are older than SPK_ARCHIVE_TIME # Delete all SignedPreKeys that are older than SPK_ARCHIVE_TIME
@@ -465,3 +458,27 @@ class OmemoState:
class NoValidSessions(Exception): class NoValidSessions(Exception):
pass pass
class SelfMessage(Exception):
pass
class MessageNotForDevice(Exception):
pass
class DecryptionFailed(Exception):
pass
class KeyExchangeMessage(Exception):
pass
class InvalidMessage(Exception):
pass
class DuplicateMessage(Exception):
pass

View File

@@ -39,8 +39,14 @@ from gajim.common.modules.base import BaseModule
from gajim.common.modules.util import event_node from gajim.common.modules.util import event_node
from omemo.backend.state import OmemoState from omemo.backend.state import OmemoState
from omemo.backend.state import KeyExchangeMessage
from omemo.backend.state import SelfMessage
from omemo.backend.state import MessageNotForDevice
from omemo.backend.state import DecryptionFailed
from omemo.backend.state import DuplicateMessage
from omemo.modules.util import prepare_stanza from omemo.modules.util import prepare_stanza
ALLOWED_TAGS = [ ALLOWED_TAGS = [
('request', nbxmpp.NS_RECEIPTS), ('request', nbxmpp.NS_RECEIPTS),
('active', nbxmpp.NS_CHATSTATES), ('active', nbxmpp.NS_CHATSTATES),
@@ -159,65 +165,73 @@ class OMEMO(BaseModule):
return return
if properties.is_mam_message: if properties.is_mam_message:
if properties.omemo.sid == self.omemo.own_device_id: from_jid = self._process_mam_message(properties)
log.info('%s => Skip message because it was sent by us', elif properties.from_muc:
self._account) from_jid = self._process_muc_message(properties)
raise NodeProcessed
log.info('%s => Message received, archive: %s',
self._account, properties.mam.archive)
else: else:
log.info('%s => Message received', self._account) from_jid = properties.jid.getBare()
from_jid = properties.jid.getBare() if from_jid is None:
return
if properties.from_muc: log.info('%s => Message received from: %s', self._account, from_jid)
if properties.is_mam_message:
log.info('%s => MUC MAM Message received', self._account)
if properties.muc_user.jid is None:
log.info('No real jid found, ignore message')
return
from_jid = properties.muc_user.jid.getBare()
else:
room_jid = properties.jid.getBare()
resource = properties.jid.getResource()
if properties.muc_ofrom is not None:
# History Message from MUC
from_jid = properties.muc_ofrom.getBare()
else:
try:
from_jid = self.groupchat[room_jid][resource]
except KeyError:
log.debug('Groupchat: Last resort trying to '
'find SID in DB')
from_jid = self.omemo.store.getJidFromDevice(
properties.omemo.sid)
if not from_jid:
log.error("%s => Can't decrypt GroupChat Message "
"from %s", self._account, resource)
return
self.groupchat[room_jid][resource] = from_jid
log.debug('GroupChat Message from: %s', from_jid) try:
return self.omemo.decrypt_message(properties.omemo,
from_jid)
except (KeyExchangeMessage, DuplicateMessage):
raise NodeProcessed
plaintext = '' except SelfMessage:
if properties.omemo.sid == self.omemo.own_device_id: if properties.from_muc:
if properties.omemo.payload in self.gc_message: if properties.omemo.payload in self.gc_message:
plaintext = self.gc_message[properties.omemo.payload] plaintext = self.gc_message[properties.omemo.payload]
del self.gc_message[properties.omemo.payload] del self.gc_message[properties.omemo.payload]
else: return plaintext
log.error("%s => Can't decrypt own GroupChat Message",
self._account)
return
else:
plaintext = self.omemo.decrypt_msg(properties.omemo, from_jid)
if not plaintext: log.warning("%s => Can't decrypt own GroupChat Message",
self._account)
raise NodeProcessed
except (DecryptionFailed, MessageNotForDevice):
return return
prepare_stanza(stanza, plaintext) prepare_stanza(stanza, plaintext)
self.print_msg_to_log(stanza) self._debug_print_stanza(stanza)
properties.encrypted = EncryptionData({'name': ENCRYPTION_NAME}) properties.encrypted = EncryptionData({'name': ENCRYPTION_NAME})
def _process_muc_message(self, properties):
room_jid = properties.jid.getBare()
resource = properties.jid.getResource()
if properties.muc_ofrom is not None:
# History Message from MUC
return properties.muc_ofrom.getBare()
try:
return self.groupchat[room_jid][resource]
except KeyError:
log.info('%s => Groupchat: Last resort trying to '
'find SID in DB', self._account)
from_jid = self.omemo.store.getJidFromDevice(properties.omemo.sid)
if not from_jid:
log.error("%s => Can't decrypt GroupChat Message "
"from %s", self._account, resource)
return
self.groupchat[room_jid][resource] = from_jid
return from_jid
def _process_mam_message(self, properties):
log.info('%s => Message received, archive: %s',
self._account, properties.mam.archive)
from_jid = properties.jid.getBare()
if properties.from_muc:
log.info('%s => MUC MAM Message received', self._account)
if properties.muc_user.jid is None:
log.info('%s => No real jid found', self._account)
return
from_jid = properties.muc_user.jid.getBare()
return from_jid
def _on_muc_user_presence(self, _con, _stanza, properties): def _on_muc_user_presence(self, _con, _stanza, properties):
if properties.type == PresenceType.ERROR: if properties.type == PresenceType.ERROR:
return return
@@ -371,7 +385,7 @@ class OMEMO(BaseModule):
create_omemo_message(event.msg_iq, omemo_message, create_omemo_message(event.msg_iq, omemo_message,
node_whitelist=ALLOWED_TAGS) node_whitelist=ALLOWED_TAGS)
self.print_msg_to_log(event.msg_iq) self._debug_print_stanza(event.msg_iq)
callback(event) callback(event)
def encrypt_message(self, conn, event, callback): def encrypt_message(self, conn, event, callback):
@@ -385,8 +399,7 @@ class OMEMO(BaseModule):
to_jid = app.get_jid_without_resource(event.jid) to_jid = app.get_jid_without_resource(event.jid)
try: try:
omemo_message = self.omemo.create_msg( omemo_message = self.omemo.create_msg(to_jid, event.message)
self.own_jid, to_jid, event.message)
if omemo_message is None: if omemo_message is None:
raise OMEMOError('Error while encrypting') raise OMEMOError('Error while encrypting')
@@ -401,7 +414,7 @@ class OMEMO(BaseModule):
create_omemo_message(event.msg_iq, omemo_message, create_omemo_message(event.msg_iq, omemo_message,
node_whitelist=ALLOWED_TAGS) node_whitelist=ALLOWED_TAGS)
self.print_msg_to_log(event.msg_iq) self._debug_print_stanza(event.msg_iq)
event.xhtml = None event.xhtml = None
event.encrypted = ENCRYPTION_NAME event.encrypted = ENCRYPTION_NAME
event.additional_data['encrypted'] = {'name': ENCRYPTION_NAME} event.additional_data['encrypted'] = {'name': ENCRYPTION_NAME}
@@ -561,7 +574,7 @@ class OMEMO(BaseModule):
self.are_keys_missing(jid) self.are_keys_missing(jid)
@staticmethod @staticmethod
def print_msg_to_log(stanza): def _debug_print_stanza(stanza):
log.debug('-'*15) log.debug('-'*15)
stanzastr = '\n' + stanza.__str__(fancy=True) stanzastr = '\n' + stanza.__str__(fancy=True)
stanzastr = stanzastr[0:-1] stanzastr = stanzastr[0:-1]