[omemo] Port python-omemo changes from master

This commit is contained in:
Philipp Hörist
2017-01-24 11:29:58 +01:00
parent d38fe92d0f
commit a58ff4f6e8
8 changed files with 241 additions and 86 deletions

View File

@@ -18,6 +18,7 @@
# #
import sys
import logging import logging
log = logging.getLogger('gajim.plugin_system.omemo') log = logging.getLogger('gajim.plugin_system.omemo')
try: try:
@@ -35,7 +36,11 @@ def encrypt(key, iv, plaintext):
def decrypt(key, iv, ciphertext): def decrypt(key, iv, ciphertext):
return aes_decrypt(key, iv, ciphertext) plaintext = aes_decrypt(key, iv, ciphertext).decode('utf-8')
if sys.version_info < (3, 0):
return unicode(plaintext)
else:
return plaintext
class NoValidSessions(Exception): class NoValidSessions(Exception):

View File

@@ -29,11 +29,14 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import logging
from struct import pack, unpack from struct import pack, unpack
from Crypto.Cipher import AES from Crypto.Cipher import AES
from Crypto.Util import strxor from Crypto.Util import strxor
log = logging.getLogger('gajim.plugin_system.omemo')
def gcm_rightshift(vec): def gcm_rightshift(vec):
for x in range(15, 0, -1): for x in range(15, 0, -1):
@@ -140,13 +143,20 @@ def gcm_encrypt(k, iv, plaintext, auth_data):
def aes_encrypt(key, nonce, plaintext): def aes_encrypt(key, nonce, plaintext):
""" Use AES128 GCM with the given key and iv to encrypt the payload. """ """ Use AES128 GCM with the given key and iv to encrypt the payload. """
c, t = gcm_encrypt(key, nonce, plaintext, '') return gcm_encrypt(key, nonce, plaintext, '')
result = c + t
return result
def aes_decrypt(_key, nonce, payload):
def aes_decrypt(key, nonce, payload):
""" Use AES128 GCM with the given key and iv to decrypt the payload. """ """ Use AES128 GCM with the given key and iv to decrypt the payload. """
ciphertext = payload[:-16] if len(_key) >= 32:
mac = payload[-16:] # XEP-0384
log.debug('XEP Compliant Key/Tag')
ciphertext = payload
key = _key[:16]
mac = _key[16:]
else:
# Legacy
log.debug('Legacy Key/Tag')
ciphertext = payload[:-16]
key = _key
mac = payload[-16:]
return gcm_decrypt(key, nonce, ciphertext, '', mac) return gcm_decrypt(key, nonce, ciphertext, '', mac)

View File

@@ -19,6 +19,7 @@
import os import os
import logging
from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers.modes import GCM from cryptography.hazmat.primitives.ciphers.modes import GCM
@@ -32,11 +33,22 @@ if os.name == 'nt':
else: else:
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
log = logging.getLogger('gajim.plugin_system.omemo')
def aes_decrypt(key, iv, payload): def aes_decrypt(_key, iv, payload):
""" Use AES128 GCM with the given key and iv to decrypt the payload. """ """ Use AES128 GCM with the given key and iv to decrypt the payload. """
data = payload[:-16] if len(_key) >= 32:
tag = payload[-16:] # XEP-0384
log.debug('XEP Compliant Key/Tag')
data = payload
key = _key[:16]
tag = _key[16:]
else:
# Legacy
log.debug('Legacy Key/Tag')
data = payload[:-16]
key = _key
tag = payload[-16:]
if os.name == 'nt': if os.name == 'nt':
_backend = backend _backend = backend
else: else:
@@ -58,4 +70,4 @@ def aes_encrypt(key, iv, plaintext):
algorithms.AES(key), algorithms.AES(key),
GCM(iv), GCM(iv),
backend=_backend).encryptor() backend=_backend).encryptor()
return encryptor.update(plaintext) + encryptor.finalize() + encryptor.tag return encryptor.update(plaintext) + encryptor.finalize(), encryptor.tag

View File

