gotr: update provided potr to 1.0.0beta7
This commit is contained in:
@@ -19,14 +19,16 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import struct
|
||||
from potr.utils import pack_mpi, read_mpi, pack_data, read_data, unpack
|
||||
|
||||
OTRTAG = b'?OTR'
|
||||
MESSAGE_TAG_BASE = b' \t \t\t\t\t \t \t \t '
|
||||
MESSAGE_TAG_V1 = b' \t \t \t '
|
||||
MESSAGE_TAG_V2 = b' \t\t \t '
|
||||
MESSAGE_TAGS = {
|
||||
1:b' \t \t \t ',
|
||||
2:b' \t\t \t ',
|
||||
3:b' \t\t \t\t',
|
||||
}
|
||||
|
||||
MSGTYPE_NOTOTR = 0
|
||||
MSGTYPE_TAGGEDPLAINTEXT = 1
|
||||
@@ -62,6 +64,8 @@ def registermessage(cls):
|
||||
def registertlv(cls):
|
||||
if not hasattr(cls, 'parsePayload'):
|
||||
raise TypeError('registered tlv types need parsePayload()')
|
||||
if cls.typ is None:
|
||||
raise TypeError('registered tlv type needs type ID')
|
||||
tlvClasses[cls.typ] = cls
|
||||
return cls
|
||||
|
||||
@@ -84,16 +88,6 @@ class OTRMessage(object):
|
||||
__slots__ = ['payload']
|
||||
version = 0x0002
|
||||
msgtype = 0
|
||||
def __init__(self, payload):
|
||||
self.payload = payload
|
||||
|
||||
def getPayload(self):
|
||||
return self.payload
|
||||
|
||||
def __bytes__(self):
|
||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||
+ self.getPayload()
|
||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
@@ -110,6 +104,7 @@ class OTRMessage(object):
|
||||
class Error(OTRMessage):
|
||||
__slots__ = ['error']
|
||||
def __init__(self, error):
|
||||
super(Error, self).__init__()
|
||||
self.error = error
|
||||
|
||||
def __repr__(self):
|
||||
@@ -119,56 +114,58 @@ class Error(OTRMessage):
|
||||
return b'?OTR Error:' + self.error
|
||||
|
||||
class Query(OTRMessage):
|
||||
__slots__ = ['v1', 'v2']
|
||||
def __init__(self, v1, v2):
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
__slots__ = ['versions']
|
||||
def __init__(self, versions=set()):
|
||||
super(Query, self).__init__()
|
||||
self.versions = versions
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
v2 = False
|
||||
v1 = False
|
||||
if len(data) > 0 and data[0:1] == b'?':
|
||||
data = data[1:]
|
||||
v1 = True
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError('can only parse bytes')
|
||||
udata = data.decode('ascii', errors='replace')
|
||||
|
||||
if len(data) > 0 and data[0:1] == b'v':
|
||||
for c in data[1:]:
|
||||
if c == b'2'[0]:
|
||||
v2 = True
|
||||
return cls(v1, v2)
|
||||
versions = set()
|
||||
if len(udata) > 0 and udata[0] == '?':
|
||||
udata = udata[1:]
|
||||
versions.add(1)
|
||||
|
||||
if len(udata) > 0 and udata[0] == 'v':
|
||||
versions.update(( int(c) for c in udata if c.isdigit() ))
|
||||
return cls(versions)
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.Query(v1=%r,v2=%r)>'%(self.v1,self.v2)
|
||||
return '<proto.Query(versions=%r)>' % (self.versions)
|
||||
|
||||
def __bytes__(self):
|
||||
d = b'?OTR'
|
||||
if self.v1:
|
||||
if 1 in self.versions:
|
||||
d += b'?'
|
||||
d += b'v'
|
||||
if self.v2:
|
||||
d += b'2'
|
||||
|
||||
# in python3 there is only int->unicode conversion
|
||||
# so I convert to unicode and encode it to a byte string
|
||||
versions = [ '%d' % v for v in self.versions if v != 1 ]
|
||||
d += ''.join(versions).encode('ascii')
|
||||
|
||||
d += b'?'
|
||||
return d
|
||||
|
||||
class TaggedPlaintext(Query):
|
||||
__slots__ = ['msg']
|
||||
def __init__(self, msg, v1, v2):
|
||||
def __init__(self, msg, versions):
|
||||
super(TaggedPlaintext, self).__init__(versions)
|
||||
self.msg = msg
|
||||
self.v1 = v1
|
||||
self.v2 = v2
|
||||
|
||||
def __bytes__(self):
|
||||
data = self.msg + MESSAGE_TAG_BASE
|
||||
if self.v1:
|
||||
data += MESSAGE_TAG_V1
|
||||
if self.v2:
|
||||
data += MESSAGE_TAG_V2
|
||||
for v in self.versions:
|
||||
data += MESSAGE_TAGS[v]
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
return '<proto.TaggedPlaintext(v1={v1!r},v2={v2!r},msg={msg!r})>' \
|
||||
.format(v1=self.v1, v2=self.v2, msg=self.msg)
|
||||
return '<proto.TaggedPlaintext(versions={versions!r},msg={msg!r})>' \
|
||||
.format(versions=self.versions, msg=self.msg)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data):
|
||||
@@ -177,21 +174,18 @@ class TaggedPlaintext(Query):
|
||||
raise TypeError(
|
||||
'this is not a tagged plaintext ({0!r:.20})'.format(data))
|
||||
|
||||
v1 = False
|
||||
v2 = False
|
||||
|
||||
tags = [ data[i:i+8] for i in range(tagPos, len(data), 8) ]
|
||||
for tag in tags:
|
||||
if not tag.isspace():
|
||||
break
|
||||
v1 |= tag == MESSAGE_TAG_V1
|
||||
v2 |= tag == MESSAGE_TAG_V2
|
||||
versions = set([ version for version, tag in MESSAGE_TAGS.items() if tag
|
||||
in tags ])
|
||||
|
||||
return TaggedPlaintext(data[:tagPos], v1, v2)
|
||||
return TaggedPlaintext(data[:tagPos], versions)
|
||||
|
||||
class GenericOTRMessage(OTRMessage):
|
||||
__slots__ = ['data']
|
||||
fields = []
|
||||
|
||||
def __init__(self, *args):
|
||||
super(GenericOTRMessage, self).__init__()
|
||||
if len(args) != len(self.fields):
|
||||
raise TypeError('%s needs %d arguments, got %d' %
|
||||
(self.__class__.__name__, len(self.fields), len(args)))
|
||||
@@ -213,6 +207,11 @@ class GenericOTRMessage(OTRMessage):
|
||||
self.__getattr__(attr) # existence check
|
||||
self.data[attr] = val
|
||||
|
||||
def __bytes__(self):
|
||||
data = struct.pack(b'!HB', self.version, self.msgtype) \
|
||||
+ self.getPayload()
|
||||
return b'?OTR:' + base64.b64encode(data) + b'.'
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
data = ''
|
||||
@@ -224,11 +223,10 @@ class GenericOTRMessage(OTRMessage):
|
||||
def parsePayload(cls, data):
|
||||
data = base64.b64decode(data)
|
||||
args = []
|
||||
for k, ftype in cls.fields:
|
||||
for _, ftype in cls.fields:
|
||||
if ftype == 'data':
|
||||
value, data = read_data(data)
|
||||
elif isinstance(ftype, bytes):
|
||||
size = int(struct.calcsize(ftype))
|
||||
value, data = unpack(ftype, data)
|
||||
elif isinstance(ftype, int):
|
||||
value, data = data[:ftype], data[ftype:]
|
||||
@@ -251,26 +249,24 @@ class GenericOTRMessage(OTRMessage):
|
||||
|
||||
class AKEMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
pass
|
||||
|
||||
@registermessage
|
||||
class DHCommit(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x02
|
||||
fields = [('encgx','data'), ('hashgx','data'), ]
|
||||
|
||||
fields = [('encgx', 'data'), ('hashgx', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class DHKey(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x0a
|
||||
fields = [('gy','data'), ]
|
||||
fields = [('gy', 'data'), ]
|
||||
|
||||
@registermessage
|
||||
class RevealSig(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x11
|
||||
fields = [('rkey','data'), ('encsig','data'), ('mac',20),]
|
||||
fields = [('rkey', 'data'), ('encsig', 'data'), ('mac', 20),]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
@@ -280,7 +276,7 @@ class RevealSig(AKEMessage):
|
||||
class Signature(AKEMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x12
|
||||
fields = [('encsig','data'), ('mac',20)]
|
||||
fields = [('encsig', 'data'), ('mac', 20)]
|
||||
|
||||
def getMacedData(self):
|
||||
p = self.encsig
|
||||
@@ -290,8 +286,9 @@ class Signature(AKEMessage):
|
||||
class DataMessage(GenericOTRMessage):
|
||||
__slots__ = []
|
||||
msgtype = 0x03
|
||||
fields = [('flags',b'!B'), ('skeyid',b'!I'), ('rkeyid',b'!I'), ('dhy','data'),
|
||||
('ctr',8), ('encmsg','data'), ('mac',20), ('oldmacs','data'), ]
|
||||
fields = [('flags', b'!B'), ('skeyid', b'!I'), ('rkeyid', b'!I'),
|
||||
('dhy', 'data'), ('ctr', 8), ('encmsg', 'data'), ('mac', 20),
|
||||
('oldmacs', 'data'), ]
|
||||
|
||||
def getMacedData(self):
|
||||
return struct.pack(b'!HB', self.version, self.msgtype) + \
|
||||
@@ -300,6 +297,10 @@ class DataMessage(GenericOTRMessage):
|
||||
@bytesAndStrings
|
||||
class TLV(object):
|
||||
__slots__ = []
|
||||
typ = None
|
||||
|
||||
def getPayload(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self):
|
||||
val = self.getPayload()
|
||||
@@ -330,11 +331,28 @@ class TLV(object):
|
||||
def __neq__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
@registertlv
|
||||
class PaddingTLV(TLV):
|
||||
typ = 0
|
||||
|
||||
__slots__ = ['padding']
|
||||
|
||||
def __init__(self, padding):
|
||||
super(PaddingTLV, self).__init__()
|
||||
self.padding = padding
|
||||
|
||||
def getPayload(self):
|
||||
return self.padding
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data)
|
||||
|
||||
@registertlv
|
||||
class DisconnectTLV(TLV):
|
||||
typ = 1
|
||||
def __init__(self):
|
||||
pass
|
||||
super(DisconnectTLV, self).__init__()
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
@@ -348,8 +366,14 @@ class DisconnectTLV(TLV):
|
||||
|
||||
class SMPTLV(TLV):
|
||||
__slots__ = ['mpis']
|
||||
dlen = None
|
||||
|
||||
def __init__(self, mpis=[]):
|
||||
def __init__(self, mpis=None):
|
||||
super(SMPTLV, self).__init__()
|
||||
if mpis is None:
|
||||
mpis = []
|
||||
if self.dlen is None:
|
||||
raise TypeError('no amount of mpis specified in dlen')
|
||||
if len(mpis) != self.dlen:
|
||||
raise TypeError('expected {0} mpis, got {1}'
|
||||
.format(self.dlen, len(mpis)))
|
||||
@@ -366,7 +390,7 @@ class SMPTLV(TLV):
|
||||
mpis = []
|
||||
if cls.dlen > 0:
|
||||
count, data = unpack(b'!I', data)
|
||||
for i in range(count):
|
||||
for _ in range(count):
|
||||
n, data = read_mpi(data)
|
||||
mpis.append(n)
|
||||
if len(data) > 0:
|
||||
@@ -419,3 +443,23 @@ class SMPABORTTLV(SMPTLV):
|
||||
|
||||
def getPayload(self):
|
||||
return b''
|
||||
|
||||
@registertlv
|
||||
class ExtraKeyTLV(TLV):
|
||||
typ = 8
|
||||
|
||||
__slots__ = ['appid', 'appdata']
|
||||
|
||||
def __init__(self, appid, appdata):
|
||||
super(ExtraKeyTLV, self).__init__()
|
||||
self.appid = appid
|
||||
self.appdata = appdata
|
||||
if appdata is None:
|
||||
self.appdata = b''
|
||||
|
||||
def getPayload(self):
|
||||
return self.appid + self.appdata
|
||||
|
||||
@classmethod
|
||||
def parsePayload(cls, data):
|
||||
return cls(data[:4], data[4:])
|
||||
|
||||
Reference in New Issue
Block a user