contrib/python-zstandard/tests/test_compressor.py
changeset 40121 73fef626dae3
parent 37495 b1fb341d8a61
child 42070 675775c33ab6
--- a/contrib/python-zstandard/tests/test_compressor.py	Tue Sep 25 20:55:03 2018 +0900
+++ b/contrib/python-zstandard/tests/test_compressor.py	Mon Oct 08 16:27:40 2018 -0700
@@ -153,7 +153,7 @@
         no_params = zstd.get_frame_parameters(no_dict_id)
         with_params = zstd.get_frame_parameters(with_dict_id)
         self.assertEqual(no_params.dict_id, 0)
-        self.assertEqual(with_params.dict_id, 1387616518)
+        self.assertEqual(with_params.dict_id, 1880053135)
 
     def test_compress_dict_multiple(self):
         samples = []
@@ -216,7 +216,7 @@
         self.assertEqual(params.dict_id, d.dict_id())
 
         self.assertEqual(result,
-                         b'\x28\xb5\x2f\xfd\x23\x06\x59\xb5\x52\x03\x19\x00\x00'
+                         b'\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00'
                          b'\x66\x6f\x6f')
 
     def test_multithreaded_compression_params(self):
@@ -336,7 +336,9 @@
                          b'\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo')
         self.assertEqual(cobj.compress(b'bar'), b'')
         # 3 byte header plus content.
-        self.assertEqual(cobj.flush(), b'\x19\x00\x00bar')
+        self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
+                         b'\x18\x00\x00bar')
+        self.assertEqual(cobj.flush(), b'\x01\x00\x00')
 
     def test_flush_empty_block(self):
         cctx = zstd.ZstdCompressor(write_checksum=True)
@@ -576,15 +578,23 @@
     def test_context_manager(self):
         cctx = zstd.ZstdCompressor()
 
-        reader = cctx.stream_reader(b'foo' * 60)
-        with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'):
-            reader.read(10)
-
         with cctx.stream_reader(b'foo') as reader:
             with self.assertRaisesRegexp(ValueError, 'cannot __enter__ multiple times'):
                 with reader as reader2:
                     pass
 
+    def test_no_context_manager(self):
+        cctx = zstd.ZstdCompressor()
+
+        reader = cctx.stream_reader(b'foo')
+        reader.read(4)
+        self.assertFalse(reader.closed)
+
+        reader.close()
+        self.assertTrue(reader.closed)
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
+            reader.read(1)
+
     def test_not_implemented(self):
         cctx = zstd.ZstdCompressor()
 
@@ -619,13 +629,18 @@
             self.assertFalse(reader.writable())
             self.assertFalse(reader.seekable())
             self.assertFalse(reader.isatty())
+            self.assertFalse(reader.closed)
             self.assertIsNone(reader.flush())
+            self.assertFalse(reader.closed)
+
+        self.assertTrue(reader.closed)
 
     def test_read_closed(self):
         cctx = zstd.ZstdCompressor()
 
         with cctx.stream_reader(b'foo' * 60) as reader:
             reader.close()
+            self.assertTrue(reader.closed)
             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
                 reader.read(10)
 
@@ -715,7 +730,7 @@
             while reader.read(8192):
                 pass
 
-        with self.assertRaisesRegexp(zstd.ZstdError, 'read\(\) must be called from an active'):
+        with self.assertRaisesRegexp(ValueError, 'stream is closed'):
             reader.read(10)
 
     def test_bad_size(self):
@@ -792,7 +807,7 @@
         d = zstd.train_dictionary(8192, samples)
 
         h = hashlib.sha1(d.as_bytes()).hexdigest()
-        self.assertEqual(h, '3040faa0ddc37d50e71a4dd28052cb8db5d9d027')
+        self.assertEqual(h, '2b3b6428da5bf2c9cc9d4bb58ba0bc5990dd0e79')
 
         buffer = io.BytesIO()
         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
@@ -808,9 +823,16 @@
         self.assertEqual(params.window_size, 2097152)
         self.assertEqual(params.dict_id, d.dict_id())
         self.assertFalse(params.has_checksum)
