Mercurial > hg
comparison tests/test-cbor.py @ 39411:aeb551a3bb8a
cborutil: implement sans I/O decoder
The vendored CBOR package decodes by calling read(n) on an object.
There are a number of disadvantages to this:
* Uses blocking I/O. If sufficient data is not available, the decoder
will hang until it is.
* No support for partial reads. If read(n) returns less data than
requested, the decoder raises an error.
* Requires the use of a file-like object. If the original data is in,
say, a buffer, we need to "cast" it to e.g. a BytesIO to appease the
decoder.
In addition, the vendored CBOR decoder doesn't provide the flexibility
we desire. Specifically:
* It buffers indefinite length bytestrings instead of streaming them.
* It doesn't allow limiting the set of types that can be decoded. This
capability is useful when implementing a "hardened" decoder that is
less susceptible to abusive input.
* It doesn't provide sufficient "hook points" and introspection to
institute checks around behavior. These are useful for implementing
a "hardened" decoder.
This all adds up to a reasonable set of justifications for writing our
own decoder.
So, this commit implements our own CBOR decoder.
At the heart of the decoder is a function that decodes a single "item"
from a buffer. This item can be a complete simple value or a special
value, such as "start of array." Using this function, we can build a
decoder that effectively iterates over the stream of decoded items and
builds up higher-level values, such as arrays, maps, sets, and indefinite
length bytestrings. And we can do this without performing I/O in the
decoder itself.
The core of the sans I/O decoder will probably not be used directly.
Instead, it is expected that we'll build utility functions for invoking
the decoder given specific input types. This will allow extreme
flexibility in how data is delivered to the decoder.
I'm pretty happy with the state of the decoder modulo the TODO items
tracking features wanted for a "hardened" decoder. The one
thing I could be convinced to change is the handling of semantic tags.
Since we only support a single semantic tag (sets), I thought it would
be easier to handle them inline in decodeitem(). This is simpler now.
But if we add support for other semantic tags, it will likely be easier
to move semantic tag handling outside of decodeitem(). But properly
supporting semantic tags opens up a whole can of worms, as many
semantic tags imply new types. I'm optimistic we won't need these in
Mercurial. But who knows.
I'm also pretty happy with the test coverage. Writing comprehensive
tests for partial decoding did flush out a handful of bugs. One
general improvement to testing would be fuzz testing for partial
decoding. I may implement that later. I also anticipate switching the
wire protocol code to this new decoder will flush out any lingering
bugs.
Differential Revision: https://phab.mercurial-scm.org/D4414
author | Gregory Szorc <gregory.szorc@gmail.com> |
---|---|
date | Tue, 28 Aug 2018 15:02:48 -0700 |
parents | 2b3b6187c316 |
children | a40d3da89b7d |
comparison
equal
deleted
inserted
replaced
39410:fcc6bd11444b | 39411:aeb551a3bb8a |
---|---|
8 ) | 8 ) |
9 from mercurial.utils import ( | 9 from mercurial.utils import ( |
10 cborutil, | 10 cborutil, |
11 ) | 11 ) |
12 | 12 |
13 class TestCase(unittest.TestCase): | |
14 if not getattr(unittest.TestCase, 'assertRaisesRegex', False): | |
15 # Python 3.7 deprecates the regex*p* version, but 2.7 lacks | |
16 # the regex version. | |
17 assertRaisesRegex = (# camelcase-required | |
18 unittest.TestCase.assertRaisesRegexp) | |
19 | |
13 def loadit(it): | 20 def loadit(it): |
14 return cbor.loads(b''.join(it)) | 21 return cbor.loads(b''.join(it)) |
15 | 22 |
16 class BytestringTests(unittest.TestCase): | 23 class BytestringTests(TestCase): |
17 def testsimple(self): | 24 def testsimple(self): |
18 self.assertEqual( | 25 self.assertEqual( |
19 list(cborutil.streamencode(b'foobar')), | 26 list(cborutil.streamencode(b'foobar')), |
20 [b'\x46', b'foobar']) | 27 [b'\x46', b'foobar']) |
21 | 28 |
22 self.assertEqual( | 29 self.assertEqual( |
23 loadit(cborutil.streamencode(b'foobar')), | 30 loadit(cborutil.streamencode(b'foobar')), |
24 b'foobar') | 31 b'foobar') |
25 | 32 |
33 self.assertEqual(cborutil.decodeall(b'\x46foobar'), | |
34 [b'foobar']) | |
35 | |
36 self.assertEqual(cborutil.decodeall(b'\x46foobar\x45fizbi'), | |
37 [b'foobar', b'fizbi']) | |
38 | |
26 def testlong(self): | 39 def testlong(self): |
27 source = b'x' * 1048576 | 40 source = b'x' * 1048576 |
28 | 41 |
29 self.assertEqual(loadit(cborutil.streamencode(source)), source) | 42 self.assertEqual(loadit(cborutil.streamencode(source)), source) |
43 | |
44 encoded = b''.join(cborutil.streamencode(source)) | |
45 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
30 | 46 |
31 def testfromiter(self): | 47 def testfromiter(self): |
32 # This is the example from RFC 7049 Section 2.2.2. | 48 # This is the example from RFC 7049 Section 2.2.2. |
33 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99'] | 49 source = [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99'] |
34 | 50 |
45 | 61 |
46 self.assertEqual( | 62 self.assertEqual( |
47 loadit(cborutil.streamencodebytestringfromiter(source)), | 63 loadit(cborutil.streamencodebytestringfromiter(source)), |
48 b''.join(source)) | 64 b''.join(source)) |
49 | 65 |
66 self.assertEqual(cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd' | |
67 b'\x43\xee\xff\x99\xff'), | |
68 [b'\xaa\xbb\xcc\xdd', b'\xee\xff\x99', b'']) | |
69 | |
70 for i, chunk in enumerate( | |
71 cborutil.decodeall(b'\x5f\x44\xaa\xbb\xcc\xdd' | |
72 b'\x43\xee\xff\x99\xff')): | |
73 self.assertIsInstance(chunk, cborutil.bytestringchunk) | |
74 | |
75 if i == 0: | |
76 self.assertTrue(chunk.isfirst) | |
77 else: | |
78 self.assertFalse(chunk.isfirst) | |
79 | |
80 if i == 2: | |
81 self.assertTrue(chunk.islast) | |
82 else: | |
83 self.assertFalse(chunk.islast) | |
84 | |
50 def testfromiterlarge(self): | 85 def testfromiterlarge(self): |
51 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576] | 86 source = [b'a' * 16, b'b' * 128, b'c' * 1024, b'd' * 1048576] |
52 | 87 |
53 self.assertEqual( | 88 self.assertEqual( |
54 loadit(cborutil.streamencodebytestringfromiter(source)), | 89 loadit(cborutil.streamencodebytestringfromiter(source)), |
69 | 104 |
70 dest = b''.join(cborutil.streamencodeindefinitebytestring( | 105 dest = b''.join(cborutil.streamencodeindefinitebytestring( |
71 source, chunksize=42)) | 106 source, chunksize=42)) |
72 self.assertEqual(cbor.loads(dest), source) | 107 self.assertEqual(cbor.loads(dest), source) |
73 | 108 |
109 self.assertEqual(b''.join(cborutil.decodeall(dest)), source) | |
110 | |
111 for chunk in cborutil.decodeall(dest): | |
112 self.assertIsInstance(chunk, cborutil.bytestringchunk) | |
113 self.assertIn(len(chunk), (0, 8, 42)) | |
114 | |
115 encoded = b'\x5f\xff' | |
116 b = cborutil.decodeall(encoded) | |
117 self.assertEqual(b, [b'']) | |
118 self.assertTrue(b[0].isfirst) | |
119 self.assertTrue(b[0].islast) | |
120 | |
74 def testreadtoiter(self): | 121 def testreadtoiter(self): |
75 source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff') | 122 source = io.BytesIO(b'\x5f\x44\xaa\xbb\xcc\xdd\x43\xee\xff\x99\xff') |
76 | 123 |
77 it = cborutil.readindefinitebytestringtoiter(source) | 124 it = cborutil.readindefinitebytestringtoiter(source) |
78 self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd') | 125 self.assertEqual(next(it), b'\xaa\xbb\xcc\xdd') |
79 self.assertEqual(next(it), b'\xee\xff\x99') | 126 self.assertEqual(next(it), b'\xee\xff\x99') |
80 | 127 |
81 with self.assertRaises(StopIteration): | 128 with self.assertRaises(StopIteration): |
82 next(it) | 129 next(it) |
83 | 130 |
84 class IntTests(unittest.TestCase): | 131 def testdecodevariouslengths(self): |
132 for i in (0, 1, 22, 23, 24, 25, 254, 255, 256, 65534, 65535, 65536): | |
133 source = b'x' * i | |
134 encoded = b''.join(cborutil.streamencode(source)) | |
135 | |
136 if len(source) < 24: | |
137 hlen = 1 | |
138 elif len(source) < 256: | |
139 hlen = 2 | |
140 elif len(source) < 65536: | |
141 hlen = 3 | |
142 elif len(source) < 1048576: | |
143 hlen = 5 | |
144 | |
145 self.assertEqual(cborutil.decodeitem(encoded), | |
146 (True, source, hlen + len(source), | |
147 cborutil.SPECIAL_NONE)) | |
148 | |
149 def testpartialdecode(self): | |
150 encoded = b''.join(cborutil.streamencode(b'foobar')) | |
151 | |
152 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
153 (False, None, -6, cborutil.SPECIAL_NONE)) | |
154 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
155 (False, None, -5, cborutil.SPECIAL_NONE)) | |
156 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
157 (False, None, -4, cborutil.SPECIAL_NONE)) | |
158 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
159 (False, None, -3, cborutil.SPECIAL_NONE)) | |
160 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
161 (False, None, -2, cborutil.SPECIAL_NONE)) | |
162 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
163 (False, None, -1, cborutil.SPECIAL_NONE)) | |
164 self.assertEqual(cborutil.decodeitem(encoded[0:7]), | |
165 (True, b'foobar', 7, cborutil.SPECIAL_NONE)) | |
166 | |
167 def testpartialdecodevariouslengths(self): | |
168 lens = [ | |
169 2, | |
170 3, | |
171 10, | |
172 23, | |
173 24, | |
174 25, | |
175 31, | |
176 100, | |
177 254, | |
178 255, | |
179 256, | |
180 257, | |
181 16384, | |
182 65534, | |
183 65535, | |
184 65536, | |
185 65537, | |
186 131071, | |
187 131072, | |
188 131073, | |
189 1048575, | |
190 1048576, | |
191 1048577, | |
192 ] | |
193 | |
194 for size in lens: | |
195 if size < 24: | |
196 hlen = 1 | |
197 elif size < 2**8: | |
198 hlen = 2 | |
199 elif size < 2**16: | |
200 hlen = 3 | |
201 elif size < 2**32: | |
202 hlen = 5 | |
203 else: | |
204 assert False | |
205 | |
206 source = b'x' * size | |
207 encoded = b''.join(cborutil.streamencode(source)) | |
208 | |
209 res = cborutil.decodeitem(encoded[0:1]) | |
210 | |
211 if hlen > 1: | |
212 self.assertEqual(res, (False, None, -(hlen - 1), | |
213 cborutil.SPECIAL_NONE)) | |
214 else: | |
215 self.assertEqual(res, (False, None, -(size + hlen - 1), | |
216 cborutil.SPECIAL_NONE)) | |
217 | |
218 # Decoding partial header reports remaining header size. | |
219 for i in range(hlen - 1): | |
220 self.assertEqual(cborutil.decodeitem(encoded[0:i + 1]), | |
221 (False, None, -(hlen - i - 1), | |
222 cborutil.SPECIAL_NONE)) | |
223 | |
224 # Decoding complete header reports item size. | |
225 self.assertEqual(cborutil.decodeitem(encoded[0:hlen]), | |
226 (False, None, -size, cborutil.SPECIAL_NONE)) | |
227 | |
228 # Decoding single byte after header reports item size - 1 | |
229 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + 1]), | |
230 (False, None, -(size - 1), cborutil.SPECIAL_NONE)) | |
231 | |
232 # Decoding all but the last byte reports -1 needed. | |
233 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size - 1]), | |
234 (False, None, -1, cborutil.SPECIAL_NONE)) | |
235 | |
236 # Decoding last byte retrieves value. | |
237 self.assertEqual(cborutil.decodeitem(encoded[0:hlen + size]), | |
238 (True, source, hlen + size, cborutil.SPECIAL_NONE)) | |
239 | |
240 def testindefinitepartialdecode(self): | |
241 encoded = b''.join(cborutil.streamencodebytestringfromiter( | |
242 [b'foobar', b'biz'])) | |
243 | |
244 # First item should be begin of bytestring special. | |
245 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
246 (True, None, 1, | |
247 cborutil.SPECIAL_START_INDEFINITE_BYTESTRING)) | |
248 | |
249 # Second item should be the first chunk. But only available when | |
250 # we give it 7 bytes (1 byte header + 6 byte chunk). | |
251 self.assertEqual(cborutil.decodeitem(encoded[1:2]), | |
252 (False, None, -6, cborutil.SPECIAL_NONE)) | |
253 self.assertEqual(cborutil.decodeitem(encoded[1:3]), | |
254 (False, None, -5, cborutil.SPECIAL_NONE)) | |
255 self.assertEqual(cborutil.decodeitem(encoded[1:4]), | |
256 (False, None, -4, cborutil.SPECIAL_NONE)) | |
257 self.assertEqual(cborutil.decodeitem(encoded[1:5]), | |
258 (False, None, -3, cborutil.SPECIAL_NONE)) | |
259 self.assertEqual(cborutil.decodeitem(encoded[1:6]), | |
260 (False, None, -2, cborutil.SPECIAL_NONE)) | |
261 self.assertEqual(cborutil.decodeitem(encoded[1:7]), | |
262 (False, None, -1, cborutil.SPECIAL_NONE)) | |
263 | |
264 self.assertEqual(cborutil.decodeitem(encoded[1:8]), | |
265 (True, b'foobar', 7, cborutil.SPECIAL_NONE)) | |
266 | |
267 # Third item should be second chunk. But only available when | |
268 # we give it 4 bytes (1 byte header + 3 byte chunk). | |
269 self.assertEqual(cborutil.decodeitem(encoded[8:9]), | |
270 (False, None, -3, cborutil.SPECIAL_NONE)) | |
271 self.assertEqual(cborutil.decodeitem(encoded[8:10]), | |
272 (False, None, -2, cborutil.SPECIAL_NONE)) | |
273 self.assertEqual(cborutil.decodeitem(encoded[8:11]), | |
274 (False, None, -1, cborutil.SPECIAL_NONE)) | |
275 | |
276 self.assertEqual(cborutil.decodeitem(encoded[8:12]), | |
277 (True, b'biz', 4, cborutil.SPECIAL_NONE)) | |
278 | |
279 # Fourth item should be end of indefinite stream marker. | |
280 self.assertEqual(cborutil.decodeitem(encoded[12:13]), | |
281 (True, None, 1, cborutil.SPECIAL_INDEFINITE_BREAK)) | |
282 | |
283 # Now test the behavior when going through the decoder. | |
284 | |
285 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:1]), | |
286 (False, 1, 0)) | |
287 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:2]), | |
288 (False, 1, 6)) | |
289 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:3]), | |
290 (False, 1, 5)) | |
291 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:4]), | |
292 (False, 1, 4)) | |
293 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:5]), | |
294 (False, 1, 3)) | |
295 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:6]), | |
296 (False, 1, 2)) | |
297 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:7]), | |
298 (False, 1, 1)) | |
299 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:8]), | |
300 (True, 8, 0)) | |
301 | |
302 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:9]), | |
303 (True, 8, 3)) | |
304 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:10]), | |
305 (True, 8, 2)) | |
306 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:11]), | |
307 (True, 8, 1)) | |
308 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:12]), | |
309 (True, 12, 0)) | |
310 | |
311 self.assertEqual(cborutil.sansiodecoder().decode(encoded[0:13]), | |
312 (True, 13, 0)) | |
313 | |
314 decoder = cborutil.sansiodecoder() | |
315 decoder.decode(encoded[0:8]) | |
316 values = decoder.getavailable() | |
317 self.assertEqual(values, [b'foobar']) | |
318 self.assertTrue(values[0].isfirst) | |
319 self.assertFalse(values[0].islast) | |
320 | |
321 self.assertEqual(decoder.decode(encoded[8:12]), | |
322 (True, 4, 0)) | |
323 values = decoder.getavailable() | |
324 self.assertEqual(values, [b'biz']) | |
325 self.assertFalse(values[0].isfirst) | |
326 self.assertFalse(values[0].islast) | |
327 | |
328 self.assertEqual(decoder.decode(encoded[12:]), | |
329 (True, 1, 0)) | |
330 values = decoder.getavailable() | |
331 self.assertEqual(values, [b'']) | |
332 self.assertFalse(values[0].isfirst) | |
333 self.assertTrue(values[0].islast) | |
334 | |
335 class StringTests(TestCase): | |
336 def testdecodeforbidden(self): | |
337 encoded = b'\x63foo' | |
338 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
339 'string major type not supported'): | |
340 cborutil.decodeall(encoded) | |
341 | |
342 class IntTests(TestCase): | |
85 def testsmall(self): | 343 def testsmall(self): |
86 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00']) | 344 self.assertEqual(list(cborutil.streamencode(0)), [b'\x00']) |
345 self.assertEqual(cborutil.decodeall(b'\x00'), [0]) | |
346 | |
87 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01']) | 347 self.assertEqual(list(cborutil.streamencode(1)), [b'\x01']) |
348 self.assertEqual(cborutil.decodeall(b'\x01'), [1]) | |
349 | |
88 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02']) | 350 self.assertEqual(list(cborutil.streamencode(2)), [b'\x02']) |
351 self.assertEqual(cborutil.decodeall(b'\x02'), [2]) | |
352 | |
89 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03']) | 353 self.assertEqual(list(cborutil.streamencode(3)), [b'\x03']) |
354 self.assertEqual(cborutil.decodeall(b'\x03'), [3]) | |
355 | |
90 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04']) | 356 self.assertEqual(list(cborutil.streamencode(4)), [b'\x04']) |
357 self.assertEqual(cborutil.decodeall(b'\x04'), [4]) | |
358 | |
359 # Multiple value decode works. | |
360 self.assertEqual(cborutil.decodeall(b'\x00\x01\x02\x03\x04'), | |
361 [0, 1, 2, 3, 4]) | |
91 | 362 |
92 def testnegativesmall(self): | 363 def testnegativesmall(self): |
93 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20']) | 364 self.assertEqual(list(cborutil.streamencode(-1)), [b'\x20']) |
365 self.assertEqual(cborutil.decodeall(b'\x20'), [-1]) | |
366 | |
94 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21']) | 367 self.assertEqual(list(cborutil.streamencode(-2)), [b'\x21']) |
368 self.assertEqual(cborutil.decodeall(b'\x21'), [-2]) | |
369 | |
95 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22']) | 370 self.assertEqual(list(cborutil.streamencode(-3)), [b'\x22']) |
371 self.assertEqual(cborutil.decodeall(b'\x22'), [-3]) | |
372 | |
96 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23']) | 373 self.assertEqual(list(cborutil.streamencode(-4)), [b'\x23']) |
374 self.assertEqual(cborutil.decodeall(b'\x23'), [-4]) | |
375 | |
97 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24']) | 376 self.assertEqual(list(cborutil.streamencode(-5)), [b'\x24']) |
377 self.assertEqual(cborutil.decodeall(b'\x24'), [-5]) | |
378 | |
379 # Multiple value decode works. | |
380 self.assertEqual(cborutil.decodeall(b'\x20\x21\x22\x23\x24'), | |
381 [-1, -2, -3, -4, -5]) | |
98 | 382 |
99 def testrange(self): | 383 def testrange(self): |
100 for i in range(-70000, 70000, 10): | 384 for i in range(-70000, 70000, 10): |
101 self.assertEqual( | 385 encoded = b''.join(cborutil.streamencode(i)) |
102 b''.join(cborutil.streamencode(i)), | 386 |
103 cbor.dumps(i)) | 387 self.assertEqual(encoded, cbor.dumps(i)) |
104 | 388 self.assertEqual(cborutil.decodeall(encoded), [i]) |
105 class ArrayTests(unittest.TestCase): | 389 |
390 def testdecodepartialubyte(self): | |
391 encoded = b''.join(cborutil.streamencode(250)) | |
392 | |
393 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
394 (False, None, -1, cborutil.SPECIAL_NONE)) | |
395 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
396 (True, 250, 2, cborutil.SPECIAL_NONE)) | |
397 | |
398 def testdecodepartialbyte(self): | |
399 encoded = b''.join(cborutil.streamencode(-42)) | |
400 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
401 (False, None, -1, cborutil.SPECIAL_NONE)) | |
402 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
403 (True, -42, 2, cborutil.SPECIAL_NONE)) | |
404 | |
405 def testdecodepartialushort(self): | |
406 encoded = b''.join(cborutil.streamencode(2**15)) | |
407 | |
408 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
409 (False, None, -2, cborutil.SPECIAL_NONE)) | |
410 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
411 (False, None, -1, cborutil.SPECIAL_NONE)) | |
412 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
413 (True, 2**15, 3, cborutil.SPECIAL_NONE)) | |
414 | |
415 def testdecodepartialshort(self): | |
416 encoded = b''.join(cborutil.streamencode(-1024)) | |
417 | |
418 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
419 (False, None, -2, cborutil.SPECIAL_NONE)) | |
420 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
421 (False, None, -1, cborutil.SPECIAL_NONE)) | |
422 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
423 (True, -1024, 3, cborutil.SPECIAL_NONE)) | |
424 | |
425 def testdecodepartialulong(self): | |
426 encoded = b''.join(cborutil.streamencode(2**28)) | |
427 | |
428 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
429 (False, None, -4, cborutil.SPECIAL_NONE)) | |
430 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
431 (False, None, -3, cborutil.SPECIAL_NONE)) | |
432 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
433 (False, None, -2, cborutil.SPECIAL_NONE)) | |
434 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
435 (False, None, -1, cborutil.SPECIAL_NONE)) | |
436 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
437 (True, 2**28, 5, cborutil.SPECIAL_NONE)) | |
438 | |
439 def testdecodepartiallong(self): | |
440 encoded = b''.join(cborutil.streamencode(-1048580)) | |
441 | |
442 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
443 (False, None, -4, cborutil.SPECIAL_NONE)) | |
444 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
445 (False, None, -3, cborutil.SPECIAL_NONE)) | |
446 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
447 (False, None, -2, cborutil.SPECIAL_NONE)) | |
448 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
449 (False, None, -1, cborutil.SPECIAL_NONE)) | |
450 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
451 (True, -1048580, 5, cborutil.SPECIAL_NONE)) | |
452 | |
453 def testdecodepartialulonglong(self): | |
454 encoded = b''.join(cborutil.streamencode(2**32)) | |
455 | |
456 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
457 (False, None, -8, cborutil.SPECIAL_NONE)) | |
458 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
459 (False, None, -7, cborutil.SPECIAL_NONE)) | |
460 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
461 (False, None, -6, cborutil.SPECIAL_NONE)) | |
462 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
463 (False, None, -5, cborutil.SPECIAL_NONE)) | |
464 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
465 (False, None, -4, cborutil.SPECIAL_NONE)) | |
466 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
467 (False, None, -3, cborutil.SPECIAL_NONE)) | |
468 self.assertEqual(cborutil.decodeitem(encoded[0:7]), | |
469 (False, None, -2, cborutil.SPECIAL_NONE)) | |
470 self.assertEqual(cborutil.decodeitem(encoded[0:8]), | |
471 (False, None, -1, cborutil.SPECIAL_NONE)) | |
472 self.assertEqual(cborutil.decodeitem(encoded[0:9]), | |
473 (True, 2**32, 9, cborutil.SPECIAL_NONE)) | |
474 | |
475 with self.assertRaisesRegex( | |
476 cborutil.CBORDecodeError, 'input data not fully consumed'): | |
477 cborutil.decodeall(encoded[0:1]) | |
478 | |
479 with self.assertRaisesRegex( | |
480 cborutil.CBORDecodeError, 'input data not fully consumed'): | |
481 cborutil.decodeall(encoded[0:2]) | |
482 | |
483 def testdecodepartiallonglong(self): | |
484 encoded = b''.join(cborutil.streamencode(-7000000000)) | |
485 | |
486 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
487 (False, None, -8, cborutil.SPECIAL_NONE)) | |
488 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
489 (False, None, -7, cborutil.SPECIAL_NONE)) | |
490 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
491 (False, None, -6, cborutil.SPECIAL_NONE)) | |
492 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
493 (False, None, -5, cborutil.SPECIAL_NONE)) | |
494 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
495 (False, None, -4, cborutil.SPECIAL_NONE)) | |
496 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
497 (False, None, -3, cborutil.SPECIAL_NONE)) | |
498 self.assertEqual(cborutil.decodeitem(encoded[0:7]), | |
499 (False, None, -2, cborutil.SPECIAL_NONE)) | |
500 self.assertEqual(cborutil.decodeitem(encoded[0:8]), | |
501 (False, None, -1, cborutil.SPECIAL_NONE)) | |
502 self.assertEqual(cborutil.decodeitem(encoded[0:9]), | |
503 (True, -7000000000, 9, cborutil.SPECIAL_NONE)) | |
504 | |
505 class ArrayTests(TestCase): | |
106 def testempty(self): | 506 def testempty(self): |
107 self.assertEqual(list(cborutil.streamencode([])), [b'\x80']) | 507 self.assertEqual(list(cborutil.streamencode([])), [b'\x80']) |
108 self.assertEqual(loadit(cborutil.streamencode([])), []) | 508 self.assertEqual(loadit(cborutil.streamencode([])), []) |
109 | 509 |
510 self.assertEqual(cborutil.decodeall(b'\x80'), [[]]) | |
511 | |
110 def testbasic(self): | 512 def testbasic(self): |
111 source = [b'foo', b'bar', 1, -10] | 513 source = [b'foo', b'bar', 1, -10] |
112 | 514 |
113 self.assertEqual(list(cborutil.streamencode(source)), [ | 515 chunks = [ |
114 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29']) | 516 b'\x84', b'\x43', b'foo', b'\x43', b'bar', b'\x01', b'\x29'] |
517 | |
518 self.assertEqual(list(cborutil.streamencode(source)), chunks) | |
519 | |
520 self.assertEqual(cborutil.decodeall(b''.join(chunks)), [source]) | |
115 | 521 |
116 def testemptyfromiter(self): | 522 def testemptyfromiter(self): |
117 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])), | 523 self.assertEqual(b''.join(cborutil.streamencodearrayfromiter([])), |
118 b'\x9f\xff') | 524 b'\x9f\xff') |
525 | |
526 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
527 'indefinite length uint not allowed'): | |
528 cborutil.decodeall(b'\x9f\xff') | |
119 | 529 |
120 def testfromiter1(self): | 530 def testfromiter1(self): |
121 source = [b'foo'] | 531 source = [b'foo'] |
122 | 532 |
123 self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [ | 533 self.assertEqual(list(cborutil.streamencodearrayfromiter(source)), [ |
127 ]) | 537 ]) |
128 | 538 |
129 dest = b''.join(cborutil.streamencodearrayfromiter(source)) | 539 dest = b''.join(cborutil.streamencodearrayfromiter(source)) |
130 self.assertEqual(cbor.loads(dest), source) | 540 self.assertEqual(cbor.loads(dest), source) |
131 | 541 |
542 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
543 'indefinite length uint not allowed'): | |
544 cborutil.decodeall(dest) | |
545 | |
132 def testtuple(self): | 546 def testtuple(self): |
133 source = (b'foo', None, 42) | 547 source = (b'foo', None, 42) |
134 | 548 encoded = b''.join(cborutil.streamencode(source)) |
135 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))), | 549 |
136 list(source)) | 550 self.assertEqual(cbor.loads(encoded), list(source)) |
137 | 551 |
138 class SetTests(unittest.TestCase): | 552 self.assertEqual(cborutil.decodeall(encoded), [list(source)]) |
553 | |
554 def testpartialdecode(self): | |
555 source = list(range(4)) | |
556 encoded = b''.join(cborutil.streamencode(source)) | |
557 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
558 (True, 4, 1, cborutil.SPECIAL_START_ARRAY)) | |
559 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
560 (True, 4, 1, cborutil.SPECIAL_START_ARRAY)) | |
561 | |
562 source = list(range(23)) | |
563 encoded = b''.join(cborutil.streamencode(source)) | |
564 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
565 (True, 23, 1, cborutil.SPECIAL_START_ARRAY)) | |
566 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
567 (True, 23, 1, cborutil.SPECIAL_START_ARRAY)) | |
568 | |
569 source = list(range(24)) | |
570 encoded = b''.join(cborutil.streamencode(source)) | |
571 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
572 (False, None, -1, cborutil.SPECIAL_NONE)) | |
573 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
574 (True, 24, 2, cborutil.SPECIAL_START_ARRAY)) | |
575 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
576 (True, 24, 2, cborutil.SPECIAL_START_ARRAY)) | |
577 | |
578 source = list(range(256)) | |
579 encoded = b''.join(cborutil.streamencode(source)) | |
580 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
581 (False, None, -2, cborutil.SPECIAL_NONE)) | |
582 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
583 (False, None, -1, cborutil.SPECIAL_NONE)) | |
584 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
585 (True, 256, 3, cborutil.SPECIAL_START_ARRAY)) | |
586 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
587 (True, 256, 3, cborutil.SPECIAL_START_ARRAY)) | |
588 | |
589 def testnested(self): | |
590 source = [[], [], [[], [], []]] | |
591 encoded = b''.join(cborutil.streamencode(source)) | |
592 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
593 | |
594 source = [True, None, [True, 0, 2], [None], [], [[[]], -87]] | |
595 encoded = b''.join(cborutil.streamencode(source)) | |
596 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
597 | |
598 # A set within an array. | |
599 source = [None, {b'foo', b'bar', None, False}, set()] | |
600 encoded = b''.join(cborutil.streamencode(source)) | |
601 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
602 | |
603 # A map within an array. | |
604 source = [None, {}, {b'foo': b'bar', True: False}, [{}]] | |
605 encoded = b''.join(cborutil.streamencode(source)) | |
606 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
607 | |
608 def testindefinitebytestringvalues(self): | |
609 # Single value array whose value is an empty indefinite bytestring. | |
610 encoded = b'\x81\x5f\x40\xff' | |
611 | |
612 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
613 'indefinite length bytestrings not ' | |
614 'allowed as array values'): | |
615 cborutil.decodeall(encoded) | |
616 | |
617 class SetTests(TestCase): | |
139 def testempty(self): | 618 def testempty(self): |
140 self.assertEqual(list(cborutil.streamencode(set())), [ | 619 self.assertEqual(list(cborutil.streamencode(set())), [ |
141 b'\xd9\x01\x02', | 620 b'\xd9\x01\x02', |
142 b'\x80', | 621 b'\x80', |
143 ]) | 622 ]) |
144 | 623 |
624 self.assertEqual(cborutil.decodeall(b'\xd9\x01\x02\x80'), [set()]) | |
625 | |
145 def testset(self): | 626 def testset(self): |
146 source = {b'foo', None, 42} | 627 source = {b'foo', None, 42} |
147 | 628 encoded = b''.join(cborutil.streamencode(source)) |
148 self.assertEqual(cbor.loads(b''.join(cborutil.streamencode(source))), | 629 |
149 source) | 630 self.assertEqual(cbor.loads(encoded), source) |
150 | 631 |
151 class BoolTests(unittest.TestCase): | 632 self.assertEqual(cborutil.decodeall(encoded), [source]) |
633 | |
634 def testinvalidtag(self): | |
635 # Must use array to encode sets. | |
636 encoded = b'\xd9\x01\x02\xa0' | |
637 | |
638 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
639 'expected array after finite set ' | |
640 'semantic tag'): | |
641 cborutil.decodeall(encoded) | |
642 | |
643 def testpartialdecode(self): | |
644 # Semantic tag item will be 3 bytes. Set header will be variable | |
645 # depending on length. | |
646 encoded = b''.join(cborutil.streamencode({i for i in range(23)})) | |
647 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
648 (False, None, -2, cborutil.SPECIAL_NONE)) | |
649 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
650 (False, None, -1, cborutil.SPECIAL_NONE)) | |
651 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
652 (False, None, -1, cborutil.SPECIAL_NONE)) | |
653 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
654 (True, 23, 4, cborutil.SPECIAL_START_SET)) | |
655 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
656 (True, 23, 4, cborutil.SPECIAL_START_SET)) | |
657 | |
658 encoded = b''.join(cborutil.streamencode({i for i in range(24)})) | |
659 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
660 (False, None, -2, cborutil.SPECIAL_NONE)) | |
661 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
662 (False, None, -1, cborutil.SPECIAL_NONE)) | |
663 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
664 (False, None, -1, cborutil.SPECIAL_NONE)) | |
665 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
666 (False, None, -1, cborutil.SPECIAL_NONE)) | |
667 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
668 (True, 24, 5, cborutil.SPECIAL_START_SET)) | |
669 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
670 (True, 24, 5, cborutil.SPECIAL_START_SET)) | |
671 | |
672 encoded = b''.join(cborutil.streamencode({i for i in range(256)})) | |
673 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
674 (False, None, -2, cborutil.SPECIAL_NONE)) | |
675 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
676 (False, None, -1, cborutil.SPECIAL_NONE)) | |
677 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
678 (False, None, -1, cborutil.SPECIAL_NONE)) | |
679 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
680 (False, None, -2, cborutil.SPECIAL_NONE)) | |
681 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
682 (False, None, -1, cborutil.SPECIAL_NONE)) | |
683 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
684 (True, 256, 6, cborutil.SPECIAL_START_SET)) | |
685 | |
686 def testinvalidvalue(self): | |
687 encoded = b''.join([ | |
688 b'\xd9\x01\x02', # semantic tag | |
689 b'\x81', # array of size 1 | |
690 b'\x5f\x43foo\xff', # indefinite length bytestring "foo" | |
691 ]) | |
692 | |
693 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
694 'indefinite length bytestrings not ' | |
695 'allowed as set values'): | |
696 cborutil.decodeall(encoded) | |
697 | |
698 encoded = b''.join([ | |
699 b'\xd9\x01\x02', | |
700 b'\x81', | |
701 b'\x80', # empty array | |
702 ]) | |
703 | |
704 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
705 'collections not allowed as set values'): | |
706 cborutil.decodeall(encoded) | |
707 | |
708 encoded = b''.join([ | |
709 b'\xd9\x01\x02', | |
710 b'\x81', | |
711 b'\xa0', # empty map | |
712 ]) | |
713 | |
714 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
715 'collections not allowed as set values'): | |
716 cborutil.decodeall(encoded) | |
717 | |
718 encoded = b''.join([ | |
719 b'\xd9\x01\x02', | |
720 b'\x81', | |
721 b'\xd9\x01\x02\x81\x01', # set with integer 1 | |
722 ]) | |
723 | |
724 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
725 'collections not allowed as set values'): | |
726 cborutil.decodeall(encoded) | |
727 | |
728 class BoolTests(TestCase): | |
152 def testbasic(self): | 729 def testbasic(self): |
153 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5']) | 730 self.assertEqual(list(cborutil.streamencode(True)), [b'\xf5']) |
154 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4']) | 731 self.assertEqual(list(cborutil.streamencode(False)), [b'\xf4']) |
155 | 732 |
156 self.assertIs(loadit(cborutil.streamencode(True)), True) | 733 self.assertIs(loadit(cborutil.streamencode(True)), True) |
157 self.assertIs(loadit(cborutil.streamencode(False)), False) | 734 self.assertIs(loadit(cborutil.streamencode(False)), False) |
158 | 735 |
159 class NoneTests(unittest.TestCase): | 736 self.assertEqual(cborutil.decodeall(b'\xf4'), [False]) |
737 self.assertEqual(cborutil.decodeall(b'\xf5'), [True]) | |
738 | |
739 self.assertEqual(cborutil.decodeall(b'\xf4\xf5\xf5\xf4'), | |
740 [False, True, True, False]) | |
741 | |
742 class NoneTests(TestCase): | |
160 def testbasic(self): | 743 def testbasic(self): |
161 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6']) | 744 self.assertEqual(list(cborutil.streamencode(None)), [b'\xf6']) |
162 | 745 |
163 self.assertIs(loadit(cborutil.streamencode(None)), None) | 746 self.assertIs(loadit(cborutil.streamencode(None)), None) |
164 | 747 |
165 class MapTests(unittest.TestCase): | 748 self.assertEqual(cborutil.decodeall(b'\xf6'), [None]) |
749 self.assertEqual(cborutil.decodeall(b'\xf6\xf6'), [None, None]) | |
750 | |
751 class MapTests(TestCase): | |
166 def testempty(self): | 752 def testempty(self): |
167 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0']) | 753 self.assertEqual(list(cborutil.streamencode({})), [b'\xa0']) |
168 self.assertEqual(loadit(cborutil.streamencode({})), {}) | 754 self.assertEqual(loadit(cborutil.streamencode({})), {}) |
169 | 755 |
756 self.assertEqual(cborutil.decodeall(b'\xa0'), [{}]) | |
757 | |
170 def testemptyindefinite(self): | 758 def testemptyindefinite(self): |
171 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [ | 759 self.assertEqual(list(cborutil.streamencodemapfromiter([])), [ |
172 b'\xbf', b'\xff']) | 760 b'\xbf', b'\xff']) |
173 | 761 |
174 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {}) | 762 self.assertEqual(loadit(cborutil.streamencodemapfromiter([])), {}) |
763 | |
764 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
765 'indefinite length uint not allowed'): | |
766 cborutil.decodeall(b'\xbf\xff') | |
175 | 767 |
176 def testone(self): | 768 def testone(self): |
177 source = {b'foo': b'bar'} | 769 source = {b'foo': b'bar'} |
178 self.assertEqual(list(cborutil.streamencode(source)), [ | 770 self.assertEqual(list(cborutil.streamencode(source)), [ |
179 b'\xa1', b'\x43', b'foo', b'\x43', b'bar']) | 771 b'\xa1', b'\x43', b'foo', b'\x43', b'bar']) |
180 | 772 |
181 self.assertEqual(loadit(cborutil.streamencode(source)), source) | 773 self.assertEqual(loadit(cborutil.streamencode(source)), source) |
774 | |
775 self.assertEqual(cborutil.decodeall(b'\xa1\x43foo\x43bar'), [source]) | |
182 | 776 |
183 def testmultiple(self): | 777 def testmultiple(self): |
184 source = { | 778 source = { |
185 b'foo': b'bar', | 779 b'foo': b'bar', |
186 b'baz': b'value1', | 780 b'baz': b'value1', |
190 | 784 |
191 self.assertEqual( | 785 self.assertEqual( |
192 loadit(cborutil.streamencodemapfromiter(source.items())), | 786 loadit(cborutil.streamencodemapfromiter(source.items())), |
193 source) | 787 source) |
194 | 788 |
789 encoded = b''.join(cborutil.streamencode(source)) | |
790 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
791 | |
195 def testcomplex(self): | 792 def testcomplex(self): |
196 source = { | 793 source = { |
197 b'key': 1, | 794 b'key': 1, |
198 2: -10, | 795 2: -10, |
199 } | 796 } |
203 | 800 |
204 self.assertEqual( | 801 self.assertEqual( |
205 loadit(cborutil.streamencodemapfromiter(source.items())), | 802 loadit(cborutil.streamencodemapfromiter(source.items())), |
206 source) | 803 source) |
207 | 804 |
805 encoded = b''.join(cborutil.streamencode(source)) | |
806 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
807 | |
808 def testnested(self): | |
809 source = {b'key1': None, b'key2': {b'sub1': b'sub2'}, b'sub2': {}} | |
810 encoded = b''.join(cborutil.streamencode(source)) | |
811 | |
812 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
813 | |
814 source = { | |
815 b'key1': [], | |
816 b'key2': [None, False], | |
817 b'key3': {b'foo', b'bar'}, | |
818 b'key4': {}, | |
819 } | |
820 encoded = b''.join(cborutil.streamencode(source)) | |
821 self.assertEqual(cborutil.decodeall(encoded), [source]) | |
822 | |
823 def testillegalkey(self): | |
824 encoded = b''.join([ | |
825 # map header + len 1 | |
826 b'\xa1', | |
827 # indefinite length bytestring "foo" in key position | |
828 b'\x5f\x03foo\xff' | |
829 ]) | |
830 | |
831 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
832 'indefinite length bytestrings not ' | |
833 'allowed as map keys'): | |
834 cborutil.decodeall(encoded) | |
835 | |
836 encoded = b''.join([ | |
837 b'\xa1', | |
838 b'\x80', # empty array | |
839 b'\x43foo', | |
840 ]) | |
841 | |
842 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
843 'collections not supported as map keys'): | |
844 cborutil.decodeall(encoded) | |
845 | |
846 def testillegalvalue(self): | |
847 encoded = b''.join([ | |
848 b'\xa1', # map headers | |
849 b'\x43foo', # key | |
850 b'\x5f\x03bar\xff', # indefinite length value | |
851 ]) | |
852 | |
853 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
854 'indefinite length bytestrings not ' | |
855 'allowed as map values'): | |
856 cborutil.decodeall(encoded) | |
857 | |
858 def testpartialdecode(self): | |
859 source = {b'key1': b'value1'} | |
860 encoded = b''.join(cborutil.streamencode(source)) | |
861 | |
862 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
863 (True, 1, 1, cborutil.SPECIAL_START_MAP)) | |
864 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
865 (True, 1, 1, cborutil.SPECIAL_START_MAP)) | |
866 | |
867 source = {b'key%d' % i: None for i in range(23)} | |
868 encoded = b''.join(cborutil.streamencode(source)) | |
869 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
870 (True, 23, 1, cborutil.SPECIAL_START_MAP)) | |
871 | |
872 source = {b'key%d' % i: None for i in range(24)} | |
873 encoded = b''.join(cborutil.streamencode(source)) | |
874 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
875 (False, None, -1, cborutil.SPECIAL_NONE)) | |
876 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
877 (True, 24, 2, cborutil.SPECIAL_START_MAP)) | |
878 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
879 (True, 24, 2, cborutil.SPECIAL_START_MAP)) | |
880 | |
881 source = {b'key%d' % i: None for i in range(256)} | |
882 encoded = b''.join(cborutil.streamencode(source)) | |
883 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
884 (False, None, -2, cborutil.SPECIAL_NONE)) | |
885 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
886 (False, None, -1, cborutil.SPECIAL_NONE)) | |
887 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
888 (True, 256, 3, cborutil.SPECIAL_START_MAP)) | |
889 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
890 (True, 256, 3, cborutil.SPECIAL_START_MAP)) | |
891 | |
892 source = {b'key%d' % i: None for i in range(65536)} | |
893 encoded = b''.join(cborutil.streamencode(source)) | |
894 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
895 (False, None, -4, cborutil.SPECIAL_NONE)) | |
896 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
897 (False, None, -3, cborutil.SPECIAL_NONE)) | |
898 self.assertEqual(cborutil.decodeitem(encoded[0:3]), | |
899 (False, None, -2, cborutil.SPECIAL_NONE)) | |
900 self.assertEqual(cborutil.decodeitem(encoded[0:4]), | |
901 (False, None, -1, cborutil.SPECIAL_NONE)) | |
902 self.assertEqual(cborutil.decodeitem(encoded[0:5]), | |
903 (True, 65536, 5, cborutil.SPECIAL_START_MAP)) | |
904 self.assertEqual(cborutil.decodeitem(encoded[0:6]), | |
905 (True, 65536, 5, cborutil.SPECIAL_START_MAP)) | |
906 | |
907 class SemanticTagTests(TestCase): | |
908 def testdecodeforbidden(self): | |
909 for i in range(500): | |
910 if i == cborutil.SEMANTIC_TAG_FINITE_SET: | |
911 continue | |
912 | |
913 tag = cborutil.encodelength(cborutil.MAJOR_TYPE_SEMANTIC, | |
914 i) | |
915 | |
916 encoded = tag + cborutil.encodelength(cborutil.MAJOR_TYPE_UINT, 42) | |
917 | |
918 # Partial decode is incomplete. | |
919 if i < 24: | |
920 pass | |
921 elif i < 256: | |
922 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
923 (False, None, -1, cborutil.SPECIAL_NONE)) | |
924 elif i < 65536: | |
925 self.assertEqual(cborutil.decodeitem(encoded[0:1]), | |
926 (False, None, -2, cborutil.SPECIAL_NONE)) | |
927 self.assertEqual(cborutil.decodeitem(encoded[0:2]), | |
928 (False, None, -1, cborutil.SPECIAL_NONE)) | |
929 | |
930 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
931 'semantic tag \d+ not allowed'): | |
932 cborutil.decodeitem(encoded) | |
933 | |
934 class SpecialTypesTests(TestCase): | |
935 def testforbiddentypes(self): | |
936 for i in range(256): | |
937 if i == cborutil.SUBTYPE_FALSE: | |
938 continue | |
939 elif i == cborutil.SUBTYPE_TRUE: | |
940 continue | |
941 elif i == cborutil.SUBTYPE_NULL: | |
942 continue | |
943 | |
944 encoded = cborutil.encodelength(cborutil.MAJOR_TYPE_SPECIAL, i) | |
945 | |
946 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
947 'special type \d+ not allowed'): | |
948 cborutil.decodeitem(encoded) | |
949 | |
950 class SansIODecoderTests(TestCase): | |
951 def testemptyinput(self): | |
952 decoder = cborutil.sansiodecoder() | |
953 self.assertEqual(decoder.decode(b''), (False, 0, 0)) | |
954 | |
955 class DecodeallTests(TestCase): | |
956 def testemptyinput(self): | |
957 self.assertEqual(cborutil.decodeall(b''), []) | |
958 | |
959 def testpartialinput(self): | |
960 encoded = b''.join([ | |
961 b'\x82', # array of 2 elements | |
962 b'\x01', # integer 1 | |
963 ]) | |
964 | |
965 with self.assertRaisesRegex(cborutil.CBORDecodeError, | |
966 'input data not complete'): | |
967 cborutil.decodeall(encoded) | |
968 | |
208 if __name__ == '__main__': | 969 if __name__ == '__main__': |
209 import silenttestrunner | 970 import silenttestrunner |
210 silenttestrunner.main(__name__) | 971 silenttestrunner.main(__name__) |