contrib/python-zstandard/tests/test_decompressor.py
changeset 30444 b86a448a2965
child 30924 c32454d69b85
equal deleted inserted replaced
30443:2e484bdea8c4 30444:b86a448a2965
       
     1 import io
       
     2 import random
       
     3 import struct
       
     4 import sys
       
     5 
       
     6 try:
       
     7     import unittest2 as unittest
       
     8 except ImportError:
       
     9     import unittest
       
    10 
       
    11 import zstd
       
    12 
       
    13 from .common import OpCountingBytesIO
       
    14 
       
    15 
       
    16 if sys.version_info[0] >= 3:
       
    17     next = lambda it: it.__next__()
       
    18 else:
       
    19     next = lambda it: it.next()
       
    20 
       
    21 
       
    22 class TestDecompressor_decompress(unittest.TestCase):
       
    23     def test_empty_input(self):
       
    24         dctx = zstd.ZstdDecompressor()
       
    25 
       
    26         with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
       
    27             dctx.decompress(b'')
       
    28 
       
    29     def test_invalid_input(self):
       
    30         dctx = zstd.ZstdDecompressor()
       
    31 
       
    32         with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
       
    33             dctx.decompress(b'foobar')
       
    34 
       
    35     def test_no_content_size_in_frame(self):
       
    36         cctx = zstd.ZstdCompressor(write_content_size=False)
       
    37         compressed = cctx.compress(b'foobar')
       
    38 
       
    39         dctx = zstd.ZstdDecompressor()
       
    40         with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
       
    41             dctx.decompress(compressed)
       
    42 
       
    43     def test_content_size_present(self):
       
    44         cctx = zstd.ZstdCompressor(write_content_size=True)
       
    45         compressed = cctx.compress(b'foobar')
       
    46 
       
    47         dctx = zstd.ZstdDecompressor()
       
    48         decompressed  = dctx.decompress(compressed)
       
    49         self.assertEqual(decompressed, b'foobar')
       
    50 
       
    51     def test_max_output_size(self):
       
    52         cctx = zstd.ZstdCompressor(write_content_size=False)
       
    53         source = b'foobar' * 256
       
    54         compressed = cctx.compress(source)
       
    55 
       
    56         dctx = zstd.ZstdDecompressor()
       
    57         # Will fit into buffer exactly the size of input.
       
    58         decompressed = dctx.decompress(compressed, max_output_size=len(source))
       
    59         self.assertEqual(decompressed, source)
       
    60 
       
    61         # Input size - 1 fails
       
    62         with self.assertRaisesRegexp(zstd.ZstdError, 'Destination buffer is too small'):
       
    63             dctx.decompress(compressed, max_output_size=len(source) - 1)
       
    64 
       
    65         # Input size + 1 works
       
    66         decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
       
    67         self.assertEqual(decompressed, source)
       
    68 
       
    69         # A much larger buffer works.
       
    70         decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
       
    71         self.assertEqual(decompressed, source)
       
    72 
       
    73     def test_stupidly_large_output_buffer(self):
       
    74         cctx = zstd.ZstdCompressor(write_content_size=False)
       
    75         compressed = cctx.compress(b'foobar' * 256)
       
    76         dctx = zstd.ZstdDecompressor()
       
    77 
       
    78         # Will get OverflowError on some Python distributions that can't
       
    79         # handle really large integers.
       
    80         with self.assertRaises((MemoryError, OverflowError)):
       
    81             dctx.decompress(compressed, max_output_size=2**62)
       
    82 
       
    83     def test_dictionary(self):
       
    84         samples = []
       
    85         for i in range(128):
       
    86             samples.append(b'foo' * 64)
       
    87             samples.append(b'bar' * 64)
       
    88             samples.append(b'foobar' * 64)
       
    89 
       
    90         d = zstd.train_dictionary(8192, samples)
       
    91 
       
    92         orig = b'foobar' * 16384
       
    93         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
       
    94         compressed = cctx.compress(orig)
       
    95 
       
    96         dctx = zstd.ZstdDecompressor(dict_data=d)
       
    97         decompressed = dctx.decompress(compressed)
       
    98 
       
    99         self.assertEqual(decompressed, orig)
       
   100 
       
   101     def test_dictionary_multiple(self):
       
   102         samples = []
       
   103         for i in range(128):
       
   104             samples.append(b'foo' * 64)
       
   105             samples.append(b'bar' * 64)
       
   106             samples.append(b'foobar' * 64)
       
   107 
       
   108         d = zstd.train_dictionary(8192, samples)
       
   109 
       
   110         sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
       
   111         compressed = []
       
   112         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
       
   113         for source in sources:
       
   114             compressed.append(cctx.compress(source))
       
   115 
       
   116         dctx = zstd.ZstdDecompressor(dict_data=d)
       
   117         for i in range(len(sources)):
       
   118             decompressed = dctx.decompress(compressed[i])
       
   119             self.assertEqual(decompressed, sources[i])
       
   120 
       
   121 
       
   122 class TestDecompressor_copy_stream(unittest.TestCase):
       
   123     def test_no_read(self):
       
   124         source = object()
       
   125         dest = io.BytesIO()
       
   126 
       
   127         dctx = zstd.ZstdDecompressor()
       
   128         with self.assertRaises(ValueError):
       
   129             dctx.copy_stream(source, dest)
       
   130 
       
   131     def test_no_write(self):
       
   132         source = io.BytesIO()
       
   133         dest = object()
       
   134 
       
   135         dctx = zstd.ZstdDecompressor()
       
   136         with self.assertRaises(ValueError):
       
   137             dctx.copy_stream(source, dest)
       
   138 
       
   139     def test_empty(self):
       
   140         source = io.BytesIO()
       
   141         dest = io.BytesIO()
       
   142 
       
   143         dctx = zstd.ZstdDecompressor()
       
   144         # TODO should this raise an error?
       
   145         r, w = dctx.copy_stream(source, dest)
       
   146 
       
   147         self.assertEqual(r, 0)
       
   148         self.assertEqual(w, 0)
       
   149         self.assertEqual(dest.getvalue(), b'')
       
   150 
       
   151     def test_large_data(self):
       
   152         source = io.BytesIO()
       
   153         for i in range(255):
       
   154             source.write(struct.Struct('>B').pack(i) * 16384)
       
   155         source.seek(0)
       
   156 
       
   157         compressed = io.BytesIO()
       
   158         cctx = zstd.ZstdCompressor()
       
   159         cctx.copy_stream(source, compressed)
       
   160 
       
   161         compressed.seek(0)
       
   162         dest = io.BytesIO()
       
   163         dctx = zstd.ZstdDecompressor()
       
   164         r, w = dctx.copy_stream(compressed, dest)
       
   165 
       
   166         self.assertEqual(r, len(compressed.getvalue()))
       
   167         self.assertEqual(w, len(source.getvalue()))
       
   168 
       
   169     def test_read_write_size(self):
       
   170         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(
       
   171             b'foobarfoobar'))
       
   172 
       
   173         dest = OpCountingBytesIO()
       
   174         dctx = zstd.ZstdDecompressor()
       
   175         r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
       
   176 
       
   177         self.assertEqual(r, len(source.getvalue()))
       
   178         self.assertEqual(w, len(b'foobarfoobar'))
       
   179         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
       
   180         self.assertEqual(dest._write_count, len(dest.getvalue()))
       
   181 
       
   182 
       
   183 class TestDecompressor_decompressobj(unittest.TestCase):
       
   184     def test_simple(self):
       
   185         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
       
   186 
       
   187         dctx = zstd.ZstdDecompressor()
       
   188         dobj = dctx.decompressobj()
       
   189         self.assertEqual(dobj.decompress(data), b'foobar')
       
   190 
       
   191     def test_reuse(self):
       
   192         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
       
   193 
       
   194         dctx = zstd.ZstdDecompressor()
       
   195         dobj = dctx.decompressobj()
       
   196         dobj.decompress(data)
       
   197 
       
   198         with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
       
   199             dobj.decompress(data)
       
   200 
       
   201 
       
   202 def decompress_via_writer(data):
       
   203     buffer = io.BytesIO()
       
   204     dctx = zstd.ZstdDecompressor()
       
   205     with dctx.write_to(buffer) as decompressor:
       
   206         decompressor.write(data)
       
   207     return buffer.getvalue()
       
   208 
       
   209 
       
   210 class TestDecompressor_write_to(unittest.TestCase):
       
   211     def test_empty_roundtrip(self):
       
   212         cctx = zstd.ZstdCompressor()
       
   213         empty = cctx.compress(b'')
       
   214         self.assertEqual(decompress_via_writer(empty), b'')
       
   215 
       
   216     def test_large_roundtrip(self):
       
   217         chunks = []
       
   218         for i in range(255):
       
   219             chunks.append(struct.Struct('>B').pack(i) * 16384)
       
   220         orig = b''.join(chunks)
       
   221         cctx = zstd.ZstdCompressor()
       
   222         compressed = cctx.compress(orig)
       
   223 
       
   224         self.assertEqual(decompress_via_writer(compressed), orig)
       
   225 
       
   226     def test_multiple_calls(self):
       
   227         chunks = []
       
   228         for i in range(255):
       
   229             for j in range(255):
       
   230                 chunks.append(struct.Struct('>B').pack(j) * i)
       
   231 
       
   232         orig = b''.join(chunks)
       
   233         cctx = zstd.ZstdCompressor()
       
   234         compressed = cctx.compress(orig)
       
   235 
       
   236         buffer = io.BytesIO()
       
   237         dctx = zstd.ZstdDecompressor()
       
   238         with dctx.write_to(buffer) as decompressor:
       
   239             pos = 0
       
   240             while pos < len(compressed):
       
   241                 pos2 = pos + 8192
       
   242                 decompressor.write(compressed[pos:pos2])
       
   243                 pos += 8192
       
   244         self.assertEqual(buffer.getvalue(), orig)
       
   245 
       
   246     def test_dictionary(self):
       
   247         samples = []
       
   248         for i in range(128):
       
   249             samples.append(b'foo' * 64)
       
   250             samples.append(b'bar' * 64)
       
   251             samples.append(b'foobar' * 64)
       
   252 
       
   253         d = zstd.train_dictionary(8192, samples)
       
   254 
       
   255         orig = b'foobar' * 16384
       
   256         buffer = io.BytesIO()
       
   257         cctx = zstd.ZstdCompressor(dict_data=d)
       
   258         with cctx.write_to(buffer) as compressor:
       
   259             compressor.write(orig)
       
   260 
       
   261         compressed = buffer.getvalue()
       
   262         buffer = io.BytesIO()
       
   263 
       
   264         dctx = zstd.ZstdDecompressor(dict_data=d)
       
   265         with dctx.write_to(buffer) as decompressor:
       
   266             decompressor.write(compressed)
       
   267 
       
   268         self.assertEqual(buffer.getvalue(), orig)
       
   269 
       
   270     def test_memory_size(self):
       
   271         dctx = zstd.ZstdDecompressor()
       
   272         buffer = io.BytesIO()
       
   273         with dctx.write_to(buffer) as decompressor:
       
   274             size = decompressor.memory_size()
       
   275 
       
   276         self.assertGreater(size, 100000)
       
   277 
       
   278     def test_write_size(self):
       
   279         source = zstd.ZstdCompressor().compress(b'foobarfoobar')
       
   280         dest = OpCountingBytesIO()
       
   281         dctx = zstd.ZstdDecompressor()
       
   282         with dctx.write_to(dest, write_size=1) as decompressor:
       
   283             s = struct.Struct('>B')
       
   284             for c in source:
       
   285                 if not isinstance(c, str):
       
   286                     c = s.pack(c)
       
   287                 decompressor.write(c)
       
   288 
       
   289 
       
   290         self.assertEqual(dest.getvalue(), b'foobarfoobar')
       
   291         self.assertEqual(dest._write_count, len(dest.getvalue()))
       
   292 
       
   293 
       
   294 class TestDecompressor_read_from(unittest.TestCase):
       
   295     def test_type_validation(self):
       
   296         dctx = zstd.ZstdDecompressor()
       
   297 
       
   298         # Object with read() works.
       
   299         dctx.read_from(io.BytesIO())
       
   300 
       
   301         # Buffer protocol works.
       
   302         dctx.read_from(b'foobar')
       
   303 
       
   304         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
       
   305             dctx.read_from(True)
       
   306 
       
   307     def test_empty_input(self):
       
   308         dctx = zstd.ZstdDecompressor()
       
   309 
       
   310         source = io.BytesIO()
       
   311         it = dctx.read_from(source)
       
   312         # TODO this is arguably wrong. Should get an error about missing frame foo.
       
   313         with self.assertRaises(StopIteration):
       
   314             next(it)
       
   315 
       
   316         it = dctx.read_from(b'')
       
   317         with self.assertRaises(StopIteration):
       
   318             next(it)
       
   319 
       
   320     def test_invalid_input(self):
       
   321         dctx = zstd.ZstdDecompressor()
       
   322 
       
   323         source = io.BytesIO(b'foobar')
       
   324         it = dctx.read_from(source)
       
   325         with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
       
   326             next(it)
       
   327 
       
   328         it = dctx.read_from(b'foobar')
       
   329         with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
       
   330             next(it)
       
   331 
       
   332     def test_empty_roundtrip(self):
       
   333         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
       
   334         empty = cctx.compress(b'')
       
   335 
       
   336         source = io.BytesIO(empty)
       
   337         source.seek(0)
       
   338 
       
   339         dctx = zstd.ZstdDecompressor()
       
   340         it = dctx.read_from(source)
       
   341 
       
   342         # No chunks should be emitted since there is no data.
       
   343         with self.assertRaises(StopIteration):
       
   344             next(it)
       
   345 
       
   346         # Again for good measure.
       
   347         with self.assertRaises(StopIteration):
       
   348             next(it)
       
   349 
       
   350     def test_skip_bytes_too_large(self):
       
   351         dctx = zstd.ZstdDecompressor()
       
   352 
       
   353         with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
       
   354             dctx.read_from(b'', skip_bytes=1, read_size=1)
       
   355 
       
   356         with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
       
   357             b''.join(dctx.read_from(b'foobar', skip_bytes=10))
       
   358 
       
   359     def test_skip_bytes(self):
       
   360         cctx = zstd.ZstdCompressor(write_content_size=False)
       
   361         compressed = cctx.compress(b'foobar')
       
   362 
       
   363         dctx = zstd.ZstdDecompressor()
       
   364         output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3))
       
   365         self.assertEqual(output, b'foobar')
       
   366 
       
   367     def test_large_output(self):
       
   368         source = io.BytesIO()
       
   369         source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
       
   370         source.write(b'o')
       
   371         source.seek(0)
       
   372 
       
   373         cctx = zstd.ZstdCompressor(level=1)
       
   374         compressed = io.BytesIO(cctx.compress(source.getvalue()))
       
   375         compressed.seek(0)
       
   376 
       
   377         dctx = zstd.ZstdDecompressor()
       
   378         it = dctx.read_from(compressed)
       
   379 
       
   380         chunks = []
       
   381         chunks.append(next(it))
       
   382         chunks.append(next(it))
       
   383 
       
   384         with self.assertRaises(StopIteration):
       
   385             next(it)
       
   386 
       
   387         decompressed = b''.join(chunks)
       
   388         self.assertEqual(decompressed, source.getvalue())
       
   389 
       
   390         # And again with buffer protocol.
       
   391         it = dctx.read_from(compressed.getvalue())
       
   392         chunks = []
       
   393         chunks.append(next(it))
       
   394         chunks.append(next(it))
       
   395 
       
   396         with self.assertRaises(StopIteration):
       
   397             next(it)
       
   398 
       
   399         decompressed = b''.join(chunks)
       
   400         self.assertEqual(decompressed, source.getvalue())
       
   401 
       
   402     def test_large_input(self):
       
   403         bytes = list(struct.Struct('>B').pack(i) for i in range(256))
       
   404         compressed = io.BytesIO()
       
   405         input_size = 0
       
   406         cctx = zstd.ZstdCompressor(level=1)
       
   407         with cctx.write_to(compressed) as compressor:
       
   408             while True:
       
   409                 compressor.write(random.choice(bytes))
       
   410                 input_size += 1
       
   411 
       
   412                 have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
       
   413                 have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
       
   414                 if have_compressed and have_raw:
       
   415                     break
       
   416 
       
   417         compressed.seek(0)
       
   418         self.assertGreater(len(compressed.getvalue()),
       
   419                            zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
       
   420 
       
   421         dctx = zstd.ZstdDecompressor()
       
   422         it = dctx.read_from(compressed)
       
   423 
       
   424         chunks = []
       
   425         chunks.append(next(it))
       
   426         chunks.append(next(it))
       
   427         chunks.append(next(it))
       
   428 
       
   429         with self.assertRaises(StopIteration):
       
   430             next(it)
       
   431 
       
   432         decompressed = b''.join(chunks)
       
   433         self.assertEqual(len(decompressed), input_size)
       
   434 
       
   435         # And again with buffer protocol.
       
   436         it = dctx.read_from(compressed.getvalue())
       
   437 
       
   438         chunks = []
       
   439         chunks.append(next(it))
       
   440         chunks.append(next(it))
       
   441         chunks.append(next(it))
       
   442 
       
   443         with self.assertRaises(StopIteration):
       
   444             next(it)
       
   445 
       
   446         decompressed = b''.join(chunks)
       
   447         self.assertEqual(len(decompressed), input_size)
       
   448 
       
   449     def test_interesting(self):
       
   450         # Found this edge case via fuzzing.
       
   451         cctx = zstd.ZstdCompressor(level=1)
       
   452 
       
   453         source = io.BytesIO()
       
   454 
       
   455         compressed = io.BytesIO()
       
   456         with cctx.write_to(compressed) as compressor:
       
   457             for i in range(256):
       
   458                 chunk = b'\0' * 1024
       
   459                 compressor.write(chunk)
       
   460                 source.write(chunk)
       
   461 
       
   462         dctx = zstd.ZstdDecompressor()
       
   463 
       
   464         simple = dctx.decompress(compressed.getvalue(),
       
   465                                  max_output_size=len(source.getvalue()))
       
   466         self.assertEqual(simple, source.getvalue())
       
   467 
       
   468         compressed.seek(0)
       
   469         streamed = b''.join(dctx.read_from(compressed))
       
   470         self.assertEqual(streamed, source.getvalue())
       
   471 
       
   472     def test_read_write_size(self):
       
   473         source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
       
   474         dctx = zstd.ZstdDecompressor()
       
   475         for chunk in dctx.read_from(source, read_size=1, write_size=1):
       
   476             self.assertEqual(len(chunk), 1)
       
   477 
       
   478         self.assertEqual(source._read_count, len(source.getvalue()))