-        self.assertEqual(compressed,
-                         b'\x28\xb5\x2f\xfd\x03\x58\x06\x59\xb5\x52\x5d\x00'
-                         b'\x00\x00\x02\xfc\x3d\x3f\xd9\xb0\x51\x03\x45\x89')
+
+        h = hashlib.sha1(compressed).hexdigest()
+        self.assertEqual(h, '23f88344263678478f5f82298e0a5d1833125786')
+
+        source = b'foo' + b'bar' + (b'foo' * 16384)
+
+        dctx = zstd.ZstdDecompressor(dict_data=d)
+
+        self.assertEqual(dctx.decompress(compressed, max_output_size=len(source)),
+                         source)
 
     def test_compression_params(self):
         params = zstd.ZstdCompressionParameters(
@@ -1157,6 +1179,181 @@
         b''.join(cctx.read_to_iter(source))
 
 
+@make_cffi
+class TestCompressor_chunker(unittest.TestCase):
+    def test_empty(self):
+        cctx = zstd.ZstdCompressor(write_content_size=False)
+        chunker = cctx.chunker()
+
+        it = chunker.compress(b'')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        it = chunker.finish()
+
+        self.assertEqual(next(it), b'\x28\xb5\x2f\xfd\x00\x50\x01\x00\x00')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+    def test_simple_input(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker()
+
+        it = chunker.compress(b'foobar')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        it = chunker.compress(b'baz' * 30)
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        it = chunker.finish()
+
+        self.assertEqual(next(it),
+                         b'\x28\xb5\x2f\xfd\x00\x50\x7d\x00\x00\x48\x66\x6f'
+                         b'\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e')
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+    def test_input_size(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker(size=1024)
+
+        it = chunker.compress(b'x' * 1000)
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        it = chunker.compress(b'y' * 24)
+
+        with self.assertRaises(StopIteration):
+            next(it)
+
+        chunks = list(chunker.finish())
+
+        self.assertEqual(chunks, [
+            b'\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00'
+            b'\xa0\x16\xe3\x2b\x80\x05'
+        ])
+
+        dctx = zstd.ZstdDecompressor()
+
+        self.assertEqual(dctx.decompress(b''.join(chunks)),
+                         (b'x' * 1000) + (b'y' * 24))
+
+    def test_small_chunk_size(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker(chunk_size=1)
+
+        chunks = list(chunker.compress(b'foo' * 1024))
+        self.assertEqual(chunks, [])
+
+        chunks = list(chunker.finish())
+        self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
+
+        self.assertEqual(
+            b''.join(chunks),
+            b'\x28\xb5\x2f\xfd\x00\x50\x55\x00\x00\x18\x66\x6f\x6f\x01\x00'
+            b'\xfa\xd3\x77\x43')
+
+        dctx = zstd.ZstdDecompressor()
+        self.assertEqual(dctx.decompress(b''.join(chunks),
+                                         max_output_size=10000),
+                         b'foo' * 1024)
+
+    def test_input_types(self):
+        cctx = zstd.ZstdCompressor()
+
+        mutable_array = bytearray(3)
+        mutable_array[:] = b'foo'
+
+        sources = [
+            memoryview(b'foo'),
+            bytearray(b'foo'),
+            mutable_array,
+        ]
+
+        for source in sources:
+            chunker = cctx.chunker()
+
+            self.assertEqual(list(chunker.compress(source)), [])
+            self.assertEqual(list(chunker.finish()), [
+                b'\x28\xb5\x2f\xfd\x00\x50\x19\x00\x00\x66\x6f\x6f'
+            ])
+
+    def test_flush(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker()
+
+        self.assertEqual(list(chunker.compress(b'foo' * 1024)), [])
+        self.assertEqual(list(chunker.compress(b'bar' * 1024)), [])
+
+        chunks1 = list(chunker.flush())
+
+        self.assertEqual(chunks1, [
+            b'\x28\xb5\x2f\xfd\x00\x50\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72'
+            b'\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02'
+        ])
+
+        self.assertEqual(list(chunker.flush()), [])
+        self.assertEqual(list(chunker.flush()), [])
+
+        self.assertEqual(list(chunker.compress(b'baz' * 1024)), [])
+
+        chunks2 = list(chunker.flush())
+        self.assertEqual(len(chunks2), 1)
+
+        chunks3 = list(chunker.finish())
+        self.assertEqual(len(chunks2), 1)
+
+        dctx = zstd.ZstdDecompressor()
+
+        self.assertEqual(dctx.decompress(b''.join(chunks1 + chunks2 + chunks3),
+                                         max_output_size=10000),
+                         (b'foo' * 1024) + (b'bar' * 1024) + (b'baz' * 1024))
+
+    def test_compress_after_finish(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker()
+
+        list(chunker.compress(b'foo'))
+        list(chunker.finish())
+
+        with self.assertRaisesRegexp(
+                zstd.ZstdError,
+                'cannot call compress\(\) after compression finished'):
+            list(chunker.compress(b'foo'))
+
+    def test_flush_after_finish(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker()
+
+        list(chunker.compress(b'foo'))
+        list(chunker.finish())
+
+        with self.assertRaisesRegexp(
+                zstd.ZstdError,
+                'cannot call flush\(\) after compression finished'):
+            list(chunker.flush())
+
+    def test_finish_after_finish(self):
+        cctx = zstd.ZstdCompressor()
+        chunker = cctx.chunker()
+
+        list(chunker.compress(b'foo'))
+        list(chunker.finish())
+
+        with self.assertRaisesRegexp(
+                zstd.ZstdError,
+                'cannot call finish\(\) after compression finished'):
+            list(chunker.finish())
+
+
 class TestCompressor_multi_compress_to_buffer(unittest.TestCase):
     def test_invalid_inputs(self):
         cctx = zstd.ZstdCompressor()