--- a/mercurial/utils/cborutil.py Tue Aug 28 15:22:06 2018 -0700
+++ b/mercurial/utils/cborutil.py Tue Aug 28 15:02:48 2018 -0700
@@ -8,6 +8,7 @@
from __future__ import absolute_import
import struct
+import sys
from ..thirdparty.cbor.cbor2 import (
decoder as decodermod,
@@ -35,11 +36,16 @@
SUBTYPE_MASK = 0b00011111
+SUBTYPE_FALSE = 20
+SUBTYPE_TRUE = 21
+SUBTYPE_NULL = 22
SUBTYPE_HALF_FLOAT = 25
SUBTYPE_SINGLE_FLOAT = 26
SUBTYPE_DOUBLE_FLOAT = 27
SUBTYPE_INDEFINITE = 31
+SEMANTIC_TAG_FINITE_SET = 258
+
# Indefinite types begin with their major type ORd with information value 31.
BEGIN_INDEFINITE_BYTESTRING = struct.pack(
r'>B', MAJOR_TYPE_BYTESTRING << 5 | SUBTYPE_INDEFINITE)
@@ -146,7 +152,7 @@
def streamencodeset(s):
# https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml defines
# semantic tag 258 for finite sets.
- yield encodelength(MAJOR_TYPE_SEMANTIC, 258)
+ yield encodelength(MAJOR_TYPE_SEMANTIC, SEMANTIC_TAG_FINITE_SET)
for chunk in streamencodearray(sorted(s, key=_mixedtypesortkey)):
yield chunk
@@ -260,3 +266,710 @@
len(chunk), length))
yield chunk
+
+class CBORDecodeError(Exception):
+ """Represents an error decoding CBOR."""
+
+if sys.version_info.major >= 3:
+ def _elementtointeger(b, i):
+ return b[i]
+else:
+ def _elementtointeger(b, i):
+ return ord(b[i])
+
+STRUCT_BIG_UBYTE = struct.Struct(r'>B')
+STRUCT_BIG_USHORT = struct.Struct('>H')
+STRUCT_BIG_ULONG = struct.Struct('>L')
+STRUCT_BIG_ULONGLONG = struct.Struct('>Q')
+
+SPECIAL_NONE = 0
+SPECIAL_START_INDEFINITE_BYTESTRING = 1
+SPECIAL_START_ARRAY = 2
+SPECIAL_START_MAP = 3
+SPECIAL_START_SET = 4
+SPECIAL_INDEFINITE_BREAK = 5
+
+def decodeitem(b, offset=0):
+ """Decode a new CBOR value from a buffer at offset.
+
+ This function attempts to decode up to one complete CBOR value
+ from ``b`` starting at offset ``offset``.
+
+ The beginning of a collection (such as an array, map, set, or
+ indefinite length bytestring) counts as a single value. For these
+ special cases, a state flag will indicate that a special value was seen.
+
+ When called, the function either returns a decoded value or gives
+ a hint as to how many more bytes are needed to do so. By calling
+ the function repeatedly given a stream of bytes, the caller can
+ build up the original values.
+
+ Returns a tuple with the following elements:
+
+ * Bool indicating whether a complete value was decoded.
+ * A decoded value if first value is True otherwise None
+ * Integer number of bytes. If positive, the number of bytes
+ read. If negative, the number of bytes we need to read to
+ decode this value or the next chunk in this value.
+ * One of the ``SPECIAL_*`` constants indicating special treatment
+ for this value. ``SPECIAL_NONE`` means this is a fully decoded
+ simple value (such as an integer or bool).
+ """
+
+ initial = _elementtointeger(b, offset)
+ offset += 1
+
+ majortype = initial >> 5
+ subtype = initial & SUBTYPE_MASK
+
+ if majortype == MAJOR_TYPE_UINT:
+ complete, value, readcount = decodeuint(subtype, b, offset)
+
+ if complete:
+ return True, value, readcount + 1, SPECIAL_NONE
+ else:
+ return False, None, readcount, SPECIAL_NONE
+
+ elif majortype == MAJOR_TYPE_NEGINT:
+ # Negative integers are the same as UINT except inverted minus 1.
+ complete, value, readcount = decodeuint(subtype, b, offset)
+
+ if complete:
+ return True, -value - 1, readcount + 1, SPECIAL_NONE
+ else:
+ return False, None, readcount, SPECIAL_NONE
+
+ elif majortype == MAJOR_TYPE_BYTESTRING:
+ # Beginning of bytestrings are treated as uints in order to
+ # decode their length, which may be indefinite.
+ complete, size, readcount = decodeuint(subtype, b, offset,
+ allowindefinite=True)
+
+ # We don't know the size of the bytestring. It must be a definite
+ # length since the indefinite subtype would be encoded in the initial
+ # byte.
+ if not complete:
+ return False, None, readcount, SPECIAL_NONE
+
+ # We know the length of the bytestring.
+ if size is not None:
+ # And the data is available in the buffer.
+ if offset + readcount + size <= len(b):
+ value = b[offset + readcount:offset + readcount + size]
+ return True, value, readcount + size + 1, SPECIAL_NONE
+
+ # And we need more data in order to return the bytestring.
+ else:
+ wanted = len(b) - offset - readcount - size
+ return False, None, wanted, SPECIAL_NONE
+
+ # It is an indefinite length bytestring.
+ else:
+ return True, None, 1, SPECIAL_START_INDEFINITE_BYTESTRING
+
+ elif majortype == MAJOR_TYPE_STRING:
+ raise CBORDecodeError('string major type not supported')
+
+ elif majortype == MAJOR_TYPE_ARRAY:
+ # Beginning of arrays are treated as uints in order to decode their
+ # length. We don't allow indefinite length arrays.
+ complete, size, readcount = decodeuint(subtype, b, offset)
+
+ if complete:
+ return True, size, readcount + 1, SPECIAL_START_ARRAY
+ else:
+ return False, None, readcount, SPECIAL_NONE
+
+ elif majortype == MAJOR_TYPE_MAP:
+ # Beginning of maps are treated as uints in order to decode their
+ # number of elements. We don't allow indefinite length maps.
+ complete, size, readcount = decodeuint(subtype, b, offset)
+
+ if complete:
+ return True, size, readcount + 1, SPECIAL_START_MAP
+ else:
+ return False, None, readcount, SPECIAL_NONE
+
+ elif majortype == MAJOR_TYPE_SEMANTIC:
+ # Semantic tag value is read the same as a uint.
+ complete, tagvalue, readcount = decodeuint(subtype, b, offset)
+
+ if not complete:
+ return False, None, readcount, SPECIAL_NONE
+
+ # This behavior here is a little wonky. The main type being "decorated"
+ # by this semantic tag follows. A more robust parser would probably emit
+ # a special flag indicating this as a semantic tag and let the caller
+ # deal with the types that follow. But since we don't support many
+ # semantic tags, it is easier to deal with the special cases here and
+ # hide complexity from the caller. If we add support for more semantic
+ # tags, we should probably move semantic tag handling into the caller.
+ if tagvalue == SEMANTIC_TAG_FINITE_SET:
+ if offset + readcount >= len(b):
+ return False, None, -1, SPECIAL_NONE
+
+ complete, size, readcount2, special = decodeitem(b,
+ offset + readcount)
+
+ if not complete:
+ return False, None, readcount2, SPECIAL_NONE
+
+ if special != SPECIAL_START_ARRAY:
+ raise CBORDecodeError('expected array after finite set '
+ 'semantic tag')
+
+ return True, size, readcount + readcount2 + 1, SPECIAL_START_SET
+
+ else:
+ raise CBORDecodeError('semantic tag %d not allowed' % tagvalue)
+
+ elif majortype == MAJOR_TYPE_SPECIAL:
+ # Only specific values for the information field are allowed.
+ if subtype == SUBTYPE_FALSE:
+ return True, False, 1, SPECIAL_NONE
+ elif subtype == SUBTYPE_TRUE:
+ return True, True, 1, SPECIAL_NONE
+ elif subtype == SUBTYPE_NULL:
+ return True, None, 1, SPECIAL_NONE
+ elif subtype == SUBTYPE_INDEFINITE:
+ return True, None, 1, SPECIAL_INDEFINITE_BREAK
+ # Anything else (e.g. 24, where the value is in the next byte) is rejected.
+ else:
+ raise CBORDecodeError('special type %d not allowed' % subtype)
+ else:
+ assert False
+
+def decodeuint(subtype, b, offset=0, allowindefinite=False):
+ """Decode an unsigned integer.
+
+ ``subtype`` is the lower 5 bits from the initial byte CBOR item
+ "header." ``b`` is a buffer containing bytes. ``offset`` points to
+ the index of the first byte after the byte that ``subtype`` was
+ derived from.
+
+ ``allowindefinite`` allows the special indefinite length value
+ indicator.
+
+ Returns a 3-tuple of (successful, value, count).
+
+ The first element is a bool indicating if decoding completed. The 2nd
+ is the decoded integer value or None if not fully decoded or the subtype
+ is 31 and ``allowindefinite`` is True. The 3rd value is the count of bytes.
+ If positive, it is the number of additional bytes decoded. If negative,
+ it is the number of additional bytes needed to decode this value.
+ """
+
+ # Small values are inline.
+ if subtype < 24:
+ return True, subtype, 0
+ # Indefinite length specifier.
+ elif subtype == 31:
+ if allowindefinite:
+ return True, None, 0
+ else:
+ raise CBORDecodeError('indefinite length uint not allowed here')
+ elif subtype >= 28:
+ raise CBORDecodeError('unsupported subtype on integer type: %d' %
+ subtype)
+
+ if subtype == 24:
+ s = STRUCT_BIG_UBYTE
+ elif subtype == 25:
+ s = STRUCT_BIG_USHORT
+ elif subtype == 26:
+ s = STRUCT_BIG_ULONG
+ elif subtype == 27:
+ s = STRUCT_BIG_ULONGLONG
+ else:
+ raise CBORDecodeError('bounds condition checking violation')
+
+ if len(b) - offset >= s.size:
+ return True, s.unpack_from(b, offset)[0], s.size
+ else:
+ return False, None, len(b) - offset - s.size
+
+class bytestringchunk(bytes):
+ """Represents a chunk/segment in an indefinite length bytestring.
+
+ This behaves like a ``bytes`` but in addition has the ``isfirst``
+ and ``islast`` attributes indicating whether this chunk is the first
+ or last in an indefinite length bytestring.
+ """
+
+ def __new__(cls, v, first=False, last=False):
+ self = bytes.__new__(cls, v)
+ self.isfirst = first
+ self.islast = last
+
+ return self
+
+class sansiodecoder(object):
+ """A CBOR decoder that doesn't perform its own I/O.
+
+ To use, construct an instance and feed it segments containing
+ CBOR-encoded bytes via ``decode()``. The return value from ``decode()``
+ indicates whether a fully-decoded value is available, how many bytes
+ were consumed, and offers a hint as to how many bytes should be fed
+ in next time to decode the next value.
+
+ The decoder assumes it will decode N discrete CBOR values, not just
+ a single value. i.e. if the bytestream contains uints packed one after
+ the other, the decoder will decode them all, rather than just the initial
+ one.
+
+ When ``decode()`` indicates a value is available, call ``getavailable()``
+ to return all fully decoded values.
+
+ ``decode()`` can partially decode input. It is up to the caller to keep
+ track of what data was consumed and to pass unconsumed data in on the
+ next invocation.
+
+ The decoder decodes atomically at the *item* level. See ``decodeitem()``.
+ If an *item* cannot be fully decoded, the decoder won't record it as
+ partially consumed. Instead, the caller will be instructed to pass in
+ the initial bytes of this item on the next invocation. This does result
+ in some redundant parsing. But the overhead should be minimal.
+
+ This decoder only supports a subset of CBOR as required by Mercurial.
+ It lacks support for:
+
+ * Indefinite length arrays
+ * Indefinite length maps
+ * Use of indefinite length bytestrings as keys or values within
+ arrays, maps, or sets.
+ * Nested arrays, maps, or sets within sets
+ * Any semantic tag that isn't a mathematical finite set
+ * Floating point numbers
+ * Undefined special value
+
+ CBOR types are decoded to Python types as follows:
+
+ uint -> int
+ negint -> int
+ bytestring -> bytes
+ map -> dict
+ array -> list
+ True -> bool
+ False -> bool
+ null -> None
+ indefinite length bytestring chunk -> [bytestringchunk]
+
+ The only non-obvious mapping here is an indefinite length bytestring
+ to the ``bytestringchunk`` type. This is to facilitate streaming
+ indefinite length bytestrings out of the decoder and to differentiate
+ a regular bytestring from an indefinite length bytestring.
+ """
+
+ _STATE_NONE = 0
+ _STATE_WANT_MAP_KEY = 1
+ _STATE_WANT_MAP_VALUE = 2
+ _STATE_WANT_ARRAY_VALUE = 3
+ _STATE_WANT_SET_VALUE = 4
+ _STATE_WANT_BYTESTRING_CHUNK_FIRST = 5
+ _STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT = 6
+
+ def __init__(self):
+ # TODO add support for limiting size of bytestrings
+ # TODO add support for limiting number of keys / values in collections
+ # TODO add support for limiting size of buffered partial values
+
+ self.decodedbytecount = 0
+
+ self._state = self._STATE_NONE
+
+ # Stack of active nested collections. Each entry is a dict describing
+ # the collection.
+ self._collectionstack = []
+
+ # Fully decoded key to use for the current map.
+ self._currentmapkey = None
+
+ # Fully decoded values available for retrieval.
+ self._decodedvalues = []
+
+ @property
+ def inprogress(self):
+ """Whether the decoder has partially decoded a value."""
+ return self._state != self._STATE_NONE
+
+ def decode(self, b, offset=0):
+ """Attempt to decode bytes from an input buffer.
+
+ ``b`` is a collection of bytes and ``offset`` is the byte
+ offset within that buffer from which to begin reading data.
+
+ ``b`` must support ``len()`` and accessing bytes slices via
+ ``__getitem__``. Typically ``bytes`` instances are used.
+
+ Returns a tuple with the following fields:
+
+ * Bool indicating whether values are available for retrieval.
+ * Integer indicating the number of bytes that were fully consumed,
+ starting from ``offset``.
+ * Integer indicating the number of bytes that are desired for the
+ next call in order to decode an item.
+ """
+ if not b:
+ return bool(self._decodedvalues), 0, 0
+
+ initialoffset = offset
+
+ # We could easily split the body of this loop into a function. But
+ # Python performance is sensitive to function calls and collections
+ # are composed of many items. So leaving as a while loop could help
+ # with performance. One thing that may not help is the use of
+ # if..elif versus a lookup/dispatch table. There may be value
+ # in switching that.
+ while offset < len(b):
+ # Attempt to decode an item. This could be a whole value or a
+ # special value indicating an event, such as start or end of a
+ # collection or indefinite length type.
+ complete, value, readcount, special = decodeitem(b, offset)
+
+ if readcount > 0:
+ self.decodedbytecount += readcount
+
+ if not complete:
+ assert readcount < 0
+ return (
+ bool(self._decodedvalues),
+ offset - initialoffset,
+ -readcount,
+ )
+
+ offset += readcount
+
+ # No nested state. We either have a full value or beginning of a
+ # complex value to deal with.
+ if self._state == self._STATE_NONE:
+ # A normal value.
+ if special == SPECIAL_NONE:
+ self._decodedvalues.append(value)
+
+ elif special == SPECIAL_START_ARRAY:
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': [],
+ })
+ self._state = self._STATE_WANT_ARRAY_VALUE
+
+ elif special == SPECIAL_START_MAP:
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': {},
+ })
+ self._state = self._STATE_WANT_MAP_KEY
+
+ elif special == SPECIAL_START_SET:
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': set(),
+ })
+ self._state = self._STATE_WANT_SET_VALUE
+
+ elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+ self._state = self._STATE_WANT_BYTESTRING_CHUNK_FIRST
+
+ else:
+ raise CBORDecodeError('unhandled special state: %d' %
+ special)
+
+ # This value becomes an element of the current array.
+ elif self._state == self._STATE_WANT_ARRAY_VALUE:
+ # Simple values get appended.
+ if special == SPECIAL_NONE:
+ c = self._collectionstack[-1]
+ c['v'].append(value)
+ c['remaining'] -= 1
+
+ # self._state doesn't need to change.
+
+ # An array nested within an array.
+ elif special == SPECIAL_START_ARRAY:
+ lastc = self._collectionstack[-1]
+ newvalue = []
+
+ lastc['v'].append(newvalue)
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue,
+ })
+
+ # self._state doesn't need to change.
+
+ # A map nested within an array.
+ elif special == SPECIAL_START_MAP:
+ lastc = self._collectionstack[-1]
+ newvalue = {}
+
+ lastc['v'].append(newvalue)
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue
+ })
+
+ self._state = self._STATE_WANT_MAP_KEY
+
+ elif special == SPECIAL_START_SET:
+ lastc = self._collectionstack[-1]
+ newvalue = set()
+
+ lastc['v'].append(newvalue)
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue,
+ })
+
+ self._state = self._STATE_WANT_SET_VALUE
+
+ elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+ raise CBORDecodeError('indefinite length bytestrings '
+ 'not allowed as array values')
+
+ else:
+ raise CBORDecodeError('unhandled special item when '
+ 'expecting array value: %d' % special)
+
+ # This value becomes the key of the current map instance.
+ elif self._state == self._STATE_WANT_MAP_KEY:
+ if special == SPECIAL_NONE:
+ self._currentmapkey = value
+ self._state = self._STATE_WANT_MAP_VALUE
+
+ elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+ raise CBORDecodeError('indefinite length bytestrings '
+ 'not allowed as map keys')
+
+ elif special in (SPECIAL_START_ARRAY, SPECIAL_START_MAP,
+ SPECIAL_START_SET):
+ raise CBORDecodeError('collections not supported as map '
+ 'keys')
+
+ # We do not allow special values to be used as map keys.
+ else:
+ raise CBORDecodeError('unhandled special item when '
+ 'expecting map key: %d' % special)
+
+ # This value becomes the value of the current map key.
+ elif self._state == self._STATE_WANT_MAP_VALUE:
+ # Simple values simply get inserted into the map.
+ if special == SPECIAL_NONE:
+ lastc = self._collectionstack[-1]
+ lastc['v'][self._currentmapkey] = value
+ lastc['remaining'] -= 1
+
+ self._state = self._STATE_WANT_MAP_KEY
+
+ # A new array is used as the map value.
+ elif special == SPECIAL_START_ARRAY:
+ lastc = self._collectionstack[-1]
+ newvalue = []
+
+ lastc['v'][self._currentmapkey] = newvalue
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue,
+ })
+
+ self._state = self._STATE_WANT_ARRAY_VALUE
+
+ # A new map is used as the map value.
+ elif special == SPECIAL_START_MAP:
+ lastc = self._collectionstack[-1]
+ newvalue = {}
+
+ lastc['v'][self._currentmapkey] = newvalue
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue,
+ })
+
+ self._state = self._STATE_WANT_MAP_KEY
+
+ # A new set is used as the map value.
+ elif special == SPECIAL_START_SET:
+ lastc = self._collectionstack[-1]
+ newvalue = set()
+
+ lastc['v'][self._currentmapkey] = newvalue
+ lastc['remaining'] -= 1
+
+ self._collectionstack.append({
+ 'remaining': value,
+ 'v': newvalue,
+ })
+
+ self._state = self._STATE_WANT_SET_VALUE
+
+ elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+ raise CBORDecodeError('indefinite length bytestrings not '
+ 'allowed as map values')
+
+ else:
+ raise CBORDecodeError('unhandled special item when '
+ 'expecting map value: %d' % special)
+
+ self._currentmapkey = None
+
+ # This value is added to the current set.
+ elif self._state == self._STATE_WANT_SET_VALUE:
+ if special == SPECIAL_NONE:
+ lastc = self._collectionstack[-1]
+ lastc['v'].add(value)
+ lastc['remaining'] -= 1
+
+ elif special == SPECIAL_START_INDEFINITE_BYTESTRING:
+ raise CBORDecodeError('indefinite length bytestrings not '
+ 'allowed as set values')
+
+ elif special in (SPECIAL_START_ARRAY,
+ SPECIAL_START_MAP,
+ SPECIAL_START_SET):
+ raise CBORDecodeError('collections not allowed as set '
+ 'values')
+
+ # We don't allow non-trivial types to exist as set values.
+ else:
+ raise CBORDecodeError('unhandled special item when '
+ 'expecting set value: %d' % special)
+
+ # This value represents the first chunk in an indefinite length
+ # bytestring.
+ elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_FIRST:
+ # We received a full chunk.
+ if special == SPECIAL_NONE:
+ self._decodedvalues.append(bytestringchunk(value,
+ first=True))
+
+ self._state = self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT
+
+ # The end of stream marker. This means it is an empty
+ # indefinite length bytestring.
+ elif special == SPECIAL_INDEFINITE_BREAK:
+ # We /could/ convert this to a b''. But we want to preserve
+ # the nature of the underlying data so consumers expecting
+ # an indefinite length bytestring get one.
+ self._decodedvalues.append(bytestringchunk(b'',
+ first=True,
+ last=True))
+
+ # Since indefinite length bytestrings can't be used in
+ # collections, we must be at the root level.
+ assert not self._collectionstack
+ self._state = self._STATE_NONE
+
+ else:
+ raise CBORDecodeError('unexpected special value when '
+ 'expecting bytestring chunk: %d' %
+ special)
+
+ # This value represents the non-initial chunk in an indefinite
+ # length bytestring.
+ elif self._state == self._STATE_WANT_BYTESTRING_CHUNK_SUBSEQUENT:
+ # We received a full chunk.
+ if special == SPECIAL_NONE:
+ self._decodedvalues.append(bytestringchunk(value))
+
+ # The end of stream marker.
+ elif special == SPECIAL_INDEFINITE_BREAK:
+ self._decodedvalues.append(bytestringchunk(b'', last=True))
+
+ # Since indefinite length bytestrings can't be used in
+ # collections, we must be at the root level.
+ assert not self._collectionstack
+ self._state = self._STATE_NONE
+
+ else:
+ raise CBORDecodeError('unexpected special value when '
+ 'expecting bytestring chunk: %d' %
+ special)
+
+ else:
+ raise CBORDecodeError('unhandled decoder state: %d' %
+ self._state)
+
+ # We could have just added the final value in a collection. End
+ # all complete collections at the top of the stack.
+ while True:
+ # Bail if we're not waiting on a new collection item.
+ if self._state not in (self._STATE_WANT_ARRAY_VALUE,
+ self._STATE_WANT_MAP_KEY,
+ self._STATE_WANT_SET_VALUE):
+ break
+
+ # Or we are expecting more items for this collection.
+ lastc = self._collectionstack[-1]
+
+ if lastc['remaining']:
+ break
+
+ # The collection at the top of the stack is complete.
+
+ # Discard it, as it isn't needed for future items.
+ self._collectionstack.pop()
+
+ # If this is a nested collection, we don't emit it, since it
+ # will be emitted by its parent collection. But we do need to
+ # update state to reflect what the new top-most collection
+ # on the stack is.
+ if self._collectionstack:
+ self._state = {
+ list: self._STATE_WANT_ARRAY_VALUE,
+ dict: self._STATE_WANT_MAP_KEY,
+ set: self._STATE_WANT_SET_VALUE,
+ }[type(self._collectionstack[-1]['v'])]
+
+ # If this is the root collection, emit it.
+ else:
+ self._decodedvalues.append(lastc['v'])
+ self._state = self._STATE_NONE
+
+ return (
+ bool(self._decodedvalues),
+ offset - initialoffset,
+ 0,
+ )
+
+ def getavailable(self):
+ """Returns a list of fully decoded values.
+
+ Once values are retrieved, they won't be available on the next call.
+ """
+
+ l = list(self._decodedvalues)
+ self._decodedvalues = []
+ return l
+
+ def decodeall(b):
+ """Decode all CBOR items present in a buffer of bytes.
+
+ In addition to regular decode errors, raises CBORDecodeError if the
+ entirety of the passed buffer does not fully decode to complete CBOR
+ values. This includes failure to decode any value, incomplete collection
+ types, incomplete indefinite length items, and extra data at the end of
+ the buffer.
+ """
+ if not b:
+ return []
+
+ decoder = sansiodecoder()
+
+ havevalues, readcount, wantbytes = decoder.decode(b)
+
+ if readcount != len(b):
+ raise CBORDecodeError('input data not fully consumed')
+
+ if decoder.inprogress:
+ raise CBORDecodeError('input data not complete')
+
+ return decoder.getavailable()
--- a/tests/test-cbor.py Tue Aug 28 15:22:06 2018 -0700
+++ b/tests/test-cbor.py Tue Aug 28 15:02:48 2018 -0700
@@ -10,10 +10,17 @@
cborutil,
)
+class TestCase(unittest.TestCase):
+ if not getattr(unittest.TestCase, 'assertRaisesRegex', False):
+ # Python 3.7 deprecates the regex*p* version, but 2.7 lacks
+ # the regex version.
+ assertRaisesRegex = (# camelcase-required
+ unittest.TestCase.assertRaisesRegexp)
+
def loadit(it):
return cbor.loads(b''.join(it))
-class BytestringTests(unittest.TestCase):
+class BytestringTests(TestCase):
def testsimple(self):
self.assertEqual(
list(cborutil.streamencode(b'foobar')),
@@ -23,11 +30,20 @@
loadit(cborutil.streamencode(b'foobar')),
b'foobar')
+ self.assertEqual(cborutil.decodeall(b'\x46foobar'),
+ [b'foobar'])
+
+ self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'),
+ [b'foobar', b'fizbi'])
+
def testlong(self):
source = b'x' * 1048576
self.assertEqual(loadit(cborutil.streamencode(source)), source)
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
def testfromiter(self):
# This is the example from RFC 7049 Section 2.2.2.
source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99']
@@ -47,6 +63,25 @@
loadit(cborutil.streamencodebytestringfromiter(source)),
b''.join(source))
+ self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+ b'\x43\xee\xff\x99\xff'),
+ [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''])
+
+ for i, chunk in enumerate(
+ cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd'
+ b'\x43\xee\xff\x99\xff')):
+ self.assertIsInstance(chunk, cborutil.bytestringchunk)
+
+ if i == 0:
+ self.assertTrue(chunk.isfirst)
+ else:
+ self.assertFalse(chunk.isfirst)
+
+ if i == 2:
+ self.assertTrue(chunk.islast)
+ else:
+ self.assertFalse(chunk.islast)
+
def testfromiterlarge(self):
source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576]
@@ -71,6 +106,18 @@
source, chunksize=42))
self.assertEqual(cbor.loads(dest), source)
+ self.assertEqual(b''.join(cborutil.decodeall(dest)), source)
+
+ for chunk in cborutil.decodeall(dest):
+ self.assertIsInstance(chunk, cborutil.bytestringchunk)
+ self.assertIn(len(chunk), (0, 8, 42))
+
+ encoded = b'\x5f\xff'
+ b = cborutil.decodeall(encoded)
+ self.assertEqual(b, [b''])
+ self.assertTrue(b[0].isfirst)
+ self.assertTrue(b[0].islast)
+
def testreadtoiter(self):
source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff')
@@ -81,42 +128,405 @@
with self.assertRaises(StopIteration):
next(it)
-class IntTests(unittest.TestCase):
+ def testdecodevariouslengths(self):
+ for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536):
+ source = b'x' * i
+ encoded = b''.join(cborutil.streamencode(source))
+
+ if len(source) < 24:
+ hlen = 1
+ elif len(source) < 256:
+ hlen = 2
+ elif len(source) < 65536:
+ hlen = 3
+ elif len(source) < 1048576:
+ hlen = 5
+
+ self.assertEqual(cborutil.decodeitem(encoded),
+ (True, source, hlen + len(source),
+ cborutil.SPECIAL_NONE))
+
+ def testpartialdecode(self):
+ encoded = b''.join(cborutil.streamencode(b'foobar'))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -6, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -5, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+ (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+ def testpartialdecodevariouslengths(self):
+ lens = [
+ 2,
+ 3,
+ 10,
+ 23,
+ 24,
+ 25,
+ 31,
+ 100,
+ 254,
+ 255,
+ 256,
+ 257,
+ 16384,
+ 65534,
+ 65535,
+ 65536,
+ 65537,
+ 131071,
+ 131072,
+ 131073,
+ 1048575,
+ 1048576,
+ 1048577,
+ ]
+
+ for size in lens:
+ if size < 24:
+ hlen = 1
+ elif size < 2**8:
+ hlen = 2
+ elif size < 2**16:
+ hlen = 3
+ elif size < 2**32:
+ hlen = 5
+ else:
+ assert False
+
+ source = b'x' * size
+ encoded = b''.join(cborutil.streamencode(source))
+
+ res = cborutil.decodeitem(encoded[0:1])
+
+ if hlen > 1:
+ self.assertEqual(res, (False, None, -(hlen - 1),
+ cborutil.SPECIAL_NONE))
+ else:
+ self.assertEqual(res, (False, None, -(size + hlen - 1),
+ cborutil.SPECIAL_NONE))
+
+ # Decoding partial header reports remaining header size.
+ for i in range(hlen - 1):
+ self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]),
+ (False, None, -(hlen - i - 1),
+ cborutil.SPECIAL_NONE))
+
+ # Decoding complete header reports item size.
+ self.assertEqual(cborutil.decodeitem(encoded[0:hlen]),
+ (False, None, -size, cborutil.SPECIAL_NONE))
+
+ # Decoding single byte after header reports item size - 1
+ self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]),
+ (False, None, -(size - 1), cborutil.SPECIAL_NONE))
+
+ # Decoding all but the last byte reports -1 needed.
+ self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+
+ # Decoding last byte retrieves value.
+ self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]),
+ (True, source, hlen + size, cborutil.SPECIAL_NONE))
+
+ def testindefinitepartialdecode(self):
+ encoded = b''.join(cborutil.streamencodebytestringfromiter(
+ [b'foobar', b'biz']))
+
+ # First item should be begin of bytestring special.
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (True, None, 1,
+ cborutil.SPECIAL_START_INDEFINITE_BYTESTRING))
+
+ # Second item should be the first chunk. But only available when
+ # we give it 7 bytes (1 byte header + 6 byte chunk).
+ self.assertEqual(cborutil.decodeitem(encoded[1:2]),
+ (False, None, -6, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[1:3]),
+ (False, None, -5, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[1:4]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[1:5]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[1:6]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[1:7]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+
+ self.assertEqual(cborutil.decodeitem(encoded[1:8]),
+ (True, b'foobar', 7, cborutil.SPECIAL_NONE))
+
+ # Third item should be second chunk. But only available when
+ # we give it 4 bytes (1 byte header + 3 byte chunk).
+ self.assertEqual(cborutil.decodeitem(encoded[8:9]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[8:10]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[8:11]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+
+ self.assertEqual(cborutil.decodeitem(encoded[8:12]),
+ (True, b'biz', 4, cborutil.SPECIAL_NONE))
+
+ # Fourth item should be end of indefinite stream marker.
+ self.assertEqual(cborutil.decodeitem(encoded[12:13]),
+ (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK))
+
+ # Now test the behavior when going through the decoder.
+
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]),
+ (False, 1, 0))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]),
+ (False, 1, 6))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]),
+ (False, 1, 5))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]),
+ (False, 1, 4))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]),
+ (False, 1, 3))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]),
+ (False, 1, 2))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]),
+ (False, 1, 1))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]),
+ (True, 8, 0))
+
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]),
+ (True, 8, 3))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]),
+ (True, 8, 2))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]),
+ (True, 8, 1))
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]),
+ (True, 12, 0))
+
+ self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]),
+ (True, 13, 0))
+
+ decoder = cborutil.sansiodecoder()
+ decoder.decode(encoded[0:8])
+ values = decoder.getavailable()
+ self.assertEqual(values, [b'foobar'])
+ self.assertTrue(values[0].isfirst)
+ self.assertFalse(values[0].islast)
+
+ self.assertEqual(decoder.decode(encoded[8:12]),
+ (True, 4, 0))
+ values = decoder.getavailable()
+ self.assertEqual(values, [b'biz'])
+ self.assertFalse(values[0].isfirst)
+ self.assertFalse(values[0].islast)
+
+ self.assertEqual(decoder.decode(encoded[12:]),
+ (True, 1, 0))
+ values = decoder.getavailable()
+ self.assertEqual(values, [b''])
+ self.assertFalse(values[0].isfirst)
+ self.assertTrue(values[0].islast)
+
+class StringTests(TestCase):
+ def testdecodeforbidden(self):
+ encoded = b'\x63foo'
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'string major type not supported'):
+ cborutil.decodeall(encoded)
+
+class IntTests(TestCase):
def testsmall(self):
self.assertEqual(list(cborutil.streamencode(0)), [b'\x00'])
+ self.assertEqual(cborutil.decodeall(b'\x00'), [0])
+
self.assertEqual(list(cborutil.streamencode(1)), [b'\x01'])
+ self.assertEqual(cborutil.decodeall(b'\x01'), [1])
+
self.assertEqual(list(cborutil.streamencode(2)), [b'\x02'])
+ self.assertEqual(cborutil.decodeall(b'\x02'), [2])
+
self.assertEqual(list(cborutil.streamencode(3)), [b'\x03'])
+ self.assertEqual(cborutil.decodeall(b'\x03'), [3])
+
self.assertEqual(list(cborutil.streamencode(4)), [b'\x04'])
+ self.assertEqual(cborutil.decodeall(b'\x04'), [4])
+
+ # Multiple value decode works.
+ self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'),
+ [0, 1, 2, 3, 4])
def testnegativesmall(self):
self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20'])
+ self.assertEqual(cborutil.decodeall(b'\x20'), [-1])
+
self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21'])
+ self.assertEqual(cborutil.decodeall(b'\x21'), [-2])
+
self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22'])
+ self.assertEqual(cborutil.decodeall(b'\x22'), [-3])
+
self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23'])
+ self.assertEqual(cborutil.decodeall(b'\x23'), [-4])
+
self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24'])
+ self.assertEqual(cborutil.decodeall(b'\x24'), [-5])
+
+ # Multiple value decode works.
+ self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'),
+ [-1, -2, -3, -4, -5])
def testrange(self):
for i in range(-70000, 70000, 10):
- self.assertEqual(
- b''.join(cborutil.streamencode(i)),
- cbor.dumps(i))
+ encoded = b''.join(cborutil.streamencode(i))
+
+ self.assertEqual(encoded, cbor.dumps(i))
+ self.assertEqual(cborutil.decodeall(encoded), [i])
+
+ def testdecodepartialubyte(self):
+ encoded = b''.join(cborutil.streamencode(250))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 250, 2, cborutil.SPECIAL_NONE))
+
+ def testdecodepartialbyte(self):
+ encoded = b''.join(cborutil.streamencode(-42))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, -42, 2, cborutil.SPECIAL_NONE))
+
+ def testdecodepartialushort(self):
+ encoded = b''.join(cborutil.streamencode(2**15))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, 2**15, 3, cborutil.SPECIAL_NONE))
+
+ def testdecodepartialshort(self):
+ encoded = b''.join(cborutil.streamencode(-1024))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, -1024, 3, cborutil.SPECIAL_NONE))
+
+ def testdecodepartialulong(self):
+ encoded = b''.join(cborutil.streamencode(2**28))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (True, 2**28, 5, cborutil.SPECIAL_NONE))
+
+ def testdecodepartiallong(self):
+ encoded = b''.join(cborutil.streamencode(-1048580))
-class ArrayTests(unittest.TestCase):
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (True, -1048580, 5, cborutil.SPECIAL_NONE))
+
+ def testdecodepartialulonglong(self):
+ encoded = b''.join(cborutil.streamencode(2**32))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -8, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -7, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -6, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -5, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+ (True, 2**32, 9, cborutil.SPECIAL_NONE))
+
+ with self.assertRaisesRegex(
+ cborutil.CBORDecodeError, 'input data not fully consumed'):
+ cborutil.decodeall(encoded[0:1])
+
+ with self.assertRaisesRegex(
+ cborutil.CBORDecodeError, 'input data not fully consumed'):
+ cborutil.decodeall(encoded[0:2])
+
+ def testdecodepartiallonglong(self):
+ encoded = b''.join(cborutil.streamencode(-7000000000))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -8, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -7, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -6, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -5, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:7]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:8]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:9]),
+ (True, -7000000000, 9, cborutil.SPECIAL_NONE))
+
+class ArrayTests(TestCase):
def testempty(self):
self.assertEqual(list(cborutil.streamencode([])), [b'\x80'])
self.assertEqual(loadit(cborutil.streamencode([])), [])
+ self.assertEqual(cborutil.decodeall(b'\x80'), [[]])
+
def testbasic(self):
source = [b'foo', b'bar', 1, -10]
- self.assertEqual(list(cborutil.streamencode(source)), [
- b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'])
+ chunks = [
+ b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']
+
+ self.assertEqual(list(cborutil.streamencode(source)), chunks)
+
+ self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source])
def testemptyfromiter(self):
self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])),
b'\x9f\xff')
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length uint not allowed'):
+ cborutil.decodeall(b'\x9f\xff')
+
def testfromiter1(self):
source = [b'foo']
@@ -129,26 +539,193 @@
dest = b''.join(cborutil.streamencodearrayfromiter(source))
self.assertEqual(cbor.loads(dest), source)
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length uint not allowed'):
+ cborutil.decodeall(dest)
+
def testtuple(self):
source = (b'foo', None, 42)
+ encoded = b''.join(cborutil.streamencode(source))
- self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
- list(source))
+ self.assertEqual(cbor.loads(encoded), list(source))
+
+ self.assertEqual(cborutil.decodeall(encoded), [list(source)])
+
+ def testpartialdecode(self):
+ source = list(range(4))
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 4, 1, cborutil.SPECIAL_START_ARRAY))
+
+ source = list(range(23))
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 23, 1, cborutil.SPECIAL_START_ARRAY))
+
+ source = list(range(24))
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, 24, 2, cborutil.SPECIAL_START_ARRAY))
-class SetTests(unittest.TestCase):
+ source = list(range(256))
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (True, 256, 3, cborutil.SPECIAL_START_ARRAY))
+
+ def testnested(self):
+ source = [[], [], [[], [], []]]
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ source = [True, None, [True, 0, 2], [None], [], [[[]], -87]]
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ # A set within an array.
+ source = [None, {b'foo', b'bar', None, False}, set()]
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ # A map within an array.
+ source = [None, {}, {b'foo': b'bar', True: False}, [{}]]
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ def testindefinitebytestringvalues(self):
+ # Single value array whose value is an empty indefinite bytestring.
+ encoded = b'\x81\x5f\x40\xff'
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length bytestrings not '
+ 'allowed as array values'):
+ cborutil.decodeall(encoded)
+
+class SetTests(TestCase):
def testempty(self):
self.assertEqual(list(cborutil.streamencode(set())), [
b'\xd9\x01\x02',
b'\x80',
])
+ self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()])
+
def testset(self):
source = {b'foo', None, 42}
+ encoded = b''.join(cborutil.streamencode(source))
- self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))),
- source)
+ self.assertEqual(cbor.loads(encoded), source)
+
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ def testinvalidtag(self):
+ # Must use array to encode sets.
+ encoded = b'\xd9\x01\x02\xa0'
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'expected array after finite set '
+ 'semantic tag'):
+ cborutil.decodeall(encoded)
+
+ def testpartialdecode(self):
+ # Semantic tag item will be 3 bytes. Set header will be variable
+ # depending on length.
+ encoded = b''.join(cborutil.streamencode({i for i in range(23)}))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (True, 23, 4, cborutil.SPECIAL_START_SET))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (True, 23, 4, cborutil.SPECIAL_START_SET))
+
+ encoded = b''.join(cborutil.streamencode({i for i in range(24)}))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (True, 24, 5, cborutil.SPECIAL_START_SET))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (True, 24, 5, cborutil.SPECIAL_START_SET))
-class BoolTests(unittest.TestCase):
+ encoded = b''.join(cborutil.streamencode({i for i in range(256)}))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (True, 256, 6, cborutil.SPECIAL_START_SET))
+
+ def testinvalidvalue(self):
+ encoded = b''.join([
+ b'\xd9\x01\x02', # semantic tag
+ b'\x81', # array of size 1
+ b'\x5f\x43foo\xff', # indefinite length bytestring "foo"
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length bytestrings not '
+ 'allowed as set values'):
+ cborutil.decodeall(encoded)
+
+ encoded = b''.join([
+ b'\xd9\x01\x02',
+ b'\x81',
+ b'\x80', # empty array
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'collections not allowed as set values'):
+ cborutil.decodeall(encoded)
+
+ encoded = b''.join([
+ b'\xd9\x01\x02',
+ b'\x81',
+ b'\xa0', # empty map
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'collections not allowed as set values'):
+ cborutil.decodeall(encoded)
+
+ encoded = b''.join([
+ b'\xd9\x01\x02',
+ b'\x81',
+ b'\xd9\x01\x02\x81\x01', # set with integer 1
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'collections not allowed as set values'):
+ cborutil.decodeall(encoded)
+
+class BoolTests(TestCase):
def testbasic(self):
self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5'])
self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4'])
@@ -156,23 +733,38 @@
self.assertIs(loadit(cborutil.streamencode(True)), True)
self.assertIs(loadit(cborutil.streamencode(False)), False)
-class NoneTests(unittest.TestCase):
+ self.assertEqual(cborutil.decodeall(b'\xf4'), [False])
+ self.assertEqual(cborutil.decodeall(b'\xf5'), [True])
+
+ self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'),
+ [False, True, True, False])
+
+class NoneTests(TestCase):
def testbasic(self):
self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6'])
self.assertIs(loadit(cborutil.streamencode(None)), None)
-class MapTests(unittest.TestCase):
+ self.assertEqual(cborutil.decodeall(b'\xf6'), [None])
+ self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None])
+
+class MapTests(TestCase):
def testempty(self):
self.assertEqual(list(cborutil.streamencode({})), [b'\xa0'])
self.assertEqual(loadit(cborutil.streamencode({})), {})
+ self.assertEqual(cborutil.decodeall(b'\xa0'), [{}])
+
def testemptyindefinite(self):
self.assertEqual(list(cborutil.streamencodemapfromiter([])), [
b'\xbf', b'\xff'])
self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {})
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length uint not allowed'):
+ cborutil.decodeall(b'\xbf\xff')
+
def testone(self):
source = {b'foo': b'bar'}
self.assertEqual(list(cborutil.streamencode(source)), [
@@ -180,6 +772,8 @@
self.assertEqual(loadit(cborutil.streamencode(source)), source)
+ self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source])
+
def testmultiple(self):
source = {
b'foo': b'bar',
@@ -192,6 +786,9 @@
loadit(cborutil.streamencodemapfromiter(source.items())),
source)
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
def testcomplex(self):
source = {
b'key': 1,
@@ -205,6 +802,170 @@
loadit(cborutil.streamencodemapfromiter(source.items())),
source)
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ def testnested(self):
+ source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}}
+ encoded = b''.join(cborutil.streamencode(source))
+
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ source = {
+ b'key1': [],
+ b'key2': [None, False],
+ b'key3': {b'foo', b'bar'},
+ b'key4': {},
+ }
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeall(encoded), [source])
+
+ def testillegalkey(self):
+ encoded = b''.join([
+ # map header + len 1
+ b'\xa1',
+ # indefinite length bytestring "foo" in key position
+ b'\x5f\x43foo\xff'
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length bytestrings not '
+ 'allowed as map keys'):
+ cborutil.decodeall(encoded)
+
+ encoded = b''.join([
+ b'\xa1',
+ b'\x80', # empty array
+ b'\x43foo',
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'collections not supported as map keys'):
+ cborutil.decodeall(encoded)
+
+ def testillegalvalue(self):
+ encoded = b''.join([
+ b'\xa1', # map header + len 1
+ b'\x43foo', # key
+ b'\x5f\x43bar\xff', # indefinite length bytestring "bar" as value
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'indefinite length bytestrings not '
+ 'allowed as map values'):
+ cborutil.decodeall(encoded)
+
+ def testpartialdecode(self):
+ source = {b'key1': b'value1'}
+ encoded = b''.join(cborutil.streamencode(source))
+
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (True, 1, 1, cborutil.SPECIAL_START_MAP))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 1, 1, cborutil.SPECIAL_START_MAP))
+
+ source = {b'key%d' % i: None for i in range(23)}
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (True, 23, 1, cborutil.SPECIAL_START_MAP))
+
+ source = {b'key%d' % i: None for i in range(24)}
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (True, 24, 2, cborutil.SPECIAL_START_MAP))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, 24, 2, cborutil.SPECIAL_START_MAP))
+
+ source = {b'key%d' % i: None for i in range(256)}
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (True, 256, 3, cborutil.SPECIAL_START_MAP))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (True, 256, 3, cborutil.SPECIAL_START_MAP))
+
+ source = {b'key%d' % i: None for i in range(65536)}
+ encoded = b''.join(cborutil.streamencode(source))
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -4, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -3, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:3]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:4]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:5]),
+ (True, 65536, 5, cborutil.SPECIAL_START_MAP))
+ self.assertEqual(cborutil.decodeitem(encoded[0:6]),
+ (True, 65536, 5, cborutil.SPECIAL_START_MAP))
+
+class SemanticTagTests(TestCase):
+ def testdecodeforbidden(self):
+ for i in range(500):
+ if i == cborutil.SEMANTIC_TAG_FINITE_SET:
+ continue
+
+ tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC,
+ i)
+
+ encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42)
+
+ # Partial decode is incomplete.
+ if i < 24:
+ pass
+ elif i < 256:
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+ elif i < 65536:
+ self.assertEqual(cborutil.decodeitem(encoded[0:1]),
+ (False, None, -2, cborutil.SPECIAL_NONE))
+ self.assertEqual(cborutil.decodeitem(encoded[0:2]),
+ (False, None, -1, cborutil.SPECIAL_NONE))
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ r'semantic tag \d+ not allowed'):
+ cborutil.decodeitem(encoded)
+
+class SpecialTypesTests(TestCase):
+ def testforbiddentypes(self):
+ for i in range(256):
+ if i == cborutil.SUBTYPE_FALSE:
+ continue
+ elif i == cborutil.SUBTYPE_TRUE:
+ continue
+ elif i == cborutil.SUBTYPE_NULL:
+ continue
+
+ encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i)
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ r'special type \d+ not allowed'):
+ cborutil.decodeitem(encoded)
+
+class SansIODecoderTests(TestCase):
+ def testemptyinput(self):
+ decoder = cborutil.sansiodecoder()
+ self.assertEqual(decoder.decode(b''), (False, 0, 0))
+
+class DecodeallTests(TestCase):
+ def testemptyinput(self):
+ self.assertEqual(cborutil.decodeall(b''), [])
+
+ def testpartialinput(self):
+ encoded = b''.join([
+ b'\x82', # array of 2 elements
+ b'\x01', # integer 1
+ ])
+
+ with self.assertRaisesRegex(cborutil.CBORDecodeError,
+ 'input data not complete'):
+ cborutil.decodeall(encoded)
+
if __name__ == '__main__':
import silenttestrunner
silenttestrunner.main(__name__)