[omemo] Refactor AxolotlStore
- Dont use cursor object - Use namedtuple factory
This commit is contained in:
@@ -17,6 +17,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
from axolotl.state.axolotlstore import AxolotlStore
|
from axolotl.state.axolotlstore import AxolotlStore
|
||||||
from axolotl.state.signedprekeyrecord import SignedPreKeyRecord
|
from axolotl.state.signedprekeyrecord import SignedPreKeyRecord
|
||||||
@@ -44,41 +46,58 @@ UNTRUSTED = 0
|
|||||||
|
|
||||||
|
|
||||||
class LiteAxolotlStore(AxolotlStore):
|
class LiteAxolotlStore(AxolotlStore):
|
||||||
def __init__(self, connection):
|
def __init__(self, db_path):
|
||||||
self.dbConn = connection
|
self._con = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
self.dbConn.text_factory = bytes
|
self._con.row_factory = self._namedtuple_factory
|
||||||
self.createDb()
|
self.createDb()
|
||||||
self.migrateDb()
|
self.migrateDb()
|
||||||
c = self.dbConn.cursor()
|
|
||||||
c.execute("PRAGMA synchronous=NORMAL;")
|
self._con.execute("PRAGMA secure_delete=1")
|
||||||
c.execute("PRAGMA journal_mode;")
|
self._con.execute("PRAGMA synchronous=NORMAL;")
|
||||||
mode = c.fetchone()[0]
|
mode = self._con.execute("PRAGMA journal_mode;").fetchone()[0]
|
||||||
|
|
||||||
# WAL is a persistent DB mode, don't override it if user has set it
|
# WAL is a persistent DB mode, don't override it if user has set it
|
||||||
if mode != 'wal':
|
if mode != 'wal':
|
||||||
c.execute("PRAGMA journal_mode=MEMORY;")
|
self._con.execute("PRAGMA journal_mode=MEMORY;")
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
if not self.getLocalRegistrationId():
|
if not self.getLocalRegistrationId():
|
||||||
log.info("Generating OMEMO keys")
|
log.info("Generating OMEMO keys")
|
||||||
self._generate_axolotl_keys()
|
self._generate_axolotl_keys()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _namedtuple_factory(cursor, row):
|
||||||
|
fields = []
|
||||||
|
for col in cursor.description:
|
||||||
|
if col[0] == '_id':
|
||||||
|
fields.append('id')
|
||||||
|
elif 'strftime' in col[0]:
|
||||||
|
fields.append('formated_time')
|
||||||
|
elif 'MAX' in col[0] or 'COUNT' in col[0]:
|
||||||
|
col_name = col[0].replace('(', '_')
|
||||||
|
col_name = col_name.replace(')', '')
|
||||||
|
fields.append(col_name.lower())
|
||||||
|
else:
|
||||||
|
fields.append(col[0])
|
||||||
|
return namedtuple("Row", fields)(*row)
|
||||||
|
|
||||||
def _generate_axolotl_keys(self):
|
def _generate_axolotl_keys(self):
|
||||||
identityKeyPair = KeyHelper.generateIdentityKeyPair()
|
identity_key_pair = KeyHelper.generateIdentityKeyPair()
|
||||||
registrationId = KeyHelper.generateRegistrationId()
|
registration_id = KeyHelper.generateRegistrationId()
|
||||||
preKeys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(),
|
pre_keys = KeyHelper.generatePreKeys(KeyHelper.getRandomSequence(),
|
||||||
DEFAULT_PREKEY_AMOUNT)
|
DEFAULT_PREKEY_AMOUNT)
|
||||||
self.storeLocalData(registrationId, identityKeyPair)
|
self.storeLocalData(registration_id, identity_key_pair)
|
||||||
|
|
||||||
signedPreKey = KeyHelper.generateSignedPreKey(
|
signed_pre_key = KeyHelper.generateSignedPreKey(
|
||||||
identityKeyPair, KeyHelper.getRandomSequence(65536))
|
identity_key_pair, KeyHelper.getRandomSequence(65536))
|
||||||
|
|
||||||
self.storeSignedPreKey(signedPreKey.getId(), signedPreKey)
|
self.storeSignedPreKey(signed_pre_key.getId(), signed_pre_key)
|
||||||
|
|
||||||
for preKey in preKeys:
|
for pre_key in pre_keys:
|
||||||
self.storePreKey(preKey.getId(), preKey)
|
self.storePreKey(pre_key.getId(), pre_key)
|
||||||
|
|
||||||
def user_version(self):
|
def user_version(self):
|
||||||
return self.dbConn.execute('PRAGMA user_version').fetchone()[0]
|
return self._con.execute('PRAGMA user_version').fetchone()[0]
|
||||||
|
|
||||||
def createDb(self):
|
def createDb(self):
|
||||||
if self.user_version() == 0:
|
if self.user_version() == 0:
|
||||||
@@ -122,7 +141,7 @@ class LiteAxolotlStore(AxolotlStore):
|
|||||||
PRAGMA user_version=5;
|
PRAGMA user_version=5;
|
||||||
END TRANSACTION;
|
END TRANSACTION;
|
||||||
""" % (create_tables)
|
""" % (create_tables)
|
||||||
self.dbConn.executescript(create_db_sql)
|
self._con.executescript(create_db_sql)
|
||||||
|
|
||||||
def migrateDb(self):
|
def migrateDb(self):
|
||||||
""" Migrates the DB
|
""" Migrates the DB
|
||||||
@@ -138,11 +157,12 @@ class LiteAxolotlStore(AxolotlStore):
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.dbConn.executescript(""" BEGIN TRANSACTION;
|
self._con.executescript(
|
||||||
%s
|
""" BEGIN TRANSACTION;
|
||||||
PRAGMA user_version=2;
|
%s
|
||||||
END TRANSACTION;
|
PRAGMA user_version=2;
|
||||||
""" % (delete_dupes))
|
END TRANSACTION;
|
||||||
|
""" % (delete_dupes))
|
||||||
|
|
||||||
if self.user_version() < 3:
|
if self.user_version() < 3:
|
||||||
# Create a UNIQUE INDEX so every public key/recipient_id tuple
|
# Create a UNIQUE INDEX so every public key/recipient_id tuple
|
||||||
@@ -152,11 +172,12 @@ class LiteAxolotlStore(AxolotlStore):
|
|||||||
ON identities (public_key, recipient_id);
|
ON identities (public_key, recipient_id);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.dbConn.executescript(""" BEGIN TRANSACTION;
|
self._con.executescript(
|
||||||
%s
|
""" BEGIN TRANSACTION;
|
||||||
PRAGMA user_version=3;
|
%s
|
||||||
END TRANSACTION;
|
PRAGMA user_version=3;
|
||||||
""" % (add_index))
|
END TRANSACTION;
|
||||||
|
""" % (add_index))
|
||||||
|
|
||||||
if self.user_version() < 4:
|
if self.user_version() < 4:
|
||||||
# Adds column "active" to the sessions table
|
# Adds column "active" to the sessions table
|
||||||
@@ -164,11 +185,12 @@ class LiteAxolotlStore(AxolotlStore):
|
|||||||
ADD COLUMN active INTEGER DEFAULT 1;
|
ADD COLUMN active INTEGER DEFAULT 1;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.dbConn.executescript(""" BEGIN TRANSACTION;
|
self._con.executescript(
|
||||||
%s
|
""" BEGIN TRANSACTION;
|
||||||
PRAGMA user_version=4;
|
%s
|
||||||
END TRANSACTION;
|
PRAGMA user_version=4;
|
||||||
""" % (add_active))
|
END TRANSACTION;
|
||||||
|
""" % (add_active))
|
||||||
|
|
||||||
if self.user_version() < 5:
|
if self.user_version() < 5:
|
||||||
# Adds DEFAULT Timestamp
|
# Adds DEFAULT Timestamp
|
||||||
@@ -182,437 +204,316 @@ class LiteAxolotlStore(AxolotlStore):
|
|||||||
UPDATE identities SET shown = 1;
|
UPDATE identities SET shown = 1;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.dbConn.executescript(""" BEGIN TRANSACTION;
|
self._con.executescript(
|
||||||
%s
|
""" BEGIN TRANSACTION;
|
||||||
PRAGMA user_version=5;
|
%s
|
||||||
END TRANSACTION;
|
PRAGMA user_version=5;
|
||||||
""" % (add_timestamp))
|
END TRANSACTION;
|
||||||
|
""" % (add_timestamp))
|
||||||
|
|
||||||
|
|
||||||
def loadSignedPreKey(self, signedPreKeyId):
|
def loadSignedPreKey(self, signedPreKeyId):
|
||||||
q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?"
|
query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?'
|
||||||
|
result = self._con.execute(query, (signedPreKeyId, )).fetchone()
|
||||||
cursor = self.dbConn.cursor()
|
if result is None:
|
||||||
cursor.execute(q, (signedPreKeyId, ))
|
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if not result:
|
|
||||||
raise InvalidKeyIdException("No such signedprekeyrecord! %s " %
|
raise InvalidKeyIdException("No such signedprekeyrecord! %s " %
|
||||||
signedPreKeyId)
|
signedPreKeyId)
|
||||||
|
return SignedPreKeyRecord(serialized=result.record)
|
||||||
return SignedPreKeyRecord(serialized=result[0])
|
|
||||||
|
|
||||||
def loadSignedPreKeys(self):
|
def loadSignedPreKeys(self):
|
||||||
q = "SELECT record FROM signed_prekeys"
|
query = 'SELECT record FROM signed_prekeys'
|
||||||
|
results = self._con.execute(query).fetchall()
|
||||||
cursor = self.dbConn.cursor()
|
return [SignedPreKeyRecord(serialized=row.record) for row in results]
|
||||||
cursor.execute(q, )
|
|
||||||
result = cursor.fetchall()
|
|
||||||
results = []
|
|
||||||
for row in result:
|
|
||||||
results.append(SignedPreKeyRecord(serialized=row[0]))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def storeSignedPreKey(self, signedPreKeyId, signedPreKeyRecord):
|
def storeSignedPreKey(self, signedPreKeyId, signedPreKeyRecord):
|
||||||
q = "INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)"
|
query = 'INSERT INTO signed_prekeys (prekey_id, record) VALUES(?,?)'
|
||||||
cursor = self.dbConn.cursor()
|
self._con.execute(query, (signedPreKeyId,
|
||||||
cursor.execute(q, (signedPreKeyId, signedPreKeyRecord.serialize()))
|
signedPreKeyRecord.serialize()))
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def containsSignedPreKey(self, signedPreKeyId):
|
def containsSignedPreKey(self, signedPreKeyId):
|
||||||
q = "SELECT record FROM signed_prekeys WHERE prekey_id = ?"
|
query = 'SELECT record FROM signed_prekeys WHERE prekey_id = ?'
|
||||||
cursor = self.dbConn.cursor()
|
result = self._con.execute(query, (signedPreKeyId,)).fetchone()
|
||||||
cursor.execute(q, (signedPreKeyId, ))
|
return result is not None
|
||||||
return cursor.fetchone() is not None
|
|
||||||
|
|
||||||
def removeSignedPreKey(self, signedPreKeyId):
|
def removeSignedPreKey(self, signedPreKeyId):
|
||||||
q = "DELETE FROM signed_prekeys WHERE prekey_id = ?"
|
query = 'DELETE FROM signed_prekeys WHERE prekey_id = ?'
|
||||||
cursor = self.dbConn.cursor()
|
self._con.execute(query, (signedPreKeyId,))
|
||||||
cursor.execute(q, (signedPreKeyId, ))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def getNextSignedPreKeyId(self):
|
def getNextSignedPreKeyId(self):
|
||||||
result = self.getCurrentSignedPreKeyId()
|
result = self.getCurrentSignedPreKeyId()
|
||||||
if not result:
|
if result is None:
|
||||||
return 1 # StartId if no SignedPreKeys exist
|
return 1 # StartId if no SignedPreKeys exist
|
||||||
else:
|
return (result % (Medium.MAX_VALUE - 1)) + 1
|
||||||
return (result % (Medium.MAX_VALUE - 1)) + 1
|
|
||||||
|
|
||||||
def getCurrentSignedPreKeyId(self):
|
def getCurrentSignedPreKeyId(self):
|
||||||
q = "SELECT MAX(prekey_id) FROM signed_prekeys"
|
query = 'SELECT MAX(prekey_id) FROM signed_prekeys'
|
||||||
|
result = self._con.execute(query).fetchone()
|
||||||
cursor = self.dbConn.cursor()
|
return result.max_prekey_id if result is not None else None
|
||||||
cursor.execute(q)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return result[0]
|
|
||||||
|
|
||||||
def getSignedPreKeyTimestamp(self, signedPreKeyId):
|
def getSignedPreKeyTimestamp(self, signedPreKeyId):
|
||||||
q = "SELECT strftime('%s', timestamp) FROM " \
|
query = '''SELECT strftime('%s', timestamp) FROM
|
||||||
"signed_prekeys WHERE prekey_id = ?"
|
signed_prekeys WHERE prekey_id = ?'''
|
||||||
|
|
||||||
cursor = self.dbConn.cursor()
|
result = self._con.execute(query, (signedPreKeyId,)).fetchone()
|
||||||
cursor.execute(q, (signedPreKeyId, ))
|
if result is None:
|
||||||
|
raise InvalidKeyIdException('No such signedprekeyrecord! %s' %
|
||||||
result = cursor.fetchone()
|
|
||||||
if not result:
|
|
||||||
raise InvalidKeyIdException("No such signedprekeyrecord! %s " %
|
|
||||||
signedPreKeyId)
|
signedPreKeyId)
|
||||||
|
|
||||||
return result[0]
|
return result.formated_time
|
||||||
|
|
||||||
def removeOldSignedPreKeys(self, timestamp):
|
def removeOldSignedPreKeys(self, timestamp):
|
||||||
q = "DELETE FROM signed_prekeys " \
|
query = '''DELETE FROM signed_prekeys
|
||||||
"WHERE timestamp < datetime(?, 'unixepoch')"
|
WHERE timestamp < datetime(?, "unixepoch")'''
|
||||||
cursor = self.dbConn.cursor()
|
self._con.execute(query, (timestamp,))
|
||||||
cursor.execute(q, (timestamp, ))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def loadSession(self, recipientId, deviceId):
|
def loadSession(self, recipientId, deviceId):
|
||||||
q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?"
|
query = '''SELECT record FROM sessions WHERE
|
||||||
c = self.dbConn.cursor()
|
recipient_id = ? AND device_id = ?'''
|
||||||
c.execute(q, (recipientId, deviceId))
|
result = self._con.execute(query, (recipientId, deviceId)).fetchone()
|
||||||
result = c.fetchone()
|
if result is None:
|
||||||
|
|
||||||
if result:
|
|
||||||
return SessionRecord(serialized=result[0])
|
|
||||||
else:
|
|
||||||
return SessionRecord()
|
return SessionRecord()
|
||||||
|
return SessionRecord(serialized=result.record)
|
||||||
def getSubDeviceSessions(self, recipientId):
|
|
||||||
q = "SELECT device_id from sessions WHERE recipient_id = ?"
|
|
||||||
c = self.dbConn.cursor()
|
|
||||||
c.execute(q, (recipientId, ))
|
|
||||||
result = c.fetchall()
|
|
||||||
|
|
||||||
deviceIds = [r[0] for r in result]
|
|
||||||
return deviceIds
|
|
||||||
|
|
||||||
def getJidFromDevice(self, device_id):
|
def getJidFromDevice(self, device_id):
|
||||||
q = "SELECT recipient_id from sessions WHERE device_id = ?"
|
query = 'SELECT recipient_id from sessions WHERE device_id = ?'
|
||||||
c = self.dbConn.cursor()
|
result = self._con.execute(query, (device_id, )).fetchone()
|
||||||
c.execute(q, (device_id, ))
|
return result.recipient_id if result is not None else None
|
||||||
result = c.fetchone()
|
|
||||||
|
|
||||||
return result[0].decode('utf-8') if result else None
|
|
||||||
|
|
||||||
def getActiveDeviceTuples(self):
|
def getActiveDeviceTuples(self):
|
||||||
q = "SELECT recipient_id, device_id FROM sessions WHERE active = 1"
|
query = 'SELECT recipient_id, device_id FROM sessions WHERE active = 1'
|
||||||
c = self.dbConn.cursor()
|
return self._con.execute(query).fetchall()
|
||||||
result = []
|
|
||||||
for row in c.execute(q):
|
|
||||||
result.append((row[0].decode('utf-8'), row[1]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def storeSession(self, recipientId, deviceId, sessionRecord):
|
def storeSession(self, recipientId, deviceId, sessionRecord):
|
||||||
self.deleteSession(recipientId, deviceId)
|
self.deleteSession(recipientId, deviceId)
|
||||||
|
|
||||||
q = "INSERT INTO sessions(recipient_id, device_id, record) VALUES(?,?,?)"
|
query = '''INSERT INTO sessions(recipient_id, device_id, record)
|
||||||
c = self.dbConn.cursor()
|
VALUES(?,?,?)'''
|
||||||
c.execute(q, (recipientId, deviceId, sessionRecord.serialize()))
|
self._con.execute(query, (recipientId,
|
||||||
self.dbConn.commit()
|
deviceId,
|
||||||
|
sessionRecord.serialize()))
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
def containsSession(self, recipientId, deviceId):
|
def containsSession(self, recipientId, deviceId):
|
||||||
q = "SELECT record FROM sessions WHERE recipient_id = ? AND device_id = ?"
|
query = '''SELECT record FROM sessions
|
||||||
c = self.dbConn.cursor()
|
WHERE recipient_id = ? AND device_id = ?'''
|
||||||
c.execute(q, (recipientId, deviceId))
|
result = self._con.execute(query, (recipientId, deviceId)).fetchone()
|
||||||
result = c.fetchone()
|
|
||||||
|
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
def deleteSession(self, recipientId, deviceId):
|
def deleteSession(self, recipientId, deviceId):
|
||||||
q = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?"
|
query = "DELETE FROM sessions WHERE recipient_id = ? AND device_id = ?"
|
||||||
self.dbConn.cursor().execute(q, (recipientId, deviceId))
|
self._con.execute(query, (recipientId, deviceId))
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def deleteAllSessions(self, recipientId):
|
def deleteAllSessions(self, recipientId):
|
||||||
q = "DELETE FROM sessions WHERE recipient_id = ?"
|
query = 'DELETE FROM sessions WHERE recipient_id = ?'
|
||||||
self.dbConn.cursor().execute(q, (recipientId, ))
|
self._con.execute(query, (recipientId,))
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def getAllSessions(self):
|
|
||||||
q = "SELECT _id, recipient_id, device_id, record, active from sessions"
|
|
||||||
c = self.dbConn.cursor()
|
|
||||||
result = []
|
|
||||||
for row in c.execute(q):
|
|
||||||
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getSessionsFromJid(self, recipientId):
|
def getSessionsFromJid(self, recipientId):
|
||||||
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
|
query = '''SELECT _id, recipient_id, device_id, record, active
|
||||||
" WHERE recipient_id = ?"
|
from sessions WHERE recipient_id = ?'''
|
||||||
c = self.dbConn.cursor()
|
return self._con.execute(query, (recipientId,)).fetchall()
|
||||||
result = []
|
|
||||||
for row in c.execute(q, (recipientId,)):
|
|
||||||
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getSessionsFromJids(self, recipientId):
|
def getSessionsFromJids(self, recipientIds):
|
||||||
q = "SELECT _id, recipient_id, device_id, record, active from sessions" \
|
query = '''SELECT _id, recipient_id, device_id, record, active from sessions
|
||||||
" WHERE recipient_id IN ({})" \
|
WHERE recipient_id IN ({})'''.format(
|
||||||
.format(', '.join(['?'] * len(recipientId)))
|
', '.join(['?'] * len(recipientIds)))
|
||||||
c = self.dbConn.cursor()
|
return self._con.execute(query, recipientIds).fetchall()
|
||||||
result = []
|
|
||||||
for row in c.execute(q, recipientId):
|
|
||||||
result.append((row[0], row[1].decode('utf-8'), row[2], row[3], row[4]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def setActiveState(self, deviceList, jid):
|
def setActiveState(self, deviceList, jid):
|
||||||
c = self.dbConn.cursor()
|
query = '''UPDATE sessions SET active = 1
|
||||||
|
WHERE recipient_id = ? AND device_id IN ({})'''.format(
|
||||||
|
', '.join(['?'] * len(deviceList)))
|
||||||
|
self._con.execute(query, (jid,) + tuple(deviceList))
|
||||||
|
|
||||||
q = "UPDATE sessions SET active = {} " \
|
query = '''UPDATE sessions SET active = 0
|
||||||
"WHERE recipient_id = '{}' AND device_id IN ({})" \
|
WHERE recipient_id = ? AND device_id NOT IN ({})'''.format(
|
||||||
.format(1, jid, ', '.join(['?'] * len(deviceList)))
|
', '.join(['?'] * len(deviceList)))
|
||||||
c.execute(q, deviceList)
|
self._con.execute(query, (jid,) + tuple(deviceList))
|
||||||
|
self._con.commit()
|
||||||
q = "UPDATE sessions SET active = {} " \
|
|
||||||
"WHERE recipient_id = '{}' AND device_id NOT IN ({})" \
|
|
||||||
.format(0, jid, ', '.join(['?'] * len(deviceList)))
|
|
||||||
c.execute(q, deviceList)
|
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def getInactiveSessionsKeys(self, recipientId):
|
def getInactiveSessionsKeys(self, recipientId):
|
||||||
q = "SELECT record FROM sessions WHERE active = 0 AND recipient_id = ?"
|
query = '''SELECT record FROM sessions
|
||||||
c = self.dbConn.cursor()
|
WHERE active = 0 AND recipient_id = ?'''
|
||||||
result = []
|
result = self._con.execute(query, (recipientId,)).fetchall()
|
||||||
for row in c.execute(q, (recipientId,)):
|
|
||||||
public_key = (SessionRecord(serialized=row[0]).
|
results = []
|
||||||
|
for row in result:
|
||||||
|
public_key = (SessionRecord(serialized=row.record).
|
||||||
getSessionState().getRemoteIdentityKey().
|
getSessionState().getRemoteIdentityKey().
|
||||||
getPublicKey())
|
getPublicKey())
|
||||||
result.append(public_key.serialize())
|
results.append(public_key.serialize())
|
||||||
return result
|
return results
|
||||||
|
|
||||||
def loadPreKey(self, preKeyId):
|
def loadPreKey(self, preKeyId):
|
||||||
q = "SELECT record FROM prekeys WHERE prekey_id = ?"
|
query = '''SELECT record FROM prekeys WHERE prekey_id = ?'''
|
||||||
|
|
||||||
cursor = self.dbConn.cursor()
|
result = self._con.execute(query, (preKeyId,)).fetchone()
|
||||||
cursor.execute(q, (preKeyId, ))
|
if result is None:
|
||||||
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if not result:
|
|
||||||
raise Exception("No such prekeyRecord!")
|
raise Exception("No such prekeyRecord!")
|
||||||
|
return PreKeyRecord(serialized=result.record)
|
||||||
return PreKeyRecord(serialized=result[0])
|
|
||||||
|
|
||||||
def loadPendingPreKeys(self):
|
def loadPendingPreKeys(self):
|
||||||
q = "SELECT record FROM prekeys"
|
query = '''SELECT record FROM prekeys'''
|
||||||
cursor = self.dbConn.cursor()
|
result = self._con.execute(query).fetchall()
|
||||||
cursor.execute(q)
|
return [PreKeyRecord(serialized=row.record) for row in result]
|
||||||
result = cursor.fetchall()
|
|
||||||
|
|
||||||
return [PreKeyRecord(serialized=r[0]) for r in result]
|
|
||||||
|
|
||||||
def storePreKey(self, preKeyId, preKeyRecord):
|
def storePreKey(self, preKeyId, preKeyRecord):
|
||||||
q = "INSERT INTO prekeys (prekey_id, record) VALUES(?,?)"
|
query = 'INSERT INTO prekeys (prekey_id, record) VALUES(?,?)'
|
||||||
cursor = self.dbConn.cursor()
|
self._con.execute(query, (preKeyId, preKeyRecord.serialize()))
|
||||||
cursor.execute(q, (preKeyId, preKeyRecord.serialize()))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def containsPreKey(self, preKeyId):
|
def containsPreKey(self, preKeyId):
|
||||||
q = "SELECT record FROM prekeys WHERE prekey_id = ?"
|
query = 'SELECT record FROM prekeys WHERE prekey_id = ?'
|
||||||
cursor = self.dbConn.cursor()
|
result = self._con.execute(query, (preKeyId,)).fetchone()
|
||||||
cursor.execute(q, (preKeyId, ))
|
return result is not None
|
||||||
return cursor.fetchone() is not None
|
|
||||||
|
|
||||||
def removePreKey(self, preKeyId):
|
def removePreKey(self, preKeyId):
|
||||||
q = "DELETE FROM prekeys WHERE prekey_id = ?"
|
query = 'DELETE FROM prekeys WHERE prekey_id = ?'
|
||||||
cursor = self.dbConn.cursor()
|
self._con.execute(query, (preKeyId,))
|
||||||
cursor.execute(q, (preKeyId, ))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def getCurrentPreKeyId(self):
|
def getCurrentPreKeyId(self):
|
||||||
q = "SELECT MAX(prekey_id) FROM prekeys"
|
query = 'SELECT MAX(prekey_id) FROM prekeys'
|
||||||
cursor = self.dbConn.cursor()
|
return self._con.execute(query).fetchone().max_prekey_id
|
||||||
cursor.execute(q)
|
|
||||||
return cursor.fetchone()[0]
|
|
||||||
|
|
||||||
def getPreKeyCount(self):
|
def getPreKeyCount(self):
|
||||||
q = "SELECT COUNT(prekey_id) FROM prekeys"
|
query = 'SELECT COUNT(prekey_id) FROM prekeys'
|
||||||
cursor = self.dbConn.cursor()
|
return self._con.execute(query).fetchone().count_prekey_id
|
||||||
cursor.execute(q)
|
|
||||||
return cursor.fetchone()[0]
|
|
||||||
|
|
||||||
def generateNewPreKeys(self, count):
|
def generateNewPreKeys(self, count):
|
||||||
startId = self.getCurrentPreKeyId() + 1
|
start_id = self.getCurrentPreKeyId() + 1
|
||||||
preKeys = KeyHelper.generatePreKeys(startId, count)
|
pre_keys = KeyHelper.generatePreKeys(start_id, count)
|
||||||
|
|
||||||
for preKey in preKeys:
|
for pre_key in pre_keys:
|
||||||
self.storePreKey(preKey.getId(), preKey)
|
self.storePreKey(pre_key.getId(), pre_key)
|
||||||
|
|
||||||
def getIdentityKeyPair(self):
|
def getIdentityKeyPair(self):
|
||||||
q = "SELECT public_key, private_key FROM identities " + \
|
query = '''SELECT public_key, private_key FROM identities
|
||||||
"WHERE recipient_id = -1"
|
WHERE recipient_id = -1'''
|
||||||
c = self.dbConn.cursor()
|
result = self._con.execute(query).fetchone()
|
||||||
c.execute(q)
|
|
||||||
result = c.fetchone()
|
|
||||||
|
|
||||||
publicKey, privateKey = result
|
|
||||||
return IdentityKeyPair(
|
return IdentityKeyPair(
|
||||||
IdentityKey(DjbECPublicKey(publicKey[1:])),
|
IdentityKey(DjbECPublicKey(result.public_key[1:])),
|
||||||
DjbECPrivateKey(privateKey))
|
DjbECPrivateKey(result.private_key))
|
||||||
|
|
||||||
def getLocalRegistrationId(self):
|
def getLocalRegistrationId(self):
|
||||||
q = "SELECT registration_id FROM identities WHERE recipient_id = -1"
|
query = 'SELECT registration_id FROM identities WHERE recipient_id = -1'
|
||||||
c = self.dbConn.cursor()
|
result = self._con.execute(query).fetchone()
|
||||||
c.execute(q)
|
return result.registration_id if result is not None else None
|
||||||
result = c.fetchone()
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def storeLocalData(self, registrationId, identityKeyPair):
|
def storeLocalData(self, registrationId, identityKeyPair):
|
||||||
q = "INSERT INTO identities( " + \
|
query = '''INSERT INTO identities(
|
||||||
"recipient_id, registration_id, public_key, private_key) " + \
|
recipient_id, registration_id, public_key, private_key)
|
||||||
"VALUES(-1, ?, ?, ?)"
|
VALUES(-1, ?, ?, ?)'''
|
||||||
c = self.dbConn.cursor()
|
|
||||||
c.execute(q,
|
|
||||||
(registrationId,
|
|
||||||
identityKeyPair.getPublicKey().getPublicKey().serialize(),
|
|
||||||
identityKeyPair.getPrivateKey().serialize()))
|
|
||||||
|
|
||||||
self.dbConn.commit()
|
public_key = identityKeyPair.getPublicKey().getPublicKey().serialize()
|
||||||
|
private_key = identityKeyPair.getPrivateKey().serialize()
|
||||||
|
self._con.execute(query, (registrationId, public_key, private_key))
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
def saveIdentity(self, recipientId, identityKey):
|
def saveIdentity(self, recipientId, identityKey):
|
||||||
q = "INSERT INTO identities (recipient_id, public_key, trust) " \
|
query = '''INSERT INTO identities (recipient_id, public_key, trust)
|
||||||
"VALUES(?, ?, ?)"
|
VALUES(?, ?, ?)'''
|
||||||
c = self.dbConn.cursor()
|
if not self.containsIdentity(recipientId, identityKey):
|
||||||
|
self._con.execute(query, (recipientId,
|
||||||
|
identityKey.getPublicKey().serialize(),
|
||||||
|
UNDECIDED))
|
||||||
|
self._con.commit()
|
||||||
|
|
||||||
if not self.getIdentity(recipientId, identityKey):
|
def containsIdentity(self, recipientId, identityKey):
|
||||||
c.execute(q, (recipientId,
|
query = '''SELECT * FROM identities WHERE recipient_id = ?
|
||||||
identityKey.getPublicKey().serialize(),
|
AND public_key = ?'''
|
||||||
UNDECIDED))
|
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def getIdentity(self, recipientId, identityKey):
|
public_key = identityKey.getPublicKey().serialize()
|
||||||
q = "SELECT * FROM identities WHERE recipient_id = ? " \
|
result = self._con.execute(query, (recipientId,
|
||||||
"AND public_key = ?"
|
public_key)).fetchone()
|
||||||
c = self.dbConn.cursor()
|
|
||||||
|
|
||||||
c.execute(q, (recipientId, identityKey.getPublicKey().serialize()))
|
|
||||||
result = c.fetchone()
|
|
||||||
|
|
||||||
return result is not None
|
return result is not None
|
||||||
|
|
||||||
def deleteIdentity(self, recipientId, identityKey):
|
def deleteIdentity(self, recipientId, identityKey):
|
||||||
q = "DELETE FROM identities WHERE recipient_id = ? AND public_key = ?"
|
query = '''DELETE FROM identities
|
||||||
c = self.dbConn.cursor()
|
WHERE recipient_id = ? AND public_key = ?'''
|
||||||
c.execute(q, (recipientId,
|
public_key = identityKey.getPublicKey().serialize()
|
||||||
identityKey.getPublicKey().serialize()))
|
self._con.execute(query, (recipientId, public_key))
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def isTrustedIdentity(self, recipientId, identityKey):
|
def isTrustedIdentity(self, recipientId, identityKey):
|
||||||
q = "SELECT trust FROM identities WHERE recipient_id = ? " \
|
query = '''SELECT trust FROM identities WHERE recipient_id = ?
|
||||||
"AND public_key = ?"
|
AND public_key = ?'''
|
||||||
c = self.dbConn.cursor()
|
public_key = identityKey.getPublicKey().serialize()
|
||||||
|
result = self._con.execute(query, (recipientId, public_key)).fetchone()
|
||||||
c.execute(q, (recipientId, identityKey.getPublicKey().serialize()))
|
if result is None:
|
||||||
result = c.fetchone()
|
return True
|
||||||
|
|
||||||
states = [UNTRUSTED, TRUSTED, UNDECIDED]
|
states = [UNTRUSTED, TRUSTED, UNDECIDED]
|
||||||
|
if result.trust in states:
|
||||||
if result and result[0] in states:
|
return result.trust
|
||||||
return result[0]
|
return False
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def getAllFingerprints(self):
|
def getAllFingerprints(self):
|
||||||
q = "SELECT _id, recipient_id, public_key, trust FROM identities " \
|
query = '''SELECT _id, recipient_id, public_key, trust FROM identities
|
||||||
"WHERE recipient_id != -1 ORDER BY recipient_id ASC"
|
WHERE recipient_id != -1 ORDER BY recipient_id ASC'''
|
||||||
c = self.dbConn.cursor()
|
return self._con.execute(query).fetchall()
|
||||||
|
|
||||||
result = []
|
|
||||||
for row in c.execute(q):
|
|
||||||
result.append((row[0], row[1], row[2], row[3]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getFingerprints(self, jid):
|
def getFingerprints(self, jid):
|
||||||
q = "SELECT _id, recipient_id, public_key, trust FROM identities " \
|
query = '''SELECT _id, recipient_id, public_key, trust FROM identities
|
||||||
"WHERE recipient_id =? ORDER BY trust ASC"
|
WHERE recipient_id =? ORDER BY trust ASC'''
|
||||||
c = self.dbConn.cursor()
|
return self._con.execute(query, (jid,)).fetchall()
|
||||||
|
|
||||||
result = []
|
|
||||||
c.execute(q, (jid,))
|
|
||||||
rows = c.fetchall()
|
|
||||||
for row in rows:
|
|
||||||
result.append((row[0], row[1], row[2], row[3]))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getTrustedFingerprints(self, jid):
|
def getTrustedFingerprints(self, jid):
|
||||||
q = "SELECT public_key FROM identities WHERE recipient_id = ? AND trust = ?"
|
query = '''SELECT public_key FROM identities
|
||||||
c = self.dbConn.cursor()
|
WHERE recipient_id = ? AND trust = ?'''
|
||||||
|
result = self._con.execute(query, (jid, TRUSTED)).fetchall()
|
||||||
result = []
|
return [row.public_key for row in result]
|
||||||
c.execute(q, (jid, TRUSTED))
|
|
||||||
rows = c.fetchall()
|
|
||||||
for row in rows:
|
|
||||||
result.append(row[0])
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getUndecidedFingerprints(self, jid):
|
def getUndecidedFingerprints(self, jid):
|
||||||
q = "SELECT trust FROM identities WHERE recipient_id = ? AND trust = ?"
|
query = '''SELECT trust FROM identities
|
||||||
c = self.dbConn.cursor()
|
WHERE recipient_id = ? AND trust = ?'''
|
||||||
|
return self._con.execute(query, (jid, UNDECIDED)).fetchall()
|
||||||
result = []
|
|
||||||
c.execute(q, (jid, UNDECIDED))
|
|
||||||
result = c.fetchall()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def getNewFingerprints(self, jid):
|
def getNewFingerprints(self, jid):
|
||||||
q = "SELECT _id FROM identities WHERE shown = 0 AND " \
|
query = '''SELECT _id FROM identities WHERE shown = 0
|
||||||
"recipient_id = ?"
|
AND recipient_id = ?'''
|
||||||
c = self.dbConn.cursor()
|
|
||||||
result = []
|
result = self._con.execute(query, (jid,)).fetchall()
|
||||||
for row in c.execute(q, (jid,)):
|
return [row.id for row in result]
|
||||||
result.append(row[0])
|
|
||||||
return result
|
|
||||||
|
|
||||||
def setShownFingerprints(self, fingerprints):
|
def setShownFingerprints(self, fingerprints):
|
||||||
q = "UPDATE identities SET shown = 1 WHERE _id IN ({})" \
|
query = 'UPDATE identities SET shown = 1 WHERE _id IN ({})'.format(
|
||||||
.format(', '.join(['?'] * len(fingerprints)))
|
', '.join(['?'] * len(fingerprints)))
|
||||||
c = self.dbConn.cursor()
|
self._con.execute(query, fingerprints)
|
||||||
c.execute(q, fingerprints)
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def setTrust(self, identityKey, trust):
|
def setTrust(self, identityKey, trust):
|
||||||
q = "UPDATE identities SET trust = ? WHERE public_key = ?"
|
query = 'UPDATE identities SET trust = ? WHERE public_key = ?'
|
||||||
c = self.dbConn.cursor()
|
public_key = identityKey.getPublicKey().serialize()
|
||||||
c.execute(q, (trust, identityKey.getPublicKey().serialize()))
|
self._con.execute(query, (trust, public_key))
|
||||||
self.dbConn.commit()
|
self._con.commit()
|
||||||
|
|
||||||
def activate(self, jid):
|
def activate(self, jid):
|
||||||
q = """INSERT OR REPLACE INTO encryption_state (jid, encryption)
|
query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption)
|
||||||
VALUES (?, 1) """
|
VALUES (?, 1)'''
|
||||||
|
|
||||||
c = self.dbConn.cursor()
|
self._con.execute(query, (jid,))
|
||||||
c.execute(q, (jid, ))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def deactivate(self, jid):
|
def deactivate(self, jid):
|
||||||
q = """INSERT OR REPLACE INTO encryption_state (jid, encryption)
|
query = '''INSERT OR REPLACE INTO encryption_state (jid, encryption)
|
||||||
VALUES (?, 0)"""
|
VALUES (?, 0)'''
|
||||||
|
|
||||||
c = self.dbConn.cursor()
|
self._con.execute(query, (jid, ))
|
||||||
c.execute(q, (jid, ))
|
self._con.commit()
|
||||||
self.dbConn.commit()
|
|
||||||
|
|
||||||
def is_active(self, jid):
|
def is_active(self, jid):
|
||||||
q = 'SELECT encryption FROM encryption_state where jid = ?;'
|
query = 'SELECT encryption FROM encryption_state where jid = ?'
|
||||||
c = self.dbConn.cursor()
|
result = self._con.execute(query, (jid,)).fetchone()
|
||||||
c.execute(q, (jid, ))
|
return result.encryption if result is not None else False
|
||||||
result = c.fetchone()
|
|
||||||
if result is None:
|
|
||||||
return False
|
|
||||||
return result[0]
|
|
||||||
|
|
||||||
def exist(self, jid):
|
def exist(self, jid):
|
||||||
q = 'SELECT encryption FROM encryption_state where jid = ?;'
|
query = 'SELECT encryption FROM encryption_state where jid = ?'
|
||||||
c = self.dbConn.cursor()
|
result = self._con.execute(query, (jid,)).fetchone()
|
||||||
c.execute(q, (jid, ))
|
return result is not None
|
||||||
result = c.fetchone()
|
|
||||||
if result is None:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|||||||
@@ -17,8 +17,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import binascii
|
|
||||||
import textwrap
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from nbxmpp.structs import OMEMOBundle
|
from nbxmpp.structs import OMEMOBundle
|
||||||
@@ -53,14 +51,15 @@ UNDECIDED = 2
|
|||||||
|
|
||||||
|
|
||||||
class OmemoState:
|
class OmemoState:
|
||||||
def __init__(self, own_jid, db_con, account, xmpp_con):
|
def __init__(self, own_jid, db_path, account, xmpp_con):
|
||||||
self.account = account
|
self.account = account
|
||||||
self.xmpp_con = xmpp_con
|
self.xmpp_con = xmpp_con
|
||||||
self._session_ciphers = defaultdict(dict)
|
self._session_ciphers = defaultdict(dict)
|
||||||
self.own_jid = own_jid
|
self.own_jid = own_jid
|
||||||
self.device_ids = {}
|
self.device_ids = {}
|
||||||
self.own_devices = []
|
self.own_devices = []
|
||||||
self.store = LiteAxolotlStore(db_con)
|
|
||||||
|
self.store = LiteAxolotlStore(db_path)
|
||||||
for jid, device_id in self.store.getActiveDeviceTuples():
|
for jid, device_id in self.store.getActiveDeviceTuples():
|
||||||
if jid != own_jid:
|
if jid != own_jid:
|
||||||
self.add_device(jid, device_id)
|
self.add_device(jid, device_id)
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
import nbxmpp
|
import nbxmpp
|
||||||
from nbxmpp.protocol import NodeProcessed
|
from nbxmpp.protocol import NodeProcessed
|
||||||
@@ -30,7 +29,6 @@ from nbxmpp.structs import StanzaHandler
|
|||||||
from nbxmpp.modules.omemo import create_omemo_message
|
from nbxmpp.modules.omemo import create_omemo_message
|
||||||
|
|
||||||
from gajim.common import app
|
from gajim.common import app
|
||||||
from gajim.common import ged
|
|
||||||
from gajim.common import helpers
|
from gajim.common import helpers
|
||||||
from gajim.common import configpaths
|
from gajim.common import configpaths
|
||||||
from gajim.common.nec import NetworkEvent
|
from gajim.common.nec import NetworkEvent
|
||||||
@@ -117,9 +115,7 @@ class OMEMO(BaseModule):
|
|||||||
def __get_omemo(self):
|
def __get_omemo(self):
|
||||||
data_dir = configpaths.get('MY_DATA')
|
data_dir = configpaths.get('MY_DATA')
|
||||||
db_path = os.path.join(data_dir, 'omemo_' + self.own_jid + '.db')
|
db_path = os.path.join(data_dir, 'omemo_' + self.own_jid + '.db')
|
||||||
conn = sqlite3.connect(db_path, check_same_thread=False)
|
return OmemoState(self.own_jid, db_path, self._account, self)
|
||||||
conn.execute("PRAGMA secure_delete=1")
|
|
||||||
return OmemoState(self.own_jid, conn, self._account, self)
|
|
||||||
|
|
||||||
def on_signed_in(self):
|
def on_signed_in(self):
|
||||||
log.info('%s => Announce Support after Sign In', self._account)
|
log.info('%s => Announce Support after Sign In', self._account)
|
||||||
|
|||||||
Reference in New Issue
Block a user