Mercurial > hg
view tests/test-cbor.py @ 49000:dd6b67d5c256 stable
rust: fix unsound `OwningDirstateMap`
As per the previous patch, `OwningDirstateMap` is unsound. Self-referential
structs are difficult to implement correctly in Rust since the compiler is
free to move structs around as much as it wants to. They are also very rarely
needed in practice, so the state-of-the-art on how they should be done within
the Rust rules is still a bit new.
The crate `ouroboros` is an attempt at providing a safe way (in the Rust sense)
of declaring self-referential structs. It is getting a lot of attention and was
improved very quickly when soundness issues were found in the past: rather than
relying on our own (limited) review circle, we might as well use the de-facto
common crate to fix this problem. This will give us a much better chance of
finding issues should any new ones be discovered as well as the benefit of
fewer `unsafe` APIs of our own.
I was starting to think about how I would present a safe API to the old struct
but soon realized that the callback-based approach was already done in
`ouroboros`, along with a lot more care towards refusing incorrect structs.
In short: we don't return a mutable reference to the `DirstateMap` anymore, we
expect users of its API to pass a `FnOnce` that takes the map as an argument.
This allows our `OwningDirstateMap` to control the input and output lifetimes
of the code that modifies it to prevent such issues.
Changing to `ouroboros` meant changing every API with it, but it is relatively
low churn in the end. It correctly identified the example buggy modification of
`copy_map_insert` outlined in the previous patch as violating the borrow rules.
Differential Revision: https://phab.mercurial-scm.org/D12429
author | Raphaël Gomès <rgomes@octobus.net> |
---|---|
date | Tue, 05 Apr 2022 10:55:28 +0200 |
parents | 89a2afe31e82 |
children | 6000f5b25c9b |
line wrap: on
line source
from __future__ import absolute_import import os import sys import unittest # TODO migrate to canned cbor test strings and stop using thirdparty.cbor tpp = os.path.normpath( os.path.join(os.path.dirname(__file__), '..', 'mercurial', 'thirdparty') ) if not os.path.exists(tpp): # skip, not in a repo sys.exit(80) sys.path[0:0] = [tpp] import cbor del sys.path[0] from mercurial.utils import cborutil class TestCase(unittest.TestCase): if not getattr(unittest.TestCase, 'assertRaisesRegex', False): # Python 3.7 deprecates the regex*p* version, but 2.7 lacks # the regex version. assertRaisesRegex = ( # camelcase-required unittest.TestCase.assertRaisesRegexp ) def loadit(it): return cbor.loads(b''.join(it)) class BytestringTests(TestCase): def testsimple(self): self.assertEqual( list(cborutil.streamencode(b'foobar')), [b'\x46', b'foobar'] ) self.assertEqual(loadit(cborutil.streamencode(b'foobar')), b'foobar') self.assertEqual(cborutil.decodeall(b'\x46foobar'), [b'foobar']) self.assertEqual( cborutil.decodeall(b'\x46foobar\x45fizbi'), [b'foobar', b'fizbi'] ) def testlong(self): source = b'x' * 1048576 self.assertEqual(loadit(cborutil.streamencode(source)), source) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) def testfromiter(self): # This is the example from RFC 7049 Section 2.2.2. 
source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99'] self.assertEqual( list(cborutil.streamencodebytestringfromiter(source)), [ b'\x5f', b'\x44', b'\xaa\xbb\xcc\xdd', b'\x43', b'\xee\xff\x99', b'\xff', ], ) self.assertEqual( loadit(cborutil.streamencodebytestringfromiter(source)), b''.join(source), ) self.assertEqual( cborutil.decodeall( b'\x5f\x44\xaa\xbb\xcc\xdd' b'\x43\xee\xff\x99\xff' ), [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b''], ) for i, chunk in enumerate( cborutil.decodeall( b'\x5f\x44\xaa\xbb\xcc\xdd' b'\x43\xee\xff\x99\xff' ) ): self.assertIsInstance(chunk, cborutil.bytestringchunk) if i == 0: self.assertTrue(chunk.isfirst) else: self.assertFalse(chunk.isfirst) if i == 2: self.assertTrue(chunk.islast) else: self.assertFalse(chunk.islast) def testfromiterlarge(self): source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576] self.assertEqual( loadit(cborutil.streamencodebytestringfromiter(source)), b''.join(source), ) def testindefinite(self): source = b'\x00\x01\x02\x03' + b'\xff' * 16384 it = cborutil.streamencodeindefinitebytestring(source, chunksize=2) self.assertEqual(next(it), b'\x5f') self.assertEqual(next(it), b'\x42') self.assertEqual(next(it), b'\x00\x01') self.assertEqual(next(it), b'\x42') self.assertEqual(next(it), b'\x02\x03') self.assertEqual(next(it), b'\x42') self.assertEqual(next(it), b'\xff\xff') dest = b''.join( cborutil.streamencodeindefinitebytestring(source, chunksize=42) ) self.assertEqual(cbor.loads(dest), source) self.assertEqual(b''.join(cborutil.decodeall(dest)), source) for chunk in cborutil.decodeall(dest): self.assertIsInstance(chunk, cborutil.bytestringchunk) self.assertIn(len(chunk), (0, 8, 42)) encoded = b'\x5f\xff' b = cborutil.decodeall(encoded) self.assertEqual(b, [b'']) self.assertTrue(b[0].isfirst) self.assertTrue(b[0].islast) def testdecodevariouslengths(self): for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536): source = b'x' * i encoded = b''.join(cborutil.streamencode(source)) if len(source) < 24: 
hlen = 1 elif len(source) < 256: hlen = 2 elif len(source) < 65536: hlen = 3 elif len(source) < 1048576: hlen = 5 self.assertEqual( cborutil.decodeitem(encoded), (True, source, hlen + len(source), cborutil.SPECIAL_NONE), ) def testpartialdecode(self): encoded = b''.join(cborutil.streamencode(b'foobar')) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -6, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -5, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:7]), (True, b'foobar', 7, cborutil.SPECIAL_NONE), ) def testpartialdecodevariouslengths(self): lens = [ 2, 3, 10, 23, 24, 25, 31, 100, 254, 255, 256, 257, 16384, 65534, 65535, 65536, 65537, 131071, 131072, 131073, 1048575, 1048576, 1048577, ] for size in lens: if size < 24: hlen = 1 elif size < 2 ** 8: hlen = 2 elif size < 2 ** 16: hlen = 3 elif size < 2 ** 32: hlen = 5 else: assert False source = b'x' * size encoded = b''.join(cborutil.streamencode(source)) res = cborutil.decodeitem(encoded[0:1]) if hlen > 1: self.assertEqual( res, (False, None, -(hlen - 1), cborutil.SPECIAL_NONE) ) else: self.assertEqual( res, (False, None, -(size + hlen - 1), cborutil.SPECIAL_NONE), ) # Decoding partial header reports remaining header size. for i in range(hlen - 1): self.assertEqual( cborutil.decodeitem(encoded[0 : i + 1]), (False, None, -(hlen - i - 1), cborutil.SPECIAL_NONE), ) # Decoding complete header reports item size. 
self.assertEqual( cborutil.decodeitem(encoded[0:hlen]), (False, None, -size, cborutil.SPECIAL_NONE), ) # Decoding single byte after header reports item size - 1 self.assertEqual( cborutil.decodeitem(encoded[0 : hlen + 1]), (False, None, -(size - 1), cborutil.SPECIAL_NONE), ) # Decoding all but the last byte reports -1 needed. self.assertEqual( cborutil.decodeitem(encoded[0 : hlen + size - 1]), (False, None, -1, cborutil.SPECIAL_NONE), ) # Decoding last byte retrieves value. self.assertEqual( cborutil.decodeitem(encoded[0 : hlen + size]), (True, source, hlen + size, cborutil.SPECIAL_NONE), ) def testindefinitepartialdecode(self): encoded = b''.join( cborutil.streamencodebytestringfromiter([b'foobar', b'biz']) ) # First item should be begin of bytestring special. self.assertEqual( cborutil.decodeitem(encoded[0:1]), (True, None, 1, cborutil.SPECIAL_START_INDEFINITE_BYTESTRING), ) # Second item should be the first chunk. But only available when # we give it 7 bytes (1 byte header + 6 byte chunk). self.assertEqual( cborutil.decodeitem(encoded[1:2]), (False, None, -6, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:3]), (False, None, -5, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:4]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:5]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:6]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:7]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[1:8]), (True, b'foobar', 7, cborutil.SPECIAL_NONE), ) # Third item should be second chunk. But only available when # we give it 4 bytes (1 byte header + 3 byte chunk). 
self.assertEqual( cborutil.decodeitem(encoded[8:9]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[8:10]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[8:11]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[8:12]), (True, b'biz', 4, cborutil.SPECIAL_NONE), ) # Fourth item should be end of indefinite stream marker. self.assertEqual( cborutil.decodeitem(encoded[12:13]), (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK), ) # Now test the behavior when going through the decoder. self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:1]), (False, 1, 0) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:2]), (False, 1, 6) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:3]), (False, 1, 5) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:4]), (False, 1, 4) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:5]), (False, 1, 3) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:6]), (False, 1, 2) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:7]), (False, 1, 1) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:8]), (True, 8, 0) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:9]), (True, 8, 3) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:10]), (True, 8, 2) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:11]), (True, 8, 1) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:12]), (True, 12, 0) ) self.assertEqual( cborutil.sansiodecoder().decode(encoded[0:13]), (True, 13, 0) ) decoder = cborutil.sansiodecoder() decoder.decode(encoded[0:8]) values = decoder.getavailable() self.assertEqual(values, [b'foobar']) self.assertTrue(values[0].isfirst) self.assertFalse(values[0].islast) self.assertEqual(decoder.decode(encoded[8:12]), (True, 4, 0)) values = decoder.getavailable() self.assertEqual(values, [b'biz']) 
self.assertFalse(values[0].isfirst) self.assertFalse(values[0].islast) self.assertEqual(decoder.decode(encoded[12:]), (True, 1, 0)) values = decoder.getavailable() self.assertEqual(values, [b'']) self.assertFalse(values[0].isfirst) self.assertTrue(values[0].islast) class StringTests(TestCase): def testdecodeforbidden(self): encoded = b'\x63foo' with self.assertRaisesRegex( cborutil.CBORDecodeError, 'string major type not supported' ): cborutil.decodeall(encoded) class IntTests(TestCase): def testsmall(self): self.assertEqual(list(cborutil.streamencode(0)), [b'\x00']) self.assertEqual(cborutil.decodeall(b'\x00'), [0]) self.assertEqual(list(cborutil.streamencode(1)), [b'\x01']) self.assertEqual(cborutil.decodeall(b'\x01'), [1]) self.assertEqual(list(cborutil.streamencode(2)), [b'\x02']) self.assertEqual(cborutil.decodeall(b'\x02'), [2]) self.assertEqual(list(cborutil.streamencode(3)), [b'\x03']) self.assertEqual(cborutil.decodeall(b'\x03'), [3]) self.assertEqual(list(cborutil.streamencode(4)), [b'\x04']) self.assertEqual(cborutil.decodeall(b'\x04'), [4]) # Multiple value decode works. self.assertEqual( cborutil.decodeall(b'\x00\x01\x02\x03\x04'), [0, 1, 2, 3, 4] ) def testnegativesmall(self): self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20']) self.assertEqual(cborutil.decodeall(b'\x20'), [-1]) self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21']) self.assertEqual(cborutil.decodeall(b'\x21'), [-2]) self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22']) self.assertEqual(cborutil.decodeall(b'\x22'), [-3]) self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23']) self.assertEqual(cborutil.decodeall(b'\x23'), [-4]) self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24']) self.assertEqual(cborutil.decodeall(b'\x24'), [-5]) # Multiple value decode works. 
self.assertEqual( cborutil.decodeall(b'\x20\x21\x22\x23\x24'), [-1, -2, -3, -4, -5] ) def testrange(self): for i in range(-70000, 70000, 10): encoded = b''.join(cborutil.streamencode(i)) self.assertEqual(encoded, cbor.dumps(i)) self.assertEqual(cborutil.decodeall(encoded), [i]) def testdecodepartialubyte(self): encoded = b''.join(cborutil.streamencode(250)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 250, 2, cborutil.SPECIAL_NONE), ) def testdecodepartialbyte(self): encoded = b''.join(cborutil.streamencode(-42)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, -42, 2, cborutil.SPECIAL_NONE), ) def testdecodepartialushort(self): encoded = b''.join(cborutil.streamencode(2 ** 15)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, 2 ** 15, 3, cborutil.SPECIAL_NONE), ) def testdecodepartialshort(self): encoded = b''.join(cborutil.streamencode(-1024)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (True, -1024, 3, cborutil.SPECIAL_NONE), ) def testdecodepartialulong(self): encoded = b''.join(cborutil.streamencode(2 ** 28)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( 
cborutil.decodeitem(encoded[0:4]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, 2 ** 28, 5, cborutil.SPECIAL_NONE), ) def testdecodepartiallong(self): encoded = b''.join(cborutil.streamencode(-1048580)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, -1048580, 5, cborutil.SPECIAL_NONE), ) def testdecodepartialulonglong(self): encoded = b''.join(cborutil.streamencode(2 ** 32)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -8, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -7, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -6, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -5, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:7]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:8]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:9]), (True, 2 ** 32, 9, cborutil.SPECIAL_NONE), ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'input data not fully consumed' ): cborutil.decodeall(encoded[0:1]) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'input data not fully consumed' ): cborutil.decodeall(encoded[0:2]) def testdecodepartiallonglong(self): 
encoded = b''.join(cborutil.streamencode(-7000000000)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -8, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -7, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -6, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -5, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:7]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:8]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:9]), (True, -7000000000, 9, cborutil.SPECIAL_NONE), ) class ArrayTests(TestCase): def testempty(self): self.assertEqual(list(cborutil.streamencode([])), [b'\x80']) self.assertEqual(loadit(cborutil.streamencode([])), []) self.assertEqual(cborutil.decodeall(b'\x80'), [[]]) def testbasic(self): source = [b'foo', b'bar', 1, -10] chunks = [b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'] self.assertEqual(list(cborutil.streamencode(source)), chunks) self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source]) def testemptyfromiter(self): self.assertEqual( b''.join(cborutil.streamencodearrayfromiter([])), b'\x9f\xff' ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length uint not allowed' ): cborutil.decodeall(b'\x9f\xff') def testfromiter1(self): source = [b'foo'] self.assertEqual( list(cborutil.streamencodearrayfromiter(source)), [ b'\x9f', b'\x43', b'foo', b'\xff', ], ) dest = b''.join(cborutil.streamencodearrayfromiter(source)) self.assertEqual(cbor.loads(dest), source) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length uint not allowed' ): 
cborutil.decodeall(dest) def testtuple(self): source = (b'foo', None, 42) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cbor.loads(encoded), list(source)) self.assertEqual(cborutil.decodeall(encoded), [list(source)]) def testpartialdecode(self): source = list(range(4)) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (True, 4, 1, cborutil.SPECIAL_START_ARRAY), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 4, 1, cborutil.SPECIAL_START_ARRAY), ) source = list(range(23)) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (True, 23, 1, cborutil.SPECIAL_START_ARRAY), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 23, 1, cborutil.SPECIAL_START_ARRAY), ) source = list(range(24)) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 24, 2, cborutil.SPECIAL_START_ARRAY), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (True, 24, 2, cborutil.SPECIAL_START_ARRAY), ) source = list(range(256)) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (True, 256, 3, cborutil.SPECIAL_START_ARRAY), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (True, 256, 3, cborutil.SPECIAL_START_ARRAY), ) def testnested(self): source = [[], [], [[], [], []]] encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) source = [True, None, [True, 0, 2], [None], [], [[[]], -87]] encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) # A set within 
an array. source = [None, {b'foo', b'bar', None, False}, set()] encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) # A map within an array. source = [None, {}, {b'foo': b'bar', True: False}, [{}]] encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) def testindefinitebytestringvalues(self): # Single value array whose value is an empty indefinite bytestring. encoded = b'\x81\x5f\x40\xff' with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length bytestrings not ' 'allowed as array values', ): cborutil.decodeall(encoded) class SetTests(TestCase): def testempty(self): self.assertEqual( list(cborutil.streamencode(set())), [ b'\xd9\x01\x02', b'\x80', ], ) self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()]) def testset(self): source = {b'foo', None, 42} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cbor.loads(encoded), source) self.assertEqual(cborutil.decodeall(encoded), [source]) def testinvalidtag(self): # Must use array to encode sets. encoded = b'\xd9\x01\x02\xa0' with self.assertRaisesRegex( cborutil.CBORDecodeError, 'expected array after finite set ' 'semantic tag', ): cborutil.decodeall(encoded) def testpartialdecode(self): # Semantic tag item will be 3 bytes. Set header will be variable # depending on length. 
encoded = b''.join(cborutil.streamencode({i for i in range(23)})) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (True, 23, 4, cborutil.SPECIAL_START_SET), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, 23, 4, cborutil.SPECIAL_START_SET), ) encoded = b''.join(cborutil.streamencode({i for i in range(24)})) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, 24, 5, cborutil.SPECIAL_START_SET), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (True, 24, 5, cborutil.SPECIAL_START_SET), ) encoded = b''.join(cborutil.streamencode({i for i in range(256)})) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (True, 256, 6, cborutil.SPECIAL_START_SET), ) def testinvalidvalue(self): encoded = b''.join( [ b'\xd9\x01\x02', # semantic tag b'\x81', # array of size 1 b'\x5f\x43foo\xff', # indefinite 
length bytestring "foo" ] ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length bytestrings not ' 'allowed as set values', ): cborutil.decodeall(encoded) encoded = b''.join( [ b'\xd9\x01\x02', b'\x81', b'\x80', ] ) # empty array with self.assertRaisesRegex( cborutil.CBORDecodeError, 'collections not allowed as set values' ): cborutil.decodeall(encoded) encoded = b''.join( [ b'\xd9\x01\x02', b'\x81', b'\xa0', ] ) # empty map with self.assertRaisesRegex( cborutil.CBORDecodeError, 'collections not allowed as set values' ): cborutil.decodeall(encoded) encoded = b''.join( [ b'\xd9\x01\x02', b'\x81', b'\xd9\x01\x02\x81\x01', # set with integer 1 ] ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'collections not allowed as set values' ): cborutil.decodeall(encoded) class BoolTests(TestCase): def testbasic(self): self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5']) self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4']) self.assertIs(loadit(cborutil.streamencode(True)), True) self.assertIs(loadit(cborutil.streamencode(False)), False) self.assertEqual(cborutil.decodeall(b'\xf4'), [False]) self.assertEqual(cborutil.decodeall(b'\xf5'), [True]) self.assertEqual( cborutil.decodeall(b'\xf4\xf5\xf5\xf4'), [False, True, True, False] ) class NoneTests(TestCase): def testbasic(self): self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6']) self.assertIs(loadit(cborutil.streamencode(None)), None) self.assertEqual(cborutil.decodeall(b'\xf6'), [None]) self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None]) class MapTests(TestCase): def testempty(self): self.assertEqual(list(cborutil.streamencode({})), [b'\xa0']) self.assertEqual(loadit(cborutil.streamencode({})), {}) self.assertEqual(cborutil.decodeall(b'\xa0'), [{}]) def testemptyindefinite(self): self.assertEqual( list(cborutil.streamencodemapfromiter([])), [b'\xbf', b'\xff'] ) self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {}) with self.assertRaisesRegex( 
cborutil.CBORDecodeError, 'indefinite length uint not allowed' ): cborutil.decodeall(b'\xbf\xff') def testone(self): source = {b'foo': b'bar'} self.assertEqual( list(cborutil.streamencode(source)), [b'\xa1', b'\x43', b'foo', b'\x43', b'bar'], ) self.assertEqual(loadit(cborutil.streamencode(source)), source) self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source]) def testmultiple(self): source = { b'foo': b'bar', b'baz': b'value1', } self.assertEqual(loadit(cborutil.streamencode(source)), source) self.assertEqual( loadit(cborutil.streamencodemapfromiter(source.items())), source ) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) def testcomplex(self): source = { b'key': 1, 2: -10, } self.assertEqual(loadit(cborutil.streamencode(source)), source) self.assertEqual( loadit(cborutil.streamencodemapfromiter(source.items())), source ) encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) def testnested(self): source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) source = { b'key1': [], b'key2': [None, False], b'key3': {b'foo', b'bar'}, b'key4': {}, } encoded = b''.join(cborutil.streamencode(source)) self.assertEqual(cborutil.decodeall(encoded), [source]) def testillegalkey(self): encoded = b''.join( [ # map header + len 1 b'\xa1', # indefinite length bytestring "foo" in key position b'\x5f\x03foo\xff', ] ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length bytestrings not ' 'allowed as map keys', ): cborutil.decodeall(encoded) encoded = b''.join( [ b'\xa1', b'\x80', b'\x43foo', ] ) # empty array with self.assertRaisesRegex( cborutil.CBORDecodeError, 'collections not supported as map keys' ): cborutil.decodeall(encoded) def testillegalvalue(self): encoded = b''.join( [ b'\xa1', # map headers b'\x43foo', # key 
b'\x5f\x03bar\xff', # indefinite length value ] ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'indefinite length bytestrings not ' 'allowed as map values', ): cborutil.decodeall(encoded) def testpartialdecode(self): source = {b'key1': b'value1'} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (True, 1, 1, cborutil.SPECIAL_START_MAP), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 1, 1, cborutil.SPECIAL_START_MAP), ) source = {b'key%d' % i: None for i in range(23)} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (True, 23, 1, cborutil.SPECIAL_START_MAP), ) source = {b'key%d' % i: None for i in range(24)} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (True, 24, 2, cborutil.SPECIAL_START_MAP), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (True, 24, 2, cborutil.SPECIAL_START_MAP), ) source = {b'key%d' % i: None for i in range(256)} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (True, 256, 3, cborutil.SPECIAL_START_MAP), ) self.assertEqual( cborutil.decodeitem(encoded[0:4]), (True, 256, 3, cborutil.SPECIAL_START_MAP), ) source = {b'key%d' % i: None for i in range(65536)} encoded = b''.join(cborutil.streamencode(source)) self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -4, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -3, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:3]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( 
cborutil.decodeitem(encoded[0:4]), (False, None, -1, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:5]), (True, 65536, 5, cborutil.SPECIAL_START_MAP), ) self.assertEqual( cborutil.decodeitem(encoded[0:6]), (True, 65536, 5, cborutil.SPECIAL_START_MAP), ) class SemanticTagTests(TestCase): def testdecodeforbidden(self): for i in range(500): if i == cborutil.SEMANTIC_TAG_FINITE_SET: continue tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC, i) encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42) # Partial decode is incomplete. if i < 24: pass elif i < 256: self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -1, cborutil.SPECIAL_NONE), ) elif i < 65536: self.assertEqual( cborutil.decodeitem(encoded[0:1]), (False, None, -2, cborutil.SPECIAL_NONE), ) self.assertEqual( cborutil.decodeitem(encoded[0:2]), (False, None, -1, cborutil.SPECIAL_NONE), ) with self.assertRaisesRegex( cborutil.CBORDecodeError, r'semantic tag \d+ not allowed' ): cborutil.decodeitem(encoded) class SpecialTypesTests(TestCase): def testforbiddentypes(self): for i in range(256): if i == cborutil.SUBTYPE_FALSE: continue elif i == cborutil.SUBTYPE_TRUE: continue elif i == cborutil.SUBTYPE_NULL: continue encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i) with self.assertRaisesRegex( cborutil.CBORDecodeError, r'special type \d+ not allowed' ): cborutil.decodeitem(encoded) class SansIODecoderTests(TestCase): def testemptyinput(self): decoder = cborutil.sansiodecoder() self.assertEqual(decoder.decode(b''), (False, 0, 0)) class BufferingDecoderTests(TestCase): def testsimple(self): source = [ b'foobar', b'x' * 128, {b'foo': b'bar'}, True, False, None, [None] * 128, ] encoded = b''.join(cborutil.streamencode(source)) for step in range(1, 32): decoder = cborutil.bufferingdecoder() start = 0 while start < len(encoded): decoder.decode(encoded[start : start + step]) start += step self.assertEqual(decoder.getavailable(), [source]) def 
testbytearray(self): source = b''.join(cborutil.streamencode(b'foobar')) decoder = cborutil.bufferingdecoder() decoder.decode(bytearray(source)) self.assertEqual(decoder.getavailable(), [b'foobar']) class DecodeallTests(TestCase): def testemptyinput(self): self.assertEqual(cborutil.decodeall(b''), []) def testpartialinput(self): encoded = b''.join( [ b'\x82', b'\x01', ] # array of 2 elements # integer 1 ) with self.assertRaisesRegex( cborutil.CBORDecodeError, 'input data not complete' ): cborutil.decodeall(encoded) if __name__ == '__main__': import silenttestrunner silenttestrunner.main(__name__)