gotr: update provided potr to 1.0.0beta7

This commit is contained in:
Kjell Braden
2013-09-22 17:05:00 +02:00
parent e0902e8cd3
commit fd8f9f7d80
7 changed files with 331 additions and 227 deletions

View File

@@ -24,4 +24,4 @@ from potr.utils import human_hash
''' version is: (major, minor, patch, sub) with sub being one of 'alpha', ''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
'beta', 'final' ''' 'beta', 'final' '''
VERSION = (1, 0, 0, 'beta5') VERSION = (1, 0, 0, 'beta7')

View File

@@ -26,8 +26,8 @@ from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi
DEFAULT_KEYTYPE = 0x0000 DEFAULT_KEYTYPE = 0x0000
pkTypes = {} pkTypes = {}
def registerkeytype(cls): def registerkeytype(cls):
if not hasattr(cls, 'parsePayload'): if cls.keyType is None:
raise TypeError('registered key types need parsePayload()') raise TypeError('registered key class needs a type value')
pkTypes[cls.keyType] = cls pkTypes[cls.keyType] = cls
return cls return cls
@@ -35,12 +35,16 @@ def generateDefaultKey():
return pkTypes[DEFAULT_KEYTYPE].generate() return pkTypes[DEFAULT_KEYTYPE].generate()
class PK(object): class PK(object):
__slots__ = [] keyType = None
@classmethod @classmethod
def generate(cls): def generate(cls):
raise NotImplementedError raise NotImplementedError
@classmethod
def parsePayload(cls, data, private=False):
raise NotImplementedError
def sign(self, data): def sign(self, data):
raise NotImplementedError raise NotImplementedError
def verify(self, data): def verify(self, data):
@@ -80,13 +84,13 @@ class PK(object):
@classmethod @classmethod
def parsePrivateKey(cls, data): def parsePrivateKey(cls, data):
implCls, data = cls.getImplementation(data) implCls, data = cls.getImplementation(data)
logging.debug('Got privkey of type %r' % implCls) logging.debug('Got privkey of type %r', implCls)
return implCls.parsePayload(data, private=True) return implCls.parsePayload(data, private=True)
@classmethod @classmethod
def parsePublicKey(cls, data): def parsePublicKey(cls, data):
implCls, data = cls.getImplementation(data) implCls, data = cls.getImplementation(data)
logging.debug('Got pubkey of type %r' % implCls) logging.debug('Got pubkey of type %r', implCls)
return implCls.parsePayload(data) return implCls.parsePayload(data)
def __str__(self): def __str__(self):

View File

@@ -15,18 +15,16 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
from Crypto import Cipher, Random from Crypto import Cipher
from Crypto.Hash import SHA256 as _SHA256 from Crypto.Hash import SHA256 as _SHA256
from Crypto.Hash import SHA as _SHA1 from Crypto.Hash import SHA as _SHA1
from Crypto.Hash import HMAC as _HMAC from Crypto.Hash import HMAC as _HMAC
from Crypto.PublicKey import DSA from Crypto.PublicKey import DSA
from Crypto.Random import random
from numbers import Number from numbers import Number
from potr.compatcrypto import common from potr.compatcrypto import common
from potr.utils import pack_mpi, read_mpi, bytes_to_long, long_to_bytes from potr.utils import read_mpi, bytes_to_long, long_to_bytes
# XXX atfork?
RNG = Random.new()
def SHA256(data): def SHA256(data):
return _SHA256.new(data).digest() return _SHA256.new(data).digest()
@@ -54,7 +52,6 @@ def AESCTR(key, counter=0):
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter) return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
class Counter(object): class Counter(object):
__slots__ = ['prefix', 'val']
def __init__(self, prefix): def __init__(self, prefix):
self.prefix = prefix self.prefix = prefix
self.val = 0 self.val = 0
@@ -72,17 +69,15 @@ class Counter(object):
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val) return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
def byteprefix(self): def byteprefix(self):
return long_to_bytes(self.prefix).rjust(8, b'\0') return long_to_bytes(self.prefix, 8)
def __call__(self): def __call__(self):
val = long_to_bytes(self.val) bytesuffix = long_to_bytes(self.val, 8)
prefix = long_to_bytes(self.prefix)
self.val += 1 self.val += 1
return self.byteprefix() + val.rjust(8, b'\0') return self.byteprefix() + bytesuffix
@common.registerkeytype @common.registerkeytype
class DSAKey(common.PK): class DSAKey(common.PK):
__slots__ = ['priv', 'pub']
keyType = 0x0000 keyType = 0x0000
def __init__(self, key=None, private=False): def __init__(self, key=None, private=False):
@@ -111,10 +106,10 @@ class DSAKey(common.PK):
return SHA1(self.getSerializedPublicPayload()) return SHA1(self.getSerializedPublicPayload())
def sign(self, data): def sign(self, data):
# 2 <= K <= q = 160bit = 20 byte # 2 <= K <= q
K = bytes_to_long(RNG.read(19)) + 2 K = random.randrange(2, self.priv.q)
r, s = self.priv.sign(data, K) r, s = self.priv.sign(data, K)
return long_to_bytes(r) + long_to_bytes(s) return long_to_bytes(r, 20) + long_to_bytes(s, 20)
def verify(self, data, sig): def verify(self, data, sig):
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:]) r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])

View File

@@ -19,7 +19,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
try: try:
basestring = basestring type(basestring)
except NameError: except NameError:
# all strings are unicode in python3k # all strings are unicode in python3k
basestring = str basestring = str
@@ -27,7 +27,7 @@ except NameError:
# callable is not available in python 3.0 and 3.1 # callable is not available in python 3.0 and 3.1
try: try:
callable = callable type(callable)
except NameError: except NameError:
from collections import Callable from collections import Callable
def callable(x): def callable(x):
@@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
from potr import crypt from potr import crypt
from potr import proto from potr import proto
from potr import compatcrypto
from time import time from time import time
@@ -62,16 +63,11 @@ OFFER_REJECTED = 2
OFFER_ACCEPTED = 3 OFFER_ACCEPTED = 3
class Context(object): class Context(object):
__slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend',
'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state',
'inject', 'trust', 'peer', 'trustName']
def __init__(self, account, peername): def __init__(self, account, peername):
self.user = account self.user = account
self.peer = peername self.peer = peername
self.policy = {} self.policy = {}
self.crypto = crypt.CryptEngine(self) self.crypto = crypt.CryptEngine(self)
self.discardFragment()
self.tagOffer = OFFER_NOTSENT self.tagOffer = OFFER_NOTSENT
self.mayRetransmit = 0 self.mayRetransmit = 0
self.lastSend = 0 self.lastSend = 0
@@ -79,6 +75,10 @@ class Context(object):
self.state = STATE_PLAINTEXT self.state = STATE_PLAINTEXT
self.trustName = self.peer self.trustName = self.peer
self.fragmentInfo = None
self.fragment = None
self.discardFragment()
def getPolicy(self, key): def getPolicy(self, key):
raise NotImplementedError raise NotImplementedError
@@ -100,13 +100,19 @@ class Context(object):
params = message.split(b',') params = message.split(b',')
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit(): if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
logger.warning('invalid formed fragmented message: %r', params) logger.warning('invalid formed fragmented message: %r', params)
return None self.discardFragment()
return message
K, N = self.fragmentInfo K, N = self.fragmentInfo
try:
k = int(params[1])
n = int(params[2])
except ValueError:
logger.warning('invalid formed fragmented message: %r', params)
self.discardFragment()
return message
k = int(params[1])
n = int(params[2])
fragData = params[3] fragData = params[3]
logger.debug(params) logger.debug(params)
@@ -114,17 +120,17 @@ class Context(object):
if n >= k == 1: if n >= k == 1:
# first fragment # first fragment
self.discardFragment() self.discardFragment()
self.fragmentInfo = (k,n) self.fragmentInfo = (k, n)
self.fragment.append(fragData) self.fragment.append(fragData)
elif N == n >= k > 1 and k == K+1: elif N == n >= k > 1 and k == K+1:
# accumulate # accumulate
self.fragmentInfo = (k,n) self.fragmentInfo = (k, n)
self.fragment.append(fragData) self.fragment.append(fragData)
else: else:
# bad, discard # bad, discard
self.discardFragment() self.discardFragment()
logger.warning('invalid fragmented message: %r', params) logger.warning('invalid fragmented message: %r', params)
return None return message
if n == k > 0: if n == k > 0:
assembled = b''.join(self.fragment) assembled = b''.join(self.fragment)
@@ -210,7 +216,7 @@ class Context(object):
if self.state != STATE_ENCRYPTED: if self.state != STATE_ENCRYPTED:
self.sendInternal(proto.Error( self.sendInternal(proto.Error(
'You sent encrypted to {user}, who wasn\'t expecting it.' 'You sent encrypted to {user}, who wasn\'t expecting it.'
.format(user=self.user.name)), appdata=appdata) .format(user=self.user.name).encode('utf-8')), appdata=appdata)
if ignore: if ignore:
return IGN return IGN
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE) raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
@@ -263,12 +269,13 @@ class Context(object):
return msg return msg
def processOutgoingMessage(self, msg, flags, tlvs=[]): def processOutgoingMessage(self, msg, flags, tlvs=[]):
if isinstance(self.parse(msg), proto.Query): isQuery = self.parseExplicitQuery(msg) is not None
if isQuery:
return self.user.getDefaultQueryMessage(self.getPolicy) return self.user.getDefaultQueryMessage(self.getPolicy)
if self.state == STATE_PLAINTEXT: if self.state == STATE_PLAINTEXT:
if self.getPolicy('REQUIRE_ENCRYPTION'): if self.getPolicy('REQUIRE_ENCRYPTION'):
if not isinstance(self.parse(msg), proto.Query): if not isQuery:
self.lastMessage = msg self.lastMessage = msg
self.lastSend = time() self.lastSend = time()
self.mayRetransmit = 2 self.mayRetransmit = 2
@@ -277,8 +284,12 @@ class Context(object):
return msg return msg
if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED: if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED:
self.tagOffer = OFFER_SENT self.tagOffer = OFFER_SENT
return proto.TaggedPlaintext(msg, self.getPolicy('ALLOW_V1'), versions = set()
self.getPolicy('ALLOW_V2')) if self.getPolicy('ALLOW_V1'):
versions.add(1)
if self.getPolicy('ALLOW_V2'):
versions.add(2)
return proto.TaggedPlaintext(msg, versions)
return msg return msg
if self.state == STATE_ENCRYPTED: if self.state == STATE_ENCRYPTED:
msg = self.crypto.createDataMessage(msg, flags, tlvs) msg = self.crypto.createDataMessage(msg, flags, tlvs)
@@ -304,9 +315,9 @@ class Context(object):
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None): def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
mms = self.maxMessageSize(appdata) mms = self.maxMessageSize(appdata)
msgLen = len(msg) msgLen = len(msg)
if mms != 0 and len(msg) > mms: if mms != 0 and msgLen > mms:
fms = mms - 19 fms = mms - 19
fragments = [ msg[i:i+fms] for i in range(0, len(msg), fms) ] fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ]
fc = len(fragments) fc = len(fragments)
@@ -375,9 +386,9 @@ class Context(object):
self.crypto.smpSecret(secret, question=question, appdata=appdata) self.crypto.smpSecret(secret, question=question, appdata=appdata)
def handleQuery(self, message, appdata=None): def handleQuery(self, message, appdata=None):
if message.v2 and self.getPolicy('ALLOW_V2'): if 2 in message.versions and self.getPolicy('ALLOW_V2'):
self.authStartV2(appdata=appdata) self.authStartV2(appdata=appdata)
elif message.v1 and self.getPolicy('ALLOW_V1'): elif 1 in message.versions and self.getPolicy('ALLOW_V1'):
self.authStartV1(appdata=appdata) self.authStartV1(appdata=appdata)
def authStartV1(self, appdata=None): def authStartV1(self, appdata=None):
@@ -386,7 +397,33 @@ class Context(object):
def authStartV2(self, appdata=None): def authStartV2(self, appdata=None):
self.crypto.startAKE(appdata=appdata) self.crypto.startAKE(appdata=appdata)
def parse(self, message): def parseExplicitQuery(self, message):
otrTagPos = message.find(proto.OTRTAG)
if otrTagPos == -1:
return None
indexBase = otrTagPos + len(proto.OTRTAG)
if len(message) <= indexBase:
return None
compare = message[indexBase]
hasq = compare == b'?'[0]
hasv = compare == b'v'[0]
if not hasq and not hasv:
return None
hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0]
if hasv:
end = message.find(b'?', indexBase+1)
else:
end = indexBase+1
return message[indexBase:end]
def parse(self, message, nofragment=False):
otrTagPos = message.find(proto.OTRTAG) otrTagPos = message.find(proto.OTRTAG)
if otrTagPos == -1: if otrTagPos == -1:
if proto.MESSAGE_TAG_BASE in message: if proto.MESSAGE_TAG_BASE in message:
@@ -395,38 +432,40 @@ class Context(object):
return message return message
indexBase = otrTagPos + len(proto.OTRTAG) indexBase = otrTagPos + len(proto.OTRTAG)
if len(message) <= indexBase:
return message
compare = message[indexBase] compare = message[indexBase]
if compare == b','[0]: if nofragment is False and compare == b','[0]:
message = self.fragmentAccumulate(message[indexBase:]) message = self.fragmentAccumulate(message[indexBase:])
if message is None: if message is None:
return None return None
else: else:
return self.parse(message) return self.parse(message, nofragment=True)
else: else:
self.discardFragment() self.discardFragment()
hasq = compare == b'?'[0] queryPayload = self.parseExplicitQuery(message)
hasv = compare == b'v'[0] if queryPayload is not None:
if hasq or hasv: return proto.Query.parse(queryPayload)
hasv |= len(message) > indexBase+1 and \
message[indexBase+1] == b'v'[0]
if hasv:
end = message.find(b'?', indexBase+1)
else:
end = indexBase+1
payload = message[indexBase:end]
return proto.Query.parse(payload)
if compare == b':'[0] and len(message) > indexBase + 4: if compare == b':'[0] and len(message) > indexBase + 4:
infoTag = base64.b64decode(message[indexBase+1:indexBase+5]) try:
classInfo = struct.unpack(b'!HB', infoTag) infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
cls = proto.messageClasses.get(classInfo, None) classInfo = struct.unpack(b'!HB', infoTag)
if cls is None:
cls = proto.messageClasses.get(classInfo, None)
if cls is None:
return message
logger.debug('{user} got msg {typ!r}' \
.format(user=self.user.name, typ=cls))
return cls.parsePayload(message[indexBase+5:])
except (TypeError, struct.error):
logger.exception('could not parse OTR message %s', message)
return message return message
logger.debug('{user} got msg {typ!r}' \
.format(user=self.user.name, typ=cls))
return cls.parsePayload(message[indexBase+5:])
if message[indexBase:indexBase+7] == b' Error:': if message[indexBase:indexBase+7] == b' Error:':
return proto.Error(message[indexBase+7:]) return proto.Error(message[indexBase+7:])
@@ -437,6 +476,22 @@ class Context(object):
"""Return the max message size for this context.""" """Return the max message size for this context."""
return self.user.maxMessageSize return self.user.maxMessageSize
def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None):
""" retrieves the generated extra symmetric key.
if extraKeyAppId is set, notifies the chat partner about intended
usage (additional application specific information can be supplied in
extraKeyAppData).
returns the 256 bit symmetric key """
if self.state != STATE_ENCRYPTED:
raise NotEncryptedError
if extraKeyAppId is not None:
tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)]
self.sendInternal(b'', tlvs=tlvs, appdata=appdata)
return self.crypto.extraKey
class Account(object): class Account(object):
contextclass = Context contextclass = Context
def __init__(self, name, protocol, maxMessageSize, privkey=None): def __init__(self, name, protocol, maxMessageSize, privkey=None):
@@ -447,10 +502,10 @@ class Account(object):
self.ctxs = {} self.ctxs = {}
self.trusts = {} self.trusts = {}
self.maxMessageSize = maxMessageSize self.maxMessageSize = maxMessageSize
self.defaultQuery = b'?OTRv{versions}?\n{accountname} has requested ' \ self.defaultQuery = '?OTRv{versions}?\n{accountname} has requested ' \
b'an Off-the-Record private conversation. However, you ' \ 'an Off-the-Record private conversation. However, you ' \
b'do not have a plugin to support that.\nSee '\ 'do not have a plugin to support that.\nSee '\
b'http://otr.cypherpunks.ca/ for more information.'; 'http://otr.cypherpunks.ca/ for more information.'
def __repr__(self): def __repr__(self):
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__, return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
@@ -461,7 +516,7 @@ class Account(object):
self.privkey = self.loadPrivkey() self.privkey = self.loadPrivkey()
if self.privkey is None: if self.privkey is None:
if autogen is True: if autogen is True:
self.privkey = crypt.generateDefaultKey() self.privkey = compatcrypto.generateDefaultKey()
self.savePrivkey() self.savePrivkey()
else: else:
raise LookupError raise LookupError
@@ -484,8 +539,9 @@ class Account(object):
return self.ctxs[uid] return self.ctxs[uid]
def getDefaultQueryMessage(self, policy): def getDefaultQueryMessage(self, policy):
v = b'2' if policy('ALLOW_V2') else b'' v = '2' if policy('ALLOW_V2') else ''
return self.defaultQuery.format(accountname=self.name, versions=v) msg = self.defaultQuery.format(accountname=self.name, versions=v)
return msg.encode('ascii')
def setTrust(self, key, fingerprint, trustLevel): def setTrust(self, key, fingerprint, trustLevel):
if key not in self.trusts: if key not in self.trusts:

View File

@@ -22,8 +22,8 @@ import logging
import struct import struct
from potr.compatcrypto import SHA256, SHA1, HMAC, SHA1HMAC, SHA256HMAC, \ from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
SHA256HMAC160, Counter, AESCTR, RNG, PK, generateDefaultKey SHA256HMAC160, Counter, AESCTR, PK, random
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
from potr import proto from potr import proto
@@ -36,32 +36,31 @@ STATE_AWAITING_SIG = 4
STATE_V1_SETUP = 5 STATE_V1_SETUP = 5
DH1536_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919 DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
DH1536_MODULUS_2 = DH1536_MODULUS-2 DH_MODULUS_2 = DH_MODULUS-2
DH1536_GENERATOR = 2 DH_GENERATOR = 2
SM_ORDER = (DH1536_MODULUS - 1) // 2 DH_BITS = 1536
DH_MAX = 2**DH_BITS
SM_ORDER = (DH_MODULUS - 1) // 2
def check_group(n): def check_group(n):
return 2 <= n <= DH1536_MODULUS_2 return 2 <= n <= DH_MODULUS_2
def check_exp(n): def check_exp(n):
return 1 <= n < SM_ORDER return 1 <= n < SM_ORDER
class DH(object): class DH(object):
__slots__ = ['priv', 'pub']
@classmethod @classmethod
def set_params(cls, prime, gen): def set_params(cls, prime, gen):
cls.prime = prime cls.prime = prime
cls.gen = gen cls.gen = gen
def __init__(self): def __init__(self):
self.priv = bytes_to_long(RNG.read(40)) self.priv = random.randrange(2, 2**320)
self.pub = pow(self.gen, self.priv, self.prime) self.pub = pow(self.gen, self.priv, self.prime)
DH.set_params(DH1536_MODULUS, DH1536_GENERATOR) DH.set_params(DH_MODULUS, DH_GENERATOR)
class DHSession(object): class DHSession(object):
__slots__ = ['sendenc', 'sendmac', 'rcvenc', 'rcvmac', 'sendctr', 'rcvctr',
'sendmacused', 'rcvmacused']
def __init__(self, sendenc, sendmac, rcvenc, rcvmac): def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
self.sendenc = sendenc self.sendenc = sendenc
self.sendmac = sendmac self.sendmac = sendmac
@@ -79,7 +78,7 @@ class DHSession(object):
@classmethod @classmethod
def create(cls, dh, y): def create(cls, dh, y):
s = pow(y, dh.priv, DH1536_MODULUS) s = pow(y, dh.priv, DH_MODULUS)
sb = pack_mpi(s) sb = pack_mpi(s)
if dh.pub > y: if dh.pub > y:
@@ -96,9 +95,6 @@ class DHSession(object):
return cls(sendenc, sendmac, rcvenc, rcvmac) return cls(sendenc, sendmac, rcvenc, rcvmac)
class CryptEngine(object): class CryptEngine(object):
__slots__ = ['ctx', 'ake', 'sessionId', 'sessionIdHalf', 'theirKeyid',
'theirY', 'theirOldY', 'ourOldDHKey', 'ourDHKey', 'ourKeyid',
'sessionkeys', 'theirPubkey', 'savedMacKeys', 'smp']
def __init__(self, ctx): def __init__(self, ctx):
self.ctx = ctx self.ctx = ctx
self.ake = None self.ake = None
@@ -118,6 +114,7 @@ class CryptEngine(object):
self.savedMacKeys = [] self.savedMacKeys = []
self.smp = None self.smp = None
self.extraKey = None
def revealMacs(self, ours=True): def revealMacs(self, ours=True):
if ours: if ours:
@@ -174,7 +171,7 @@ class CryptEngine(object):
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()): if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
logger.error('HMACs don\'t match') logger.error('HMACs don\'t match')
raise InvalidParameterError raise InvalidParameterError
sesskey.rcvmacused = 1 sesskey.rcvmacused = True
newCtrPrefix = bytes_to_long(msg.ctr) newCtrPrefix = bytes_to_long(msg.ctr)
if newCtrPrefix <= sesskey.rcvctr.prefix: if newCtrPrefix <= sesskey.rcvctr.prefix:
@@ -223,11 +220,14 @@ class CryptEngine(object):
self.smp = SMPHandler(self) self.smp = SMPHandler(self)
self.smp.abort(appdata=appdata) self.smp.abort(appdata=appdata)
def createDataMessage(self, message, flags=0, tlvs=[]): def createDataMessage(self, message, flags=0, tlvs=None):
# check MSGSTATE # check MSGSTATE
if self.theirKeyid == 0: if self.theirKeyid == 0:
raise InvalidParameterError raise InvalidParameterError
if tlvs is None:
tlvs = []
sess = self.sessionkeys[1][0] sess = self.sessionkeys[1][0]
sess.sendctr.inc() sess.sendctr.inc()
@@ -303,13 +303,16 @@ class CryptEngine(object):
self.ourKeyid = ake.ourKeyid self.ourKeyid = ake.ourKeyid
self.theirY = ake.gy self.theirY = ake.gy
self.theirOldY = None self.theirOldY = None
self.extraKey = ake.extraKey
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub: if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
# XXX is this really ok?
self.ourDHKey = ake.dh self.ourDHKey = ake.dh
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY) self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
self.rotateDHKeys() self.rotateDHKeys()
# we don't need the AKE anymore, free the reference
self.ake = None
self.ctx._wentEncrypted() self.ctx._wentEncrypted()
logger.info('went encrypted with {0}'.format(self.theirPubkey)) logger.info('went encrypted with {0}'.format(self.theirPubkey))
@@ -317,10 +320,6 @@ class CryptEngine(object):
self.smp = None self.smp = None
class AuthKeyExchange(object): class AuthKeyExchange(object):
__slots__ = ['privkey', 'state', 'r', 'encgx', 'hashgx', 'ourKeyid',
'theirPubkey', 'theirKeyid', 'enc_c', 'enc_cp', 'mac_m1',
'mac_m1p', 'mac_m2', 'mac_m2p', 'sessionId', 'dh', 'onSuccess',
'gy', 'lastmsg', 'sessionIdHalf']
def __init__(self, privkey, onSuccess): def __init__(self, privkey, onSuccess):
self.privkey = privkey self.privkey = privkey
self.state = STATE_NONE self.state = STATE_NONE
@@ -341,9 +340,11 @@ class AuthKeyExchange(object):
self.dh = DH() self.dh = DH()
self.onSuccess = onSuccess self.onSuccess = onSuccess
self.gy = None self.gy = None
self.extraKey = None
self.lastmsg = None
def startAKE(self): def startAKE(self):
self.r = RNG.read(16) self.r = long_to_bytes(random.getrandbits(128))
gxmpi = pack_mpi(self.dh.pub) gxmpi = pack_mpi(self.dh.pub)
@@ -444,15 +445,17 @@ class AuthKeyExchange(object):
self.state = STATE_NONE self.state = STATE_NONE
def createAuthKeys(self): def createAuthKeys(self):
s = pow(self.gy, self.dh.priv, DH1536_MODULUS) s = pow(self.gy, self.dh.priv, DH_MODULUS)
sbyte = pack_mpi(s) sbyte = pack_mpi(s)
self.sessionId = SHA256(b'\0' + sbyte)[:8] self.sessionId = SHA256(b'\x00' + sbyte)[:8]
enc = SHA256(b'\1' + sbyte) enc = SHA256(b'\x01' + sbyte)
self.enc_c, self.enc_cp = enc[:16], enc[16:] self.enc_c = enc[:16]
self.mac_m1 = SHA256(b'\2' + sbyte) self.enc_cp = enc[16:]
self.mac_m2 = SHA256(b'\3' + sbyte) self.mac_m1 = SHA256(b'\x02' + sbyte)
self.mac_m1p = SHA256(b'\4' + sbyte) self.mac_m2 = SHA256(b'\x03' + sbyte)
self.mac_m2p = SHA256(b'\5' + sbyte) self.mac_m1p = SHA256(b'\x04' + sbyte)
self.mac_m2p = SHA256(b'\x05' + sbyte)
self.extraKey = SHA256(b'\xff' + sbyte)
def calculatePubkeyAuth(self, key, mackey): def calculatePubkeyAuth(self, key, mackey):
pubkey = self.privkey.serializePublicKey() pubkey = self.privkey.serializePublicKey()
@@ -490,14 +493,15 @@ SMPPROG_FAILED = -1
SMPPROG_SUCCEEDED = 1 SMPPROG_SUCCEEDED = 1
class SMPHandler: class SMPHandler:
__slots__ = ['crypto', 'questionReceived', 'prog', 'state', 'g1', 'g3o',
'x2', 'x3', 'g2', 'g3', 'pab', 'qab', 'secret', 'p', 'q']
def __init__(self, crypto): def __init__(self, crypto):
self.crypto = crypto self.crypto = crypto
self.state = 1 self.state = 1
self.g1 = DH1536_GENERATOR self.g1 = DH_GENERATOR
self.g2 = None
self.g3 = None
self.g3o = None self.g3o = None
self.x2 = None
self.x3 = None
self.prog = SMPPROG_OK self.prog = SMPPROG_OK
self.pab = None self.pab = None
self.qab = None self.qab = None
@@ -539,11 +543,11 @@ class SMPHandler:
self.g3o = msg[3] self.g3o = msg[3]
self.x2 = bytes_to_long(RNG.read(192)) self.x2 = random.randrange(2, DH_MAX)
self.x3 = bytes_to_long(RNG.read(192)) self.x3 = random.randrange(2, DH_MAX)
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS) self.g2 = pow(msg[0], self.x2, DH_MODULUS)
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS) self.g3 = pow(msg[3], self.x3, DH_MODULUS)
self.prog = SMPPROG_OK self.prog = SMPPROG_OK
self.state = 0 self.state = 0
@@ -568,29 +572,29 @@ class SMPHandler:
return return
self.g3o = msg[3] self.g3o = msg[3]
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS) self.g2 = pow(msg[0], self.x2, DH_MODULUS)
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS) self.g3 = pow(msg[3], self.x3, DH_MODULUS)
if not self.check_equal_coords(msg[6:11], 5): if not self.check_equal_coords(msg[6:11], 5):
logger.error('invalid SMP2TLV received') logger.error('invalid SMP2TLV received')
self.abort(appdata=appdata) self.abort(appdata=appdata)
return return
r = bytes_to_long(RNG.read(192)) r = random.randrange(2, DH_MAX)
self.p = pow(self.g3, r, DH1536_MODULUS) self.p = pow(self.g3, r, DH_MODULUS)
msg = [self.p] msg = [self.p]
qa1 = pow(self.g1, r, DH1536_MODULUS) qa1 = pow(self.g1, r, DH_MODULUS)
qa2 = pow(self.g2, self.secret, DH1536_MODULUS) qa2 = pow(self.g2, self.secret, DH_MODULUS)
self.q = qa1*qa2 % DH1536_MODULUS self.q = qa1*qa2 % DH_MODULUS
msg.append(self.q) msg.append(self.q)
msg += self.proof_equal_coords(r, 6) msg += self.proof_equal_coords(r, 6)
inv = invMod(mp) inv = invMod(mp)
self.pab = self.p * inv % DH1536_MODULUS self.pab = self.p * inv % DH_MODULUS
inv = invMod(mq) inv = invMod(mq)
self.qab = self.q * inv % DH1536_MODULUS self.qab = self.q * inv % DH_MODULUS
msg.append(pow(self.qab, self.x3, DH1536_MODULUS)) msg.append(pow(self.qab, self.x3, DH_MODULUS))
msg += self.proof_equal_logs(7) msg += self.proof_equal_logs(7)
self.state = 4 self.state = 4
@@ -613,9 +617,9 @@ class SMPHandler:
return return
inv = invMod(self.p) inv = invMod(self.p)
self.pab = msg[0] * inv % DH1536_MODULUS self.pab = msg[0] * inv % DH_MODULUS
inv = invMod(self.q) inv = invMod(self.q)
self.qab = msg[1] * inv % DH1536_MODULUS self.qab = msg[1] * inv % DH_MODULUS
if not self.check_equal_logs(msg[5:8], 7): if not self.check_equal_logs(msg[5:8], 7):
logger.error('invalid SMP3TLV received') logger.error('invalid SMP3TLV received')
@@ -623,10 +627,10 @@ class SMPHandler:
return return
md = msg[5] md = msg[5]
msg = [pow(self.qab, self.x3, DH1536_MODULUS)] msg = [pow(self.qab, self.x3, DH_MODULUS)]
msg += self.proof_equal_logs(8) msg += self.proof_equal_logs(8)
rab = pow(md, self.x3, DH1536_MODULUS) rab = pow(md, self.x3, DH_MODULUS)
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
if self.prog != SMPPROG_SUCCEEDED: if self.prog != SMPPROG_SUCCEEDED:
@@ -654,7 +658,7 @@ class SMPHandler:
self.abort(appdata=appdata) self.abort(appdata=appdata)
return return
rab = pow(msg[0], self.x3, DH1536_MODULUS) rab = pow(msg[0], self.x3, DH_MODULUS)
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
@@ -679,12 +683,12 @@ class SMPHandler:
self.secret = bytes_to_long(combSecret) self.secret = bytes_to_long(combSecret)
self.x2 = bytes_to_long(RNG.read(192)) self.x2 = random.randrange(2, DH_MAX)
self.x3 = bytes_to_long(RNG.read(192)) self.x3 = random.randrange(2, DH_MAX)
msg = [pow(self.g1, self.x2, DH1536_MODULUS)] msg = [pow(self.g1, self.x2, DH_MODULUS)]
msg += proof_known_log(self.g1, self.x2, 1) msg += proof_known_log(self.g1, self.x2, 1)
msg.append(pow(self.g1, self.x3, DH1536_MODULUS)) msg.append(pow(self.g1, self.x3, DH_MODULUS))
msg += proof_known_log(self.g1, self.x3, 2) msg += proof_known_log(self.g1, self.x3, 2)
self.prog = SMPPROG_OK self.prog = SMPPROG_OK
@@ -700,19 +704,19 @@ class SMPHandler:
self.secret = bytes_to_long(combSecret) self.secret = bytes_to_long(combSecret)
msg = [pow(self.g1, self.x2, DH1536_MODULUS)] msg = [pow(self.g1, self.x2, DH_MODULUS)]
msg += proof_known_log(self.g1, self.x2, 3) msg += proof_known_log(self.g1, self.x2, 3)
msg.append(pow(self.g1, self.x3, DH1536_MODULUS)) msg.append(pow(self.g1, self.x3, DH_MODULUS))
msg += proof_known_log(self.g1, self.x3, 4) msg += proof_known_log(self.g1, self.x3, 4)
r = bytes_to_long(RNG.read(192)) r = random.randrange(2, DH_MAX)
self.p = pow(self.g3, r, DH1536_MODULUS) self.p = pow(self.g3, r, DH_MODULUS)
msg.append(self.p) msg.append(self.p)
qb1 = pow(self.g1, r, DH1536_MODULUS) qb1 = pow(self.g1, r, DH_MODULUS)
qb2 = pow(self.g2, self.secret, DH1536_MODULUS) qb2 = pow(self.g2, self.secret, DH_MODULUS)
self.q = qb1 * qb2 % DH1536_MODULUS self.q = qb1 * qb2 % DH_MODULUS
msg.append(self.q) msg.append(self.q)
msg += self.proof_equal_coords(r, 5) msg += self.proof_equal_coords(r, 5)
@@ -721,11 +725,11 @@ class SMPHandler:
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata) self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
def proof_equal_coords(self, r, v): def proof_equal_coords(self, r, v):
r1 = bytes_to_long(RNG.read(192)) r1 = random.randrange(2, DH_MAX)
r2 = bytes_to_long(RNG.read(192)) r2 = random.randrange(2, DH_MAX)
temp2 = pow(self.g1, r1, DH1536_MODULUS) \ temp2 = pow(self.g1, r1, DH_MODULUS) \
* pow(self.g2, r2, DH1536_MODULUS) % DH1536_MODULUS * pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
temp1 = pow(self.g3, r1, DH1536_MODULUS) temp1 = pow(self.g3, r1, DH_MODULUS)
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
c = bytes_to_long(cb) c = bytes_to_long(cb)
@@ -739,21 +743,21 @@ class SMPHandler:
def check_equal_coords(self, coords, v): def check_equal_coords(self, coords, v):
(p, q, c, d1, d2) = coords (p, q, c, d1, d2) = coords
temp1 = pow(self.g3, d1, DH1536_MODULUS) * pow(p, c, DH1536_MODULUS) \ temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \
% DH1536_MODULUS % DH_MODULUS
temp2 = pow(self.g1, d1, DH1536_MODULUS) \ temp2 = pow(self.g1, d1, DH_MODULUS) \
* pow(self.g2, d2, DH1536_MODULUS) \ * pow(self.g2, d2, DH_MODULUS) \
* pow(q, c, DH1536_MODULUS) % DH1536_MODULUS * pow(q, c, DH_MODULUS) % DH_MODULUS
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
return long_to_bytes(c) == cprime return long_to_bytes(c, 32) == cprime
def proof_equal_logs(self, v): def proof_equal_logs(self, v):
r = bytes_to_long(RNG.read(192)) r = random.randrange(2, DH_MAX)
temp1 = pow(self.g1, r, DH1536_MODULUS) temp1 = pow(self.g1, r, DH_MODULUS)
temp2 = pow(self.qab, r, DH1536_MODULUS) temp2 = pow(self.qab, r, DH_MODULUS)
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
c = bytes_to_long(cb) c = bytes_to_long(cb)
@@ -763,29 +767,29 @@ class SMPHandler:
def check_equal_logs(self, logs, v): def check_equal_logs(self, logs, v):
(r, c, d) = logs (r, c, d) = logs
temp1 = pow(self.g1, d, DH1536_MODULUS) \ temp1 = pow(self.g1, d, DH_MODULUS) \
* pow(self.g3o, c, DH1536_MODULUS) % DH1536_MODULUS * pow(self.g3o, c, DH_MODULUS) % DH_MODULUS
temp2 = pow(self.qab, d, DH1536_MODULUS) \ temp2 = pow(self.qab, d, DH_MODULUS) \
* pow(r, c, DH1536_MODULUS) % DH1536_MODULUS * pow(r, c, DH_MODULUS) % DH_MODULUS
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2)) cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
return long_to_bytes(c) == cprime return long_to_bytes(c, 32) == cprime
def proof_known_log(g, x, v): def proof_known_log(g, x, v):
r = bytes_to_long(RNG.read(192)) r = random.randrange(2, DH_MAX)
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH1536_MODULUS)))) c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
temp = x * c % SM_ORDER temp = x * c % SM_ORDER
return c, (r-temp) % SM_ORDER return c, (r-temp) % SM_ORDER
def check_known_log(c, d, g, x, v): def check_known_log(c, d, g, x, v):
gd = pow(g, d, DH1536_MODULUS) gd = pow(g, d, DH_MODULUS)
xc = pow(x, c, DH1536_MODULUS) xc = pow(x, c, DH_MODULUS)
gdxc = gd * xc % DH1536_MODULUS gdxc = gd * xc % DH_MODULUS
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c) return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32)
def invMod(n): def invMod(n):
return pow(n, DH1536_MODULUS_2, DH1536_MODULUS) return pow(n, DH_MODULUS_2, DH_MODULUS)
class InvalidParameterError(RuntimeError): class InvalidParameterError(RuntimeError):
pass pass

View File

@@ -19,14 +19,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
import logging
import struct import struct
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
OTRTAG = b'?OTR' OTRTAG = b'?OTR'
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t ' MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
MESSAGE_TAG_V1 = b' \t \t \t ' MESSAGE_TAGS = {
MESSAGE_TAG_V2 = b' \t\t \t ' 1:b' \t \t \t ',
2:b' \t\t \t ',
3:b' \t\t \t\t',
}
MSGTYPE_NOTOTR = 0 MSGTYPE_NOTOTR = 0
MSGTYPE_TAGGEDPLAINTEXT = 1 MSGTYPE_TAGGEDPLAINTEXT = 1
@@ -62,6 +64,8 @@ def registermessage(cls):
def registertlv(cls): def registertlv(cls):
if not hasattr(cls, 'parsePayload'): if not hasattr(cls, 'parsePayload'):
raise TypeError('registered tlv types need parsePayload()') raise TypeError('registered tlv types need parsePayload()')
if cls.typ is None:
raise TypeError('registered tlv type needs type ID')
tlvClasses[cls.typ] = cls tlvClasses[cls.typ] = cls
return cls return cls
@@ -84,16 +88,6 @@ class OTRMessage(object):
__slots__ = ['payload'] __slots__ = ['payload']
version = 0x0002 version = 0x0002
msgtype = 0 msgtype = 0
def __init__(self, payload):
self.payload = payload
def getPayload(self):
return self.payload
def __bytes__(self):
data = struct.pack(b'!HB', self.version, self.msgtype) \
+ self.getPayload()
return b'?OTR:' + base64.b64encode(data) + b'.'
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
@@ -110,6 +104,7 @@ class OTRMessage(object):
class Error(OTRMessage): class Error(OTRMessage):
__slots__ = ['error'] __slots__ = ['error']
def __init__(self, error): def __init__(self, error):
super(Error, self).__init__()
self.error = error self.error = error
def __repr__(self): def __repr__(self):
@@ -119,56 +114,58 @@ class Error(OTRMessage):
return b'?OTR Error:' + self.error return b'?OTR Error:' + self.error
class Query(OTRMessage): class Query(OTRMessage):
__slots__ = ['v1', 'v2'] __slots__ = ['versions']
def __init__(self, v1, v2): def __init__(self, versions=set()):
self.v1 = v1 super(Query, self).__init__()
self.v2 = v2 self.versions = versions
@classmethod @classmethod
def parse(cls, data): def parse(cls, data):
v2 = False if not isinstance(data, bytes):
v1 = False raise TypeError('can only parse bytes')
if len(data) > 0 and data[0:1] == b'?': udata = data.decode('ascii', errors='replace')
data = data[1:]
v1 = True
if len(data) > 0 and data[0:1] == b'v': versions = set()
for c in data[1:]: if len(udata) > 0 and udata[0] == '?':
if c == b'2'[0]: udata = udata[1:]
v2 = True versions.add(1)
return cls(v1, v2)
if len(udata) > 0 and udata[0] == 'v':
versions.update(( int(c) for c in udata if c.isdigit() ))
return cls(versions)
def __repr__(self): def __repr__(self):
return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2) return '<proto.Query(versions=%r)>' % (self.versions)
def __bytes__(self): def __bytes__(self):
d = b'?OTR' d = b'?OTR'
if self.v1: if 1 in self.versions:
d += b'?' d += b'?'
d += b'v' d += b'v'
if self.v2:
d += b'2' # in python3 there is only int->unicode conversion
# so I convert to unicode and encode it to a byte string
versions = [ '%d' % v for v in self.versions if v != 1 ]
d += ''.join(versions).encode('ascii')
d += b'?' d += b'?'
return d return d
class TaggedPlaintext(Query): class TaggedPlaintext(Query):
__slots__ = ['msg'] __slots__ = ['msg']
def __init__(self, msg, v1, v2): def __init__(self, msg, versions):
super(TaggedPlaintext, self).__init__(versions)
self.msg = msg self.msg = msg
self.v1 = v1
self.v2 = v2
def __bytes__(self): def __bytes__(self):
data = self.msg + MESSAGE_TAG_BASE data = self.msg + MESSAGE_TAG_BASE
if self.v1: for v in self.versions:
data += MESSAGE_TAG_V1 data += MESSAGE_TAGS[v]
if self.v2:
data += MESSAGE_TAG_V2
return data return data
def __repr__(self): def __repr__(self):
return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \ return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
.format(v1=self.v1, v2=self.v2, msg=self.msg) .format(versions=self.versions, msg=self.msg)
@classmethod @classmethod
def parse(cls, data): def parse(cls, data):
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
raise TypeError( raise TypeError(
'this is not a tagged plaintext ({0!r:.20})'.format(data)) 'this is not a tagged plaintext ({0!r:.20})'.format(data))
v1 = False
v2 = False
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ] tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
for tag in tags: versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
if not tag.isspace(): in tags ])
break
v1 |= tag == MESSAGE_TAG_V1
v2 |= tag == MESSAGE_TAG_V2
return TaggedPlaintext(data[:tagPos], v1, v2) return TaggedPlaintext(data[:tagPos], versions)
class GenericOTRMessage(OTRMessage): class GenericOTRMessage(OTRMessage):
__slots__ = ['data'] __slots__ = ['data']
fields = []
def __init__(self, *args): def __init__(self, *args):
super(GenericOTRMessage, self).__init__()
if len(args) != len(self.fields): if len(args) != len(self.fields):
raise TypeError('%s needs %d arguments, got %d' % raise TypeError('%s needs %d arguments, got %d' %
(self.__class__.__name__, len(self.fields), len(args))) (self.__class__.__name__, len(self.fields), len(args)))
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
self.__getattr__(attr) # existence check self.__getattr__(attr) # existence check
self.data[attr] = val self.data[attr] = val
def __bytes__(self):
data = struct.pack(b'!HB', self.version, self.msgtype) \
+ self.getPayload()
return b'?OTR:' + base64.b64encode(data) + b'.'
def __repr__(self): def __repr__(self):
name = self.__class__.__name__ name = self.__class__.__name__
data = '' data = ''
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
def parsePayload(cls, data): def parsePayload(cls, data):
data = base64.b64decode(data) data = base64.b64decode(data)
args = [] args = []
for k, ftype in cls.fields: for _, ftype in cls.fields:
if ftype == 'data': if ftype == 'data':
value, data = read_data(data) value, data = read_data(data)
elif isinstance(ftype, bytes): elif isinstance(ftype, bytes):
size = int(struct.calcsize(ftype))
value, data = unpack(ftype, data) value, data = unpack(ftype, data)
elif isinstance(ftype, int): elif isinstance(ftype, int):
value, data = data[:ftype], data[ftype:] value, data = data[:ftype], data[ftype:]
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
class AKEMessage(GenericOTRMessage): class AKEMessage(GenericOTRMessage):
__slots__ = [] __slots__ = []
pass
@registermessage @registermessage
class DHCommit(AKEMessage): class DHCommit(AKEMessage):
__slots__ = [] __slots__ = []
msgtype = 0x02 msgtype = 0x02
fields = [('encgx','data'), ('hashgx','data'), ] fields = [('encgx', 'data'), ('hashgx', 'data'), ]
@registermessage @registermessage
class DHKey(AKEMessage): class DHKey(AKEMessage):
__slots__ = [] __slots__ = []
msgtype = 0x0a msgtype = 0x0a
fields = [('gy','data'), ] fields = [('gy', 'data'), ]
@registermessage @registermessage
class RevealSig(AKEMessage): class RevealSig(AKEMessage):
__slots__ = [] __slots__ = []
msgtype = 0x11 msgtype = 0x11
fields = [('rkey','data'), ('encsig','data'), ('mac',20),] fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
def getMacedData(self): def getMacedData(self):
p = self.encsig p = self.encsig
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
class Signature(AKEMessage): class Signature(AKEMessage):
__slots__ = [] __slots__ = []
msgtype = 0x12 msgtype = 0x12
fields = [('encsig','data'), ('mac',20)] fields = [('encsig', 'data'), ('mac', 20)]
def getMacedData(self): def getMacedData(self):
p = self.encsig p = self.encsig
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
class DataMessage(GenericOTRMessage): class DataMessage(GenericOTRMessage):
__slots__ = [] __slots__ = []
msgtype = 0x03 msgtype = 0x03
fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'), fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ] ('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
('oldmacs', 'data'), ]
def getMacedData(self): def getMacedData(self):
return struct.pack(b'!HB', self.version, self.msgtype) + \ return struct.pack(b'!HB', self.version, self.msgtype) + \
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
@bytesAndStrings @bytesAndStrings
class TLV(object): class TLV(object):
__slots__ = [] __slots__ = []
typ = None
def getPayload(self):
raise NotImplementedError
def __repr__(self): def __repr__(self):
val = self.getPayload() val = self.getPayload()
@@ -330,11 +331,28 @@ class TLV(object):
def __neq__(self, other): def __neq__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
@registertlv
class PaddingTLV(TLV):
typ = 0
__slots__ = ['padding']
def __init__(self, padding):
super(PaddingTLV, self).__init__()
self.padding = padding
def getPayload(self):
return self.padding
@classmethod
def parsePayload(cls, data):
return cls(data)
@registertlv @registertlv
class DisconnectTLV(TLV): class DisconnectTLV(TLV):
typ = 1 typ = 1
def __init__(self): def __init__(self):
pass super(DisconnectTLV, self).__init__()
def getPayload(self): def getPayload(self):
return b'' return b''
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
class SMPTLV(TLV): class SMPTLV(TLV):
__slots__ = ['mpis'] __slots__ = ['mpis']
dlen = None
def __init__(self, mpis=[]): def __init__(self, mpis=None):
super(SMPTLV, self).__init__()
if mpis is None:
mpis = []
if self.dlen is None:
raise TypeError('no amount of mpis specified in dlen')
if len(mpis) != self.dlen: if len(mpis) != self.dlen:
raise TypeError('expected {0} mpis, got {1}' raise TypeError('expected {0} mpis, got {1}'
.format(self.dlen, len(mpis))) .format(self.dlen, len(mpis)))
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
mpis = [] mpis = []
if cls.dlen > 0: if cls.dlen > 0:
count, data = unpack(b'!I', data) count, data = unpack(b'!I', data)
for i in range(count): for _ in range(count):
n, data = read_mpi(data) n, data = read_mpi(data)
mpis.append(n) mpis.append(n)
if len(data) > 0: if len(data) > 0:
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
def getPayload(self): def getPayload(self):
return b'' return b''
@registertlv
class ExtraKeyTLV(TLV):
typ = 8
__slots__ = ['appid', 'appdata']
def __init__(self, appid, appdata):
super(ExtraKeyTLV, self).__init__()
self.appid = appid
self.appdata = appdata
if appdata is None:
self.appdata = b''
def getPayload(self):
return self.appid + self.appdata
@classmethod
def parsePayload(cls, data):
return cls(data[:4], data[4:])

View File

@@ -43,11 +43,12 @@ def bytes_to_long(b):
s += byte_to_long(b[i:i+1]) << 8*(l-i-1) s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
return s return s
def long_to_bytes(l): def long_to_bytes(l, n=0):
b = b'' b = b''
while l != 0: while l != 0 or n > 0:
b = long_to_byte(l & 0xff) + b b = long_to_byte(l & 0xff) + b
l >>= 8 l >>= 8
n -= 1
return b return b
def byte_to_long(b): def byte_to_long(b):