contrib/python-zstandard/tests/test_decompressor.py
changeset 30895 c32454d69b85
parent 30435 b86a448a2965
child 31796 e0dc40530c5a
--- a/contrib/python-zstandard/tests/test_decompressor.py	Thu Feb 09 21:44:32 2017 -0500
+++ b/contrib/python-zstandard/tests/test_decompressor.py	Tue Feb 07 23:24:47 2017 -0800
@@ -10,7 +10,10 @@
 
 import zstd
 
-from .common import OpCountingBytesIO
+from .common import (
+    make_cffi,
+    OpCountingBytesIO,
+)
 
 
 if sys.version_info[0] >= 3:
@@ -19,6 +22,7 @@
     next = lambda it: it.next()
 
 
+@make_cffi
 class TestDecompressor_decompress(unittest.TestCase):
     def test_empty_input(self):
         dctx = zstd.ZstdDecompressor()
@@ -119,6 +123,7 @@
             self.assertEqual(decompressed, sources[i])
 
 
+@make_cffi
 class TestDecompressor_copy_stream(unittest.TestCase):
     def test_no_read(self):
         source = object()
@@ -180,6 +185,7 @@
         self.assertEqual(dest._write_count, len(dest.getvalue()))
 
 
+@make_cffi
 class TestDecompressor_decompressobj(unittest.TestCase):
     def test_simple(self):
         data = zstd.ZstdCompressor(level=1).compress(b'foobar')
@@ -207,6 +213,7 @@
     return buffer.getvalue()
 
 
+@make_cffi
 class TestDecompressor_write_to(unittest.TestCase):
     def test_empty_roundtrip(self):
         cctx = zstd.ZstdCompressor()
@@ -256,14 +263,14 @@
         buffer = io.BytesIO()
         cctx = zstd.ZstdCompressor(dict_data=d)
         with cctx.write_to(buffer) as compressor:
-            compressor.write(orig)
+            self.assertEqual(compressor.write(orig), 1544)
 
         compressed = buffer.getvalue()
         buffer = io.BytesIO()
 
         dctx = zstd.ZstdDecompressor(dict_data=d)
         with dctx.write_to(buffer) as decompressor:
-            decompressor.write(compressed)
+            self.assertEqual(decompressor.write(compressed), len(orig))
 
         self.assertEqual(buffer.getvalue(), orig)
 
@@ -291,6 +298,7 @@
         self.assertEqual(dest._write_count, len(dest.getvalue()))
 
 
+@make_cffi
 class TestDecompressor_read_from(unittest.TestCase):
     def test_type_validation(self):
         dctx = zstd.ZstdDecompressor()
@@ -302,7 +310,7 @@
         dctx.read_from(b'foobar')
 
         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
-            dctx.read_from(True)
+            b''.join(dctx.read_from(True))
 
     def test_empty_input(self):
         dctx = zstd.ZstdDecompressor()
@@ -351,7 +359,7 @@
         dctx = zstd.ZstdDecompressor()
 
         with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
-            dctx.read_from(b'', skip_bytes=1, read_size=1)
+            b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1))
 
         with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
             b''.join(dctx.read_from(b'foobar', skip_bytes=10))
@@ -476,3 +484,94 @@
             self.assertEqual(len(chunk), 1)
 
         self.assertEqual(source._read_count, len(source.getvalue()))
+
+
+@make_cffi
+class TestDecompressor_content_dict_chain(unittest.TestCase):
+    def test_bad_inputs_simple(self):
+        dctx = zstd.ZstdDecompressor()
+
+        with self.assertRaises(TypeError):
+            dctx.decompress_content_dict_chain(b'foo')
+
+        with self.assertRaises(TypeError):
+            dctx.decompress_content_dict_chain((b'foo', b'bar'))
+
+        with self.assertRaisesRegexp(ValueError, 'empty input chain'):
+            dctx.decompress_content_dict_chain([])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
+            dctx.decompress_content_dict_chain([u'foo'])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
+            dctx.decompress_content_dict_chain([True])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 0 is too small to contain a zstd frame'):
+            dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'):
+            dctx.decompress_content_dict_chain([b'foo' * 8])
+
+        no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'):
+            dctx.decompress_content_dict_chain([no_size])
+
+        # Corrupt first frame.
+        frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
+        frame = frame[0:12] + frame[15:]
+        with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'):
+            dctx.decompress_content_dict_chain([frame])
+
+    def test_bad_subsequent_input(self):
+        initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
+
+        dctx = zstd.ZstdDecompressor()
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
+            dctx.decompress_content_dict_chain([initial, u'foo'])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
+            dctx.decompress_content_dict_chain([initial, None])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 1 is too small to contain a zstd frame'):
+            dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'):
+            dctx.decompress_content_dict_chain([initial, b'foo' * 8])
+
+        no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
+
+        with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'):
+            dctx.decompress_content_dict_chain([initial, no_size])
+
+        # Corrupt second frame.
+        cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64))
+        frame = cctx.compress(b'bar' * 64)
+        frame = frame[0:12] + frame[15:]
+
+        with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'):
+            dctx.decompress_content_dict_chain([initial, frame])
+
+    def test_simple(self):
+        original = [
+            b'foo' * 64,
+            b'foobar' * 64,
+            b'baz' * 64,
+            b'foobaz' * 64,
+            b'foobarbaz' * 64,
+        ]
+
+        chunks = []
+        chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0]))
+        for i, chunk in enumerate(original[1:]):
+            d = zstd.ZstdCompressionDict(original[i])
+            cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True)
+            chunks.append(cctx.compress(chunk))
+
+        for i in range(1, len(original)):
+            chain = chunks[0:i]
+            expected = original[i - 1]
+            dctx = zstd.ZstdDecompressor()
+            decompressed = dctx.decompress_content_dict_chain(chain)
+            self.assertEqual(decompressed, expected)