@@ -83,10 +83,16 @@ class LiteAxolotlStore(AxolotlStore):
def saveIdentity(self, recepientId, identityKey): def saveIdentity(self, recepientId, identityKey):
self.identityKeyStore.saveIdentity(recepientId, identityKey) self.identityKeyStore.saveIdentity(recepientId, identityKey)
def deleteIdentity(self, recipientId, identityKey):
self.identityKeyStore.deleteIdentity(recipientId, identityKey)
def isTrustedIdentity(self, recepientId, identityKey): def isTrustedIdentity(self, recepientId, identityKey):
return self.identityKeyStore.isTrustedIdentity(recepientId, return self.identityKeyStore.isTrustedIdentity(recepientId,
identityKey) identityKey)
def setTrust(self, identityKey, trust):
return self.identityKeyStore.setTrust(identityKey, trust)
def getTrustedFingerprints(self, jid): def getTrustedFingerprints(self, jid):
return self.identityKeyStore.getTrustedFingerprints(jid) return self.identityKeyStore.getTrustedFingerprints(jid)
@@ -127,6 +133,9 @@ class LiteAxolotlStore(AxolotlStore):
# TODO Reuse this # TODO Reuse this
return self.sessionStore.getSubDeviceSessions(recepientId) return self.sessionStore.getSubDeviceSessions(recepientId)
def getJidFromDevice(self, device_id):
return self.sessionStore.getJidFromDevice(device_id)
def storeSession(self, recepientId, deviceId, sessionRecord): def storeSession(self, recepientId, deviceId, sessionRecord):
self.sessionStore.storeSession(recepientId, deviceId, sessionRecord) self.sessionStore.storeSession(recepientId, deviceId, sessionRecord)
@@ -139,6 +148,15 @@ class LiteAxolotlStore(AxolotlStore):
def deleteAllSessions(self, recepientId): def deleteAllSessions(self, recepientId):
self.sessionStore.deleteAllSessions(recepientId) self.sessionStore.deleteAllSessions(recepientId)
def getSessionsFromJid(self, recipientId):
return self.sessionStore.getSessionsFromJid(recipientId)
def getSessionsFromJids(self, recipientId):
return self.sessionStore.getSessionsFromJids(recipientId)
def getAllSessions(self):
return self.sessionStore.getAllSessions()
def loadSignedPreKey(self, signedPreKeyId): def loadSignedPreKey(self, signedPreKeyId):
return self.signedPreKeyStore.loadSignedPreKey(signedPreKeyId) return self.signedPreKeyStore.loadSignedPreKey(signedPreKeyId)

View File

@@ -86,6 +86,13 @@ class LiteIdentityKeyStore(IdentityKeyStore):
return result is not None return result is not None
def deleteIdentity(self, recipientId, identityKey):
q = "DELETE FROM identities WHERE recipient_id = ? AND public_key = ?"
c = self.dbConn.cursor()
c.execute(q, (recipientId,
identityKey.getPublicKey().serialize()))
self.dbConn.commit()
def isTrustedIdentity(self, recipientId, identityKey): def isTrustedIdentity(self, recipientId, identityKey):
q = "SELECT trust FROM identities WHERE recipient_id = ? " \ q = "SELECT trust FROM identities WHERE recipient_id = ? " \
"AND public_key = ?" "AND public_key = ?"
@@ -160,8 +167,8 @@ class LiteIdentityKeyStore(IdentityKeyStore):
c.execute(q, fingerprints) c.execute(q, fingerprints)
self.dbConn.commit() self.dbConn.commit()
def setTrust(self, _id, trust): def setTrust(self, identityKey, trust):
q = "UPDATE identities SET trust = ? WHERE _id = ?" q = "UPDATE identities SET trust = ? WHERE public_key = ?"
c = self.dbConn.cursor() c = self.dbConn.cursor()
c.execute(q, (trust, _id)) c.execute(q, (trust, identityKey.getPublicKey().serialize()))
self.dbConn.commit() self.dbConn.commit()

View File

