contrib/python-zstandard/tests/test_compressor.py
changeset 30822 b54a2984cdd4
parent 30444 b86a448a2965
child 30924 c32454d69b85
--- a/contrib/python-zstandard/tests/test_compressor.py	Sat Jan 14 20:05:15 2017 +0530
+++ b/contrib/python-zstandard/tests/test_compressor.py	Sat Jan 14 19:41:43 2017 -0800
@@ -41,6 +41,14 @@
         self.assertEqual(cctx.compress(b''),
                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
 
+        # TODO should be temporary until https://github.com/facebook/zstd/issues/506
+        # is fixed.
+        cctx = zstd.ZstdCompressor(write_content_size=True)
+        with self.assertRaises(ValueError):
+            cctx.compress(b'')
+
+        cctx.compress(b'', allow_empty=True)
+
     def test_compress_large(self):
         chunks = []
         for i in range(255):
@@ -139,19 +147,45 @@
 
         self.assertEqual(len(with_size), len(no_size) + 1)
 
-    def test_compress_after_flush(self):
+    def test_compress_after_finished(self):
         cctx = zstd.ZstdCompressor()
         cobj = cctx.compressobj()
 
         cobj.compress(b'foo')
         cobj.flush()
 
-        with self.assertRaisesRegexp(zstd.ZstdError, 'cannot call compress\(\) after flush'):
+        with self.assertRaisesRegexp(zstd.ZstdError, 'cannot call compress\(\) after compressor'):
             cobj.compress(b'foo')
 
-        with self.assertRaisesRegexp(zstd.ZstdError, 'flush\(\) already called'):
+        with self.assertRaisesRegexp(zstd.ZstdError, 'compressor object already finished'):
             cobj.flush()
 
+    def test_flush_block_repeated(self):
+        cctx = zstd.ZstdCompressor(level=1)
+        cobj = cctx.compressobj()
+
+        self.assertEqual(cobj.compress(b'foo'), b'')
+        self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
+                         b'\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo')
+        self.assertEqual(cobj.compress(b'bar'), b'')
+        # 3 byte header plus content.
+        self.assertEqual(cobj.flush(), b'\x19\x00\x00bar')
+
+    def test_flush_empty_block(self):
+        cctx = zstd.ZstdCompressor(write_checksum=True)
+        cobj = cctx.compressobj()
+
+        cobj.compress(b'foobar')
+        cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
+        # No-op if no block is active (this is internal to zstd).
+        self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b'')
+
+        trailing = cobj.flush()
+        # 3 bytes block header + 4 bytes frame checksum
+        self.assertEqual(len(trailing), 7)
+        header = trailing[0:3]
+        self.assertEqual(header, b'\x01\x00\x00')
+
 
 class TestCompressor_copy_stream(unittest.TestCase):
     def test_no_read(self):
@@ -384,6 +418,43 @@
 
         self.assertEqual(len(dest.getvalue()), dest._write_count)
 
+    def test_flush_repeated(self):
+        cctx = zstd.ZstdCompressor(level=3)
+        dest = OpCountingBytesIO()
+        with cctx.write_to(dest) as compressor:
+            compressor.write(b'foo')
+            self.assertEqual(dest._write_count, 0)
+            compressor.flush()
+            self.assertEqual(dest._write_count, 1)
+            compressor.write(b'bar')
+            self.assertEqual(dest._write_count, 1)
+            compressor.flush()
+            self.assertEqual(dest._write_count, 2)
+            compressor.write(b'baz')
+
+        self.assertEqual(dest._write_count, 3)
+
+    def test_flush_empty_block(self):
+        cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
+        dest = OpCountingBytesIO()
+        with cctx.write_to(dest) as compressor:
+            compressor.write(b'foobar' * 8192)
+            count = dest._write_count
+            offset = dest.tell()
+            compressor.flush()
+            self.assertGreater(dest._write_count, count)
+            self.assertGreater(dest.tell(), offset)
+            offset = dest.tell()
+            # Ending the write here should cause an empty block to be written
+            # to denote end of frame.
+
+        trailing = dest.getvalue()[offset:]
+        # 3 bytes block header + 4 bytes frame checksum
+        self.assertEqual(len(trailing), 7)
+
+        header = trailing[0:3]
+        self.assertEqual(header, b'\x01\x00\x00')
+
 
 class TestCompressor_read_from(unittest.TestCase):
     def test_type_validation(self):