--- a/contrib/python-zstandard/tests/test_decompressor.py Thu Feb 09 21:44:32 2017 -0500
+++ b/contrib/python-zstandard/tests/test_decompressor.py Tue Feb 07 23:24:47 2017 -0800
@@ -10,7 +10,10 @@
import zstd
-from .common import OpCountingBytesIO
+from .common import (
+ make_cffi,
+ OpCountingBytesIO,
+)
if sys.version_info[0] >= 3:
@@ -19,6 +22,7 @@
next = lambda it: it.next()
+@make_cffi
class TestDecompressor_decompress(unittest.TestCase):
def test_empty_input(self):
dctx = zstd.ZstdDecompressor()
@@ -119,6 +123,7 @@
self.assertEqual(decompressed, sources[i])
+@make_cffi
class TestDecompressor_copy_stream(unittest.TestCase):
def test_no_read(self):
source = object()
@@ -180,6 +185,7 @@
self.assertEqual(dest._write_count, len(dest.getvalue()))
+@make_cffi
class TestDecompressor_decompressobj(unittest.TestCase):
def test_simple(self):
data = zstd.ZstdCompressor(level=1).compress(b'foobar')
@@ -207,6 +213,7 @@
return buffer.getvalue()
+@make_cffi
class TestDecompressor_write_to(unittest.TestCase):
def test_empty_roundtrip(self):
cctx = zstd.ZstdCompressor()
@@ -256,14 +263,14 @@
buffer = io.BytesIO()
cctx = zstd.ZstdCompressor(dict_data=d)
with cctx.write_to(buffer) as compressor:
- compressor.write(orig)
+ self.assertEqual(compressor.write(orig), 1544)
compressed = buffer.getvalue()
buffer = io.BytesIO()
dctx = zstd.ZstdDecompressor(dict_data=d)
with dctx.write_to(buffer) as decompressor:
- decompressor.write(compressed)
+ self.assertEqual(decompressor.write(compressed), len(orig))
self.assertEqual(buffer.getvalue(), orig)
@@ -291,6 +298,7 @@
self.assertEqual(dest._write_count, len(dest.getvalue()))
+@make_cffi
class TestDecompressor_read_from(unittest.TestCase):
def test_type_validation(self):
dctx = zstd.ZstdDecompressor()
@@ -302,7 +310,7 @@
dctx.read_from(b'foobar')
with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
- dctx.read_from(True)
+ b''.join(dctx.read_from(True))
def test_empty_input(self):
dctx = zstd.ZstdDecompressor()
@@ -351,7 +359,7 @@
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
- dctx.read_from(b'', skip_bytes=1, read_size=1)
+ b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1))
with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
b''.join(dctx.read_from(b'foobar', skip_bytes=10))
@@ -476,3 +484,94 @@
self.assertEqual(len(chunk), 1)
self.assertEqual(source._read_count, len(source.getvalue()))
+
+
+@make_cffi
+class TestDecompressor_content_dict_chain(unittest.TestCase):
+ def test_bad_inputs_simple(self):
+ dctx = zstd.ZstdDecompressor()
+
+ with self.assertRaises(TypeError):
+ dctx.decompress_content_dict_chain(b'foo')
+
+ with self.assertRaises(TypeError):
+ dctx.decompress_content_dict_chain((b'foo', b'bar'))
+
+ with self.assertRaisesRegexp(ValueError, 'empty input chain'):
+ dctx.decompress_content_dict_chain([])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
+ dctx.decompress_content_dict_chain([u'foo'])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
+ dctx.decompress_content_dict_chain([True])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 0 is too small to contain a zstd frame'):
+ dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'):
+ dctx.decompress_content_dict_chain([b'foo' * 8])
+
+ no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'):
+ dctx.decompress_content_dict_chain([no_size])
+
+ # Corrupt first frame.
+ frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
+ frame = frame[0:12] + frame[15:]
+ with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'):
+ dctx.decompress_content_dict_chain([frame])
+
+ def test_bad_subsequent_input(self):
+ initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
+
+ dctx = zstd.ZstdDecompressor()
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
+ dctx.decompress_content_dict_chain([initial, u'foo'])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
+ dctx.decompress_content_dict_chain([initial, None])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 1 is too small to contain a zstd frame'):
+ dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'):
+ dctx.decompress_content_dict_chain([initial, b'foo' * 8])
+
+ no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
+
+ with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'):
+ dctx.decompress_content_dict_chain([initial, no_size])
+
+ # Corrupt second frame.
+ cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64))
+ frame = cctx.compress(b'bar' * 64)
+ frame = frame[0:12] + frame[15:]
+
+ with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'):
+ dctx.decompress_content_dict_chain([initial, frame])
+
+ def test_simple(self):
+ original = [
+ b'foo' * 64,
+ b'foobar' * 64,
+ b'baz' * 64,
+ b'foobaz' * 64,
+ b'foobarbaz' * 64,
+ ]
+
+ chunks = []
+ chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0]))
+ for i, chunk in enumerate(original[1:]):
+ d = zstd.ZstdCompressionDict(original[i])
+ cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True)
+ chunks.append(cctx.compress(chunk))
+
+ for i in range(1, len(original)):
+ chain = chunks[0:i]
+ expected = original[i - 1]
+ dctx = zstd.ZstdDecompressor()
+ decompressed = dctx.decompress_content_dict_chain(chain)
+ self.assertEqual(decompressed, expected)