@@ -48,6 +48,14 @@ class LiteSessionStore(SessionStore):
deviceIds = [r[0] for r in result] deviceIds = [r[0] for r in result]
return deviceIds return deviceIds
def getJidFromDevice(self, device_id):
q = "SELECT recipient_id from sessions WHERE device_id = ?"
c = self.dbConn.cursor()
c.execute(q, (device_id, ))
result = c.fetchone()
return result[0].decode('utf-8')
def getActiveDeviceTuples(self): def getActiveDeviceTuples(self):
q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1" q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1"
c = self.dbConn.cursor() c = self.dbConn.cursor()
@@ -82,6 +90,33 @@ class LiteSessionStore(SessionStore):
self.dbConn.cursor().execute(q, (recipientId, )) self.dbConn.cursor().execute(q, (recipientId, ))
self.dbConn.commit() self.dbConn.commit()
def getAllSessions(self):
q = "SELECT _id, recipient_id, device_id, record, active from sessions"
c = self.dbConn.cursor()
result = []
for row in c.execute(q):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJid(self, recipientId):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
" WHERE recipient_id = ?"
c = self.dbConn.cursor()
result = []
for row in c.execute(q, (recipientId,)):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJids(self, recipientId):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
" WHERE recipient_id IN ({})" \
.format(', '.join(['?'] * len(recipientId)))
c = self.dbConn.cursor()
result = []
for row in c.execute(q, recipientId):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def setActiveState(self, deviceList, jid): def setActiveState(self, deviceList, jid):
c = self.dbConn.cursor() c = self.dbConn.cursor()
@@ -96,28 +131,6 @@ class LiteSessionStore(SessionStore):
c.execute(q, deviceList) c.execute(q, deviceList)
self.dbConn.commit() self.dbConn.commit()
def getActiveSessionsKeys(self, recipientId):
q = "SELECT record FROM sessions WHERE active = 1 AND recipient_id = ?"
c = self.dbConn.cursor()
result = []
for row in c.execute(q, (recipientId,)):
public_key = (SessionRecord(serialized=row[0]).
getSessionState().getRemoteIdentityKey().
getPublicKey())
result.append(public_key.serialize())
return result
def getAllActiveSessionsKeys(self):
q = "SELECT record FROM sessions WHERE active = 1"
c = self.dbConn.cursor()
result = []
for row in c.execute(q):
public_key = (SessionRecord(serialized=row[0]).
getSessionState().getRemoteIdentityKey().
getPublicKey())
result.append(public_key.serialize())
return result
def getInactiveSessionsKeys(self, recipientId): def getInactiveSessionsKeys(self, recipientId):
q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?" q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?"
c = self.dbConn.cursor() c = self.dbConn.cursor()

View File

@@ -29,6 +29,14 @@ class SQLDatabase():
self.dbConn = dbConn self.dbConn = dbConn
self.createDb() self.createDb()
self.migrateDb() self.migrateDb()
c = self.dbConn.cursor()
c.execute("PRAGMA synchronous=NORMAL;")
c.execute("PRAGMA journal_mode;")
mode = c.fetchone()[0]
# WAL is a persistent DB mode, dont override it if user has set it
if mode != 'wal':
c.execute("PRAGMA journal_mode=MEMORY;")
self.dbConn.commit()
def createDb(self): def createDb(self):
if user_version(self.dbConn) == 0: if user_version(self.dbConn) == 0:

View File

