contrib/python-zstandard/tests/test_decompressor.py
changeset 30924 c32454d69b85
parent 30444 b86a448a2965
child 31799 e0dc40530c5a
equal deleted inserted replaced
30923:5b60464efbde 30924:c32454d69b85
     8 except ImportError:
     8 except ImportError:
     9     import unittest
     9     import unittest
    10 
    10 
    11 import zstd
    11 import zstd
    12 
    12 
    13 from .common import OpCountingBytesIO
    13 from .common import (
       
    14     make_cffi,
       
    15     OpCountingBytesIO,
       
    16 )
    14 
    17 
    15 
    18 
    16 if sys.version_info[0] >= 3:
    19 if sys.version_info[0] >= 3:
    17     next = lambda it: it.__next__()
    20     next = lambda it: it.__next__()
    18 else:
    21 else:
    19     next = lambda it: it.next()
    22     next = lambda it: it.next()
    20 
    23 
    21 
    24 
       
    25 @make_cffi
    22 class TestDecompressor_decompress(unittest.TestCase):
    26 class TestDecompressor_decompress(unittest.TestCase):
    23     def test_empty_input(self):
    27     def test_empty_input(self):
    24         dctx = zstd.ZstdDecompressor()
    28         dctx = zstd.ZstdDecompressor()
    25 
    29 
    26         with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
    30         with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
   117         for i in range(len(sources)):
   121         for i in range(len(sources)):
   118             decompressed = dctx.decompress(compressed[i])
   122             decompressed = dctx.decompress(compressed[i])
   119             self.assertEqual(decompressed, sources[i])
   123             self.assertEqual(decompressed, sources[i])
   120 
   124 
   121 
   125 
       
   126 @make_cffi
   122 class TestDecompressor_copy_stream(unittest.TestCase):
   127 class TestDecompressor_copy_stream(unittest.TestCase):
   123     def test_no_read(self):
   128     def test_no_read(self):
   124         source = object()
   129         source = object()
   125         dest = io.BytesIO()
   130         dest = io.BytesIO()
   126 
   131 
   178         self.assertEqual(w, len(b'foobarfoobar'))
   183         self.assertEqual(w, len(b'foobarfoobar'))
   179         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   184         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   180         self.assertEqual(dest._write_count, len(dest.getvalue()))
   185         self.assertEqual(dest._write_count, len(dest.getvalue()))
   181 
   186 
   182 
   187 
       
   188 @make_cffi
   183 class TestDecompressor_decompressobj(unittest.TestCase):
   189 class TestDecompressor_decompressobj(unittest.TestCase):
   184     def test_simple(self):
   190     def test_simple(self):
   185         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   191         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
   186 
   192 
   187         dctx = zstd.ZstdDecompressor()
   193         dctx = zstd.ZstdDecompressor()
   205     with dctx.write_to(buffer) as decompressor:
   211     with dctx.write_to(buffer) as decompressor:
   206         decompressor.write(data)
   212         decompressor.write(data)
   207     return buffer.getvalue()
   213     return buffer.getvalue()
   208 
   214 
   209 
   215 
       
   216 @make_cffi
   210 class TestDecompressor_write_to(unittest.TestCase):
   217 class TestDecompressor_write_to(unittest.TestCase):
   211     def test_empty_roundtrip(self):
   218     def test_empty_roundtrip(self):
   212         cctx = zstd.ZstdCompressor()
   219         cctx = zstd.ZstdCompressor()
   213         empty = cctx.compress(b'')
   220         empty = cctx.compress(b'')
   214         self.assertEqual(decompress_via_writer(empty), b'')
   221         self.assertEqual(decompress_via_writer(empty), b'')
   254 
   261 
   255         orig = b'foobar' * 16384
   262         orig = b'foobar' * 16384
   256         buffer = io.BytesIO()
   263         buffer = io.BytesIO()
   257         cctx = zstd.ZstdCompressor(dict_data=d)
   264         cctx = zstd.ZstdCompressor(dict_data=d)
   258         with cctx.write_to(buffer) as compressor:
   265         with cctx.write_to(buffer) as compressor:
   259             compressor.write(orig)
   266             self.assertEqual(compressor.write(orig), 1544)
   260 
   267 
   261         compressed = buffer.getvalue()
   268         compressed = buffer.getvalue()
   262         buffer = io.BytesIO()
   269         buffer = io.BytesIO()
   263 
   270 
   264         dctx = zstd.ZstdDecompressor(dict_data=d)
   271         dctx = zstd.ZstdDecompressor(dict_data=d)
   265         with dctx.write_to(buffer) as decompressor:
   272         with dctx.write_to(buffer) as decompressor:
   266             decompressor.write(compressed)
   273             self.assertEqual(decompressor.write(compressed), len(orig))
   267 
   274 
   268         self.assertEqual(buffer.getvalue(), orig)
   275         self.assertEqual(buffer.getvalue(), orig)
   269 
   276 
   270     def test_memory_size(self):
   277     def test_memory_size(self):
   271         dctx = zstd.ZstdDecompressor()
   278         dctx = zstd.ZstdDecompressor()
   289 
   296 
   290         self.assertEqual(dest.getvalue(), b'foobarfoobar')
   297         self.assertEqual(dest.getvalue(), b'foobarfoobar')
   291         self.assertEqual(dest._write_count, len(dest.getvalue()))
   298         self.assertEqual(dest._write_count, len(dest.getvalue()))
   292 
   299 
   293 
   300 
       
   301 @make_cffi
   294 class TestDecompressor_read_from(unittest.TestCase):
   302 class TestDecompressor_read_from(unittest.TestCase):
   295     def test_type_validation(self):
   303     def test_type_validation(self):
   296         dctx = zstd.ZstdDecompressor()
   304         dctx = zstd.ZstdDecompressor()
   297 
   305 
   298         # Object with read() works.
   306         # Object with read() works.
   300 
   308 
   301         # Buffer protocol works.
   309         # Buffer protocol works.
   302         dctx.read_from(b'foobar')
   310         dctx.read_from(b'foobar')
   303 
   311 
   304         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
   312         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
   305             dctx.read_from(True)
   313             b''.join(dctx.read_from(True))
   306 
   314 
   307     def test_empty_input(self):
   315     def test_empty_input(self):
   308         dctx = zstd.ZstdDecompressor()
   316         dctx = zstd.ZstdDecompressor()
   309 
   317 
   310         source = io.BytesIO()
   318         source = io.BytesIO()
   349 
   357 
   350     def test_skip_bytes_too_large(self):
   358     def test_skip_bytes_too_large(self):
   351         dctx = zstd.ZstdDecompressor()
   359         dctx = zstd.ZstdDecompressor()
   352 
   360 
   353         with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
   361         with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
   354             dctx.read_from(b'', skip_bytes=1, read_size=1)
   362             b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1))
   355 
   363 
   356         with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
   364         with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
   357             b''.join(dctx.read_from(b'foobar', skip_bytes=10))
   365             b''.join(dctx.read_from(b'foobar', skip_bytes=10))
   358 
   366 
   359     def test_skip_bytes(self):
   367     def test_skip_bytes(self):
   474         dctx = zstd.ZstdDecompressor()
   482         dctx = zstd.ZstdDecompressor()
   475         for chunk in dctx.read_from(source, read_size=1, write_size=1):
   483         for chunk in dctx.read_from(source, read_size=1, write_size=1):
   476             self.assertEqual(len(chunk), 1)
   484             self.assertEqual(len(chunk), 1)
   477 
   485 
   478         self.assertEqual(source._read_count, len(source.getvalue()))
   486         self.assertEqual(source._read_count, len(source.getvalue()))
       
   487 
       
   488 
       
   489 @make_cffi
       
   490 class TestDecompressor_content_dict_chain(unittest.TestCase):
       
   491     def test_bad_inputs_simple(self):
       
   492         dctx = zstd.ZstdDecompressor()
       
   493 
       
   494         with self.assertRaises(TypeError):
       
   495             dctx.decompress_content_dict_chain(b'foo')
       
   496 
       
   497         with self.assertRaises(TypeError):
       
   498             dctx.decompress_content_dict_chain((b'foo', b'bar'))
       
   499 
       
   500         with self.assertRaisesRegexp(ValueError, 'empty input chain'):
       
   501             dctx.decompress_content_dict_chain([])
       
   502 
       
   503         with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
       
   504             dctx.decompress_content_dict_chain([u'foo'])
       
   505 
       
   506         with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
       
   507             dctx.decompress_content_dict_chain([True])
       
   508 
       
   509         with self.assertRaisesRegexp(ValueError, 'chunk 0 is too small to contain a zstd frame'):
       
   510             dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
       
   511 
       
   512         with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'):
       
   513             dctx.decompress_content_dict_chain([b'foo' * 8])
       
   514 
       
   515         no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
       
   516 
       
   517         with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'):
       
   518             dctx.decompress_content_dict_chain([no_size])
       
   519 
       
   520         # Corrupt first frame.
       
   521         frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
       
   522         frame = frame[0:12] + frame[15:]
       
   523         with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'):
       
   524             dctx.decompress_content_dict_chain([frame])
       
   525 
       
   526     def test_bad_subsequent_input(self):
       
   527         initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
       
   528 
       
   529         dctx = zstd.ZstdDecompressor()
       
   530 
       
   531         with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
       
   532             dctx.decompress_content_dict_chain([initial, u'foo'])
       
   533 
       
   534         with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
       
   535             dctx.decompress_content_dict_chain([initial, None])
       
   536 
       
   537         with self.assertRaisesRegexp(ValueError, 'chunk 1 is too small to contain a zstd frame'):
       
   538             dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
       
   539 
       
   540         with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'):
       
   541             dctx.decompress_content_dict_chain([initial, b'foo' * 8])
       
   542 
       
   543         no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
       
   544 
       
   545         with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'):
       
   546             dctx.decompress_content_dict_chain([initial, no_size])
       
   547 
       
   548         # Corrupt second frame.
       
   549         cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64))
       
   550         frame = cctx.compress(b'bar' * 64)
       
   551         frame = frame[0:12] + frame[15:]
       
   552 
       
   553         with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'):
       
   554             dctx.decompress_content_dict_chain([initial, frame])
       
   555 
       
   556     def test_simple(self):
       
   557         original = [
       
   558             b'foo' * 64,
       
   559             b'foobar' * 64,
       
   560             b'baz' * 64,
       
   561             b'foobaz' * 64,
       
   562             b'foobarbaz' * 64,
       
   563         ]
       
   564 
       
   565         chunks = []
       
   566         chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0]))
       
   567         for i, chunk in enumerate(original[1:]):
       
   568             d = zstd.ZstdCompressionDict(original[i])
       
   569             cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True)
       
   570             chunks.append(cctx.compress(chunk))
       
   571 
       
   572         for i in range(1, len(original)):
       
   573             chain = chunks[0:i]
       
   574             expected = original[i - 1]
       
   575             dctx = zstd.ZstdDecompressor()
       
   576             decompressed = dctx.decompress_content_dict_chain(chain)
       
   577             self.assertEqual(decompressed, expected)