From 87ece2397ed9d745f9158eba4baa7002eb4477b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20H=C3=B6rist?= Date: Sat, 16 Feb 2019 13:57:43 +0100 Subject: [PATCH] [omemo] Refactor AxolotlStore - Dont use cursor object - Use namedtuple factory --- omemo/backend/liteaxolotlstore.py | 605 +++++++++++++----------------- omemo/backend/state.py | 7 +- omemo/modules/omemo.py | 6 +- 3 files changed, 257 insertions(+), 361 deletions(-) diff --git a/omemo/backend/liteaxolotlstore.py b/omemo/backend/liteaxolotlstore.py index 0edd4b5..5ef2ddc 100644 --- a/omemo/backend/liteaxolotlstore.py +++ b/omemo/backend/liteaxolotlstore.py @@ -17,6 +17,8 @@ import logging +import sqlite3 +from collections import namedtuple from axolotl.state.axolotlstore import AxolotlStore from axolotl.state.signedprekeyrecord import SignedPreKeyRecord @@ -44,41 +46,58 @@ UNTRUSTED = 0 class LiteAxolotlStore(AxolotlStore): - def __init__(self, connection): - self.dbConn = connection - self.dbConn.text_factory = bytes + def __init__(self, db_path): + self._con = sqlite3.connect(db_path, check_same_thread=False) + self._con.row_factory = self._namedtuple_factory self.createDb() self.migrateDb() - c = self.dbConn.cursor() - c.execute("PRAGMA synchronous=NORMAL;") - c.execute("PRAGMA journal_mode;") - mode = c.fetchone()[0] + + self._con.execute("PRAGMA secure_delete=1") + self._con.execute("PRAGMA synchronous=NORMAL;") + mode = self._con.execute("PRAGMA journal_mode;").fetchone()[0] + # WAL is a persistent DB mode, don't override it if user has set it if mode != 'wal': - c.execute("PRAGMA journal_mode=MEMORY;") - self.dbConn.commit() + self._con.execute("PRAGMA journal_mode=MEMORY;") + self._con.commit() if not self.getLocalRegistrationId(): log.info("Generating OMEMO keys") self._generate_axolotl_keys() + @staticmethod + def _namedtuple_factory(cursor, row): + fields = [] + for col in cursor.description: + if col[0] == '_id': + fields.append('id') + elif 'strftime' in col[0]: + fields.append('formated_time') + elif 'MAX' in col[0] or 'COUNT' in col[0]: + col_name = col[0].replace('(', '_') + col_name = col_name.replace(')', '') + fields.append(col_name.lower()) + else: + fields.append(col[0]) + return namedtuple("Row", fields)(*row) + def _generate_axolotl_keys(self): - identityKeyPair = KeyHelper.generateIdentityKeyPair() - registrationId = KeyHelper.generateRegistrationId() - preKeys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(), - DEFAULT_PREKEY_AMOUNT) - self.storeLocalData(registrationId, identityKeyPair) + identity_key_pair = KeyHelper.generateIdentityKeyPair() + registration_id = KeyHelper.generateRegistrationId() + pre_keys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(), + DEFAULT_PREKEY_AMOUNT) + self.storeLocalData(registration_id, identity_key_pair) - signedPreKey = KeyHelper.generateSignedPreKey( - identityKeyPair, KeyHelper.getRandomSequence(65536)) + signed_pre_key = KeyHelper.generateSignedPreKey( + identity_key_pair, KeyHelper.getRandomSequence(65536)) - self.storeSignedPreKey(signedPreKey.getId(), signedPreKey) + self.storeSignedPreKey(signed_pre_key.getId(), signed_pre_key) - for preKey in preKeys: - self.storePreKey(preKey.getId(), preKey) + for pre_key in pre_keys: + self.storePreKey(pre_key.getId(), pre_key) def user_version(self): - return self.dbConn.execute('PRAGMA user_version').fetchone()[0] + return self._con.execute('PRAGMA user_version').fetchone()[0] def createDb(self): if self.user_version() == 0: @@ -122,7 +141,7 @@ class LiteAxolotlStore(AxolotlStore): PRAGMA user_version=5; END TRANSACTION; """ % (create_tables) - self.dbConn.executescript(create_db_sql) + self._con.executescript(create_db_sql) def migrateDb(self): """ Migrates the DB @@ -138,11 +157,12 @@ class LiteAxolotlStore(AxolotlStore): ); """ - self.dbConn.executescript(""" BEGIN TRANSACTION; - %s - PRAGMA user_version=2; - END TRANSACTION; - """ % (delete_dupes)) + self._con.executescript( + """ BEGIN TRANSACTION; + %s + PRAGMA user_version=2; + END TRANSACTION; + """ % (delete_dupes)) if self.user_version() < 3: # Create a UNIQUE INDEX so every public key/recipient_id tuple @@ -152,11 +172,12 @@ class LiteAxolotlStore(AxolotlStore): ON identities (public_key, recipient_id); """ - self.dbConn.executescript(""" BEGIN TRANSACTION; - %s - PRAGMA user_version=3; - END TRANSACTION; - """ % (add_index)) + self._con.executescript( + """ BEGIN TRANSACTION; + %s + PRAGMA user_version=3; + END TRANSACTION; + """ % (add_index)) if self.user_version() < 4: # Adds column "active" to the sessions table @@ -164,11 +185,12 @@ class LiteAxolotlStore(AxolotlStore): ADD COLUMN active INTEGER DEFAULT 1; """ - self.dbConn.executescript(""" BEGIN TRANSACTION; - %s - PRAGMA user_version=4; - END TRANSACTION; - """ % (add_active)) + self._con.executescript( + """ BEGIN TRANSACTION; + %s + PRAGMA user_version=4; + END TRANSACTION; + """ % (add_active)) if self.user_version() < 5: # Adds DEFAULT Timestamp @@ -182,437 +204,316 @@ class LiteAxolotlStore(AxolotlStore): UPDATE identities SET shown = 1; """ - self.dbConn.executescript(""" BEGIN TRANSACTION; - %s - PRAGMA user_version=5; - END TRANSACTION; - """ % (add_timestamp)) - + self._con.executescript( + """ BEGIN TRANSACTION; + %s + PRAGMA user_version=5; + END TRANSACTION; + """ % (add_timestamp)) def loadSignedPreKey(self, signedPreKeyId): - q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?" - - cursor = self.dbConn.cursor() - cursor.execute(q, (signedPreKeyId, )) - - result = cursor.fetchone() - if not result: + query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?' + result = self._con.execute(query, (signedPreKeyId, )).fetchone() + if result is None: raise InvalidKeyIdException("No such signedprekeyrecord! %s " % signedPreKeyId) - - return SignedPreKeyRecord(serialized=result[0]) + return SignedPreKeyRecord(serialized=result.record) def loadSignedPreKeys(self): - q = "SELECT record FROM signed_prekeys" - - cursor = self.dbConn.cursor() - cursor.execute(q, ) - result = cursor.fetchall() - results = [] - for row in result: - results.append(SignedPreKeyRecord(serialized=row[0])) - - return results + query = 'SELECT record FROM signed_prekeys' + results = self._con.execute(query).fetchall() + return [SignedPreKeyRecord(serialized=row.record) for row in results] def storeSignedPreKey(self, signedPreKeyId, signedPreKeyRecord): - q = "INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)" - cursor = self.dbConn.cursor() - cursor.execute(q, (signedPreKeyId, signedPreKeyRecord.serialize())) - self.dbConn.commit() + query = 'INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)' + self._con.execute(query, (signedPreKeyId, + signedPreKeyRecord.serialize())) + self._con.commit() def containsSignedPreKey(self, signedPreKeyId): - q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?" - cursor = self.dbConn.cursor() - cursor.execute(q, (signedPreKeyId, )) - return cursor.fetchone() is not None + query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?' + result = self._con.execute(query, (signedPreKeyId,)).fetchone() + return result is not None def removeSignedPreKey(self, signedPreKeyId): - q = "DELETE FROM signed_prekeys WHERE prekey_id = ?" - cursor = self.dbConn.cursor() - cursor.execute(q, (signedPreKeyId, )) - self.dbConn.commit() + query = 'DELETE FROM signed_prekeys WHERE prekey_id = ?' + self._con.execute(query, (signedPreKeyId,)) + self._con.commit() def getNextSignedPreKeyId(self): result = self.getCurrentSignedPreKeyId() - if not result: + if result is None: return 1 # StartId if no SignedPreKeys exist - else: - return (result % (Medium.MAX_VALUE - 1)) + 1 + return (result % (Medium.MAX_VALUE - 1)) + 1 def getCurrentSignedPreKeyId(self): - q = "SELECT MAX(prekey_id) FROM signed_prekeys" - - cursor = self.dbConn.cursor() - cursor.execute(q) - result = cursor.fetchone() - if not result: - return None - else: - return result[0] + query = 'SELECT MAX(prekey_id) FROM signed_prekeys' + result = self._con.execute(query).fetchone() + return result.max_prekey_id if result is not None else None def getSignedPreKeyTimestamp(self, signedPreKeyId): - q = "SELECT strftime('%s', timestamp) FROM " \ - "signed_prekeys WHERE prekey_id = ?" + query = '''SELECT strftime('%s', timestamp) FROM + signed_prekeys WHERE prekey_id = ?''' - cursor = self.dbConn.cursor() - cursor.execute(q, (signedPreKeyId, )) - - result = cursor.fetchone() - if not result: - raise InvalidKeyIdException("No such signedprekeyrecord! %s " % + result = self._con.execute(query, (signedPreKeyId,)).fetchone() + if result is None: + raise InvalidKeyIdException('No such signedprekeyrecord! %s' % signedPreKeyId) - return result[0] + return result.formated_time def removeOldSignedPreKeys(self, timestamp): - q = "DELETE FROM signed_prekeys " \ - "WHERE timestamp < datetime(?, 'unixepoch')" - cursor = self.dbConn.cursor() - cursor.execute(q, (timestamp, )) - self.dbConn.commit() + query = '''DELETE FROM signed_prekeys + WHERE timestamp < datetime(?, "unixepoch")''' + self._con.execute(query, (timestamp,)) + self._con.commit() def loadSession(self, recipientId, deviceId): - q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?" - c = self.dbConn.cursor() - c.execute(q, (recipientId, deviceId)) - result = c.fetchone() - - if result: - return SessionRecord(serialized=result[0]) - else: + query = '''SELECT record FROM sessions WHERE + recipient_id = ? AND device_id = ?''' + result = self._con.execute(query, (recipientId, deviceId)).fetchone() + if result is None: return SessionRecord() - - def getSubDeviceSessions(self, recipientId): - q = "SELECT device_id from sessions WHERE recipient_id = ?" - c = self.dbConn.cursor() - c.execute(q, (recipientId, )) - result = c.fetchall() - - deviceIds = [r[0] for r in result] - return deviceIds + return SessionRecord(serialized=result.record) def getJidFromDevice(self, device_id): - q = "SELECT recipient_id from sessions WHERE device_id = ?" - c = self.dbConn.cursor() - c.execute(q, (device_id, )) - result = c.fetchone() - - return result[0].decode('utf-8') if result else None + query = 'SELECT recipient_id from sessions WHERE device_id = ?' + result = self._con.execute(query, (device_id, )).fetchone() + return result.recipient_id if result is not None else None def getActiveDeviceTuples(self): - q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1" - c = self.dbConn.cursor() - result = [] - for row in c.execute(q): - result.append((row[0].decode('utf-8'), row[1])) - return result + query = 'SELECT recipient_id, device_id FROM sessions WHERE active = 1' + return self._con.execute(query).fetchall() def storeSession(self, recipientId, deviceId, sessionRecord): self.deleteSession(recipientId, deviceId) - q = "INSERT INTO sessions(recipient_id, device_id, record) VALUES(?,?,?)" - c = self.dbConn.cursor() - c.execute(q, (recipientId, deviceId, sessionRecord.serialize())) - self.dbConn.commit() + query = '''INSERT INTO sessions(recipient_id, device_id, record) + VALUES(?,?,?)''' + self._con.execute(query, (recipientId, + deviceId, + sessionRecord.serialize())) + self._con.commit() def containsSession(self, recipientId, deviceId): - q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?" - c = self.dbConn.cursor() - c.execute(q, (recipientId, deviceId)) - result = c.fetchone() - + query = '''SELECT record FROM sessions + WHERE recipient_id = ? AND device_id = ?''' + result = self._con.execute(query, (recipientId, deviceId)).fetchone() return result is not None def deleteSession(self, recipientId, deviceId): - q = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?" - self.dbConn.cursor().execute(q, (recipientId, deviceId)) - self.dbConn.commit() + query = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?" + self._con.execute(query, (recipientId, deviceId)) + self._con.commit() def deleteAllSessions(self, recipientId): - q = "DELETE FROM sessions WHERE recipient_id = ?" - self.dbConn.cursor().execute(q, (recipientId, )) - self.dbConn.commit() - - def getAllSessions(self): - q = "SELECT _id, recipient_id, device_id, record, active from sessions" - c = self.dbConn.cursor() - result = [] - for row in c.execute(q): - result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4])) - return result + query = 'DELETE FROM sessions WHERE recipient_id = ?' + self._con.execute(query, (recipientId,)) + self._con.commit() def getSessionsFromJid(self, recipientId): - q = "SELECT _id, recipient_id, device_id, record, active from sessions" \ - " WHERE recipient_id = ?" - c = self.dbConn.cursor() - result = [] - for row in c.execute(q, (recipientId,)): - result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4])) - return result + query = '''SELECT _id, recipient_id, device_id, record, active + from sessions WHERE recipient_id = ?''' + return self._con.execute(query, (recipientId,)).fetchall() - def getSessionsFromJids(self, recipientId): - q = "SELECT _id, recipient_id, device_id, record, active from sessions" \ - " WHERE recipient_id IN ({})" \ - .format(', '.join(['?'] * len(recipientId))) - c = self.dbConn.cursor() - result = [] - for row in c.execute(q, recipientId): - result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4])) - return result + def getSessionsFromJids(self, recipientIds): + query = '''SELECT _id, recipient_id, device_id, record, active from sessions + WHERE recipient_id IN ({})'''.format( + ', '.join(['?'] * len(recipientIds))) + return self._con.execute(query, recipientIds).fetchall() def setActiveState(self, deviceList, jid): - c = self.dbConn.cursor() + query = '''UPDATE sessions SET active = 1 + WHERE recipient_id = ? AND device_id IN ({})'''.format( + ', '.join(['?'] * len(deviceList))) + self._con.execute(query, (jid,) + tuple(deviceList)) - q = "UPDATE sessions SET active = {} " \ - "WHERE recipient_id = '{}' AND device_id IN ({})" \ - .format(1, jid, ', '.join(['?'] * len(deviceList))) - c.execute(q, deviceList) - - q = "UPDATE sessions SET active = {} " \ - "WHERE recipient_id = '{}' AND device_id NOT IN ({})" \ - .format(0, jid, ', '.join(['?'] * len(deviceList))) - c.execute(q, deviceList) - self.dbConn.commit() + query = '''UPDATE sessions SET active = 0 + WHERE recipient_id = ? AND device_id NOT IN ({})'''.format( + ', '.join(['?'] * len(deviceList))) + self._con.execute(query, (jid,) + tuple(deviceList)) + self._con.commit() def getInactiveSessionsKeys(self, recipientId): - q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?" - c = self.dbConn.cursor() - result = [] - for row in c.execute(q, (recipientId,)): - public_key = (SessionRecord(serialized=row[0]). + query = '''SELECT record FROM sessions + WHERE active = 0 AND recipient_id = ?''' + result = self._con.execute(query, (recipientId,)).fetchall() + + results = [] + for row in result: + public_key = (SessionRecord(serialized=row.record). getSessionState().getRemoteIdentityKey(). getPublicKey()) - result.append(public_key.serialize()) - return result + results.append(public_key.serialize()) + return results def loadPreKey(self, preKeyId): - q = "SELECT record FROM prekeys WHERE prekey_id = ?" + query = '''SELECT record FROM prekeys WHERE prekey_id = ?''' - cursor = self.dbConn.cursor() - cursor.execute(q, (preKeyId, )) - - result = cursor.fetchone() - if not result: + result = self._con.execute(query, (preKeyId,)).fetchone() + if result is None: raise Exception("No such prekeyRecord!") - - return PreKeyRecord(serialized=result[0]) + return PreKeyRecord(serialized=result.record) def loadPendingPreKeys(self): - q = "SELECT record FROM prekeys" - cursor = self.dbConn.cursor() - cursor.execute(q) - result = cursor.fetchall() - - return [PreKeyRecord(serialized=r[0]) for r in result] + query = '''SELECT record FROM prekeys''' + result = self._con.execute(query).fetchall() + return [PreKeyRecord(serialized=row.record) for row in result] def storePreKey(self, preKeyId, preKeyRecord): - q = "INSERT INTO prekeys (prekey_id, record) VALUES(?,?)" - cursor = self.dbConn.cursor() - cursor.execute(q, (preKeyId, preKeyRecord.serialize())) - self.dbConn.commit() + query = 'INSERT INTO prekeys (prekey_id, record) VALUES(?,?)' + self._con.execute(query, (preKeyId, preKeyRecord.serialize())) + self._con.commit() def containsPreKey(self, preKeyId): - q = "SELECT record FROM prekeys WHERE prekey_id = ?" - cursor = self.dbConn.cursor() - cursor.execute(q, (preKeyId, )) - return cursor.fetchone() is not None + query = 'SELECT record FROM prekeys WHERE prekey_id = ?' + result = self._con.execute(query, (preKeyId,)).fetchone() + return result is not None def removePreKey(self, preKeyId): - q = "DELETE FROM prekeys WHERE prekey_id = ?" - cursor = self.dbConn.cursor() - cursor.execute(q, (preKeyId, )) - self.dbConn.commit() + query = 'DELETE FROM prekeys WHERE prekey_id = ?' + self._con.execute(query, (preKeyId,)) + self._con.commit() def getCurrentPreKeyId(self): - q = "SELECT MAX(prekey_id) FROM prekeys" - cursor = self.dbConn.cursor() - cursor.execute(q) - return cursor.fetchone()[0] + query = 'SELECT MAX(prekey_id) FROM prekeys' + return self._con.execute(query).fetchone().max_prekey_id def getPreKeyCount(self): - q = "SELECT COUNT(prekey_id) FROM prekeys" - cursor = self.dbConn.cursor() - cursor.execute(q) - return cursor.fetchone()[0] + query = 'SELECT COUNT(prekey_id) FROM prekeys' + return self._con.execute(query).fetchone().count_prekey_id def generateNewPreKeys(self, count): - startId = self.getCurrentPreKeyId() + 1 - preKeys = KeyHelper.generatePreKeys(startId, count) + start_id = self.getCurrentPreKeyId() + 1 + pre_keys = KeyHelper.generatePreKeys(start_id, count) - for preKey in preKeys: - self.storePreKey(preKey.getId(), preKey) + for pre_key in pre_keys: + self.storePreKey(pre_key.getId(), pre_key) def getIdentityKeyPair(self): - q = "SELECT public_key, private_key FROM identities " + \ - "WHERE recipient_id = -1" - c = self.dbConn.cursor() - c.execute(q) - result = c.fetchone() + query = '''SELECT public_key, private_key FROM identities + WHERE recipient_id = -1''' + result = self._con.execute(query).fetchone() - publicKey, privateKey = result return IdentityKeyPair( - IdentityKey(DjbECPublicKey(publicKey[1:])), - DjbECPrivateKey(privateKey)) + IdentityKey(DjbECPublicKey(result.public_key[1:])), + DjbECPrivateKey(result.private_key)) def getLocalRegistrationId(self): - q = "SELECT registration_id FROM identities WHERE recipient_id = -1" - c = self.dbConn.cursor() - c.execute(q) - result = c.fetchone() - return result[0] if result else None + query = 'SELECT registration_id FROM identities WHERE recipient_id = -1' + result = self._con.execute(query).fetchone() + return result.registration_id if result is not None else None def storeLocalData(self, registrationId, identityKeyPair): - q = "INSERT INTO identities( " + \ - "recipient_id, registration_id, public_key, private_key) " + \ - "VALUES(-1, ?, ?, ?)" - c = self.dbConn.cursor() - c.execute(q, - (registrationId, - identityKeyPair.getPublicKey().getPublicKey().serialize(), - identityKeyPair.getPrivateKey().serialize())) + query = '''INSERT INTO identities( + recipient_id, registration_id, public_key, private_key) + VALUES(-1, ?, ?, ?)''' - self.dbConn.commit() + public_key = identityKeyPair.getPublicKey().getPublicKey().serialize() + private_key = identityKeyPair.getPrivateKey().serialize() + self._con.execute(query, (registrationId, public_key, private_key)) + self._con.commit() def saveIdentity(self, recipientId, identityKey): - q = "INSERT INTO identities (recipient_id, public_key, trust) " \ - "VALUES(?, ?, ?)" - c = self.dbConn.cursor() + query = '''INSERT INTO identities (recipient_id, public_key, trust) + VALUES(?, ?, ?)''' + if not self.containsIdentity(recipientId, identityKey): + self._con.execute(query, (recipientId, + identityKey.getPublicKey().serialize(), + UNDECIDED)) + self._con.commit() - if not self.getIdentity(recipientId, identityKey): - c.execute(q, (recipientId, - identityKey.getPublicKey().serialize(), - UNDECIDED)) - self.dbConn.commit() + def containsIdentity(self, recipientId, identityKey): + query = '''SELECT * FROM identities WHERE recipient_id = ? + AND public_key = ?''' - def getIdentity(self, recipientId, identityKey): - q = "SELECT * FROM identities WHERE recipient_id = ? " \ - "AND public_key = ?" - c = self.dbConn.cursor() - - c.execute(q, (recipientId, identityKey.getPublicKey().serialize())) - result = c.fetchone() + public_key = identityKey.getPublicKey().serialize() + result = self._con.execute(query, (recipientId, + public_key)).fetchone() return result is not None def deleteIdentity(self, recipientId, identityKey): - q = "DELETE FROM identities WHERE recipient_id = ? AND public_key = ?" - c = self.dbConn.cursor() - c.execute(q, (recipientId, - identityKey.getPublicKey().serialize())) - self.dbConn.commit() + query = '''DELETE FROM identities + WHERE recipient_id = ? AND public_key = ?''' + public_key = identityKey.getPublicKey().serialize() + self._con.execute(query, (recipientId, public_key)) + self._con.commit() def isTrustedIdentity(self, recipientId, identityKey): - q = "SELECT trust FROM identities WHERE recipient_id = ? " \ - "AND public_key = ?" - c = self.dbConn.cursor() - - c.execute(q, (recipientId, identityKey.getPublicKey().serialize())) - result = c.fetchone() + query = '''SELECT trust FROM identities WHERE recipient_id = ? + AND public_key = ?''' + public_key = identityKey.getPublicKey().serialize() + result = self._con.execute(query, (recipientId, public_key)).fetchone() + if result is None: + return True states = [UNTRUSTED, TRUSTED, UNDECIDED] - - if result and result[0] in states: - return result[0] - else: - return True + if result.trust in states: + return result.trust + return False def getAllFingerprints(self): - q = "SELECT _id, recipient_id, public_key, trust FROM identities " \ - "WHERE recipient_id != -1 ORDER BY recipient_id ASC" - c = self.dbConn.cursor() - - result = [] - for row in c.execute(q): - result.append((row[0], row[1], row[2], row[3])) - return result + query = '''SELECT _id, recipient_id, public_key, trust FROM identities + WHERE recipient_id != -1 ORDER BY recipient_id ASC''' + return self._con.execute(query).fetchall() def getFingerprints(self, jid): - q = "SELECT _id, recipient_id, public_key, trust FROM identities " \ - "WHERE recipient_id =? ORDER BY trust ASC" - c = self.dbConn.cursor() - - result = [] - c.execute(q, (jid,)) - rows = c.fetchall() - for row in rows: - result.append((row[0], row[1], row[2], row[3])) - return result + query = '''SELECT _id, recipient_id, public_key, trust FROM identities + WHERE recipient_id =? ORDER BY trust ASC''' + return self._con.execute(query, (jid,)).fetchall() def getTrustedFingerprints(self, jid): - q = "SELECT public_key FROM identities WHERE recipient_id = ? AND trust = ?" - c = self.dbConn.cursor() - - result = [] - c.execute(q, (jid, TRUSTED)) - rows = c.fetchall() - for row in rows: - result.append(row[0]) - return result + query = '''SELECT public_key FROM identities + WHERE recipient_id = ? AND trust = ?''' + result = self._con.execute(query, (jid, TRUSTED)).fetchall() + return [row.public_key for row in result] def getUndecidedFingerprints(self, jid): - q = "SELECT trust FROM identities WHERE recipient_id = ? AND trust = ?" - c = self.dbConn.cursor() - - result = [] - c.execute(q, (jid, UNDECIDED)) - result = c.fetchall() - - return result + query = '''SELECT trust FROM identities + WHERE recipient_id = ? AND trust = ?''' + return self._con.execute(query, (jid, UNDECIDED)).fetchall() def getNewFingerprints(self, jid): - q = "SELECT _id FROM identities WHERE shown = 0 AND " \ - "recipient_id = ?" - c = self.dbConn.cursor() - result = [] - for row in c.execute(q, (jid,)): - result.append(row[0]) - return result + query = '''SELECT _id FROM identities WHERE shown = 0 + AND recipient_id = ?''' + + result = self._con.execute(query, (jid,)).fetchall() + return [row.id for row in result] def setShownFingerprints(self, fingerprints): - q = "UPDATE identities SET shown = 1 WHERE _id IN ({})" \ - .format(', '.join(['?'] * len(fingerprints))) - c = self.dbConn.cursor() - c.execute(q, fingerprints) - self.dbConn.commit() + query = 'UPDATE identities SET shown = 1 WHERE _id IN ({})'.format( + ', '.join(['?'] * len(fingerprints))) + self._con.execute(query, fingerprints) + self._con.commit() def setTrust(self, identityKey, trust): - q = "UPDATE identities SET trust = ? WHERE public_key = ?" - c = self.dbConn.cursor() - c.execute(q, (trust, identityKey.getPublicKey().serialize())) - self.dbConn.commit() + query = 'UPDATE identities SET trust = ? WHERE public_key = ?' + public_key = identityKey.getPublicKey().serialize() + self._con.execute(query, (trust, public_key)) + self._con.commit() def activate(self, jid): - q = """INSERT OR REPLACE INTO encryption_state (jid, encryption) - VALUES (?, 1) """ + query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption) + VALUES (?, 1)''' - c = self.dbConn.cursor() - c.execute(q, (jid, )) - self.dbConn.commit() + self._con.execute(query, (jid,)) + self._con.commit() def deactivate(self, jid): - q = """INSERT OR REPLACE INTO encryption_state (jid, encryption) - VALUES (?, 0)""" + query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption) + VALUES (?, 0)''' - c = self.dbConn.cursor() - c.execute(q, (jid, )) - self.dbConn.commit() + self._con.execute(query, (jid, )) + self._con.commit() def is_active(self, jid): - q = 'SELECT encryption FROM encryption_state where jid = ?;' - c = self.dbConn.cursor() - c.execute(q, (jid, )) - result = c.fetchone() - if result is None: - return False - return result[0] + query = 'SELECT encryption FROM encryption_state where jid = ?' + result = self._con.execute(query, (jid,)).fetchone() + return result.encryption if result is not None else False def exist(self, jid): - q = 'SELECT encryption FROM encryption_state where jid = ?;' - c = self.dbConn.cursor() - c.execute(q, (jid, )) - result = c.fetchone() - if result is None: - return False - else: - return True + query = 'SELECT encryption FROM encryption_state where jid = ?' + result = self._con.execute(query, (jid,)).fetchone() + return result is not None diff --git a/omemo/backend/state.py b/omemo/backend/state.py index 429835c..5eaa087 100644 --- a/omemo/backend/state.py +++ b/omemo/backend/state.py @@ -17,8 +17,6 @@ import logging import time -import binascii -import textwrap from collections import defaultdict from nbxmpp.structs import OMEMOBundle @@ -53,14 +51,15 @@ UNDECIDED = 2 class OmemoState: - def __init__(self, own_jid, db_con, account, xmpp_con): + def __init__(self, own_jid, db_path, account, xmpp_con): self.account = account self.xmpp_con = xmpp_con self._session_ciphers = defaultdict(dict) self.own_jid = own_jid self.device_ids = {} self.own_devices = [] - self.store = LiteAxolotlStore(db_con) + + self.store = LiteAxolotlStore(db_path) for jid, device_id in self.store.getActiveDeviceTuples(): if jid != own_jid: self.add_device(jid, device_id) diff --git a/omemo/modules/omemo.py b/omemo/modules/omemo.py index 5236f60..9f21604 100644 --- a/omemo/modules/omemo.py +++ b/omemo/modules/omemo.py @@ -19,7 +19,6 @@ import os import time import logging -import sqlite3 import nbxmpp from nbxmpp.protocol import NodeProcessed @@ -30,7 +29,6 @@ from nbxmpp.structs import StanzaHandler from nbxmpp.modules.omemo import create_omemo_message from gajim.common import app -from gajim.common import ged from gajim.common import helpers from gajim.common import configpaths from gajim.common.nec import NetworkEvent @@ -117,9 +115,7 @@ class OMEMO(BaseModule): def __get_omemo(self): data_dir = configpaths.get('MY_DATA') db_path = os.path.join(data_dir, 'omemo_' + self.own_jid + '.db') - conn = sqlite3.connect(db_path, check_same_thread=False) - conn.execute("PRAGMA secure_delete=1") - return OmemoState(self.own_jid, conn, self._account, self) + return OmemoState(self.own_jid, db_path, self._account, self) def on_signed_in(self): log.info('%s => Announce Support after Sign In', self._account)