@@ -200,8 +200,8 @@ class OmemoState:
key = self.handleWhisperMessage(sender_jid, sid, encrypted_key) key = self.handleWhisperMessage(sender_jid, sid, encrypted_key)
except (NoSessionException, InvalidMessageException) as e: except (NoSessionException, InvalidMessageException) as e:
log.warning('No Session found ' + e.message) log.warning('No Session found ' + e.message)
log.warning('sender_jid => ' + str(sender_jid) + log.warning('sender_jid => ' + str(sender_jid) + ' sid =>' +
' sid =>' + sid) str(sid))
return return
except (DuplicateMessageException) as e: except (DuplicateMessageException) as e:
log.warning('Duplicate message found ' + str(e.args)) log.warning('Duplicate message found ' + str(e.args))
@@ -211,7 +211,7 @@ class OmemoState:
log.warning('Duplicate message found ' + str(e.args)) log.warning('Duplicate message found ' + str(e.args))
return return
result = decrypt(key, iv, payload).decode('utf-8') result = decrypt(key, iv, payload)
log.debug("Decrypted Message => " + result) log.debug("Decrypted Message => " + result)
return result return result
@@ -226,43 +226,44 @@ class OmemoState:
log.error('No known devices') log.error('No known devices')
return return
for dev in devices_list: payload, tag = encrypt(key, iv, plaintext)
self.get_session_cipher(jid, dev)
session_ciphers = self.session_ciphers[jid] # for XEP-384 Compliance uncomment
if not session_ciphers: # key += tag
log.warning('No session ciphers for ' + jid) payload += tag
return
# Encrypt the message key with for each of receivers devices # Encrypt the message key with for each of receivers devices
for rid, cipher in session_ciphers.items(): for device in devices_list:
try: try:
if self.isTrusted(cipher) == TRUSTED: if self.isTrusted(jid, device) == TRUSTED:
encrypted_keys[rid] = cipher.encrypt(key).serialize() cipher = self.get_session_cipher(jid, device)
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey)
else: else:
log.debug('Skipped Device because Trust is: ' + log.debug('Skipped Device because Trust is: ' +
str(self.isTrusted(cipher))) str(self.isTrusted(jid, device)))
except: except:
log.warning('Failed to find key for device ' + str(rid)) log.warning('Failed to find key for device ' + str(device))
if len(encrypted_keys) == 0: if len(encrypted_keys) == 0:
log_msg = 'Encrypted keys empty' log.error('Encrypted keys empty')
log.error(log_msg) raise NoValidSessions('Encrypted keys empty')
raise NoValidSessions(log_msg)
my_other_devices = set(self.own_devices) - set({self.own_device_id}) my_other_devices = set(self.own_devices) - set({self.own_device_id})
# Encrypt the message key with for each of our own devices # Encrypt the message key with for each of our own devices
for dev in my_other_devices: for device in my_other_devices:
try: try:
cipher = self.get_session_cipher(from_jid, dev) if self.isTrusted(from_jid, device) == TRUSTED:
if self.isTrusted(cipher) == TRUSTED: cipher = self.get_session_cipher(from_jid, device)
encrypted_keys[dev] = cipher.encrypt(key).serialize() cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[device] = (cipher_key.serialize(), prekey)
else: else:
log.debug('Skipped own Device because Trust is: ' + log.debug('Skipped own Device because Trust is: ' +
str(self.isTrusted(cipher))) str(self.isTrusted(from_jid, device)))
except: except:
log.warning('Failed to find key for device ' + str(dev)) log.warning('Failed to find key for device ' + str(device))
payload = encrypt(key, iv, plaintext)
result = {'sid': self.own_device_id, result = {'sid': self.own_device_id,
'keys': encrypted_keys, 'keys': encrypted_keys,
@@ -273,14 +274,109 @@ class OmemoState:
log.debug('Finished encrypting message') log.debug('Finished encrypting message')
return result return result
def isTrusted(self, cipher): def create_gc_msg(self, from_jid, jid, plaintext):
self.cipher = cipher key = get_random_bytes(16)
self.state = self.cipher.sessionStore. \ iv = get_random_bytes(16)
loadSession(self.cipher.recipientId, self.cipher.deviceId). \ encrypted_keys = {}
getSessionState() room = jid
self.key = self.state.getRemoteIdentityKey() encrypted_jids = []
return self.store.identityKeyStore. \
isTrustedIdentity(self.cipher.recipientId, self.key) devices_list = self.device_list_for(jid, True)
if len(devices_list) == 0:
log.error('No known devices')
return
payload, tag = encrypt(key, iv, plaintext)
# for XEP-384 Compliance uncomment
# key += tag
payload += tag
for tup in devices_list:
self.get_session_cipher(tup[0], tup[1])
# Encrypt the message key with for each of receivers devices
for nick in self.plugin.groupchat[room]:
jid_to = self.plugin.groupchat[room][nick]
if jid_to == self.own_jid:
continue
if jid_to in encrypted_jids: # We already encrypted to this JID
continue
for rid, cipher in self.session_ciphers[jid_to].items():
try:
if self.isTrusted(jid_to, rid) == TRUSTED:
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[rid] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped Device because Trust is: ' +
str(self.isTrusted(jid_to, rid)))
except:
log.exception('ERROR:')
log.warning('Failed to find key for device ' +
str(rid))
encrypted_jids.append(jid_to)
if len(encrypted_keys) == 0:
log_msg = 'Encrypted keys empty'
log.error(log_msg)
raise NoValidSessions(log_msg)
my_other_devices = set(self.own_devices) - set({self.own_device_id})
# Encrypt the message key with for each of our own devices
for dev in my_other_devices:
try:
cipher = self.get_session_cipher(from_jid, dev)
if self.isTrusted(from_jid, dev) == TRUSTED:
cipher_key = cipher.encrypt(key)
prekey = isinstance(cipher_key, PreKeyWhisperMessage)
encrypted_keys[dev] = (cipher_key.serialize(), prekey)
else:
log.debug('Skipped own Device because Trust is: ' +
str(self.isTrusted(from_jid, dev)))
except:
log.exception('ERROR:')
log.warning('Failed to find key for device ' + str(dev))
result = {'sid': self.own_device_id,
'keys': encrypted_keys,
'jid': jid,
'iv': iv,
'payload': payload}
log.debug('Finished encrypting message')
return result
def device_list_for(self, jid, gc=False):
""" Return a list of known device ids for the specified jid.
Parameters
----------
jid : string
The contacts jid
gc : bool
Groupchat Message
"""
if gc:
room = jid
devicelist = []
for nick in self.plugin.groupchat[room]:
jid_to = self.plugin.groupchat[room][nick]
if jid_to == self.own_jid:
continue
for device in self.device_ids[jid_to]:
devicelist.append((jid_to, device))
return devicelist
if jid == self.own_jid:
return set(self.own_devices) - set({self.own_device_id})
if jid not in self.device_ids:
return set()
return set(self.device_ids[jid])
def isTrusted(self, recipient_id, device_id):
record = self.store.loadSession(recipient_id, device_id)
identity_key = record.getSessionState().getRemoteIdentityKey()
return self.store.isTrustedIdentity(recipient_id, identity_key)
def getTrustedFingerprints(self, recipient_id): def getTrustedFingerprints(self, recipient_id):
inactive = self.store.getInactiveSessionsKeys(recipient_id) inactive = self.store.getInactiveSessionsKeys(recipient_id)
@@ -296,20 +392,6 @@ class OmemoState:
return undecided return undecided
def device_list_for(self, jid):
""" Return a list of known device ids for the specified jid.
Parameters
----------
jid : string
The contacts jid
"""
if jid == self.own_jid:
return set(self.own_devices) - set({self.own_device_id})
if jid not in self.device_ids:
return set()
return set(self.device_ids[jid])
def devices_without_sessions(self, jid): def devices_without_sessions(self, jid):
""" List device_ids for the given jid which have no axolotl session. """ List device_ids for the given jid which have no axolotl session.
@@ -364,10 +446,10 @@ class OmemoState:
def handleWhisperMessage(self, recipient_id, device_id, key): def handleWhisperMessage(self, recipient_id, device_id, key):
whisperMessage = WhisperMessage(serialized=key) whisperMessage = WhisperMessage(serialized=key)
sessionCipher = self.get_session_cipher(recipient_id, device_id)
log.debug(self.account + " => Received WhisperMessage from " + log.debug(self.account + " => Received WhisperMessage from " +
recipient_id) recipient_id)
if self.isTrusted(sessionCipher) >= TRUSTED: if self.isTrusted(recipient_id, device_id):
sessionCipher = self.get_session_cipher(recipient_id, device_id)
key = sessionCipher.decryptMsg(whisperMessage, textMsg=False) key = sessionCipher.decryptMsg(whisperMessage, textMsg=False)
return key return key
else: else: