gotr: update provided potr to 1.0.0beta7
This commit is contained in:
@@ -24,4 +24,4 @@ from potr.utils import human_hash
|
||||
|
||||
''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
|
||||
'beta', 'final' '''
|
||||
VERSION = (1, 0, 0, 'beta5')
|
||||
VERSION = (1, 0, 0, 'beta7')
|
||||
|
||||
@@ -26,8 +26,8 @@ from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi
|
||||
DEFAULT_KEYTYPE = 0x0000
|
||||
pkTypes = {}
|
||||
def registerkeytype(cls):
|
||||
if not hasattr(cls, 'parsePayload'):
|
||||
raise TypeError('registered key types need parsePayload()')
|
||||
if cls.keyType is None:
|
||||
raise TypeError('registered key class needs a type value')
|
||||
pkTypes[cls.keyType] = cls
|
||||
return cls
|
||||
|
||||
@@ -35,12 +35,16 @@ def generateDefaultKey():
|
||||
return pkTypes[DEFAULT_KEYTYPE].generate()
|
||||
|
||||
class PK(object):
|
||||
__slots__ = []
|
||||
keyType = None
|
||||
|
||||
@classmethod
|
||||
def generate(cls):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data, private=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def sign(self, data):
|
||||
raise NotImplementedError
|
||||
def verify(self, data):
|
||||
@@ -80,13 +84,13 @@ class PK(object):
|
||||
@classmethod
|
||||
def parsePrivateKey(cls, data):
|
||||
implCls, data = cls.getImplementation(data)
|
||||
logging.debug('Got privkey of type %r' % implCls)
|
||||
logging.debug('Got privkey of type %r', implCls)
|
||||
return implCls.parsePayload(data, private=True)
|
||||
|
||||
@classmethod
|
||||
def parsePublicKey(cls, data):
|
||||
implCls, data = cls.getImplementation(data)
|
||||
logging.debug('Got pubkey of type %r' % implCls)
|
||||
logging.debug('Got pubkey of type %r', implCls)
|
||||
return implCls.parsePayload(data)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -15,18 +15,16 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from Crypto import Cipher, Random
|
||||
from Crypto import Cipher
|
||||
from Crypto.Hash import SHA256 as _SHA256
|
||||
from Crypto.Hash import SHA as _SHA1
|
||||
from Crypto.Hash import SHA as _SHA1
|
||||
from Crypto.Hash import HMAC as _HMAC
|
||||
from Crypto.PublicKey import DSA
|
||||
from Crypto.Random import random
|
||||
from numbers import Number
|
||||
|
||||
from potr.compatcrypto import common
|
||||
from potr.utils import pack_mpi, read_mpi, bytes_to_long, long_to_bytes
|
||||
|
||||
# XXX atfork?
|
||||
RNG = Random.new()
|
||||
from potr.utils import read_mpi, bytes_to_long, long_to_bytes
|
||||
|
||||
def SHA256(data):
|
||||
return _SHA256.new(data).digest()
|
||||
@@ -54,7 +52,6 @@ def AESCTR(key, counter=0):
|
||||
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
|
||||
|
||||
class Counter(object):
|
||||
__slots__ = ['prefix', 'val']
|
||||
def __init__(self, prefix):
|
||||
self.prefix = prefix
|
||||
self.val = 0
|
||||
@@ -72,17 +69,15 @@ class Counter(object):
|
||||
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
|
||||
|
||||
def byteprefix(self):
|
||||
return long_to_bytes(self.prefix).rjust(8, b'\0')
|
||||
return long_to_bytes(self.prefix, 8)
|
||||
|
||||
def __call__(self):
|
||||
val = long_to_bytes(self.val)
|
||||
prefix = long_to_bytes(self.prefix)
|
||||
bytesuffix = long_to_bytes(self.val, 8)
|
||||
self.val += 1
|
||||
return self.byteprefix() + val.rjust(8, b'\0')
|
||||
return self.byteprefix() + bytesuffix
|
||||
|
||||
@common.registerkeytype
|
||||
class DSAKey(common.PK):
|
||||
__slots__ = ['priv', 'pub']
|
||||
keyType = 0x0000
|
||||
|
||||
def __init__(self, key=None, private=False):
|
||||
@@ -111,10 +106,10 @@ class DSAKey(common.PK):
|
||||
return SHA1(self.getSerializedPublicPayload())
|
||||
|
||||
def sign(self, data):
|
||||
# 2 <= K <= q = 160bit = 20 byte
|
||||
K = bytes_to_long(RNG.read(19)) + 2
|
||||
# 2 <= K <= q
|
||||
K = random.randrange(2, self.priv.q)
|
||||
r, s = self.priv.sign(data, K)
|
||||
return long_to_bytes(r) + long_to_bytes(s)
|
||||
return long_to_bytes(r, 20) + long_to_bytes(s, 20)
|
||||
|
||||
def verify(self, data, sig):
|
||||
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
try:
|
||||
basestring = basestring
|
||||
type(basestring)
|
||||
except NameError:
|
||||
# all strings are unicode in python3k
|
||||
basestring = str
|
||||
@@ -27,7 +27,7 @@ except NameError:
|
||||
|
||||
# callable is not available in python 3.0 and 3.1
|
||||
try:
|
||||
callable = callable
|
||||
type(callable)
|
||||
except NameError:
|
||||
from collections import Callable
|
||||
def callable(x):
|
||||
@@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from potr import crypt
|
||||
from potr import proto
|
||||
from potr import compatcrypto
|
||||
|
||||
from time import time
|
||||
|
||||
@@ -62,16 +63,11 @@ OFFER_REJECTED = 2
|
||||
OFFER_ACCEPTED = 3
|
||||
|
||||
class Context(object):
|
||||
__slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend',
|
||||
'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state',
|
||||
'inject', 'trust', 'peer', 'trustName']
|
||||
|
||||
def __init__(self, account, peername):
|
||||
self.user = account
|
||||
self.peer = peername
|
||||
self.policy = {}
|
||||
self.crypto = crypt.CryptEngine(self)
|
||||
self.discardFragment()
|
||||
self.tagOffer = OFFER_NOTSENT
|
||||
self.mayRetransmit = 0
|
||||
self.lastSend = 0
|
||||
@@ -79,6 +75,10 @@ class Context(object):
|
||||
self.state = STATE_PLAINTEXT
|
||||
self.trustName = self.peer
|
||||
|
||||
self.fragmentInfo = None
|
||||
self.fragment = None
|
||||
self.discardFragment()
|
||||
|
||||
def getPolicy(self, key):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -100,13 +100,19 @@ class Context(object):
|
||||
params = message.split(b',')
|
||||
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
|
||||
logger.warning('invalid formed fragmented message: %r', params)
|
||||
return None
|
||||
self.discardFragment()
|
||||
return message
|
||||
|
||||
|
||||
K, N = self.fragmentInfo
|
||||
try:
|
||||
k = int(params[1])
|
||||
n = int(params[2])
|
||||
except ValueError:
|
||||
logger.warning('invalid formed fragmented message: %r', params)
|
||||
self.discardFragment()
|
||||
return message
|
||||
|
||||
k = int(params[1])
|
||||
n = int(params[2])
|
||||
fragData = params[3]
|
||||
|
||||
logger.debug(params)
|
||||
@@ -114,17 +120,17 @@ class Context(object):
|
||||
if n >= k == 1:
|
||||
# first fragment
|
||||
self.discardFragment()
|
||||
self.fragmentInfo = (k,n)
|
||||
self.fragmentInfo = (k, n)
|
||||
self.fragment.append(fragData)
|
||||
elif N == n >= k > 1 and k == K+1:
|
||||
# accumulate
|
||||
self.fragmentInfo = (k,n)
|
||||
self.fragmentInfo = (k, n)
|
||||
self.fragment.append(fragData)
|
||||
else:
|
||||
# bad, discard
|
||||
self.discardFragment()
|
||||
logger.warning('invalid fragmented message: %r', params)
|
||||
return None
|
||||
return message
|
||||
|
||||
if n == k > 0:
|
||||
assembled = b''.join(self.fragment)
|
||||
@@ -210,7 +216,7 @@ class Context(object):
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
self.sendInternal(proto.Error(
|
||||
'You sent encrypted to {user}, who wasn\'t expecting it.'
|
||||
.format(user=self.user.name)), appdata=appdata)
|
||||
.format(user=self.user.name).encode('utf-8')), appdata=appdata)
|
||||
if ignore:
|
||||
return IGN
|
||||
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
|
||||
@@ -263,12 +269,13 @@ class Context(object):
|
||||
return msg
|
||||
|
||||
def processOutgoingMessage(self, msg, flags, tlvs=[]):
|
||||
if isinstance(self.parse(msg), proto.Query):
|
||||
isQuery = self.parseExplicitQuery(msg) is not None
|
||||
if isQuery:
|
||||
return self.user.getDefaultQueryMessage(self.getPolicy)
|
||||
|
||||
if self.state == STATE_PLAINTEXT:
|
||||
if self.getPolicy('REQUIRE_ENCRYPTION'):
|
||||
if not isinstance(self.parse(msg), proto.Query):
|
||||
if not isQuery:
|
||||
self.lastMessage = msg
|
||||
self.lastSend = time()
|
||||
self.mayRetransmit = 2
|
||||
@@ -277,8 +284,12 @@ class Context(object):
|
||||
return msg
|
||||
if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED:
|
||||
self.tagOffer = OFFER_SENT
|
||||
return proto.TaggedPlaintext(msg, self.getPolicy('ALLOW_V1'),
|
||||
self.getPolicy('ALLOW_V2'))
|
||||
versions = set()
|
||||
if self.getPolicy('ALLOW_V1'):
|
||||
versions.add(1)
|
||||
if self.getPolicy('ALLOW_V2'):
|
||||
versions.add(2)
|
||||
return proto.TaggedPlaintext(msg, versions)
|
||||
return msg
|
||||
if self.state == STATE_ENCRYPTED:
|
||||
msg = self.crypto.createDataMessage(msg, flags, tlvs)
|
||||
@@ -304,9 +315,9 @@ class Context(object):
|
||||
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
|
||||
mms = self.maxMessageSize(appdata)
|
||||
msgLen = len(msg)
|
||||
if mms != 0 and len(msg) > mms:
|
||||
if mms != 0 and msgLen > mms:
|
||||
fms = mms - 19
|
||||
fragments = [ msg[i:i+fms] for i in range(0, len(msg), fms) ]
|
||||
fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ]
|
||||
|
||||
fc = len(fragments)
|
||||
|
||||
@@ -375,9 +386,9 @@ class Context(object):
|
||||
self.crypto.smpSecret(secret, question=question, appdata=appdata)
|
||||
|
||||
def handleQuery(self, message, appdata=None):
|
||||
if message.v2 and self.getPolicy('ALLOW_V2'):
|
||||
if 2 in message.versions and self.getPolicy('ALLOW_V2'):
|
||||
self.authStartV2(appdata=appdata)
|
||||
elif message.v1 and self.getPolicy('ALLOW_V1'):
|
||||
elif 1 in message.versions and self.getPolicy('ALLOW_V1'):
|
||||
self.authStartV1(appdata=appdata)
|
||||
|
||||
def authStartV1(self, appdata=None):
|
||||
@@ -386,7 +397,33 @@ class Context(object):
|
||||
def authStartV2(self, appdata=None):
|
||||
self.crypto.startAKE(appdata=appdata)
|
||||
|
||||
def parse(self, message):
|
||||
def parseExplicitQuery(self, message):
|
||||
otrTagPos = message.find(proto.OTRTAG)
|
||||
|
||||
if otrTagPos == -1:
|
||||
return None
|
||||
|
||||
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||
|
||||
if len(message) <= indexBase:
|
||||
return None
|
||||
|
||||
compare = message[indexBase]
|
||||
|
||||
hasq = compare == b'?'[0]
|
||||
hasv = compare == b'v'[0]
|
||||
|
||||
if not hasq and not hasv:
|
||||
return None
|
||||
|
||||
hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0]
|
||||
if hasv:
|
||||
end = message.find(b'?', indexBase+1)
|
||||
else:
|
||||
end = indexBase+1
|
||||
return message[indexBase:end]
|
||||
|
||||
def parse(self, message, nofragment=False):
|
||||
otrTagPos = message.find(proto.OTRTAG)
|
||||
if otrTagPos == -1:
|
||||
if proto.MESSAGE_TAG_BASE in message:
|
||||
@@ -395,38 +432,40 @@ class Context(object):
|
||||
return message
|
||||
|
||||
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||
|
||||
if len(message) <= indexBase:
|
||||
return message
|
||||
|
||||
compare = message[indexBase]
|
||||
|
||||
if compare == b','[0]:
|
||||
if nofragment is False and compare == b','[0]:
|
||||
message = self.fragmentAccumulate(message[indexBase:])
|
||||
if message is None:
|
||||
return None
|
||||
else:
|
||||
return self.parse(message)
|
||||
return self.parse(message, nofragment=True)
|
||||
else:
|
||||
self.discardFragment()
|
||||
|
||||
hasq = compare == b'?'[0]
|
||||
hasv = compare == b'v'[0]
|
||||
if hasq or hasv:
|
||||
hasv |= len(message) > indexBase+1 and \
|
||||
message[indexBase+1] == b'v'[0]
|
||||
if hasv:
|
||||
end = message.find(b'?', indexBase+1)
|
||||
else:
|
||||
end = indexBase+1
|
||||
payload = message[indexBase:end]
|
||||
return proto.Query.parse(payload)
|
||||
queryPayload = self.parseExplicitQuery(message)
|
||||
if queryPayload is not None:
|
||||
return proto.Query.parse(queryPayload)
|
||||
|
||||
if compare == b':'[0] and len(message) > indexBase + 4:
|
||||
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
|
||||
classInfo = struct.unpack(b'!HB', infoTag)
|
||||
cls = proto.messageClasses.get(classInfo, None)
|
||||
if cls is None:
|
||||
try:
|
||||
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
|
||||
classInfo = struct.unpack(b'!HB', infoTag)
|
||||
|
||||
cls = proto.messageClasses.get(classInfo, None)
|
||||
if cls is None:
|
||||
return message
|
||||
|
||||
logger.debug('{user} got msg {typ!r}' \
|
||||
.format(user=self.user.name, typ=cls))
|
||||
return cls.parsePayload(message[indexBase+5:])
|
||||
except (TypeError, struct.error):
|
||||
logger.exception('could not parse OTR message %s', message)
|
||||
return message
|
||||
logger.debug('{user} got msg {typ!r}' \
|
||||
.format(user=self.user.name, typ=cls))
|
||||
return cls.parsePayload(message[indexBase+5:])
|
||||
|
||||
if message[indexBase:indexBase+7] == b' Error:':
|
||||
return proto.Error(message[indexBase+7:])
|
||||
@@ -437,6 +476,22 @@ class Context(object):
|
||||
"""Return the max message size for this context."""
|
||||
return self.user.maxMessageSize
|
||||
|
||||
def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None):
|
||||
""" retrieves the generated extra symmetric key.
|
||||
|
||||
if extraKeyAppId is set, notifies the chat partner about intended
|
||||
usage (additional application specific information can be supplied in
|
||||
extraKeyAppData).
|
||||
|
||||
returns the 256 bit symmetric key """
|
||||
|
||||
if self.state != STATE_ENCRYPTED:
|
||||
raise NotEncryptedError
|
||||
if extraKeyAppId is not None:
|
||||
tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)]
|
||||
self.sendInternal(b'', tlvs=tlvs, appdata=appdata)
|
||||
return self.crypto.extraKey
|
||||
|
||||
class Account(object):
|
||||
contextclass = Context
|
||||
def __init__(self, name, protocol, maxMessageSize, privkey=None):
|
||||
@@ -447,10 +502,10 @@ class Account(object):
|
||||
self.ctxs = {}
|
||||
self.trusts = {}
|
||||
self.maxMessageSize = maxMessageSize
|
||||
self.defaultQuery = b'?OTRv{versions}?\n{accountname} has requested ' \
|
||||
b'an Off-the-Record private conversation. However, you ' \
|
||||
b'do not have a plugin to support that.\nSee '\
|
||||
b'http://otr.cypherpunks.ca/ for more information.';
|
||||
self.defaultQuery = '?OTRv{versions}?\n{accountname} has requested ' \
|
||||
'an Off-the-Record private conversation. However, you ' \
|
||||
'do not have a plugin to support that.\nSee '\
|
||||
'http://otr.cypherpunks.ca/ for more information.'
|
||||
|
||||
def __repr__(self):
|
||||
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
|
||||
@@ -461,7 +516,7 @@ class Account(object):
|
||||
self.privkey = self.loadPrivkey()
|
||||
if self.privkey is None:
|
||||
if autogen is True:
|
||||
self.privkey = crypt.generateDefaultKey()
|
||||
self.privkey = compatcrypto.generateDefaultKey()
|
||||
self.savePrivkey()
|
||||
else:
|
||||
raise LookupError
|
||||
@@ -484,8 +539,9 @@ class Account(object):
|
||||
return self.ctxs[uid]
|
||||
|
||||
def getDefaultQueryMessage(self, policy):
|
||||
v = b'2' if policy('ALLOW_V2') else b''
|
||||
return self.defaultQuery.format(accountname=self.name, versions=v)
|
||||
v = '2' if policy('ALLOW_V2') else ''
|
||||
msg = self.defaultQuery.format(accountname=self.name, versions=v)
|
||||
return msg.encode('ascii')
|
||||
|
||||
def setTrust(self, key, fingerprint, trustLevel):
|
||||
if key not in self.trusts:
|
||||
|
||||
@@ -22,8 +22,8 @@ import logging
|
||||
import struct
|
||||
|
||||
|
||||
from potr.compatcrypto import SHA256, SHA1, HMAC, SHA1HMAC, SHA256HMAC, \
|
||||
SHA256HMAC160, Counter, AESCTR, RNG, PK, generateDefaultKey
|
||||
from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
|
||||
SHA256HMAC160, Counter, AESCTR, PK, random
|
||||
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
|
||||
from potr import proto
|
||||
|
||||
@@ -36,32 +36,31 @@ STATE_AWAITING_SIG = 4
|
||||
STATE_V1_SETUP = 5
|
||||
|
||||
|
||||
DH1536_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
|
||||
DH1536_MODULUS_2 = DH1536_MODULUS-2
|
||||
DH1536_GENERATOR = 2
|
||||
SM_ORDER = (DH1536_MODULUS - 1) // 2
|
||||
DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
|
||||
DH_MODULUS_2 = DH_MODULUS-2
|
||||
DH_GENERATOR = 2
|
||||
DH_BITS = 1536
|
||||
DH_MAX = 2**DH_BITS
|
||||
SM_ORDER = (DH_MODULUS - 1) // 2
|
||||
|
||||
def check_group(n):
|
||||
return 2 <= n <= DH1536_MODULUS_2
|
||||
return 2 <= n <= DH_MODULUS_2
|
||||
def check_exp(n):
|
||||
return 1 <= n < SM_ORDER
|
||||
|
||||
class DH(object):
|
||||
__slots__ = ['priv', 'pub']
|
||||
@classmethod
|
||||
def set_params(cls, prime, gen):
|
||||
cls.prime = prime
|
||||
cls.gen = gen
|
||||
|
||||
def __init__(self):
|
||||
self.priv = bytes_to_long(RNG.read(40))
|
||||
self.priv = random.randrange(2, 2**320)
|
||||
self.pub = pow(self.gen, self.priv, self.prime)
|
||||
|
||||
DH.set_params(DH1536_MODULUS, DH1536_GENERATOR)
|
||||
DH.set_params(DH_MODULUS, DH_GENERATOR)
|
||||
|
||||
class DHSession(object):
|
||||
__slots__ = ['sendenc', 'sendmac', 'rcvenc', 'rcvmac', 'sendctr', 'rcvctr',
|
||||
'sendmacused', 'rcvmacused']
|
||||
def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
|
||||
self.sendenc = sendenc
|
||||
self.sendmac = sendmac
|
||||
@@ -79,7 +78,7 @@ class DHSession(object):
|
||||
|
||||
@classmethod
|
||||
def create(cls, dh, y):
|
||||
s = pow(y, dh.priv, DH1536_MODULUS)
|
||||
s = pow(y, dh.priv, DH_MODULUS)
|
||||
sb = pack_mpi(s)
|
||||
|
||||
if dh.pub > y:
|
||||
@@ -96,9 +95,6 @@ class DHSession(object):
|
||||
return cls(sendenc, sendmac, rcvenc, rcvmac)
|
||||
|
||||
class CryptEngine(object):
|
||||
__slots__ = ['ctx', 'ake', 'sessionId', 'sessionIdHalf', 'theirKeyid',
|
||||
'theirY', 'theirOldY', 'ourOldDHKey', 'ourDHKey', 'ourKeyid',
|
||||
'sessionkeys', 'theirPubkey', 'savedMacKeys', 'smp']
|
||||
def __init__(self, ctx):
|
||||
self.ctx = ctx
|
||||
self.ake = None
|
||||
@@ -118,6 +114,7 @@ class CryptEngine(object):
|
||||
self.savedMacKeys = []
|
||||
|
||||
self.smp = None
|
||||
self.extraKey = None
|
||||
|
||||
def revealMacs(self, ours=True):
|
||||
if ours:
|
||||
@@ -174,7 +171,7 @@ class CryptEngine(object):
|
||||
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
|
||||
logger.error('HMACs don\'t match')
|
||||
raise InvalidParameterError
|
||||
sesskey.rcvmacused = 1
|
||||
sesskey.rcvmacused = True
|
||||
|
||||
newCtrPrefix = bytes_to_long(msg.ctr)
|
||||
if newCtrPrefix <= sesskey.rcvctr.prefix:
|
||||
@@ -223,11 +220,14 @@ class CryptEngine(object):
|
||||
self.smp = SMPHandler(self)
|
||||
self.smp.abort(appdata=appdata)
|
||||
|
||||
def createDataMessage(self, message, flags=0, tlvs=[]):
|
||||
def createDataMessage(self, message, flags=0, tlvs=None):
|
||||
# check MSGSTATE
|
||||
if self.theirKeyid == 0:
|
||||
raise InvalidParameterError
|
||||
|
||||
if tlvs is None:
|
||||
tlvs = []
|
||||
|
||||
sess = self.sessionkeys[1][0]
|
||||
sess.sendctr.inc()
|
||||
|
||||
@@ -303,13 +303,16 @@ class CryptEngine(object):
|
||||
self.ourKeyid = ake.ourKeyid
|
||||
self.theirY = ake.gy
|
||||
self.theirOldY = None
|
||||
self.extraKey = ake.extraKey
|
||||
|
||||
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
|
||||
# XXX is this really ok?
|
||||
self.ourDHKey = ake.dh
|
||||
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
|
||||
self.rotateDHKeys()
|
||||
|
||||
# we don't need the AKE anymore, free the reference
|
||||
self.ake = None
|
||||
|
||||
self.ctx._wentEncrypted()
|
||||
logger.info('went encrypted with {0}'.format(self.theirPubkey))
|
||||
|
||||
@@ -317,10 +320,6 @@ class CryptEngine(object):
|
||||
self.smp = None
|
||||
|
||||
class AuthKeyExchange(object):
|
||||
__slots__ = ['privkey', 'state', 'r', 'encgx', 'hashgx', 'ourKeyid',
|
||||
'theirPubkey', 'theirKeyid', 'enc_c', 'enc_cp', 'mac_m1',
|
||||
'mac_m1p', 'mac_m2', 'mac_m2p', 'sessionId', 'dh', 'onSuccess',
|
||||
'gy', 'lastmsg', 'sessionIdHalf']
|
||||
def __init__(self, privkey, onSuccess):
|
||||
self.privkey = privkey
|
||||
self.state = STATE_NONE
|
||||
@@ -341,9 +340,11 @@ class AuthKeyExchange(object):
|
||||
self.dh = DH()
|
||||
self.onSuccess = onSuccess
|
||||
self.gy = None
|
||||
self.extraKey = None
|
||||
self.lastmsg = None
|
||||
|
||||
def startAKE(self):
|
||||
self.r = RNG.read(16)
|
||||
self.r = long_to_bytes(random.getrandbits(128))
|
||||
|
||||
gxmpi = pack_mpi(self.dh.pub)
|
||||
|
||||
@@ -444,15 +445,17 @@ class AuthKeyExchange(object):
|
||||
self.state = STATE_NONE
|
||||
|
||||
def createAuthKeys(self):
|
||||
s = pow(self.gy, self.dh.priv, DH1536_MODULUS)
|
||||
s = pow(self.gy, self.dh.priv, DH_MODULUS)
|
||||
sbyte = pack_mpi(s)
|
||||
self.sessionId = SHA256(b'\0' + sbyte)[:8]
|
||||
enc = SHA256(b'\1' + sbyte)
|
||||
self.enc_c, self.enc_cp = enc[:16], enc[16:]
|
||||
self.mac_m1 = SHA256(b'\2' + sbyte)
|
||||
self.mac_m2 = SHA256(b'\3' + sbyte)
|
||||
self.mac_m1p = SHA256(b'\4' + sbyte)
|
||||
self.mac_m2p = SHA256(b'\5' + sbyte)
|
||||
self.sessionId = SHA256(b'\x00' + sbyte)[:8]
|
||||
enc = SHA256(b'\x01' + sbyte)
|
||||
self.enc_c = enc[:16]
|
||||
self.enc_cp = enc[16:]
|
||||
self.mac_m1 = SHA256(b'\x02' + sbyte)
|
||||
self.mac_m2 = SHA256(b'\x03' + sbyte)
|
||||
self.mac_m1p = SHA256(b'\x04' + sbyte)
|
||||
self.mac_m2p = SHA256(b'\x05' + sbyte)
|
||||
self.extraKey = SHA256(b'\xff' + sbyte)
|
||||
|
||||
def calculatePubkeyAuth(self, key, mackey):
|
||||
pubkey = self.privkey.serializePublicKey()
|
||||
@@ -490,14 +493,15 @@ SMPPROG_FAILED = -1
|
||||
SMPPROG_SUCCEEDED = 1
|
||||
|
||||
class SMPHandler:
|
||||
__slots__ = ['crypto', 'questionReceived', 'prog', 'state', 'g1', 'g3o',
|
||||
'x2', 'x3', 'g2', 'g3', 'pab', 'qab', 'secret', 'p', 'q']
|
||||
|
||||
def __init__(self, crypto):
|
||||
self.crypto = crypto
|
||||
self.state = 1
|
||||
self.g1 = DH1536_GENERATOR
|
||||
self.g1 = DH_GENERATOR
|
||||
self.g2 = None
|
||||
self.g3 = None
|
||||
self.g3o = None
|
||||
self.x2 = None
|
||||
self.x3 = None
|
||||
self.prog = SMPPROG_OK
|
||||
self.pab = None
|
||||
self.qab = None
|
||||
@@ -539,11 +543,11 @@ class SMPHandler:
|
||||
|
||||
self.g3o = msg[3]
|
||||
|
||||
self.x2 = bytes_to_long(RNG.read(192))
|
||||
self.x3 = bytes_to_long(RNG.read(192))
|
||||
self.x2 = random.randrange(2, DH_MAX)
|
||||
self.x3 = random.randrange(2, DH_MAX)
|
||||
|
||||
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
|
||||
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||
|
||||
self.prog = SMPPROG_OK
|
||||
self.state = 0
|
||||
@@ -568,29 +572,29 @@ class SMPHandler:
|
||||
return
|
||||
|
||||
self.g3o = msg[3]
|
||||
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
|
||||
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||
|
||||
if not self.check_equal_coords(msg[6:11], 5):
|
||||
logger.error('invalid SMP2TLV received')
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
r = bytes_to_long(RNG.read(192))
|
||||
self.p = pow(self.g3, r, DH1536_MODULUS)
|
||||
r = random.randrange(2, DH_MAX)
|
||||
self.p = pow(self.g3, r, DH_MODULUS)
|
||||
msg = [self.p]
|
||||
qa1 = pow(self.g1, r, DH1536_MODULUS)
|
||||
qa2 = pow(self.g2, self.secret, DH1536_MODULUS)
|
||||
self.q = qa1*qa2 % DH1536_MODULUS
|
||||
qa1 = pow(self.g1, r, DH_MODULUS)
|
||||
qa2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||
self.q = qa1*qa2 % DH_MODULUS
|
||||
msg.append(self.q)
|
||||
msg += self.proof_equal_coords(r, 6)
|
||||
|
||||
inv = invMod(mp)
|
||||
self.pab = self.p * inv % DH1536_MODULUS
|
||||
self.pab = self.p * inv % DH_MODULUS
|
||||
inv = invMod(mq)
|
||||
self.qab = self.q * inv % DH1536_MODULUS
|
||||
self.qab = self.q * inv % DH_MODULUS
|
||||
|
||||
msg.append(pow(self.qab, self.x3, DH1536_MODULUS))
|
||||
msg.append(pow(self.qab, self.x3, DH_MODULUS))
|
||||
msg += self.proof_equal_logs(7)
|
||||
|
||||
self.state = 4
|
||||
@@ -613,9 +617,9 @@ class SMPHandler:
|
||||
return
|
||||
|
||||
inv = invMod(self.p)
|
||||
self.pab = msg[0] * inv % DH1536_MODULUS
|
||||
self.pab = msg[0] * inv % DH_MODULUS
|
||||
inv = invMod(self.q)
|
||||
self.qab = msg[1] * inv % DH1536_MODULUS
|
||||
self.qab = msg[1] * inv % DH_MODULUS
|
||||
|
||||
if not self.check_equal_logs(msg[5:8], 7):
|
||||
logger.error('invalid SMP3TLV received')
|
||||
@@ -623,10 +627,10 @@ class SMPHandler:
|
||||
return
|
||||
|
||||
md = msg[5]
|
||||
msg = [pow(self.qab, self.x3, DH1536_MODULUS)]
|
||||
msg = [pow(self.qab, self.x3, DH_MODULUS)]
|
||||
msg += self.proof_equal_logs(8)
|
||||
|
||||
rab = pow(md, self.x3, DH1536_MODULUS)
|
||||
rab = pow(md, self.x3, DH_MODULUS)
|
||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||
|
||||
if self.prog != SMPPROG_SUCCEEDED:
|
||||
@@ -654,7 +658,7 @@ class SMPHandler:
|
||||
self.abort(appdata=appdata)
|
||||
return
|
||||
|
||||
rab = pow(msg[0], self.x3, DH1536_MODULUS)
|
||||
rab = pow(msg[0], self.x3, DH_MODULUS)
|
||||
|
||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||
|
||||
@@ -679,12 +683,12 @@ class SMPHandler:
|
||||
|
||||
self.secret = bytes_to_long(combSecret)
|
||||
|
||||
self.x2 = bytes_to_long(RNG.read(192))
|
||||
self.x3 = bytes_to_long(RNG.read(192))
|
||||
self.x2 = random.randrange(2, DH_MAX)
|
||||
self.x3 = random.randrange(2, DH_MAX)
|
||||
|
||||
msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
|
||||
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||
msg += proof_known_log(self.g1, self.x2, 1)
|
||||
msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
|
||||
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||
msg += proof_known_log(self.g1, self.x3, 2)
|
||||
|
||||
self.prog = SMPPROG_OK
|
||||
@@ -700,19 +704,19 @@ class SMPHandler:
|
||||
|
||||
self.secret = bytes_to_long(combSecret)
|
||||
|
||||
msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
|
||||
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||
msg += proof_known_log(self.g1, self.x2, 3)
|
||||
msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
|
||||
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||
msg += proof_known_log(self.g1, self.x3, 4)
|
||||
|
||||
r = bytes_to_long(RNG.read(192))
|
||||
r = random.randrange(2, DH_MAX)
|
||||
|
||||
self.p = pow(self.g3, r, DH1536_MODULUS)
|
||||
self.p = pow(self.g3, r, DH_MODULUS)
|
||||
msg.append(self.p)
|
||||
|
||||
qb1 = pow(self.g1, r, DH1536_MODULUS)
|
||||
qb2 = pow(self.g2, self.secret, DH1536_MODULUS)
|
||||
self.q = qb1 * qb2 % DH1536_MODULUS
|
||||
qb1 = pow(self.g1, r, DH_MODULUS)
|
||||
qb2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||
self.q = qb1 * qb2 % DH_MODULUS
|
||||
msg.append(self.q)
|
||||
|
||||
msg += self.proof_equal_coords(r, 5)
|
||||
@@ -721,11 +725,11 @@ class SMPHandler:
|
||||
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
|
||||
|
||||
def proof_equal_coords(self, r, v):
|
||||
r1 = bytes_to_long(RNG.read(192))
|
||||
r2 = bytes_to_long(RNG.read(192))
|
||||
temp2 = pow(self.g1, r1, DH1536_MODULUS) \
|
||||
* pow(self.g2, r2, DH1536_MODULUS) % DH1536_MODULUS
|
||||
temp1 = pow(self.g3, r1, DH1536_MODULUS)
|
||||
r1 = random.randrange(2, DH_MAX)
|
||||
r2 = random.randrange(2, DH_MAX)
|
||||
temp2 = pow(self.g1, r1, DH_MODULUS) \
|
||||
* pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
|
||||
temp1 = pow(self.g3, r1, DH_MODULUS)
|
||||
|
||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
c = bytes_to_long(cb)
|
||||
@@ -739,21 +743,21 @@ class SMPHandler:
|
||||
|
||||
def check_equal_coords(self, coords, v):
|
||||
(p, q, c, d1, d2) = coords
|
||||
temp1 = pow(self.g3, d1, DH1536_MODULUS) * pow(p, c, DH1536_MODULUS) \
|
||||
% DH1536_MODULUS
|
||||
temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \
|
||||
% DH_MODULUS
|
||||
|
||||
temp2 = pow(self.g1, d1, DH1536_MODULUS) \
|
||||
* pow(self.g2, d2, DH1536_MODULUS) \
|
||||
* pow(q, c, DH1536_MODULUS) % DH1536_MODULUS
|
||||
temp2 = pow(self.g1, d1, DH_MODULUS) \
|
||||
* pow(self.g2, d2, DH_MODULUS) \
|
||||
* pow(q, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
|
||||
return long_to_bytes(c) == cprime
|
||||
return long_to_bytes(c, 32) == cprime
|
||||
|
||||
def proof_equal_logs(self, v):
|
||||
r = bytes_to_long(RNG.read(192))
|
||||
temp1 = pow(self.g1, r, DH1536_MODULUS)
|
||||
temp2 = pow(self.qab, r, DH1536_MODULUS)
|
||||
r = random.randrange(2, DH_MAX)
|
||||
temp1 = pow(self.g1, r, DH_MODULUS)
|
||||
temp2 = pow(self.qab, r, DH_MODULUS)
|
||||
|
||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
c = bytes_to_long(cb)
|
||||
@@ -763,29 +767,29 @@ class SMPHandler:
|
||||
|
||||
def check_equal_logs(self, logs, v):
|
||||
(r, c, d) = logs
|
||||
temp1 = pow(self.g1, d, DH1536_MODULUS) \
|
||||
* pow(self.g3o, c, DH1536_MODULUS) % DH1536_MODULUS
|
||||
temp1 = pow(self.g1, d, DH_MODULUS) \
|
||||
* pow(self.g3o, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
temp2 = pow(self.qab, d, DH1536_MODULUS) \
|
||||
* pow(r, c, DH1536_MODULUS) % DH1536_MODULUS
|
||||
temp2 = pow(self.qab, d, DH_MODULUS) \
|
||||
* pow(r, c, DH_MODULUS) % DH_MODULUS
|
||||
|
||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||
return long_to_bytes(c) == cprime
|
||||
return long_to_bytes(c, 32) == cprime
|
||||
|
||||
def proof_known_log(g, x, v):
|
||||
r = bytes_to_long(RNG.read(192))
|
||||
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH1536_MODULUS))))
|
||||
r = random.randrange(2, DH_MAX)
|
||||
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
|
||||
temp = x * c % SM_ORDER
|
||||
return c, (r-temp) % SM_ORDER
|
||||
|
||||
def check_known_log(c, d, g, x, v):
|
||||
gd = pow(g, d, DH1536_MODULUS)
|
||||
xc = pow(x, c, DH1536_MODULUS)
|
||||
gdxc = gd * xc % DH1536_MODULUS
|
||||
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c)
|
||||
gd = pow(g, d, DH_MODULUS)
|
||||
xc = pow(x, c, DH_MODULUS)
|
||||
gdxc = gd * xc % DH_MODULUS
|
||||
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32)
|
||||
|
||||
def invMod(n):
|
||||
return pow(n, DH1536_MODULUS_2, DH1536_MODULUS)
|
||||
return pow(n, DH_MODULUS_2, DH_MODULUS)
|
||||
|
||||
class InvalidParameterError(RuntimeError):
|
||||
pass
|
||||
|
||||
@@ -19,14 +19,16 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
|
||||
|
||||
OTRTAG = b'?OTR'
|
||||
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
|
||||
MESSAGE_TAG_V1 = b' \t \t \t '
|
||||
MESSAGE_TAG_V2 = b' \t\t \t '
|
||||
MESSAGE_TAGS = {
|
||||
1:b' \t \t \t ',
|
||||
2:b' \t\t \t ',
|
||||
3:b' \t\t \t\t',
|
||||
}
|
||||
|
||||
MSGTYPE_NOTOTR = 0
|
||||
MSGTYPE_TAGGEDPLAINTEXT = 1
|
||||
@@ -62,6 +64,8 @@ def registermessage(cls):
|
||||
def registertlv(cls):
|
||||
if not hasattr(cls, 'parsePayload'):
|
||||
raise TypeError('registered tlv types need parsePayload()')
|
||||
if cls.typ is None:
|
||||
raise TypeError('registered tlv type needs type ID')
|
||||
tlvClasses[cls.typ] = cls
|
||||
return cls
|
||||
|
||||
@@ -84,16 +88,6 @@ class OTRMessage(object):
|
||||
__slots__ = ['payload']
|
||||
version = 0x0002
|
||||
msgtype = 0
|
||||
def __init__(self, payload):
|
||||
self.payload = payload
|
||||
|
||||
def getPayload(self):
|
||||
return self.payload
|
||||
|
||||
def __bytes__(self):
|
||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||
+ self.getPayload()
|
||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
@@ -110,6 +104,7 @@ class OTRMessage(object):
|
||||
class Error(OTRMessage):
|
||||
__slots__ = ['error']
|
||||
def __init__(self, error):
|
||||
super(Error, self).__init__()
|
||||
self.error = error
|
||||
|
||||
def __repr__(self):
|
||||
@@ -119,56 +114,58 @@ class Error(OTRMessage):
|
||||
return b'?OTR Error:' + self.error
|
||||
|
||||
class Query(OTRMessage):
|
||||
__slots__ = ['v1', 'v2']
|
||||
def __init__(self, v1, v2):
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
__slots__ = ['versions']
|
||||
def __init__(self, versions=set()):
|
||||
super(Query, self).__init__()
|
||||
self.versions = versions
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
v2 = False
|
||||
v1 = False
|
||||
if len(data) > 0 and data[0:1] == b'?':
|
||||
data = data[1:]
|
||||
v1 = True
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError('can only parse bytes')
|
||||
udata = data.decode('ascii', errors='replace')
|
||||
|
||||
if len(data) > 0 and data[0:1] == b'v':
|
||||
for c in data[1:]:
|
||||
if c == b'2'[0]:
|
||||
v2 = True
|
||||
return cls(v1, v2)
|
||||
versions = set()
|
||||
if len(udata) > 0 and udata[0] == '?':
|
||||
udata = udata[1:]
|
||||
versions.add(1)
|
||||
|
||||
if len(udata) > 0 and udata[0] == 'v':
|
||||
versions.update(( int(c) for c in udata if c.isdigit() ))
|
||||
return cls(versions)
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2)
|
||||
return '<proto.Query(versions=%r)>' % (self.versions)
|
||||
|
||||
def __bytes__(self):
|
||||
d = b'?OTR'
|
||||
if self.v1:
|
||||
if 1 in self.versions:
|
||||
d += b'?'
|
||||
d += b'v'
|
||||
if self.v2:
|
||||
d += b'2'
|
||||
|
||||
# in python3 there is only int->unicode conversion
|
||||
# so I convert to unicode and encode it to a byte string
|
||||
versions = [ '%d' % v for v in self.versions if v != 1 ]
|
||||
d += ''.join(versions).encode('ascii')
|
||||
|
||||
d += b'?'
|
||||
return d
|
||||
|
||||
class TaggedPlaintext(Query):
|
||||
__slots__ = ['msg']
|
||||
def __init__(self, msg, v1, v2):
|
||||
def __init__(self, msg, versions):
|
||||
super(TaggedPlaintext, self).__init__(versions)
|
||||
self.msg = msg
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
|
||||
def __bytes__(self):
|
||||
data = self.msg + MESSAGE_TAG_BASE
|
||||
if self.v1:
|
||||
data += MESSAGE_TAG_V1
|
||||
if self.v2:
|
||||
data += MESSAGE_TAG_V2
|
||||
for v in self.versions:
|
||||
data += MESSAGE_TAGS[v]
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \
|
||||
.format(v1=self.v1, v2=self.v2, msg=self.msg)
|
||||
return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
|
||||
.format(versions=self.versions, msg=self.msg)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
|
||||
raise TypeError(
|
||||
'this is not a tagged plaintext ({0!r:.20})'.format(data))
|
||||
|
||||
v1 = False
|
||||
v2 = False
|
||||
|
||||
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
|
||||
for tag in tags:
|
||||
if not tag.isspace():
|
||||
break
|
||||
v1 |= tag == MESSAGE_TAG_V1
|
||||
v2 |= tag == MESSAGE_TAG_V2
|
||||
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
|
||||
in tags ])
|
||||
|
||||
return TaggedPlaintext(data[:tagPos], v1, v2)
|
||||
return TaggedPlaintext(data[:tagPos], versions)
|
||||
|
||||
class GenericOTRMessage(OTRMessage):
|
||||
__slots__ = ['data']
|
||||
fields = []
|
||||
|
||||
def __init__(self, *args):
|
||||
super(GenericOTRMessage, self).__init__()
|
||||
if len(args) != len(self.fields):
|
||||
raise TypeError('%s needs %d arguments, got %d' %
|
||||
(self.__class__.__name__, len(self.fields), len(args)))
|
||||
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
|
||||
self.__getattr__(attr) # existence check
|
||||
self.data[attr] = val
|
||||
|
||||
def __bytes__(self):
|
||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||
+ self.getPayload()
|
||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
data = ''
|
||||
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
|
||||
def parsePayload(cls, data):
|
||||
data = base64.b64decode(data)
|
||||
args = []
|
||||
for k, ftype in cls.fields:
|
||||
for _, ftype in cls.fields:
|
||||
if ftype == 'data':
|
||||
value, data = read_data(data)
|
||||
elif isinstance(ftype, bytes):
|
||||
size = int(struct.calcsize(ftype))
|
||||
value, data = unpack(ftype, data)
|
||||
elif isinstance(ftype, int):
|
||||
value, data = data[:ftype], data[ftype:]
|
||||
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
|
||||
|
||||
class AKEMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
pass
|
||||
|
||||
@registermessage
|
||||
class DHCommit(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x02
|
||||
fields = [('encgx','data'), ('hashgx','data'), ]
|
||||
|
||||
fields = [('encgx', 'data'), ('hashgx', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class DHKey(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x0a
|
||||
fields = [('gy','data'), ]
|
||||
fields = [('gy', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class RevealSig(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x11
|
||||
fields = [('rkey','data'), ('encsig','data'), ('mac',20),]
|
||||
fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
|
||||
class Signature(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x12
|
||||
fields = [('encsig','data'), ('mac',20)]
|
||||
fields = [('encsig', 'data'), ('mac', 20)]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
|
||||
class DataMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x03
|
||||
fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'),
|
||||
('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ]
|
||||
fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
|
||||
('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
|
||||
('oldmacs', 'data'), ]
|
||||
|
||||
def getMacedData(self):
|
||||
return struct.pack(b'!HB', self.version, self.msgtype) + \
|
||||
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
|
||||
@bytesAndStrings
|
||||
class TLV(object):
|
||||
__slots__ = []
|
||||
typ = None
|
||||
|
||||
def getPayload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
val = self.getPayload()
|
||||
@@ -330,11 +331,28 @@ class TLV(object):
|
||||
def __neq__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
@registertlv
|
||||
class PaddingTLV(TLV):
|
||||
typ = 0
|
||||
|
||||
__slots__ = ['padding']
|
||||
|
||||
def __init__(self, padding):
|
||||
super(PaddingTLV, self).__init__()
|
||||
self.padding = padding
|
||||
|
||||
def getPayload(self):
|
||||
return self.padding
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data)
|
||||
|
||||
@registertlv
|
||||
class DisconnectTLV(TLV):
|
||||
typ = 1
|
||||
def __init__(self):
|
||||
pass
|
||||
super(DisconnectTLV, self).__init__()
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
|
||||
|
||||
class SMPTLV(TLV):
|
||||
__slots__ = ['mpis']
|
||||
dlen = None
|
||||
|
||||
def __init__(self, mpis=[]):
|
||||
def __init__(self, mpis=None):
|
||||
super(SMPTLV, self).__init__()
|
||||
if mpis is None:
|
||||
mpis = []
|
||||
if self.dlen is None:
|
||||
raise TypeError('no amount of mpis specified in dlen')
|
||||
if len(mpis) != self.dlen:
|
||||
raise TypeError('expected {0} mpis, got {1}'
|
||||
.format(self.dlen, len(mpis)))
|
||||
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
|
||||
mpis = []
|
||||
if cls.dlen > 0:
|
||||
count, data = unpack(b'!I', data)
|
||||
for i in range(count):
|
||||
for _ in range(count):
|
||||
n, data = read_mpi(data)
|
||||
mpis.append(n)
|
||||
if len(data) > 0:
|
||||
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
|
||||
@registertlv
|
||||
class ExtraKeyTLV(TLV):
|
||||
typ = 8
|
||||
|
||||
__slots__ = ['appid', 'appdata']
|
||||
|
||||
def __init__(self, appid, appdata):
|
||||
super(ExtraKeyTLV, self).__init__()
|
||||
self.appid = appid
|
||||
self.appdata = appdata
|
||||
if appdata is None:
|
||||
self.appdata = b''
|
||||
|
||||
def getPayload(self):
|
||||
return self.appid + self.appdata
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data[:4], data[4:])
|
||||
|
||||
@@ -43,11 +43,12 @@ def bytes_to_long(b):
|
||||
s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
|
||||
return s
|
||||
|
||||
def long_to_bytes(l):
|
||||
def long_to_bytes(l, n=0):
|
||||
b = b''
|
||||
while l != 0:
|
||||
while l != 0 or n > 0:
|
||||
b = long_to_byte(l & 0xff) + b
|
||||
l >>= 8
|
||||
n -= 1
|
||||
return b
|
||||
|
||||
def byte_to_long(b):
|
||||
|
||||
Reference in New Issue
Block a user