gotr: update provided potr to 1.0.0beta7
This commit is contained in:
@@ -24,4 +24,4 @@ from potr.utils import human_hash
|
|||||||
|
|
||||||
''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
|
''' version is: (major, minor, patch, sub) with sub being one of 'alpha',
|
||||||
'beta', 'final' '''
|
'beta', 'final' '''
|
||||||
VERSION = (1, 0, 0, 'beta5')
|
VERSION = (1, 0, 0, 'beta7')
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ from potr.utils import human_hash, bytes_to_long, unpack, pack_mpi
|
|||||||
DEFAULT_KEYTYPE = 0x0000
|
DEFAULT_KEYTYPE = 0x0000
|
||||||
pkTypes = {}
|
pkTypes = {}
|
||||||
def registerkeytype(cls):
|
def registerkeytype(cls):
|
||||||
if not hasattr(cls, 'parsePayload'):
|
if cls.keyType is None:
|
||||||
raise TypeError('registered key types need parsePayload()')
|
raise TypeError('registered key class needs a type value')
|
||||||
pkTypes[cls.keyType] = cls
|
pkTypes[cls.keyType] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@@ -35,12 +35,16 @@ def generateDefaultKey():
|
|||||||
return pkTypes[DEFAULT_KEYTYPE].generate()
|
return pkTypes[DEFAULT_KEYTYPE].generate()
|
||||||
|
|
||||||
class PK(object):
|
class PK(object):
|
||||||
__slots__ = []
|
keyType = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate(cls):
|
def generate(cls):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parsePayload(cls, data, private=False):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def sign(self, data):
|
def sign(self, data):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
def verify(self, data):
|
def verify(self, data):
|
||||||
@@ -80,13 +84,13 @@ class PK(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def parsePrivateKey(cls, data):
|
def parsePrivateKey(cls, data):
|
||||||
implCls, data = cls.getImplementation(data)
|
implCls, data = cls.getImplementation(data)
|
||||||
logging.debug('Got privkey of type %r' % implCls)
|
logging.debug('Got privkey of type %r', implCls)
|
||||||
return implCls.parsePayload(data, private=True)
|
return implCls.parsePayload(data, private=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parsePublicKey(cls, data):
|
def parsePublicKey(cls, data):
|
||||||
implCls, data = cls.getImplementation(data)
|
implCls, data = cls.getImplementation(data)
|
||||||
logging.debug('Got pubkey of type %r' % implCls)
|
logging.debug('Got pubkey of type %r', implCls)
|
||||||
return implCls.parsePayload(data)
|
return implCls.parsePayload(data)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|||||||
@@ -15,18 +15,16 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
# along with this library. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from Crypto import Cipher, Random
|
from Crypto import Cipher
|
||||||
from Crypto.Hash import SHA256 as _SHA256
|
from Crypto.Hash import SHA256 as _SHA256
|
||||||
from Crypto.Hash import SHA as _SHA1
|
from Crypto.Hash import SHA as _SHA1
|
||||||
from Crypto.Hash import HMAC as _HMAC
|
from Crypto.Hash import HMAC as _HMAC
|
||||||
from Crypto.PublicKey import DSA
|
from Crypto.PublicKey import DSA
|
||||||
|
from Crypto.Random import random
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
|
|
||||||
from potr.compatcrypto import common
|
from potr.compatcrypto import common
|
||||||
from potr.utils import pack_mpi, read_mpi, bytes_to_long, long_to_bytes
|
from potr.utils import read_mpi, bytes_to_long, long_to_bytes
|
||||||
|
|
||||||
# XXX atfork?
|
|
||||||
RNG = Random.new()
|
|
||||||
|
|
||||||
def SHA256(data):
|
def SHA256(data):
|
||||||
return _SHA256.new(data).digest()
|
return _SHA256.new(data).digest()
|
||||||
@@ -54,7 +52,6 @@ def AESCTR(key, counter=0):
|
|||||||
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
|
return Cipher.AES.new(key, Cipher.AES.MODE_CTR, counter=counter)
|
||||||
|
|
||||||
class Counter(object):
|
class Counter(object):
|
||||||
__slots__ = ['prefix', 'val']
|
|
||||||
def __init__(self, prefix):
|
def __init__(self, prefix):
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.val = 0
|
self.val = 0
|
||||||
@@ -72,17 +69,15 @@ class Counter(object):
|
|||||||
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
|
return '<Counter(p={p!r},v={v!r})>'.format(p=self.prefix, v=self.val)
|
||||||
|
|
||||||
def byteprefix(self):
|
def byteprefix(self):
|
||||||
return long_to_bytes(self.prefix).rjust(8, b'\0')
|
return long_to_bytes(self.prefix, 8)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
val = long_to_bytes(self.val)
|
bytesuffix = long_to_bytes(self.val, 8)
|
||||||
prefix = long_to_bytes(self.prefix)
|
|
||||||
self.val += 1
|
self.val += 1
|
||||||
return self.byteprefix() + val.rjust(8, b'\0')
|
return self.byteprefix() + bytesuffix
|
||||||
|
|
||||||
@common.registerkeytype
|
@common.registerkeytype
|
||||||
class DSAKey(common.PK):
|
class DSAKey(common.PK):
|
||||||
__slots__ = ['priv', 'pub']
|
|
||||||
keyType = 0x0000
|
keyType = 0x0000
|
||||||
|
|
||||||
def __init__(self, key=None, private=False):
|
def __init__(self, key=None, private=False):
|
||||||
@@ -111,10 +106,10 @@ class DSAKey(common.PK):
|
|||||||
return SHA1(self.getSerializedPublicPayload())
|
return SHA1(self.getSerializedPublicPayload())
|
||||||
|
|
||||||
def sign(self, data):
|
def sign(self, data):
|
||||||
# 2 <= K <= q = 160bit = 20 byte
|
# 2 <= K <= q
|
||||||
K = bytes_to_long(RNG.read(19)) + 2
|
K = random.randrange(2, self.priv.q)
|
||||||
r, s = self.priv.sign(data, K)
|
r, s = self.priv.sign(data, K)
|
||||||
return long_to_bytes(r) + long_to_bytes(s)
|
return long_to_bytes(r, 20) + long_to_bytes(s, 20)
|
||||||
|
|
||||||
def verify(self, data, sig):
|
def verify(self, data, sig):
|
||||||
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
|
r, s = bytes_to_long(sig[:20]), bytes_to_long(sig[20:])
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
try:
|
try:
|
||||||
basestring = basestring
|
type(basestring)
|
||||||
except NameError:
|
except NameError:
|
||||||
# all strings are unicode in python3k
|
# all strings are unicode in python3k
|
||||||
basestring = str
|
basestring = str
|
||||||
@@ -27,7 +27,7 @@ except NameError:
|
|||||||
|
|
||||||
# callable is not available in python 3.0 and 3.1
|
# callable is not available in python 3.0 and 3.1
|
||||||
try:
|
try:
|
||||||
callable = callable
|
type(callable)
|
||||||
except NameError:
|
except NameError:
|
||||||
from collections import Callable
|
from collections import Callable
|
||||||
def callable(x):
|
def callable(x):
|
||||||
@@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from potr import crypt
|
from potr import crypt
|
||||||
from potr import proto
|
from potr import proto
|
||||||
|
from potr import compatcrypto
|
||||||
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
@@ -62,16 +63,11 @@ OFFER_REJECTED = 2
|
|||||||
OFFER_ACCEPTED = 3
|
OFFER_ACCEPTED = 3
|
||||||
|
|
||||||
class Context(object):
|
class Context(object):
|
||||||
__slots__ = ['user', 'policy', 'crypto', 'tagOffer', 'lastSend',
|
|
||||||
'lastMessage', 'mayRetransmit', 'fragment', 'fragmentInfo', 'state',
|
|
||||||
'inject', 'trust', 'peer', 'trustName']
|
|
||||||
|
|
||||||
def __init__(self, account, peername):
|
def __init__(self, account, peername):
|
||||||
self.user = account
|
self.user = account
|
||||||
self.peer = peername
|
self.peer = peername
|
||||||
self.policy = {}
|
self.policy = {}
|
||||||
self.crypto = crypt.CryptEngine(self)
|
self.crypto = crypt.CryptEngine(self)
|
||||||
self.discardFragment()
|
|
||||||
self.tagOffer = OFFER_NOTSENT
|
self.tagOffer = OFFER_NOTSENT
|
||||||
self.mayRetransmit = 0
|
self.mayRetransmit = 0
|
||||||
self.lastSend = 0
|
self.lastSend = 0
|
||||||
@@ -79,6 +75,10 @@ class Context(object):
|
|||||||
self.state = STATE_PLAINTEXT
|
self.state = STATE_PLAINTEXT
|
||||||
self.trustName = self.peer
|
self.trustName = self.peer
|
||||||
|
|
||||||
|
self.fragmentInfo = None
|
||||||
|
self.fragment = None
|
||||||
|
self.discardFragment()
|
||||||
|
|
||||||
def getPolicy(self, key):
|
def getPolicy(self, key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -100,13 +100,19 @@ class Context(object):
|
|||||||
params = message.split(b',')
|
params = message.split(b',')
|
||||||
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
|
if len(params) < 5 or not params[1].isdigit() or not params[2].isdigit():
|
||||||
logger.warning('invalid formed fragmented message: %r', params)
|
logger.warning('invalid formed fragmented message: %r', params)
|
||||||
return None
|
self.discardFragment()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
K, N = self.fragmentInfo
|
K, N = self.fragmentInfo
|
||||||
|
try:
|
||||||
|
k = int(params[1])
|
||||||
|
n = int(params[2])
|
||||||
|
except ValueError:
|
||||||
|
logger.warning('invalid formed fragmented message: %r', params)
|
||||||
|
self.discardFragment()
|
||||||
|
return message
|
||||||
|
|
||||||
k = int(params[1])
|
|
||||||
n = int(params[2])
|
|
||||||
fragData = params[3]
|
fragData = params[3]
|
||||||
|
|
||||||
logger.debug(params)
|
logger.debug(params)
|
||||||
@@ -114,17 +120,17 @@ class Context(object):
|
|||||||
if n >= k == 1:
|
if n >= k == 1:
|
||||||
# first fragment
|
# first fragment
|
||||||
self.discardFragment()
|
self.discardFragment()
|
||||||
self.fragmentInfo = (k,n)
|
self.fragmentInfo = (k, n)
|
||||||
self.fragment.append(fragData)
|
self.fragment.append(fragData)
|
||||||
elif N == n >= k > 1 and k == K+1:
|
elif N == n >= k > 1 and k == K+1:
|
||||||
# accumulate
|
# accumulate
|
||||||
self.fragmentInfo = (k,n)
|
self.fragmentInfo = (k, n)
|
||||||
self.fragment.append(fragData)
|
self.fragment.append(fragData)
|
||||||
else:
|
else:
|
||||||
# bad, discard
|
# bad, discard
|
||||||
self.discardFragment()
|
self.discardFragment()
|
||||||
logger.warning('invalid fragmented message: %r', params)
|
logger.warning('invalid fragmented message: %r', params)
|
||||||
return None
|
return message
|
||||||
|
|
||||||
if n == k > 0:
|
if n == k > 0:
|
||||||
assembled = b''.join(self.fragment)
|
assembled = b''.join(self.fragment)
|
||||||
@@ -210,7 +216,7 @@ class Context(object):
|
|||||||
if self.state != STATE_ENCRYPTED:
|
if self.state != STATE_ENCRYPTED:
|
||||||
self.sendInternal(proto.Error(
|
self.sendInternal(proto.Error(
|
||||||
'You sent encrypted to {user}, who wasn\'t expecting it.'
|
'You sent encrypted to {user}, who wasn\'t expecting it.'
|
||||||
.format(user=self.user.name)), appdata=appdata)
|
.format(user=self.user.name).encode('utf-8')), appdata=appdata)
|
||||||
if ignore:
|
if ignore:
|
||||||
return IGN
|
return IGN
|
||||||
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
|
raise NotEncryptedError(EXC_UNREADABLE_MESSAGE)
|
||||||
@@ -263,12 +269,13 @@ class Context(object):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
def processOutgoingMessage(self, msg, flags, tlvs=[]):
|
def processOutgoingMessage(self, msg, flags, tlvs=[]):
|
||||||
if isinstance(self.parse(msg), proto.Query):
|
isQuery = self.parseExplicitQuery(msg) is not None
|
||||||
|
if isQuery:
|
||||||
return self.user.getDefaultQueryMessage(self.getPolicy)
|
return self.user.getDefaultQueryMessage(self.getPolicy)
|
||||||
|
|
||||||
if self.state == STATE_PLAINTEXT:
|
if self.state == STATE_PLAINTEXT:
|
||||||
if self.getPolicy('REQUIRE_ENCRYPTION'):
|
if self.getPolicy('REQUIRE_ENCRYPTION'):
|
||||||
if not isinstance(self.parse(msg), proto.Query):
|
if not isQuery:
|
||||||
self.lastMessage = msg
|
self.lastMessage = msg
|
||||||
self.lastSend = time()
|
self.lastSend = time()
|
||||||
self.mayRetransmit = 2
|
self.mayRetransmit = 2
|
||||||
@@ -277,8 +284,12 @@ class Context(object):
|
|||||||
return msg
|
return msg
|
||||||
if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED:
|
if self.getPolicy('SEND_TAG') and self.tagOffer != OFFER_REJECTED:
|
||||||
self.tagOffer = OFFER_SENT
|
self.tagOffer = OFFER_SENT
|
||||||
return proto.TaggedPlaintext(msg, self.getPolicy('ALLOW_V1'),
|
versions = set()
|
||||||
self.getPolicy('ALLOW_V2'))
|
if self.getPolicy('ALLOW_V1'):
|
||||||
|
versions.add(1)
|
||||||
|
if self.getPolicy('ALLOW_V2'):
|
||||||
|
versions.add(2)
|
||||||
|
return proto.TaggedPlaintext(msg, versions)
|
||||||
return msg
|
return msg
|
||||||
if self.state == STATE_ENCRYPTED:
|
if self.state == STATE_ENCRYPTED:
|
||||||
msg = self.crypto.createDataMessage(msg, flags, tlvs)
|
msg = self.crypto.createDataMessage(msg, flags, tlvs)
|
||||||
@@ -304,9 +315,9 @@ class Context(object):
|
|||||||
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
|
def sendFragmented(self, msg, policy=FRAGMENT_SEND_ALL, appdata=None):
|
||||||
mms = self.maxMessageSize(appdata)
|
mms = self.maxMessageSize(appdata)
|
||||||
msgLen = len(msg)
|
msgLen = len(msg)
|
||||||
if mms != 0 and len(msg) > mms:
|
if mms != 0 and msgLen > mms:
|
||||||
fms = mms - 19
|
fms = mms - 19
|
||||||
fragments = [ msg[i:i+fms] for i in range(0, len(msg), fms) ]
|
fragments = [ msg[i:i+fms] for i in range(0, msgLen, fms) ]
|
||||||
|
|
||||||
fc = len(fragments)
|
fc = len(fragments)
|
||||||
|
|
||||||
@@ -375,9 +386,9 @@ class Context(object):
|
|||||||
self.crypto.smpSecret(secret, question=question, appdata=appdata)
|
self.crypto.smpSecret(secret, question=question, appdata=appdata)
|
||||||
|
|
||||||
def handleQuery(self, message, appdata=None):
|
def handleQuery(self, message, appdata=None):
|
||||||
if message.v2 and self.getPolicy('ALLOW_V2'):
|
if 2 in message.versions and self.getPolicy('ALLOW_V2'):
|
||||||
self.authStartV2(appdata=appdata)
|
self.authStartV2(appdata=appdata)
|
||||||
elif message.v1 and self.getPolicy('ALLOW_V1'):
|
elif 1 in message.versions and self.getPolicy('ALLOW_V1'):
|
||||||
self.authStartV1(appdata=appdata)
|
self.authStartV1(appdata=appdata)
|
||||||
|
|
||||||
def authStartV1(self, appdata=None):
|
def authStartV1(self, appdata=None):
|
||||||
@@ -386,7 +397,33 @@ class Context(object):
|
|||||||
def authStartV2(self, appdata=None):
|
def authStartV2(self, appdata=None):
|
||||||
self.crypto.startAKE(appdata=appdata)
|
self.crypto.startAKE(appdata=appdata)
|
||||||
|
|
||||||
def parse(self, message):
|
def parseExplicitQuery(self, message):
|
||||||
|
otrTagPos = message.find(proto.OTRTAG)
|
||||||
|
|
||||||
|
if otrTagPos == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||||
|
|
||||||
|
if len(message) <= indexBase:
|
||||||
|
return None
|
||||||
|
|
||||||
|
compare = message[indexBase]
|
||||||
|
|
||||||
|
hasq = compare == b'?'[0]
|
||||||
|
hasv = compare == b'v'[0]
|
||||||
|
|
||||||
|
if not hasq and not hasv:
|
||||||
|
return None
|
||||||
|
|
||||||
|
hasv |= len(message) > indexBase+1 and message[indexBase+1] == b'v'[0]
|
||||||
|
if hasv:
|
||||||
|
end = message.find(b'?', indexBase+1)
|
||||||
|
else:
|
||||||
|
end = indexBase+1
|
||||||
|
return message[indexBase:end]
|
||||||
|
|
||||||
|
def parse(self, message, nofragment=False):
|
||||||
otrTagPos = message.find(proto.OTRTAG)
|
otrTagPos = message.find(proto.OTRTAG)
|
||||||
if otrTagPos == -1:
|
if otrTagPos == -1:
|
||||||
if proto.MESSAGE_TAG_BASE in message:
|
if proto.MESSAGE_TAG_BASE in message:
|
||||||
@@ -395,38 +432,40 @@ class Context(object):
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
indexBase = otrTagPos + len(proto.OTRTAG)
|
indexBase = otrTagPos + len(proto.OTRTAG)
|
||||||
|
|
||||||
|
if len(message) <= indexBase:
|
||||||
|
return message
|
||||||
|
|
||||||
compare = message[indexBase]
|
compare = message[indexBase]
|
||||||
|
|
||||||
if compare == b','[0]:
|
if nofragment is False and compare == b','[0]:
|
||||||
message = self.fragmentAccumulate(message[indexBase:])
|
message = self.fragmentAccumulate(message[indexBase:])
|
||||||
if message is None:
|
if message is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return self.parse(message)
|
return self.parse(message, nofragment=True)
|
||||||
else:
|
else:
|
||||||
self.discardFragment()
|
self.discardFragment()
|
||||||
|
|
||||||
hasq = compare == b'?'[0]
|
queryPayload = self.parseExplicitQuery(message)
|
||||||
hasv = compare == b'v'[0]
|
if queryPayload is not None:
|
||||||
if hasq or hasv:
|
return proto.Query.parse(queryPayload)
|
||||||
hasv |= len(message) > indexBase+1 and \
|
|
||||||
message[indexBase+1] == b'v'[0]
|
|
||||||
if hasv:
|
|
||||||
end = message.find(b'?', indexBase+1)
|
|
||||||
else:
|
|
||||||
end = indexBase+1
|
|
||||||
payload = message[indexBase:end]
|
|
||||||
return proto.Query.parse(payload)
|
|
||||||
|
|
||||||
if compare == b':'[0] and len(message) > indexBase + 4:
|
if compare == b':'[0] and len(message) > indexBase + 4:
|
||||||
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
|
try:
|
||||||
classInfo = struct.unpack(b'!HB', infoTag)
|
infoTag = base64.b64decode(message[indexBase+1:indexBase+5])
|
||||||
cls = proto.messageClasses.get(classInfo, None)
|
classInfo = struct.unpack(b'!HB', infoTag)
|
||||||
if cls is None:
|
|
||||||
|
cls = proto.messageClasses.get(classInfo, None)
|
||||||
|
if cls is None:
|
||||||
|
return message
|
||||||
|
|
||||||
|
logger.debug('{user} got msg {typ!r}' \
|
||||||
|
.format(user=self.user.name, typ=cls))
|
||||||
|
return cls.parsePayload(message[indexBase+5:])
|
||||||
|
except (TypeError, struct.error):
|
||||||
|
logger.exception('could not parse OTR message %s', message)
|
||||||
return message
|
return message
|
||||||
logger.debug('{user} got msg {typ!r}' \
|
|
||||||
.format(user=self.user.name, typ=cls))
|
|
||||||
return cls.parsePayload(message[indexBase+5:])
|
|
||||||
|
|
||||||
if message[indexBase:indexBase+7] == b' Error:':
|
if message[indexBase:indexBase+7] == b' Error:':
|
||||||
return proto.Error(message[indexBase+7:])
|
return proto.Error(message[indexBase+7:])
|
||||||
@@ -437,6 +476,22 @@ class Context(object):
|
|||||||
"""Return the max message size for this context."""
|
"""Return the max message size for this context."""
|
||||||
return self.user.maxMessageSize
|
return self.user.maxMessageSize
|
||||||
|
|
||||||
|
def getExtraKey(self, extraKeyAppId=None, extraKeyAppData=None, appdata=None):
|
||||||
|
""" retrieves the generated extra symmetric key.
|
||||||
|
|
||||||
|
if extraKeyAppId is set, notifies the chat partner about intended
|
||||||
|
usage (additional application specific information can be supplied in
|
||||||
|
extraKeyAppData).
|
||||||
|
|
||||||
|
returns the 256 bit symmetric key """
|
||||||
|
|
||||||
|
if self.state != STATE_ENCRYPTED:
|
||||||
|
raise NotEncryptedError
|
||||||
|
if extraKeyAppId is not None:
|
||||||
|
tlvs = [proto.ExtraKeyTLV(extraKeyAppId, extraKeyAppData)]
|
||||||
|
self.sendInternal(b'', tlvs=tlvs, appdata=appdata)
|
||||||
|
return self.crypto.extraKey
|
||||||
|
|
||||||
class Account(object):
|
class Account(object):
|
||||||
contextclass = Context
|
contextclass = Context
|
||||||
def __init__(self, name, protocol, maxMessageSize, privkey=None):
|
def __init__(self, name, protocol, maxMessageSize, privkey=None):
|
||||||
@@ -447,10 +502,10 @@ class Account(object):
|
|||||||
self.ctxs = {}
|
self.ctxs = {}
|
||||||
self.trusts = {}
|
self.trusts = {}
|
||||||
self.maxMessageSize = maxMessageSize
|
self.maxMessageSize = maxMessageSize
|
||||||
self.defaultQuery = b'?OTRv{versions}?\n{accountname} has requested ' \
|
self.defaultQuery = '?OTRv{versions}?\n{accountname} has requested ' \
|
||||||
b'an Off-the-Record private conversation. However, you ' \
|
'an Off-the-Record private conversation. However, you ' \
|
||||||
b'do not have a plugin to support that.\nSee '\
|
'do not have a plugin to support that.\nSee '\
|
||||||
b'http://otr.cypherpunks.ca/ for more information.';
|
'http://otr.cypherpunks.ca/ for more information.'
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
|
return '<{cls}(name={name!r})>'.format(cls=self.__class__.__name__,
|
||||||
@@ -461,7 +516,7 @@ class Account(object):
|
|||||||
self.privkey = self.loadPrivkey()
|
self.privkey = self.loadPrivkey()
|
||||||
if self.privkey is None:
|
if self.privkey is None:
|
||||||
if autogen is True:
|
if autogen is True:
|
||||||
self.privkey = crypt.generateDefaultKey()
|
self.privkey = compatcrypto.generateDefaultKey()
|
||||||
self.savePrivkey()
|
self.savePrivkey()
|
||||||
else:
|
else:
|
||||||
raise LookupError
|
raise LookupError
|
||||||
@@ -484,8 +539,9 @@ class Account(object):
|
|||||||
return self.ctxs[uid]
|
return self.ctxs[uid]
|
||||||
|
|
||||||
def getDefaultQueryMessage(self, policy):
|
def getDefaultQueryMessage(self, policy):
|
||||||
v = b'2' if policy('ALLOW_V2') else b''
|
v = '2' if policy('ALLOW_V2') else ''
|
||||||
return self.defaultQuery.format(accountname=self.name, versions=v)
|
msg = self.defaultQuery.format(accountname=self.name, versions=v)
|
||||||
|
return msg.encode('ascii')
|
||||||
|
|
||||||
def setTrust(self, key, fingerprint, trustLevel):
|
def setTrust(self, key, fingerprint, trustLevel):
|
||||||
if key not in self.trusts:
|
if key not in self.trusts:
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ import logging
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
||||||
from potr.compatcrypto import SHA256, SHA1, HMAC, SHA1HMAC, SHA256HMAC, \
|
from potr.compatcrypto import SHA256, SHA1, SHA1HMAC, SHA256HMAC, \
|
||||||
SHA256HMAC160, Counter, AESCTR, RNG, PK, generateDefaultKey
|
SHA256HMAC160, Counter, AESCTR, PK, random
|
||||||
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
|
from potr.utils import bytes_to_long, long_to_bytes, pack_mpi, read_mpi
|
||||||
from potr import proto
|
from potr import proto
|
||||||
|
|
||||||
@@ -36,32 +36,31 @@ STATE_AWAITING_SIG = 4
|
|||||||
STATE_V1_SETUP = 5
|
STATE_V1_SETUP = 5
|
||||||
|
|
||||||
|
|
||||||
DH1536_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
|
DH_MODULUS = 2410312426921032588552076022197566074856950548502459942654116941958108831682612228890093858261341614673227141477904012196503648957050582631942730706805009223062734745341073406696246014589361659774041027169249453200378729434170325843778659198143763193776859869524088940195577346119843545301547043747207749969763750084308926339295559968882457872412993810129130294592999947926365264059284647209730384947211681434464714438488520940127459844288859336526896320919633919
|
||||||
DH1536_MODULUS_2 = DH1536_MODULUS-2
|
DH_MODULUS_2 = DH_MODULUS-2
|
||||||
DH1536_GENERATOR = 2
|
DH_GENERATOR = 2
|
||||||
SM_ORDER = (DH1536_MODULUS - 1) // 2
|
DH_BITS = 1536
|
||||||
|
DH_MAX = 2**DH_BITS
|
||||||
|
SM_ORDER = (DH_MODULUS - 1) // 2
|
||||||
|
|
||||||
def check_group(n):
|
def check_group(n):
|
||||||
return 2 <= n <= DH1536_MODULUS_2
|
return 2 <= n <= DH_MODULUS_2
|
||||||
def check_exp(n):
|
def check_exp(n):
|
||||||
return 1 <= n < SM_ORDER
|
return 1 <= n < SM_ORDER
|
||||||
|
|
||||||
class DH(object):
|
class DH(object):
|
||||||
__slots__ = ['priv', 'pub']
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_params(cls, prime, gen):
|
def set_params(cls, prime, gen):
|
||||||
cls.prime = prime
|
cls.prime = prime
|
||||||
cls.gen = gen
|
cls.gen = gen
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.priv = bytes_to_long(RNG.read(40))
|
self.priv = random.randrange(2, 2**320)
|
||||||
self.pub = pow(self.gen, self.priv, self.prime)
|
self.pub = pow(self.gen, self.priv, self.prime)
|
||||||
|
|
||||||
DH.set_params(DH1536_MODULUS, DH1536_GENERATOR)
|
DH.set_params(DH_MODULUS, DH_GENERATOR)
|
||||||
|
|
||||||
class DHSession(object):
|
class DHSession(object):
|
||||||
__slots__ = ['sendenc', 'sendmac', 'rcvenc', 'rcvmac', 'sendctr', 'rcvctr',
|
|
||||||
'sendmacused', 'rcvmacused']
|
|
||||||
def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
|
def __init__(self, sendenc, sendmac, rcvenc, rcvmac):
|
||||||
self.sendenc = sendenc
|
self.sendenc = sendenc
|
||||||
self.sendmac = sendmac
|
self.sendmac = sendmac
|
||||||
@@ -79,7 +78,7 @@ class DHSession(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, dh, y):
|
def create(cls, dh, y):
|
||||||
s = pow(y, dh.priv, DH1536_MODULUS)
|
s = pow(y, dh.priv, DH_MODULUS)
|
||||||
sb = pack_mpi(s)
|
sb = pack_mpi(s)
|
||||||
|
|
||||||
if dh.pub > y:
|
if dh.pub > y:
|
||||||
@@ -96,9 +95,6 @@ class DHSession(object):
|
|||||||
return cls(sendenc, sendmac, rcvenc, rcvmac)
|
return cls(sendenc, sendmac, rcvenc, rcvmac)
|
||||||
|
|
||||||
class CryptEngine(object):
|
class CryptEngine(object):
|
||||||
__slots__ = ['ctx', 'ake', 'sessionId', 'sessionIdHalf', 'theirKeyid',
|
|
||||||
'theirY', 'theirOldY', 'ourOldDHKey', 'ourDHKey', 'ourKeyid',
|
|
||||||
'sessionkeys', 'theirPubkey', 'savedMacKeys', 'smp']
|
|
||||||
def __init__(self, ctx):
|
def __init__(self, ctx):
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.ake = None
|
self.ake = None
|
||||||
@@ -118,6 +114,7 @@ class CryptEngine(object):
|
|||||||
self.savedMacKeys = []
|
self.savedMacKeys = []
|
||||||
|
|
||||||
self.smp = None
|
self.smp = None
|
||||||
|
self.extraKey = None
|
||||||
|
|
||||||
def revealMacs(self, ours=True):
|
def revealMacs(self, ours=True):
|
||||||
if ours:
|
if ours:
|
||||||
@@ -174,7 +171,7 @@ class CryptEngine(object):
|
|||||||
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
|
if msg.mac != SHA1HMAC(sesskey.rcvmac, msg.getMacedData()):
|
||||||
logger.error('HMACs don\'t match')
|
logger.error('HMACs don\'t match')
|
||||||
raise InvalidParameterError
|
raise InvalidParameterError
|
||||||
sesskey.rcvmacused = 1
|
sesskey.rcvmacused = True
|
||||||
|
|
||||||
newCtrPrefix = bytes_to_long(msg.ctr)
|
newCtrPrefix = bytes_to_long(msg.ctr)
|
||||||
if newCtrPrefix <= sesskey.rcvctr.prefix:
|
if newCtrPrefix <= sesskey.rcvctr.prefix:
|
||||||
@@ -223,11 +220,14 @@ class CryptEngine(object):
|
|||||||
self.smp = SMPHandler(self)
|
self.smp = SMPHandler(self)
|
||||||
self.smp.abort(appdata=appdata)
|
self.smp.abort(appdata=appdata)
|
||||||
|
|
||||||
def createDataMessage(self, message, flags=0, tlvs=[]):
|
def createDataMessage(self, message, flags=0, tlvs=None):
|
||||||
# check MSGSTATE
|
# check MSGSTATE
|
||||||
if self.theirKeyid == 0:
|
if self.theirKeyid == 0:
|
||||||
raise InvalidParameterError
|
raise InvalidParameterError
|
||||||
|
|
||||||
|
if tlvs is None:
|
||||||
|
tlvs = []
|
||||||
|
|
||||||
sess = self.sessionkeys[1][0]
|
sess = self.sessionkeys[1][0]
|
||||||
sess.sendctr.inc()
|
sess.sendctr.inc()
|
||||||
|
|
||||||
@@ -303,13 +303,16 @@ class CryptEngine(object):
|
|||||||
self.ourKeyid = ake.ourKeyid
|
self.ourKeyid = ake.ourKeyid
|
||||||
self.theirY = ake.gy
|
self.theirY = ake.gy
|
||||||
self.theirOldY = None
|
self.theirOldY = None
|
||||||
|
self.extraKey = ake.extraKey
|
||||||
|
|
||||||
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
|
if self.ourKeyid != ake.ourKeyid + 1 or self.ourOldDHKey != ake.dh.pub:
|
||||||
# XXX is this really ok?
|
|
||||||
self.ourDHKey = ake.dh
|
self.ourDHKey = ake.dh
|
||||||
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
|
self.sessionkeys[0][0] = DHSession.create(self.ourDHKey, self.theirY)
|
||||||
self.rotateDHKeys()
|
self.rotateDHKeys()
|
||||||
|
|
||||||
|
# we don't need the AKE anymore, free the reference
|
||||||
|
self.ake = None
|
||||||
|
|
||||||
self.ctx._wentEncrypted()
|
self.ctx._wentEncrypted()
|
||||||
logger.info('went encrypted with {0}'.format(self.theirPubkey))
|
logger.info('went encrypted with {0}'.format(self.theirPubkey))
|
||||||
|
|
||||||
@@ -317,10 +320,6 @@ class CryptEngine(object):
|
|||||||
self.smp = None
|
self.smp = None
|
||||||
|
|
||||||
class AuthKeyExchange(object):
|
class AuthKeyExchange(object):
|
||||||
__slots__ = ['privkey', 'state', 'r', 'encgx', 'hashgx', 'ourKeyid',
|
|
||||||
'theirPubkey', 'theirKeyid', 'enc_c', 'enc_cp', 'mac_m1',
|
|
||||||
'mac_m1p', 'mac_m2', 'mac_m2p', 'sessionId', 'dh', 'onSuccess',
|
|
||||||
'gy', 'lastmsg', 'sessionIdHalf']
|
|
||||||
def __init__(self, privkey, onSuccess):
|
def __init__(self, privkey, onSuccess):
|
||||||
self.privkey = privkey
|
self.privkey = privkey
|
||||||
self.state = STATE_NONE
|
self.state = STATE_NONE
|
||||||
@@ -341,9 +340,11 @@ class AuthKeyExchange(object):
|
|||||||
self.dh = DH()
|
self.dh = DH()
|
||||||
self.onSuccess = onSuccess
|
self.onSuccess = onSuccess
|
||||||
self.gy = None
|
self.gy = None
|
||||||
|
self.extraKey = None
|
||||||
|
self.lastmsg = None
|
||||||
|
|
||||||
def startAKE(self):
|
def startAKE(self):
|
||||||
self.r = RNG.read(16)
|
self.r = long_to_bytes(random.getrandbits(128))
|
||||||
|
|
||||||
gxmpi = pack_mpi(self.dh.pub)
|
gxmpi = pack_mpi(self.dh.pub)
|
||||||
|
|
||||||
@@ -444,15 +445,17 @@ class AuthKeyExchange(object):
|
|||||||
self.state = STATE_NONE
|
self.state = STATE_NONE
|
||||||
|
|
||||||
def createAuthKeys(self):
|
def createAuthKeys(self):
|
||||||
s = pow(self.gy, self.dh.priv, DH1536_MODULUS)
|
s = pow(self.gy, self.dh.priv, DH_MODULUS)
|
||||||
sbyte = pack_mpi(s)
|
sbyte = pack_mpi(s)
|
||||||
self.sessionId = SHA256(b'\0' + sbyte)[:8]
|
self.sessionId = SHA256(b'\x00' + sbyte)[:8]
|
||||||
enc = SHA256(b'\1' + sbyte)
|
enc = SHA256(b'\x01' + sbyte)
|
||||||
self.enc_c, self.enc_cp = enc[:16], enc[16:]
|
self.enc_c = enc[:16]
|
||||||
self.mac_m1 = SHA256(b'\2' + sbyte)
|
self.enc_cp = enc[16:]
|
||||||
self.mac_m2 = SHA256(b'\3' + sbyte)
|
self.mac_m1 = SHA256(b'\x02' + sbyte)
|
||||||
self.mac_m1p = SHA256(b'\4' + sbyte)
|
self.mac_m2 = SHA256(b'\x03' + sbyte)
|
||||||
self.mac_m2p = SHA256(b'\5' + sbyte)
|
self.mac_m1p = SHA256(b'\x04' + sbyte)
|
||||||
|
self.mac_m2p = SHA256(b'\x05' + sbyte)
|
||||||
|
self.extraKey = SHA256(b'\xff' + sbyte)
|
||||||
|
|
||||||
def calculatePubkeyAuth(self, key, mackey):
|
def calculatePubkeyAuth(self, key, mackey):
|
||||||
pubkey = self.privkey.serializePublicKey()
|
pubkey = self.privkey.serializePublicKey()
|
||||||
@@ -490,14 +493,15 @@ SMPPROG_FAILED = -1
|
|||||||
SMPPROG_SUCCEEDED = 1
|
SMPPROG_SUCCEEDED = 1
|
||||||
|
|
||||||
class SMPHandler:
|
class SMPHandler:
|
||||||
__slots__ = ['crypto', 'questionReceived', 'prog', 'state', 'g1', 'g3o',
|
|
||||||
'x2', 'x3', 'g2', 'g3', 'pab', 'qab', 'secret', 'p', 'q']
|
|
||||||
|
|
||||||
def __init__(self, crypto):
|
def __init__(self, crypto):
|
||||||
self.crypto = crypto
|
self.crypto = crypto
|
||||||
self.state = 1
|
self.state = 1
|
||||||
self.g1 = DH1536_GENERATOR
|
self.g1 = DH_GENERATOR
|
||||||
|
self.g2 = None
|
||||||
|
self.g3 = None
|
||||||
self.g3o = None
|
self.g3o = None
|
||||||
|
self.x2 = None
|
||||||
|
self.x3 = None
|
||||||
self.prog = SMPPROG_OK
|
self.prog = SMPPROG_OK
|
||||||
self.pab = None
|
self.pab = None
|
||||||
self.qab = None
|
self.qab = None
|
||||||
@@ -539,11 +543,11 @@ class SMPHandler:
|
|||||||
|
|
||||||
self.g3o = msg[3]
|
self.g3o = msg[3]
|
||||||
|
|
||||||
self.x2 = bytes_to_long(RNG.read(192))
|
self.x2 = random.randrange(2, DH_MAX)
|
||||||
self.x3 = bytes_to_long(RNG.read(192))
|
self.x3 = random.randrange(2, DH_MAX)
|
||||||
|
|
||||||
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
|
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||||
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
|
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||||
|
|
||||||
self.prog = SMPPROG_OK
|
self.prog = SMPPROG_OK
|
||||||
self.state = 0
|
self.state = 0
|
||||||
@@ -568,29 +572,29 @@ class SMPHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.g3o = msg[3]
|
self.g3o = msg[3]
|
||||||
self.g2 = pow(msg[0], self.x2, DH1536_MODULUS)
|
self.g2 = pow(msg[0], self.x2, DH_MODULUS)
|
||||||
self.g3 = pow(msg[3], self.x3, DH1536_MODULUS)
|
self.g3 = pow(msg[3], self.x3, DH_MODULUS)
|
||||||
|
|
||||||
if not self.check_equal_coords(msg[6:11], 5):
|
if not self.check_equal_coords(msg[6:11], 5):
|
||||||
logger.error('invalid SMP2TLV received')
|
logger.error('invalid SMP2TLV received')
|
||||||
self.abort(appdata=appdata)
|
self.abort(appdata=appdata)
|
||||||
return
|
return
|
||||||
|
|
||||||
r = bytes_to_long(RNG.read(192))
|
r = random.randrange(2, DH_MAX)
|
||||||
self.p = pow(self.g3, r, DH1536_MODULUS)
|
self.p = pow(self.g3, r, DH_MODULUS)
|
||||||
msg = [self.p]
|
msg = [self.p]
|
||||||
qa1 = pow(self.g1, r, DH1536_MODULUS)
|
qa1 = pow(self.g1, r, DH_MODULUS)
|
||||||
qa2 = pow(self.g2, self.secret, DH1536_MODULUS)
|
qa2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||||
self.q = qa1*qa2 % DH1536_MODULUS
|
self.q = qa1*qa2 % DH_MODULUS
|
||||||
msg.append(self.q)
|
msg.append(self.q)
|
||||||
msg += self.proof_equal_coords(r, 6)
|
msg += self.proof_equal_coords(r, 6)
|
||||||
|
|
||||||
inv = invMod(mp)
|
inv = invMod(mp)
|
||||||
self.pab = self.p * inv % DH1536_MODULUS
|
self.pab = self.p * inv % DH_MODULUS
|
||||||
inv = invMod(mq)
|
inv = invMod(mq)
|
||||||
self.qab = self.q * inv % DH1536_MODULUS
|
self.qab = self.q * inv % DH_MODULUS
|
||||||
|
|
||||||
msg.append(pow(self.qab, self.x3, DH1536_MODULUS))
|
msg.append(pow(self.qab, self.x3, DH_MODULUS))
|
||||||
msg += self.proof_equal_logs(7)
|
msg += self.proof_equal_logs(7)
|
||||||
|
|
||||||
self.state = 4
|
self.state = 4
|
||||||
@@ -613,9 +617,9 @@ class SMPHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
inv = invMod(self.p)
|
inv = invMod(self.p)
|
||||||
self.pab = msg[0] * inv % DH1536_MODULUS
|
self.pab = msg[0] * inv % DH_MODULUS
|
||||||
inv = invMod(self.q)
|
inv = invMod(self.q)
|
||||||
self.qab = msg[1] * inv % DH1536_MODULUS
|
self.qab = msg[1] * inv % DH_MODULUS
|
||||||
|
|
||||||
if not self.check_equal_logs(msg[5:8], 7):
|
if not self.check_equal_logs(msg[5:8], 7):
|
||||||
logger.error('invalid SMP3TLV received')
|
logger.error('invalid SMP3TLV received')
|
||||||
@@ -623,10 +627,10 @@ class SMPHandler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
md = msg[5]
|
md = msg[5]
|
||||||
msg = [pow(self.qab, self.x3, DH1536_MODULUS)]
|
msg = [pow(self.qab, self.x3, DH_MODULUS)]
|
||||||
msg += self.proof_equal_logs(8)
|
msg += self.proof_equal_logs(8)
|
||||||
|
|
||||||
rab = pow(md, self.x3, DH1536_MODULUS)
|
rab = pow(md, self.x3, DH_MODULUS)
|
||||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||||
|
|
||||||
if self.prog != SMPPROG_SUCCEEDED:
|
if self.prog != SMPPROG_SUCCEEDED:
|
||||||
@@ -654,7 +658,7 @@ class SMPHandler:
|
|||||||
self.abort(appdata=appdata)
|
self.abort(appdata=appdata)
|
||||||
return
|
return
|
||||||
|
|
||||||
rab = pow(msg[0], self.x3, DH1536_MODULUS)
|
rab = pow(msg[0], self.x3, DH_MODULUS)
|
||||||
|
|
||||||
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
self.prog = SMPPROG_SUCCEEDED if self.pab == rab else SMPPROG_FAILED
|
||||||
|
|
||||||
@@ -679,12 +683,12 @@ class SMPHandler:
|
|||||||
|
|
||||||
self.secret = bytes_to_long(combSecret)
|
self.secret = bytes_to_long(combSecret)
|
||||||
|
|
||||||
self.x2 = bytes_to_long(RNG.read(192))
|
self.x2 = random.randrange(2, DH_MAX)
|
||||||
self.x3 = bytes_to_long(RNG.read(192))
|
self.x3 = random.randrange(2, DH_MAX)
|
||||||
|
|
||||||
msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
|
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||||
msg += proof_known_log(self.g1, self.x2, 1)
|
msg += proof_known_log(self.g1, self.x2, 1)
|
||||||
msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
|
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||||
msg += proof_known_log(self.g1, self.x3, 2)
|
msg += proof_known_log(self.g1, self.x3, 2)
|
||||||
|
|
||||||
self.prog = SMPPROG_OK
|
self.prog = SMPPROG_OK
|
||||||
@@ -700,19 +704,19 @@ class SMPHandler:
|
|||||||
|
|
||||||
self.secret = bytes_to_long(combSecret)
|
self.secret = bytes_to_long(combSecret)
|
||||||
|
|
||||||
msg = [pow(self.g1, self.x2, DH1536_MODULUS)]
|
msg = [pow(self.g1, self.x2, DH_MODULUS)]
|
||||||
msg += proof_known_log(self.g1, self.x2, 3)
|
msg += proof_known_log(self.g1, self.x2, 3)
|
||||||
msg.append(pow(self.g1, self.x3, DH1536_MODULUS))
|
msg.append(pow(self.g1, self.x3, DH_MODULUS))
|
||||||
msg += proof_known_log(self.g1, self.x3, 4)
|
msg += proof_known_log(self.g1, self.x3, 4)
|
||||||
|
|
||||||
r = bytes_to_long(RNG.read(192))
|
r = random.randrange(2, DH_MAX)
|
||||||
|
|
||||||
self.p = pow(self.g3, r, DH1536_MODULUS)
|
self.p = pow(self.g3, r, DH_MODULUS)
|
||||||
msg.append(self.p)
|
msg.append(self.p)
|
||||||
|
|
||||||
qb1 = pow(self.g1, r, DH1536_MODULUS)
|
qb1 = pow(self.g1, r, DH_MODULUS)
|
||||||
qb2 = pow(self.g2, self.secret, DH1536_MODULUS)
|
qb2 = pow(self.g2, self.secret, DH_MODULUS)
|
||||||
self.q = qb1 * qb2 % DH1536_MODULUS
|
self.q = qb1 * qb2 % DH_MODULUS
|
||||||
msg.append(self.q)
|
msg.append(self.q)
|
||||||
|
|
||||||
msg += self.proof_equal_coords(r, 5)
|
msg += self.proof_equal_coords(r, 5)
|
||||||
@@ -721,11 +725,11 @@ class SMPHandler:
|
|||||||
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
|
self.sendTLV(proto.SMP2TLV(msg), appdata=appdata)
|
||||||
|
|
||||||
def proof_equal_coords(self, r, v):
|
def proof_equal_coords(self, r, v):
|
||||||
r1 = bytes_to_long(RNG.read(192))
|
r1 = random.randrange(2, DH_MAX)
|
||||||
r2 = bytes_to_long(RNG.read(192))
|
r2 = random.randrange(2, DH_MAX)
|
||||||
temp2 = pow(self.g1, r1, DH1536_MODULUS) \
|
temp2 = pow(self.g1, r1, DH_MODULUS) \
|
||||||
* pow(self.g2, r2, DH1536_MODULUS) % DH1536_MODULUS
|
* pow(self.g2, r2, DH_MODULUS) % DH_MODULUS
|
||||||
temp1 = pow(self.g3, r1, DH1536_MODULUS)
|
temp1 = pow(self.g3, r1, DH_MODULUS)
|
||||||
|
|
||||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||||
c = bytes_to_long(cb)
|
c = bytes_to_long(cb)
|
||||||
@@ -739,21 +743,21 @@ class SMPHandler:
|
|||||||
|
|
||||||
def check_equal_coords(self, coords, v):
|
def check_equal_coords(self, coords, v):
|
||||||
(p, q, c, d1, d2) = coords
|
(p, q, c, d1, d2) = coords
|
||||||
temp1 = pow(self.g3, d1, DH1536_MODULUS) * pow(p, c, DH1536_MODULUS) \
|
temp1 = pow(self.g3, d1, DH_MODULUS) * pow(p, c, DH_MODULUS) \
|
||||||
% DH1536_MODULUS
|
% DH_MODULUS
|
||||||
|
|
||||||
temp2 = pow(self.g1, d1, DH1536_MODULUS) \
|
temp2 = pow(self.g1, d1, DH_MODULUS) \
|
||||||
* pow(self.g2, d2, DH1536_MODULUS) \
|
* pow(self.g2, d2, DH_MODULUS) \
|
||||||
* pow(q, c, DH1536_MODULUS) % DH1536_MODULUS
|
* pow(q, c, DH_MODULUS) % DH_MODULUS
|
||||||
|
|
||||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||||
|
|
||||||
return long_to_bytes(c) == cprime
|
return long_to_bytes(c, 32) == cprime
|
||||||
|
|
||||||
def proof_equal_logs(self, v):
|
def proof_equal_logs(self, v):
|
||||||
r = bytes_to_long(RNG.read(192))
|
r = random.randrange(2, DH_MAX)
|
||||||
temp1 = pow(self.g1, r, DH1536_MODULUS)
|
temp1 = pow(self.g1, r, DH_MODULUS)
|
||||||
temp2 = pow(self.qab, r, DH1536_MODULUS)
|
temp2 = pow(self.qab, r, DH_MODULUS)
|
||||||
|
|
||||||
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
cb = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||||
c = bytes_to_long(cb)
|
c = bytes_to_long(cb)
|
||||||
@@ -763,29 +767,29 @@ class SMPHandler:
|
|||||||
|
|
||||||
def check_equal_logs(self, logs, v):
|
def check_equal_logs(self, logs, v):
|
||||||
(r, c, d) = logs
|
(r, c, d) = logs
|
||||||
temp1 = pow(self.g1, d, DH1536_MODULUS) \
|
temp1 = pow(self.g1, d, DH_MODULUS) \
|
||||||
* pow(self.g3o, c, DH1536_MODULUS) % DH1536_MODULUS
|
* pow(self.g3o, c, DH_MODULUS) % DH_MODULUS
|
||||||
|
|
||||||
temp2 = pow(self.qab, d, DH1536_MODULUS) \
|
temp2 = pow(self.qab, d, DH_MODULUS) \
|
||||||
* pow(r, c, DH1536_MODULUS) % DH1536_MODULUS
|
* pow(r, c, DH_MODULUS) % DH_MODULUS
|
||||||
|
|
||||||
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
cprime = SHA256(struct.pack(b'B', v) + pack_mpi(temp1) + pack_mpi(temp2))
|
||||||
return long_to_bytes(c) == cprime
|
return long_to_bytes(c, 32) == cprime
|
||||||
|
|
||||||
def proof_known_log(g, x, v):
|
def proof_known_log(g, x, v):
|
||||||
r = bytes_to_long(RNG.read(192))
|
r = random.randrange(2, DH_MAX)
|
||||||
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH1536_MODULUS))))
|
c = bytes_to_long(SHA256(struct.pack(b'B', v) + pack_mpi(pow(g, r, DH_MODULUS))))
|
||||||
temp = x * c % SM_ORDER
|
temp = x * c % SM_ORDER
|
||||||
return c, (r-temp) % SM_ORDER
|
return c, (r-temp) % SM_ORDER
|
||||||
|
|
||||||
def check_known_log(c, d, g, x, v):
|
def check_known_log(c, d, g, x, v):
|
||||||
gd = pow(g, d, DH1536_MODULUS)
|
gd = pow(g, d, DH_MODULUS)
|
||||||
xc = pow(x, c, DH1536_MODULUS)
|
xc = pow(x, c, DH_MODULUS)
|
||||||
gdxc = gd * xc % DH1536_MODULUS
|
gdxc = gd * xc % DH_MODULUS
|
||||||
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c)
|
return SHA256(struct.pack(b'B', v) + pack_mpi(gdxc)) == long_to_bytes(c, 32)
|
||||||
|
|
||||||
def invMod(n):
|
def invMod(n):
|
||||||
return pow(n, DH1536_MODULUS_2, DH1536_MODULUS)
|
return pow(n, DH_MODULUS_2, DH_MODULUS)
|
||||||
|
|
||||||
class InvalidParameterError(RuntimeError):
|
class InvalidParameterError(RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -19,14 +19,16 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
|
||||||
import struct
|
import struct
|
||||||
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
|
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
|
||||||
|
|
||||||
OTRTAG = b'?OTR'
|
OTRTAG = b'?OTR'
|
||||||
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
|
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
|
||||||
MESSAGE_TAG_V1 = b' \t \t \t '
|
MESSAGE_TAGS = {
|
||||||
MESSAGE_TAG_V2 = b' \t\t \t '
|
1:b' \t \t \t ',
|
||||||
|
2:b' \t\t \t ',
|
||||||
|
3:b' \t\t \t\t',
|
||||||
|
}
|
||||||
|
|
||||||
MSGTYPE_NOTOTR = 0
|
MSGTYPE_NOTOTR = 0
|
||||||
MSGTYPE_TAGGEDPLAINTEXT = 1
|
MSGTYPE_TAGGEDPLAINTEXT = 1
|
||||||
@@ -62,6 +64,8 @@ def registermessage(cls):
|
|||||||
def registertlv(cls):
|
def registertlv(cls):
|
||||||
if not hasattr(cls, 'parsePayload'):
|
if not hasattr(cls, 'parsePayload'):
|
||||||
raise TypeError('registered tlv types need parsePayload()')
|
raise TypeError('registered tlv types need parsePayload()')
|
||||||
|
if cls.typ is None:
|
||||||
|
raise TypeError('registered tlv type needs type ID')
|
||||||
tlvClasses[cls.typ] = cls
|
tlvClasses[cls.typ] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@@ -84,16 +88,6 @@ class OTRMessage(object):
|
|||||||
__slots__ = ['payload']
|
__slots__ = ['payload']
|
||||||
version = 0x0002
|
version = 0x0002
|
||||||
msgtype = 0
|
msgtype = 0
|
||||||
def __init__(self, payload):
|
|
||||||
self.payload = payload
|
|
||||||
|
|
||||||
def getPayload(self):
|
|
||||||
return self.payload
|
|
||||||
|
|
||||||
def __bytes__(self):
|
|
||||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
|
||||||
+ self.getPayload()
|
|
||||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
@@ -110,6 +104,7 @@ class OTRMessage(object):
|
|||||||
class Error(OTRMessage):
|
class Error(OTRMessage):
|
||||||
__slots__ = ['error']
|
__slots__ = ['error']
|
||||||
def __init__(self, error):
|
def __init__(self, error):
|
||||||
|
super(Error, self).__init__()
|
||||||
self.error = error
|
self.error = error
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -119,56 +114,58 @@ class Error(OTRMessage):
|
|||||||
return b'?OTR Error:' + self.error
|
return b'?OTR Error:' + self.error
|
||||||
|
|
||||||
class Query(OTRMessage):
|
class Query(OTRMessage):
|
||||||
__slots__ = ['v1', 'v2']
|
__slots__ = ['versions']
|
||||||
def __init__(self, v1, v2):
|
def __init__(self, versions=set()):
|
||||||
self.v1 = v1
|
super(Query, self).__init__()
|
||||||
self.v2 = v2
|
self.versions = versions
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, data):
|
def parse(cls, data):
|
||||||
v2 = False
|
if not isinstance(data, bytes):
|
||||||
v1 = False
|
raise TypeError('can only parse bytes')
|
||||||
if len(data) > 0 and data[0:1] == b'?':
|
udata = data.decode('ascii', errors='replace')
|
||||||
data = data[1:]
|
|
||||||
v1 = True
|
|
||||||
|
|
||||||
if len(data) > 0 and data[0:1] == b'v':
|
versions = set()
|
||||||
for c in data[1:]:
|
if len(udata) > 0 and udata[0] == '?':
|
||||||
if c == b'2'[0]:
|
udata = udata[1:]
|
||||||
v2 = True
|
versions.add(1)
|
||||||
return cls(v1, v2)
|
|
||||||
|
if len(udata) > 0 and udata[0] == 'v':
|
||||||
|
versions.update(( int(c) for c in udata if c.isdigit() ))
|
||||||
|
return cls(versions)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2)
|
return '<proto.Query(versions=%r)>' % (self.versions)
|
||||||
|
|
||||||
def __bytes__(self):
|
def __bytes__(self):
|
||||||
d = b'?OTR'
|
d = b'?OTR'
|
||||||
if self.v1:
|
if 1 in self.versions:
|
||||||
d += b'?'
|
d += b'?'
|
||||||
d += b'v'
|
d += b'v'
|
||||||
if self.v2:
|
|
||||||
d += b'2'
|
# in python3 there is only int->unicode conversion
|
||||||
|
# so I convert to unicode and encode it to a byte string
|
||||||
|
versions = [ '%d' % v for v in self.versions if v != 1 ]
|
||||||
|
d += ''.join(versions).encode('ascii')
|
||||||
|
|
||||||
d += b'?'
|
d += b'?'
|
||||||
return d
|
return d
|
||||||
|
|
||||||
class TaggedPlaintext(Query):
|
class TaggedPlaintext(Query):
|
||||||
__slots__ = ['msg']
|
__slots__ = ['msg']
|
||||||
def __init__(self, msg, v1, v2):
|
def __init__(self, msg, versions):
|
||||||
|
super(TaggedPlaintext, self).__init__(versions)
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.v1 = v1
|
|
||||||
self.v2 = v2
|
|
||||||
|
|
||||||
def __bytes__(self):
|
def __bytes__(self):
|
||||||
data = self.msg + MESSAGE_TAG_BASE
|
data = self.msg + MESSAGE_TAG_BASE
|
||||||
if self.v1:
|
for v in self.versions:
|
||||||
data += MESSAGE_TAG_V1
|
data += MESSAGE_TAGS[v]
|
||||||
if self.v2:
|
|
||||||
data += MESSAGE_TAG_V2
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \
|
return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
|
||||||
.format(v1=self.v1, v2=self.v2, msg=self.msg)
|
.format(versions=self.versions, msg=self.msg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, data):
|
def parse(cls, data):
|
||||||
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
'this is not a tagged plaintext ({0!r:.20})'.format(data))
|
'this is not a tagged plaintext ({0!r:.20})'.format(data))
|
||||||
|
|
||||||
v1 = False
|
|
||||||
v2 = False
|
|
||||||
|
|
||||||
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
|
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
|
||||||
for tag in tags:
|
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
|
||||||
if not tag.isspace():
|
in tags ])
|
||||||
break
|
|
||||||
v1 |= tag == MESSAGE_TAG_V1
|
|
||||||
v2 |= tag == MESSAGE_TAG_V2
|
|
||||||
|
|
||||||
return TaggedPlaintext(data[:tagPos], v1, v2)
|
return TaggedPlaintext(data[:tagPos], versions)
|
||||||
|
|
||||||
class GenericOTRMessage(OTRMessage):
|
class GenericOTRMessage(OTRMessage):
|
||||||
__slots__ = ['data']
|
__slots__ = ['data']
|
||||||
|
fields = []
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
|
super(GenericOTRMessage, self).__init__()
|
||||||
if len(args) != len(self.fields):
|
if len(args) != len(self.fields):
|
||||||
raise TypeError('%s needs %d arguments, got %d' %
|
raise TypeError('%s needs %d arguments, got %d' %
|
||||||
(self.__class__.__name__, len(self.fields), len(args)))
|
(self.__class__.__name__, len(self.fields), len(args)))
|
||||||
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
|
|||||||
self.__getattr__(attr) # existence check
|
self.__getattr__(attr) # existence check
|
||||||
self.data[attr] = val
|
self.data[attr] = val
|
||||||
|
|
||||||
|
def __bytes__(self):
|
||||||
|
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||||
|
+ self.getPayload()
|
||||||
|
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
name = self.__class__.__name__
|
name = self.__class__.__name__
|
||||||
data = ''
|
data = ''
|
||||||
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
|
|||||||
def parsePayload(cls, data):
|
def parsePayload(cls, data):
|
||||||
data = base64.b64decode(data)
|
data = base64.b64decode(data)
|
||||||
args = []
|
args = []
|
||||||
for k, ftype in cls.fields:
|
for _, ftype in cls.fields:
|
||||||
if ftype == 'data':
|
if ftype == 'data':
|
||||||
value, data = read_data(data)
|
value, data = read_data(data)
|
||||||
elif isinstance(ftype, bytes):
|
elif isinstance(ftype, bytes):
|
||||||
size = int(struct.calcsize(ftype))
|
|
||||||
value, data = unpack(ftype, data)
|
value, data = unpack(ftype, data)
|
||||||
elif isinstance(ftype, int):
|
elif isinstance(ftype, int):
|
||||||
value, data = data[:ftype], data[ftype:]
|
value, data = data[:ftype], data[ftype:]
|
||||||
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
|
|||||||
|
|
||||||
class AKEMessage(GenericOTRMessage):
|
class AKEMessage(GenericOTRMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
pass
|
|
||||||
|
|
||||||
@registermessage
|
@registermessage
|
||||||
class DHCommit(AKEMessage):
|
class DHCommit(AKEMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
msgtype = 0x02
|
msgtype = 0x02
|
||||||
fields = [('encgx','data'), ('hashgx','data'), ]
|
fields = [('encgx', 'data'), ('hashgx', 'data'), ]
|
||||||
|
|
||||||
|
|
||||||
@registermessage
|
@registermessage
|
||||||
class DHKey(AKEMessage):
|
class DHKey(AKEMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
msgtype = 0x0a
|
msgtype = 0x0a
|
||||||
fields = [('gy','data'), ]
|
fields = [('gy', 'data'), ]
|
||||||
|
|
||||||
@registermessage
|
@registermessage
|
||||||
class RevealSig(AKEMessage):
|
class RevealSig(AKEMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
msgtype = 0x11
|
msgtype = 0x11
|
||||||
fields = [('rkey','data'), ('encsig','data'), ('mac',20),]
|
fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
|
||||||
|
|
||||||
def getMacedData(self):
|
def getMacedData(self):
|
||||||
p = self.encsig
|
p = self.encsig
|
||||||
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
|
|||||||
class Signature(AKEMessage):
|
class Signature(AKEMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
msgtype = 0x12
|
msgtype = 0x12
|
||||||
fields = [('encsig','data'), ('mac',20)]
|
fields = [('encsig', 'data'), ('mac', 20)]
|
||||||
|
|
||||||
def getMacedData(self):
|
def getMacedData(self):
|
||||||
p = self.encsig
|
p = self.encsig
|
||||||
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
|
|||||||
class DataMessage(GenericOTRMessage):
|
class DataMessage(GenericOTRMessage):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
msgtype = 0x03
|
msgtype = 0x03
|
||||||
fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'),
|
fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
|
||||||
('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ]
|
('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
|
||||||
|
('oldmacs', 'data'), ]
|
||||||
|
|
||||||
def getMacedData(self):
|
def getMacedData(self):
|
||||||
return struct.pack(b'!HB', self.version, self.msgtype) + \
|
return struct.pack(b'!HB', self.version, self.msgtype) + \
|
||||||
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
|
|||||||
@bytesAndStrings
|
@bytesAndStrings
|
||||||
class TLV(object):
|
class TLV(object):
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
|
typ = None
|
||||||
|
|
||||||
|
def getPayload(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
val = self.getPayload()
|
val = self.getPayload()
|
||||||
@@ -330,11 +331,28 @@ class TLV(object):
|
|||||||
def __neq__(self, other):
|
def __neq__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
@registertlv
|
||||||
|
class PaddingTLV(TLV):
|
||||||
|
typ = 0
|
||||||
|
|
||||||
|
__slots__ = ['padding']
|
||||||
|
|
||||||
|
def __init__(self, padding):
|
||||||
|
super(PaddingTLV, self).__init__()
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
def getPayload(self):
|
||||||
|
return self.padding
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parsePayload(cls, data):
|
||||||
|
return cls(data)
|
||||||
|
|
||||||
@registertlv
|
@registertlv
|
||||||
class DisconnectTLV(TLV):
|
class DisconnectTLV(TLV):
|
||||||
typ = 1
|
typ = 1
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
super(DisconnectTLV, self).__init__()
|
||||||
|
|
||||||
def getPayload(self):
|
def getPayload(self):
|
||||||
return b''
|
return b''
|
||||||
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
|
|||||||
|
|
||||||
class SMPTLV(TLV):
|
class SMPTLV(TLV):
|
||||||
__slots__ = ['mpis']
|
__slots__ = ['mpis']
|
||||||
|
dlen = None
|
||||||
|
|
||||||
def __init__(self, mpis=[]):
|
def __init__(self, mpis=None):
|
||||||
|
super(SMPTLV, self).__init__()
|
||||||
|
if mpis is None:
|
||||||
|
mpis = []
|
||||||
|
if self.dlen is None:
|
||||||
|
raise TypeError('no amount of mpis specified in dlen')
|
||||||
if len(mpis) != self.dlen:
|
if len(mpis) != self.dlen:
|
||||||
raise TypeError('expected {0} mpis, got {1}'
|
raise TypeError('expected {0} mpis, got {1}'
|
||||||
.format(self.dlen, len(mpis)))
|
.format(self.dlen, len(mpis)))
|
||||||
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
|
|||||||
mpis = []
|
mpis = []
|
||||||
if cls.dlen > 0:
|
if cls.dlen > 0:
|
||||||
count, data = unpack(b'!I', data)
|
count, data = unpack(b'!I', data)
|
||||||
for i in range(count):
|
for _ in range(count):
|
||||||
n, data = read_mpi(data)
|
n, data = read_mpi(data)
|
||||||
mpis.append(n)
|
mpis.append(n)
|
||||||
if len(data) > 0:
|
if len(data) > 0:
|
||||||
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
|
|||||||
|
|
||||||
def getPayload(self):
|
def getPayload(self):
|
||||||
return b''
|
return b''
|
||||||
|
|
||||||
|
@registertlv
|
||||||
|
class ExtraKeyTLV(TLV):
|
||||||
|
typ = 8
|
||||||
|
|
||||||
|
__slots__ = ['appid', 'appdata']
|
||||||
|
|
||||||
|
def __init__(self, appid, appdata):
|
||||||
|
super(ExtraKeyTLV, self).__init__()
|
||||||
|
self.appid = appid
|
||||||
|
self.appdata = appdata
|
||||||
|
if appdata is None:
|
||||||
|
self.appdata = b''
|
||||||
|
|
||||||
|
def getPayload(self):
|
||||||
|
return self.appid + self.appdata
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parsePayload(cls, data):
|
||||||
|
return cls(data[:4], data[4:])
|
||||||
|
|||||||
@@ -43,11 +43,12 @@ def bytes_to_long(b):
|
|||||||
s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
|
s += byte_to_long(b[i:i+1]) << 8*(l-i-1)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def long_to_bytes(l):
|
def long_to_bytes(l, n=0):
|
||||||
b = b''
|
b = b''
|
||||||
while l != 0:
|
while l != 0 or n > 0:
|
||||||
b = long_to_byte(l & 0xff) + b
|
b = long_to_byte(l & 0xff) + b
|
||||||
l >>= 8
|
l >>= 8
|
||||||
|
n -= 1
|
||||||
return b
|
return b
|
||||||
|
|
||||||
def byte_to_long(b):
|
def byte_to_long(b):
|
||||||
|
|||||||
Reference in New Issue
Block a user