[omemo] Refactor AxolotlStore

- Dont use cursor object
- Use namedtuple factory
This commit is contained in:
Philipp Hörist
2019-02-16 13:57:43 +01:00
parent 41cd853085
commit 87ece2397e
3 changed files with 257 additions and 361 deletions

View File

@@ -17,6 +17,8 @@
import logging import logging
import sqlite3
from collections import namedtuple
from axolotl.state.axolotlstore import AxolotlStore from axolotl.state.axolotlstore import AxolotlStore
from axolotl.state.signedprekeyrecord import SignedPreKeyRecord from axolotl.state.signedprekeyrecord import SignedPreKeyRecord
@@ -44,41 +46,58 @@ UNTRUSTED = 0
class LiteAxolotlStore(AxolotlStore): class LiteAxolotlStore(AxolotlStore):
def __init__(self, connection): def __init__(self, db_path):
self.dbConn = connection self._con = sqlite3.connect(db_path, check_same_thread=False)
self.dbConn.text_factory = bytes self._con.row_factory = self._namedtuple_factory
self.createDb() self.createDb()
self.migrateDb() self.migrateDb()
c = self.dbConn.cursor()
c.execute("PRAGMA synchronous=NORMAL;") self._con.execute("PRAGMA secure_delete=1")
c.execute("PRAGMA journal_mode;") self._con.execute("PRAGMA synchronous=NORMAL;")
mode = c.fetchone()[0] mode = self._con.execute("PRAGMA journal_mode;").fetchone()[0]
# WAL is a persistent DB mode, don't override it if user has set it # WAL is a persistent DB mode, don't override it if user has set it
if mode != 'wal': if mode != 'wal':
c.execute("PRAGMA journal_mode=MEMORY;") self._con.execute("PRAGMA journal_mode=MEMORY;")
self.dbConn.commit() self._con.commit()
if not self.getLocalRegistrationId(): if not self.getLocalRegistrationId():
log.info("Generating OMEMO keys") log.info("Generating OMEMO keys")
self._generate_axolotl_keys() self._generate_axolotl_keys()
@staticmethod
def _namedtuple_factory(cursor, row):
fields = []
for col in cursor.description:
if col[0] == '_id':
fields.append('id')
elif 'strftime' in col[0]:
fields.append('formated_time')
elif 'MAX' in col[0] or 'COUNT' in col[0]:
col_name = col[0].replace('(', '_')
col_name = col_name.replace(')', '')
fields.append(col_name.lower())
else:
fields.append(col[0])
return namedtuple("Row", fields)(*row)
def _generate_axolotl_keys(self): def _generate_axolotl_keys(self):
identityKeyPair = KeyHelper.generateIdentityKeyPair() identity_key_pair = KeyHelper.generateIdentityKeyPair()
registrationId = KeyHelper.generateRegistrationId() registration_id = KeyHelper.generateRegistrationId()
preKeys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(), pre_keys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(),
DEFAULT_PREKEY_AMOUNT) DEFAULT_PREKEY_AMOUNT)
self.storeLocalData(registrationId, identityKeyPair) self.storeLocalData(registration_id, identity_key_pair)
signedPreKey = KeyHelper.generateSignedPreKey( signed_pre_key = KeyHelper.generateSignedPreKey(
identityKeyPair, KeyHelper.getRandomSequence(65536)) identity_key_pair, KeyHelper.getRandomSequence(65536))
self.storeSignedPreKey(signedPreKey.getId(), signedPreKey) self.storeSignedPreKey(signed_pre_key.getId(), signed_pre_key)
for preKey in preKeys: for pre_key in pre_keys:
self.storePreKey(preKey.getId(), preKey) self.storePreKey(pre_key.getId(), pre_key)
def user_version(self): def user_version(self):
return self.dbConn.execute('PRAGMA user_version').fetchone()[0] return self._con.execute('PRAGMA user_version').fetchone()[0]
def createDb(self): def createDb(self):
if self.user_version() == 0: if self.user_version() == 0:
@@ -122,7 +141,7 @@ class LiteAxolotlStore(AxolotlStore):
PRAGMA user_version=5; PRAGMA user_version=5;
END TRANSACTION; END TRANSACTION;
""" % (create_tables) """ % (create_tables)
self.dbConn.executescript(create_db_sql) self._con.executescript(create_db_sql)
def migrateDb(self): def migrateDb(self):
""" Migrates the DB """ Migrates the DB
@@ -138,11 +157,12 @@ class LiteAxolotlStore(AxolotlStore):
); );
""" """
self.dbConn.executescript(""" BEGIN TRANSACTION; self._con.executescript(
%s """ BEGIN TRANSACTION;
PRAGMA user_version=2; %s
END TRANSACTION; PRAGMA user_version=2;
""" % (delete_dupes)) END TRANSACTION;
""" % (delete_dupes))
if self.user_version() < 3: if self.user_version() < 3:
# Create a UNIQUE INDEX so every public key/recipient_id tuple # Create a UNIQUE INDEX so every public key/recipient_id tuple
@@ -152,11 +172,12 @@ class LiteAxolotlStore(AxolotlStore):
ON identities (public_key, recipient_id); ON identities (public_key, recipient_id);
""" """
self.dbConn.executescript(""" BEGIN TRANSACTION; self._con.executescript(
%s """ BEGIN TRANSACTION;
PRAGMA user_version=3; %s
END TRANSACTION; PRAGMA user_version=3;
""" % (add_index)) END TRANSACTION;
""" % (add_index))
if self.user_version() < 4: if self.user_version() < 4:
# Adds column "active" to the sessions table # Adds column "active" to the sessions table
@@ -164,11 +185,12 @@ class LiteAxolotlStore(AxolotlStore):
ADD COLUMN active INTEGER DEFAULT 1; ADD COLUMN active INTEGER DEFAULT 1;
""" """
self.dbConn.executescript(""" BEGIN TRANSACTION; self._con.executescript(
%s """ BEGIN TRANSACTION;
PRAGMA user_version=4; %s
END TRANSACTION; PRAGMA user_version=4;
""" % (add_active)) END TRANSACTION;
""" % (add_active))
if self.user_version() < 5: if self.user_version() < 5:
# Adds DEFAULT Timestamp # Adds DEFAULT Timestamp
@@ -182,437 +204,316 @@ class LiteAxolotlStore(AxolotlStore):
UPDATE identities SET shown = 1; UPDATE identities SET shown = 1;
""" """
self.dbConn.executescript(""" BEGIN TRANSACTION; self._con.executescript(
%s """ BEGIN TRANSACTION;
PRAGMA user_version=5; %s
END TRANSACTION; PRAGMA user_version=5;
""" % (add_timestamp)) END TRANSACTION;
""" % (add_timestamp))
def loadSignedPreKey(self, signedPreKeyId): def loadSignedPreKey(self, signedPreKeyId):
q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?" query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?'
result = self._con.execute(query, (signedPreKeyId, )).fetchone()
cursor = self.dbConn.cursor() if result is None:
cursor.execute(q, (signedPreKeyId, ))
result = cursor.fetchone()
if not result:
raise InvalidKeyIdException("No such signedprekeyrecord! %s " % raise InvalidKeyIdException("No such signedprekeyrecord! %s " %
signedPreKeyId) signedPreKeyId)
return SignedPreKeyRecord(serialized=result.record)
return SignedPreKeyRecord(serialized=result[0])
def loadSignedPreKeys(self): def loadSignedPreKeys(self):
q = "SELECT record FROM signed_prekeys" query = 'SELECT record FROM signed_prekeys'
results = self._con.execute(query).fetchall()
cursor = self.dbConn.cursor() return [SignedPreKeyRecord(serialized=row.record) for row in results]
cursor.execute(q, )
result = cursor.fetchall()
results = []
for row in result:
results.append(SignedPreKeyRecord(serialized=row[0]))
return results
def storeSignedPreKey(self, signedPreKeyId, signedPreKeyRecord): def storeSignedPreKey(self, signedPreKeyId, signedPreKeyRecord):
q = "INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)" query = 'INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)'
cursor = self.dbConn.cursor() self._con.execute(query, (signedPreKeyId,
cursor.execute(q, (signedPreKeyId, signedPreKeyRecord.serialize())) signedPreKeyRecord.serialize()))
self.dbConn.commit() self._con.commit()
def containsSignedPreKey(self, signedPreKeyId): def containsSignedPreKey(self, signedPreKeyId):
q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?" query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?'
cursor = self.dbConn.cursor() result = self._con.execute(query, (signedPreKeyId,)).fetchone()
cursor.execute(q, (signedPreKeyId, )) return result is not None
return cursor.fetchone() is not None
def removeSignedPreKey(self, signedPreKeyId): def removeSignedPreKey(self, signedPreKeyId):
q = "DELETE FROM signed_prekeys WHERE prekey_id = ?" query = 'DELETE FROM signed_prekeys WHERE prekey_id = ?'
cursor = self.dbConn.cursor() self._con.execute(query, (signedPreKeyId,))
cursor.execute(q, (signedPreKeyId, )) self._con.commit()
self.dbConn.commit()
def getNextSignedPreKeyId(self): def getNextSignedPreKeyId(self):
result = self.getCurrentSignedPreKeyId() result = self.getCurrentSignedPreKeyId()
if not result: if result is None:
return 1 # StartId if no SignedPreKeys exist return 1 # StartId if no SignedPreKeys exist
else: return (result % (Medium.MAX_VALUE - 1)) + 1
return (result % (Medium.MAX_VALUE - 1)) + 1
def getCurrentSignedPreKeyId(self): def getCurrentSignedPreKeyId(self):
q = "SELECT MAX(prekey_id) FROM signed_prekeys" query = 'SELECT MAX(prekey_id) FROM signed_prekeys'
result = self._con.execute(query).fetchone()
cursor = self.dbConn.cursor() return result.max_prekey_id if result is not None else None
cursor.execute(q)
result = cursor.fetchone()
if not result:
return None
else:
return result[0]
def getSignedPreKeyTimestamp(self, signedPreKeyId): def getSignedPreKeyTimestamp(self, signedPreKeyId):
q = "SELECT strftime('%s', timestamp) FROM " \ query = '''SELECT strftime('%s', timestamp) FROM
"signed_prekeys WHERE prekey_id = ?" signed_prekeys WHERE prekey_id = ?'''
cursor = self.dbConn.cursor() result = self._con.execute(query, (signedPreKeyId,)).fetchone()
cursor.execute(q, (signedPreKeyId, )) if result is None:
raise InvalidKeyIdException('No such signedprekeyrecord! %s' %
result = cursor.fetchone()
if not result:
raise InvalidKeyIdException("No such signedprekeyrecord! %s " %
signedPreKeyId) signedPreKeyId)
return result[0] return result.formated_time
def removeOldSignedPreKeys(self, timestamp): def removeOldSignedPreKeys(self, timestamp):
q = "DELETE FROM signed_prekeys " \ query = '''DELETE FROM signed_prekeys
"WHERE timestamp < datetime(?, 'unixepoch')" WHERE timestamp < datetime(?, "unixepoch")'''
cursor = self.dbConn.cursor() self._con.execute(query, (timestamp,))
cursor.execute(q, (timestamp, )) self._con.commit()
self.dbConn.commit()
def loadSession(self, recipientId, deviceId): def loadSession(self, recipientId, deviceId):
q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?" query = '''SELECT record FROM sessions WHERE
c = self.dbConn.cursor() recipient_id = ? AND device_id = ?'''
c.execute(q, (recipientId, deviceId)) result = self._con.execute(query, (recipientId, deviceId)).fetchone()
result = c.fetchone() if result is None:
if result:
return SessionRecord(serialized=result[0])
else:
return SessionRecord() return SessionRecord()
return SessionRecord(serialized=result.record)
def getSubDeviceSessions(self, recipientId):
q = "SELECT device_id from sessions WHERE recipient_id = ?"
c = self.dbConn.cursor()
c.execute(q, (recipientId, ))
result = c.fetchall()
deviceIds = [r[0] for r in result]
return deviceIds
def getJidFromDevice(self, device_id): def getJidFromDevice(self, device_id):
q = "SELECT recipient_id from sessions WHERE device_id = ?" query = 'SELECT recipient_id from sessions WHERE device_id = ?'
c = self.dbConn.cursor() result = self._con.execute(query, (device_id, )).fetchone()
c.execute(q, (device_id, )) return result.recipient_id if result is not None else None
result = c.fetchone()
return result[0].decode('utf-8') if result else None
def getActiveDeviceTuples(self): def getActiveDeviceTuples(self):
q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1" query = 'SELECT recipient_id, device_id FROM sessions WHERE active = 1'
c = self.dbConn.cursor() return self._con.execute(query).fetchall()
result = []
for row in c.execute(q):
result.append((row[0].decode('utf-8'), row[1]))
return result
def storeSession(self, recipientId, deviceId, sessionRecord): def storeSession(self, recipientId, deviceId, sessionRecord):
self.deleteSession(recipientId, deviceId) self.deleteSession(recipientId, deviceId)
q = "INSERT INTO sessions(recipient_id, device_id, record) VALUES(?,?,?)" query = '''INSERT INTO sessions(recipient_id, device_id, record)
c = self.dbConn.cursor() VALUES(?,?,?)'''
c.execute(q, (recipientId, deviceId, sessionRecord.serialize())) self._con.execute(query, (recipientId,
self.dbConn.commit() deviceId,
sessionRecord.serialize()))
self._con.commit()
def containsSession(self, recipientId, deviceId): def containsSession(self, recipientId, deviceId):
q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?" query = '''SELECT record FROM sessions
c = self.dbConn.cursor() WHERE recipient_id = ? AND device_id = ?'''
c.execute(q, (recipientId, deviceId)) result = self._con.execute(query, (recipientId, deviceId)).fetchone()
result = c.fetchone()
return result is not None return result is not None
def deleteSession(self, recipientId, deviceId): def deleteSession(self, recipientId, deviceId):
q = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?" query = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?"
self.dbConn.cursor().execute(q, (recipientId, deviceId)) self._con.execute(query, (recipientId, deviceId))
self.dbConn.commit() self._con.commit()
def deleteAllSessions(self, recipientId): def deleteAllSessions(self, recipientId):
q = "DELETE FROM sessions WHERE recipient_id = ?" query = 'DELETE FROM sessions WHERE recipient_id = ?'
self.dbConn.cursor().execute(q, (recipientId, )) self._con.execute(query, (recipientId,))
self.dbConn.commit() self._con.commit()
def getAllSessions(self):
q = "SELECT _id, recipient_id, device_id, record, active from sessions"
c = self.dbConn.cursor()
result = []
for row in c.execute(q):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJid(self, recipientId): def getSessionsFromJid(self, recipientId):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \ query = '''SELECT _id, recipient_id, device_id, record, active
" WHERE recipient_id = ?" from sessions WHERE recipient_id = ?'''
c = self.dbConn.cursor() return self._con.execute(query, (recipientId,)).fetchall()
result = []
for row in c.execute(q, (recipientId,)):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def getSessionsFromJids(self, recipientId): def getSessionsFromJids(self, recipientIds):
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \ query = '''SELECT _id, recipient_id, device_id, record, active from sessions
" WHERE recipient_id IN ({})" \ WHERE recipient_id IN ({})'''.format(
.format(', '.join(['?'] * len(recipientId))) ', '.join(['?'] * len(recipientIds)))
c = self.dbConn.cursor() return self._con.execute(query, recipientIds).fetchall()
result = []
for row in c.execute(q, recipientId):
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
return result
def setActiveState(self, deviceList, jid): def setActiveState(self, deviceList, jid):
c = self.dbConn.cursor() query = '''UPDATE sessions SET active = 1
WHERE recipient_id = ? AND device_id IN ({})'''.format(
', '.join(['?'] * len(deviceList)))
self._con.execute(query, (jid,) + tuple(deviceList))
q = "UPDATE sessions SET active = {} " \ query = '''UPDATE sessions SET active = 0
"WHERE recipient_id = '{}' AND device_id IN ({})" \ WHERE recipient_id = ? AND device_id NOT IN ({})'''.format(
.format(1, jid, ', '.join(['?'] * len(deviceList))) ', '.join(['?'] * len(deviceList)))
c.execute(q, deviceList) self._con.execute(query, (jid,) + tuple(deviceList))
self._con.commit()
q = "UPDATE sessions SET active = {} " \
"WHERE recipient_id = '{}' AND device_id NOT IN ({})" \
.format(0, jid, ', '.join(['?'] * len(deviceList)))
c.execute(q, deviceList)
self.dbConn.commit()
def getInactiveSessionsKeys(self, recipientId): def getInactiveSessionsKeys(self, recipientId):
q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?" query = '''SELECT record FROM sessions
c = self.dbConn.cursor() WHERE active = 0 AND recipient_id = ?'''
result = [] result = self._con.execute(query, (recipientId,)).fetchall()
for row in c.execute(q, (recipientId,)):
public_key = (SessionRecord(serialized=row[0]). results = []
for row in result:
public_key = (SessionRecord(serialized=row.record).
getSessionState().getRemoteIdentityKey(). getSessionState().getRemoteIdentityKey().
getPublicKey()) getPublicKey())
result.append(public_key.serialize()) results.append(public_key.serialize())
return result return results
def loadPreKey(self, preKeyId): def loadPreKey(self, preKeyId):
q = "SELECT record FROM prekeys WHERE prekey_id = ?" query = '''SELECT record FROM prekeys WHERE prekey_id = ?'''
cursor = self.dbConn.cursor() result = self._con.execute(query, (preKeyId,)).fetchone()
cursor.execute(q, (preKeyId, )) if result is None:
result = cursor.fetchone()
if not result:
raise Exception("No such prekeyRecord!") raise Exception("No such prekeyRecord!")
return PreKeyRecord(serialized=result.record)
return PreKeyRecord(serialized=result[0])
def loadPendingPreKeys(self): def loadPendingPreKeys(self):
q = "SELECT record FROM prekeys" query = '''SELECT record FROM prekeys'''
cursor = self.dbConn.cursor() result = self._con.execute(query).fetchall()
cursor.execute(q) return [PreKeyRecord(serialized=row.record) for row in result]
result = cursor.fetchall()
return [PreKeyRecord(serialized=r[0]) for r in result]
def storePreKey(self, preKeyId, preKeyRecord): def storePreKey(self, preKeyId, preKeyRecord):
q = "INSERT INTO prekeys (prekey_id, record) VALUES(?,?)" query = 'INSERT INTO prekeys (prekey_id, record) VALUES(?,?)'
cursor = self.dbConn.cursor() self._con.execute(query, (preKeyId, preKeyRecord.serialize()))
cursor.execute(q, (preKeyId, preKeyRecord.serialize())) self._con.commit()
self.dbConn.commit()
def containsPreKey(self, preKeyId): def containsPreKey(self, preKeyId):
q = "SELECT record FROM prekeys WHERE prekey_id = ?" query = 'SELECT record FROM prekeys WHERE prekey_id = ?'
cursor = self.dbConn.cursor() result = self._con.execute(query, (preKeyId,)).fetchone()
cursor.execute(q, (preKeyId, )) return result is not None
return cursor.fetchone() is not None
def removePreKey(self, preKeyId): def removePreKey(self, preKeyId):
q = "DELETE FROM prekeys WHERE prekey_id = ?" query = 'DELETE FROM prekeys WHERE prekey_id = ?'
cursor = self.dbConn.cursor() self._con.execute(query, (preKeyId,))
cursor.execute(q, (preKeyId, )) self._con.commit()
self.dbConn.commit()
def getCurrentPreKeyId(self): def getCurrentPreKeyId(self):
q = "SELECT MAX(prekey_id) FROM prekeys" query = 'SELECT MAX(prekey_id) FROM prekeys'
cursor = self.dbConn.cursor() return self._con.execute(query).fetchone().max_prekey_id
cursor.execute(q)
return cursor.fetchone()[0]
def getPreKeyCount(self): def getPreKeyCount(self):
q = "SELECT COUNT(prekey_id) FROM prekeys" query = 'SELECT COUNT(prekey_id) FROM prekeys'
cursor = self.dbConn.cursor() return self._con.execute(query).fetchone().count_prekey_id
cursor.execute(q)
return cursor.fetchone()[0]
def generateNewPreKeys(self, count): def generateNewPreKeys(self, count):
startId = self.getCurrentPreKeyId() + 1 start_id = self.getCurrentPreKeyId() + 1
preKeys = KeyHelper.generatePreKeys(startId, count) pre_keys = KeyHelper.generatePreKeys(start_id, count)
for preKey in preKeys: for pre_key in pre_keys:
self.storePreKey(preKey.getId(), preKey) self.storePreKey(pre_key.getId(), pre_key)
def getIdentityKeyPair(self): def getIdentityKeyPair(self):
q = "SELECT public_key, private_key FROM identities " + \ query = '''SELECT public_key, private_key FROM identities
"WHERE recipient_id = -1" WHERE recipient_id = -1'''
c = self.dbConn.cursor() result = self._con.execute(query).fetchone()
c.execute(q)
result = c.fetchone()
publicKey, privateKey = result
return IdentityKeyPair( return IdentityKeyPair(
IdentityKey(DjbECPublicKey(publicKey[1:])), IdentityKey(DjbECPublicKey(result.public_key[1:])),
DjbECPrivateKey(privateKey)) DjbECPrivateKey(result.private_key))
def getLocalRegistrationId(self): def getLocalRegistrationId(self):
q = "SELECT registration_id FROM identities WHERE recipient_id = -1" query = 'SELECT registration_id FROM identities WHERE recipient_id = -1'
c = self.dbConn.cursor() result = self._con.execute(query).fetchone()
c.execute(q) return result.registration_id if result is not None else None
result = c.fetchone()
return result[0] if result else None
def storeLocalData(self, registrationId, identityKeyPair): def storeLocalData(self, registrationId, identityKeyPair):
q = "INSERT INTO identities( " + \ query = '''INSERT INTO identities(
"recipient_id, registration_id, public_key, private_key) " + \ recipient_id, registration_id, public_key, private_key)
"VALUES(-1, ?, ?, ?)" VALUES(-1, ?, ?, ?)'''
c = self.dbConn.cursor()
c.execute(q,
(registrationId,
identityKeyPair.getPublicKey().getPublicKey().serialize(),
identityKeyPair.getPrivateKey().serialize()))
self.dbConn.commit() public_key = identityKeyPair.getPublicKey().getPublicKey().serialize()
private_key = identityKeyPair.getPrivateKey().serialize()
self._con.execute(query, (registrationId, public_key, private_key))
self._con.commit()
def saveIdentity(self, recipientId, identityKey): def saveIdentity(self, recipientId, identityKey):
q = "INSERT INTO identities (recipient_id, public_key, trust) " \ query = '''INSERT INTO identities (recipient_id, public_key, trust)
"VALUES(?, ?, ?)" VALUES(?, ?, ?)'''
c = self.dbConn.cursor() if not self.containsIdentity(recipientId, identityKey):
self._con.execute(query, (recipientId,
identityKey.getPublicKey().serialize(),
UNDECIDED))
self._con.commit()
if not self.getIdentity(recipientId, identityKey): def containsIdentity(self, recipientId, identityKey):
c.execute(q, (recipientId, query = '''SELECT * FROM identities WHERE recipient_id = ?
identityKey.getPublicKey().serialize(), AND public_key = ?'''
UNDECIDED))
self.dbConn.commit()
def getIdentity(self, recipientId, identityKey): public_key = identityKey.getPublicKey().serialize()
q = "SELECT * FROM identities WHERE recipient_id = ? " \ result = self._con.execute(query, (recipientId,
"AND public_key = ?" public_key)).fetchone()
c = self.dbConn.cursor()
c.execute(q, (recipientId, identityKey.getPublicKey().serialize()))
result = c.fetchone()
return result is not None return result is not None
def deleteIdentity(self, recipientId, identityKey): def deleteIdentity(self, recipientId, identityKey):
q = "DELETE FROM identities WHERE recipient_id = ? AND public_key = ?" query = '''DELETE FROM identities
c = self.dbConn.cursor() WHERE recipient_id = ? AND public_key = ?'''
c.execute(q, (recipientId, public_key = identityKey.getPublicKey().serialize()
identityKey.getPublicKey().serialize())) self._con.execute(query, (recipientId, public_key))
self.dbConn.commit() self._con.commit()
def isTrustedIdentity(self, recipientId, identityKey): def isTrustedIdentity(self, recipientId, identityKey):
q = "SELECT trust FROM identities WHERE recipient_id = ? " \ query = '''SELECT trust FROM identities WHERE recipient_id = ?
"AND public_key = ?" AND public_key = ?'''
c = self.dbConn.cursor() public_key = identityKey.getPublicKey().serialize()
result = self._con.execute(query, (recipientId, public_key)).fetchone()
c.execute(q, (recipientId, identityKey.getPublicKey().serialize())) if result is None:
result = c.fetchone() return True
states = [UNTRUSTED, TRUSTED, UNDECIDED] states = [UNTRUSTED, TRUSTED, UNDECIDED]
if result.trust in states:
if result and result[0] in states: return result.trust
return result[0] return False
else:
return True
def getAllFingerprints(self): def getAllFingerprints(self):
q = "SELECT _id, recipient_id, public_key, trust FROM identities " \ query = '''SELECT _id, recipient_id, public_key, trust FROM identities
"WHERE recipient_id != -1 ORDER BY recipient_id ASC" WHERE recipient_id != -1 ORDER BY recipient_id ASC'''
c = self.dbConn.cursor() return self._con.execute(query).fetchall()
result = []
for row in c.execute(q):
result.append((row[0], row[1], row[2], row[3]))
return result
def getFingerprints(self, jid): def getFingerprints(self, jid):
q = "SELECT _id, recipient_id, public_key, trust FROM identities " \ query = '''SELECT _id, recipient_id, public_key, trust FROM identities
"WHERE recipient_id =? ORDER BY trust ASC" WHERE recipient_id =? ORDER BY trust ASC'''
c = self.dbConn.cursor() return self._con.execute(query, (jid,)).fetchall()
result = []
c.execute(q, (jid,))
rows = c.fetchall()
for row in rows:
result.append((row[0], row[1], row[2], row[3]))
return result
def getTrustedFingerprints(self, jid): def getTrustedFingerprints(self, jid):
q = "SELECT public_key FROM identities WHERE recipient_id = ? AND trust = ?" query = '''SELECT public_key FROM identities
c = self.dbConn.cursor() WHERE recipient_id = ? AND trust = ?'''
result = self._con.execute(query, (jid, TRUSTED)).fetchall()
result = [] return [row.public_key for row in result]
c.execute(q, (jid, TRUSTED))
rows = c.fetchall()
for row in rows:
result.append(row[0])
return result
def getUndecidedFingerprints(self, jid): def getUndecidedFingerprints(self, jid):
q = "SELECT trust FROM identities WHERE recipient_id = ? AND trust = ?" query = '''SELECT trust FROM identities
c = self.dbConn.cursor() WHERE recipient_id = ? AND trust = ?'''
return self._con.execute(query, (jid, UNDECIDED)).fetchall()
result = []
c.execute(q, (jid, UNDECIDED))
result = c.fetchall()
return result
def getNewFingerprints(self, jid): def getNewFingerprints(self, jid):
q = "SELECT _id FROM identities WHERE shown = 0 AND " \ query = '''SELECT _id FROM identities WHERE shown = 0
"recipient_id = ?" AND recipient_id = ?'''
c = self.dbConn.cursor()
result = [] result = self._con.execute(query, (jid,)).fetchall()
for row in c.execute(q, (jid,)): return [row.id for row in result]
result.append(row[0])
return result
def setShownFingerprints(self, fingerprints): def setShownFingerprints(self, fingerprints):
q = "UPDATE identities SET shown = 1 WHERE _id IN ({})" \ query = 'UPDATE identities SET shown = 1 WHERE _id IN ({})'.format(
.format(', '.join(['?'] * len(fingerprints))) ', '.join(['?'] * len(fingerprints)))
c = self.dbConn.cursor() self._con.execute(query, fingerprints)
c.execute(q, fingerprints) self._con.commit()
self.dbConn.commit()
def setTrust(self, identityKey, trust): def setTrust(self, identityKey, trust):
q = "UPDATE identities SET trust = ? WHERE public_key = ?" query = 'UPDATE identities SET trust = ? WHERE public_key = ?'
c = self.dbConn.cursor() public_key = identityKey.getPublicKey().serialize()
c.execute(q, (trust, identityKey.getPublicKey().serialize())) self._con.execute(query, (trust, public_key))
self.dbConn.commit() self._con.commit()
def activate(self, jid): def activate(self, jid):
q = """INSERT OR REPLACE INTO encryption_state (jid, encryption) query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption)
VALUES (?, 1) """ VALUES (?, 1)'''
c = self.dbConn.cursor() self._con.execute(query, (jid,))
c.execute(q, (jid, )) self._con.commit()
self.dbConn.commit()
def deactivate(self, jid): def deactivate(self, jid):
q = """INSERT OR REPLACE INTO encryption_state (jid, encryption) query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption)
VALUES (?, 0)""" VALUES (?, 0)'''
c = self.dbConn.cursor() self._con.execute(query, (jid, ))
c.execute(q, (jid, )) self._con.commit()
self.dbConn.commit()
def is_active(self, jid): def is_active(self, jid):
q = 'SELECT encryption FROM encryption_state where jid = ?;' query = 'SELECT encryption FROM encryption_state where jid = ?'
c = self.dbConn.cursor() result = self._con.execute(query, (jid,)).fetchone()
c.execute(q, (jid, )) return result.encryption if result is not None else False
result = c.fetchone()
if result is None:
return False
return result[0]
def exist(self, jid): def exist(self, jid):
q = 'SELECT encryption FROM encryption_state where jid = ?;' query = 'SELECT encryption FROM encryption_state where jid = ?'
c = self.dbConn.cursor() result = self._con.execute(query, (jid,)).fetchone()
c.execute(q, (jid, )) return result is not None
result = c.fetchone()
if result is None:
return False
else:
return True

View File

@@ -17,8 +17,6 @@
import logging import logging
import time import time
import binascii
import textwrap
from collections import defaultdict from collections import defaultdict
from nbxmpp.structs import OMEMOBundle from nbxmpp.structs import OMEMOBundle
@@ -53,14 +51,15 @@ UNDECIDED = 2
class OmemoState: class OmemoState:
def __init__(self, own_jid, db_con, account, xmpp_con): def __init__(self, own_jid, db_path, account, xmpp_con):
self.account = account self.account = account
self.xmpp_con = xmpp_con self.xmpp_con = xmpp_con
self._session_ciphers = defaultdict(dict) self._session_ciphers = defaultdict(dict)
self.own_jid = own_jid self.own_jid = own_jid
self.device_ids = {} self.device_ids = {}
self.own_devices = [] self.own_devices = []
self.store = LiteAxolotlStore(db_con)
self.store = LiteAxolotlStore(db_path)
for jid, device_id in self.store.getActiveDeviceTuples(): for jid, device_id in self.store.getActiveDeviceTuples():
if jid != own_jid: if jid != own_jid:
self.add_device(jid, device_id) self.add_device(jid, device_id)

View File

@@ -19,7 +19,6 @@
import os import os
import time import time
import logging import logging
import sqlite3
import nbxmpp import nbxmpp
from nbxmpp.protocol import NodeProcessed from nbxmpp.protocol import NodeProcessed
@@ -30,7 +29,6 @@ from nbxmpp.structs import StanzaHandler
from nbxmpp.modules.omemo import create_omemo_message from nbxmpp.modules.omemo import create_omemo_message
from gajim.common import app from gajim.common import app
from gajim.common import ged
from gajim.common import helpers from gajim.common import helpers
from gajim.common import configpaths from gajim.common import configpaths
from gajim.common.nec import NetworkEvent from gajim.common.nec import NetworkEvent
@@ -117,9 +115,7 @@ class OMEMO(BaseModule):
def __get_omemo(self): def __get_omemo(self):
data_dir = configpaths.get('MY_DATA') data_dir = configpaths.get('MY_DATA')
db_path = os.path.join(data_dir, 'omemo_' + self.own_jid + '.db') db_path = os.path.join(data_dir, 'omemo_' + self.own_jid + '.db')
conn = sqlite3.connect(db_path, check_same_thread=False) return OmemoState(self.own_jid, db_path, self._account, self)
conn.execute("PRAGMA secure_delete=1")
return OmemoState(self.own_jid, conn, self._account, self)
def on_signed_in(self): def on_signed_in(self):
log.info('%s => Announce Support after Sign In', self._account) log.info('%s => Announce Support after Sign In', self._account)