import io
import random
import struct
import sys
try:
import unittest2 as unittest
except ImportError:
import unittest
import zstd
from .common import (
make_cffi,
OpCountingBytesIO,
)
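# Shim so next() advances an iterator the same way on Python 2 and Python 3.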
if sys.version_info[0] >= 3:
next = lambda it: it.__next__()
else:
next = lambda it: it.next()
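# Tests for the one-shot ZstdDecompressor.decompress() API.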
@make_cffi
class TestDecompressor_decompress(unittest.TestCase):
def test_empty_input(self):
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
dctx.decompress(b'')
def test_invalid_input(self):
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
dctx.decompress(b'foobar')
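    # The one-shot API needs the decompressed size embedded in the frame;
    # frames written without a content size are rejected.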
def test_no_content_size_in_frame(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
compressed = cctx.compress(b'foobar')
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
dctx.decompress(compressed)
def test_content_size_present(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
compressed = cctx.compress(b'foobar')
dctx = zstd.ZstdDecompressor()
decompressed = dctx.decompress(compressed)
self.assertEqual(decompressed, b'foobar')
def test_max_output_size(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
source = b'foobar' * 256
compressed = cctx.compress(source)
dctx = zstd.ZstdDecompressor()
        # The output fits into a buffer exactly the size of the input.
decompressed = dctx.decompress(compressed, max_output_size=len(source))
self.assertEqual(decompressed, source)
# Input size - 1 fails
with self.assertRaisesRegexp(zstd.ZstdError, 'Destination buffer is too small'):
dctx.decompress(compressed, max_output_size=len(source) - 1)
# Input size + 1 works
decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1)
self.assertEqual(decompressed, source)
# A much larger buffer works.
decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64)
self.assertEqual(decompressed, source)
def test_stupidly_large_output_buffer(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
compressed = cctx.compress(b'foobar' * 256)
dctx = zstd.ZstdDecompressor()
        # Allocating the output buffer fails with MemoryError, or OverflowError
        # on builds where 2**62 does not fit in the native size type.
with self.assertRaises((MemoryError, OverflowError)):
dctx.decompress(compressed, max_output_size=2**62)
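    # Dictionary round trips: train a dictionary from repetitive samples and
    # use it for both compression and decompression.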
def test_dictionary(self):
samples = []
for i in range(128):
samples.append(b'foo' * 64)
samples.append(b'bar' * 64)
samples.append(b'foobar' * 64)
d = zstd.train_dictionary(8192, samples)
orig = b'foobar' * 16384
cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
compressed = cctx.compress(orig)
dctx = zstd.ZstdDecompressor(dict_data=d)
decompressed = dctx.decompress(compressed)
self.assertEqual(decompressed, orig)
def test_dictionary_multiple(self):
samples = []
for i in range(128):
samples.append(b'foo' * 64)
samples.append(b'bar' * 64)
samples.append(b'foobar' * 64)
d = zstd.train_dictionary(8192, samples)
sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
compressed = []
cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_content_size=True)
for source in sources:
compressed.append(cctx.compress(source))
dctx = zstd.ZstdDecompressor(dict_data=d)
        for source, frame in zip(sources, compressed):
            decompressed = dctx.decompress(frame)
            self.assertEqual(decompressed, source)
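# Tests for ZstdDecompressor.copy_stream(), which reads compressed data from a
# source stream and writes the decompressed output to a destination stream.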
@make_cffi
class TestDecompressor_copy_stream(unittest.TestCase):
def test_no_read(self):
source = object()
dest = io.BytesIO()
dctx = zstd.ZstdDecompressor()
with self.assertRaises(ValueError):
dctx.copy_stream(source, dest)
def test_no_write(self):
source = io.BytesIO()
dest = object()
dctx = zstd.ZstdDecompressor()
with self.assertRaises(ValueError):
dctx.copy_stream(source, dest)
def test_empty(self):
source = io.BytesIO()
dest = io.BytesIO()
dctx = zstd.ZstdDecompressor()
# TODO should this raise an error?
r, w = dctx.copy_stream(source, dest)
self.assertEqual(r, 0)
self.assertEqual(w, 0)
self.assertEqual(dest.getvalue(), b'')
def test_large_data(self):
source = io.BytesIO()
for i in range(255):
source.write(struct.Struct('>B').pack(i) * 16384)
source.seek(0)
compressed = io.BytesIO()
cctx = zstd.ZstdCompressor()
cctx.copy_stream(source, compressed)
compressed.seek(0)
dest = io.BytesIO()
dctx = zstd.ZstdDecompressor()
r, w = dctx.copy_stream(compressed, dest)
self.assertEqual(r, len(compressed.getvalue()))
self.assertEqual(w, len(source.getvalue()))
def test_read_write_size(self):
source = OpCountingBytesIO(zstd.ZstdCompressor().compress(
b'foobarfoobar'))
dest = OpCountingBytesIO()
dctx = zstd.ZstdDecompressor()
r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1)
self.assertEqual(r, len(source.getvalue()))
self.assertEqual(w, len(b'foobarfoobar'))
self.assertEqual(source._read_count, len(source.getvalue()) + 1)
self.assertEqual(dest._write_count, len(dest.getvalue()))
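# Tests for ZstdDecompressor.decompressobj(); each decompression object can
# only be used for a single operation.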
@make_cffi
class TestDecompressor_decompressobj(unittest.TestCase):
def test_simple(self):
data = zstd.ZstdCompressor(level=1).compress(b'foobar')
dctx = zstd.ZstdDecompressor()
dobj = dctx.decompressobj()
self.assertEqual(dobj.decompress(data), b'foobar')
def test_reuse(self):
data = zstd.ZstdCompressor(level=1).compress(b'foobar')
dctx = zstd.ZstdDecompressor()
dobj = dctx.decompressobj()
dobj.decompress(data)
with self.assertRaisesRegexp(zstd.ZstdError, 'cannot use a decompressobj'):
dobj.decompress(data)
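# Helper that decompresses data by feeding it to a write_to() decompressing writer.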
def decompress_via_writer(data):
buffer = io.BytesIO()
dctx = zstd.ZstdDecompressor()
with dctx.write_to(buffer) as decompressor:
decompressor.write(data)
return buffer.getvalue()
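# Tests for ZstdDecompressor.write_to(), which wraps a writable object and
# decompresses whatever is written to it.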
@make_cffi
class TestDecompressor_write_to(unittest.TestCase):
def test_empty_roundtrip(self):
cctx = zstd.ZstdCompressor()
empty = cctx.compress(b'')
self.assertEqual(decompress_via_writer(empty), b'')
def test_large_roundtrip(self):
chunks = []
for i in range(255):
chunks.append(struct.Struct('>B').pack(i) * 16384)
orig = b''.join(chunks)
cctx = zstd.ZstdCompressor()
compressed = cctx.compress(orig)
self.assertEqual(decompress_via_writer(compressed), orig)
def test_multiple_calls(self):
chunks = []
for i in range(255):
for j in range(255):
chunks.append(struct.Struct('>B').pack(j) * i)
orig = b''.join(chunks)
cctx = zstd.ZstdCompressor()
compressed = cctx.compress(orig)
buffer = io.BytesIO()
dctx = zstd.ZstdDecompressor()
with dctx.write_to(buffer) as decompressor:
pos = 0
while pos < len(compressed):
pos2 = pos + 8192
decompressor.write(compressed[pos:pos2])
pos += 8192
self.assertEqual(buffer.getvalue(), orig)
def test_dictionary(self):
samples = []
for i in range(128):
samples.append(b'foo' * 64)
samples.append(b'bar' * 64)
samples.append(b'foobar' * 64)
d = zstd.train_dictionary(8192, samples)
orig = b'foobar' * 16384
buffer = io.BytesIO()
cctx = zstd.ZstdCompressor(dict_data=d)
with cctx.write_to(buffer) as compressor:
self.assertEqual(compressor.write(orig), 1544)
compressed = buffer.getvalue()
buffer = io.BytesIO()
dctx = zstd.ZstdDecompressor(dict_data=d)
with dctx.write_to(buffer) as decompressor:
self.assertEqual(decompressor.write(compressed), len(orig))
self.assertEqual(buffer.getvalue(), orig)
def test_memory_size(self):
dctx = zstd.ZstdDecompressor()
buffer = io.BytesIO()
with dctx.write_to(buffer) as decompressor:
size = decompressor.memory_size()
self.assertGreater(size, 100000)
def test_write_size(self):
source = zstd.ZstdCompressor().compress(b'foobarfoobar')
dest = OpCountingBytesIO()
dctx = zstd.ZstdDecompressor()
with dctx.write_to(dest, write_size=1) as decompressor:
s = struct.Struct('>B')
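            # Feed one byte at a time. Iterating bytes yields ints on Python 3
            # and length-1 strs on Python 2, hence the isinstance() check below.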
for c in source:
if not isinstance(c, str):
c = s.pack(c)
decompressor.write(c)
self.assertEqual(dest.getvalue(), b'foobarfoobar')
self.assertEqual(dest._write_count, len(dest.getvalue()))
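# Tests for ZstdDecompressor.read_from(), which returns an iterator of
# decompressed chunks from an object with a read() method or a buffer.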
@make_cffi
class TestDecompressor_read_from(unittest.TestCase):
def test_type_validation(self):
dctx = zstd.ZstdDecompressor()
# Object with read() works.
dctx.read_from(io.BytesIO())
# Buffer protocol works.
dctx.read_from(b'foobar')
with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
b''.join(dctx.read_from(True))
def test_empty_input(self):
dctx = zstd.ZstdDecompressor()
source = io.BytesIO()
it = dctx.read_from(source)
        # TODO this is arguably wrong: an empty input should probably raise an
        # error about the missing frame.
with self.assertRaises(StopIteration):
next(it)
it = dctx.read_from(b'')
with self.assertRaises(StopIteration):
next(it)
def test_invalid_input(self):
dctx = zstd.ZstdDecompressor()
source = io.BytesIO(b'foobar')
it = dctx.read_from(source)
with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
next(it)
it = dctx.read_from(b'foobar')
with self.assertRaisesRegexp(zstd.ZstdError, 'Unknown frame descriptor'):
next(it)
def test_empty_roundtrip(self):
cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
empty = cctx.compress(b'')
source = io.BytesIO(empty)
source.seek(0)
dctx = zstd.ZstdDecompressor()
it = dctx.read_from(source)
# No chunks should be emitted since there is no data.
with self.assertRaises(StopIteration):
next(it)
# Again for good measure.
with self.assertRaises(StopIteration):
next(it)
def test_skip_bytes_too_large(self):
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(ValueError, 'skip_bytes must be smaller than read_size'):
b''.join(dctx.read_from(b'', skip_bytes=1, read_size=1))
with self.assertRaisesRegexp(ValueError, 'skip_bytes larger than first input chunk'):
b''.join(dctx.read_from(b'foobar', skip_bytes=10))
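    # skip_bytes discards leading bytes (e.g. a custom header) before the
    # zstd frame data begins.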
def test_skip_bytes(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
compressed = cctx.compress(b'foobar')
dctx = zstd.ZstdDecompressor()
output = b''.join(dctx.read_from(b'hdr' + compressed, skip_bytes=3))
self.assertEqual(output, b'foobar')
def test_large_output(self):
source = io.BytesIO()
source.write(b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
source.write(b'o')
source.seek(0)
cctx = zstd.ZstdCompressor(level=1)
compressed = io.BytesIO(cctx.compress(source.getvalue()))
compressed.seek(0)
dctx = zstd.ZstdDecompressor()
it = dctx.read_from(compressed)
chunks = []
chunks.append(next(it))
chunks.append(next(it))
with self.assertRaises(StopIteration):
next(it)
decompressed = b''.join(chunks)
self.assertEqual(decompressed, source.getvalue())
# And again with buffer protocol.
it = dctx.read_from(compressed.getvalue())
chunks = []
chunks.append(next(it))
chunks.append(next(it))
with self.assertRaises(StopIteration):
next(it)
decompressed = b''.join(chunks)
self.assertEqual(decompressed, source.getvalue())
def test_large_input(self):
        input_bytes = [struct.Struct('>B').pack(i) for i in range(256)]
compressed = io.BytesIO()
input_size = 0
cctx = zstd.ZstdCompressor(level=1)
with cctx.write_to(compressed) as compressor:
while True:
                compressor.write(random.choice(input_bytes))
input_size += 1
have_compressed = len(compressed.getvalue()) > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE
have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2
if have_compressed and have_raw:
break
compressed.seek(0)
self.assertGreater(len(compressed.getvalue()),
zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
dctx = zstd.ZstdDecompressor()
it = dctx.read_from(compressed)
chunks = []
chunks.append(next(it))
chunks.append(next(it))
chunks.append(next(it))
with self.assertRaises(StopIteration):
next(it)
decompressed = b''.join(chunks)
self.assertEqual(len(decompressed), input_size)
# And again with buffer protocol.
it = dctx.read_from(compressed.getvalue())
chunks = []
chunks.append(next(it))
chunks.append(next(it))
chunks.append(next(it))
with self.assertRaises(StopIteration):
next(it)
decompressed = b''.join(chunks)
self.assertEqual(len(decompressed), input_size)
def test_interesting(self):
# Found this edge case via fuzzing.
cctx = zstd.ZstdCompressor(level=1)
source = io.BytesIO()
compressed = io.BytesIO()
with cctx.write_to(compressed) as compressor:
for i in range(256):
chunk = b'\0' * 1024
compressor.write(chunk)
source.write(chunk)
dctx = zstd.ZstdDecompressor()
simple = dctx.decompress(compressed.getvalue(),
max_output_size=len(source.getvalue()))
self.assertEqual(simple, source.getvalue())
compressed.seek(0)
streamed = b''.join(dctx.read_from(compressed))
self.assertEqual(streamed, source.getvalue())
def test_read_write_size(self):
source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b'foobarfoobar'))
dctx = zstd.ZstdDecompressor()
for chunk in dctx.read_from(source, read_size=1, write_size=1):
self.assertEqual(len(chunk), 1)
self.assertEqual(source._read_count, len(source.getvalue()))
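# Tests for ZstdDecompressor.decompress_content_dict_chain(), which decompresses
# a chain of frames where each frame uses the previous frame's decompressed
# content as its dictionary.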
@make_cffi
class TestDecompressor_content_dict_chain(unittest.TestCase):
def test_bad_inputs_simple(self):
dctx = zstd.ZstdDecompressor()
with self.assertRaises(TypeError):
dctx.decompress_content_dict_chain(b'foo')
with self.assertRaises(TypeError):
dctx.decompress_content_dict_chain((b'foo', b'bar'))
with self.assertRaisesRegexp(ValueError, 'empty input chain'):
dctx.decompress_content_dict_chain([])
with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
dctx.decompress_content_dict_chain([u'foo'])
with self.assertRaisesRegexp(ValueError, 'chunk 0 must be bytes'):
dctx.decompress_content_dict_chain([True])
with self.assertRaisesRegexp(ValueError, 'chunk 0 is too small to contain a zstd frame'):
dctx.decompress_content_dict_chain([zstd.FRAME_HEADER])
with self.assertRaisesRegexp(ValueError, 'chunk 0 is not a valid zstd frame'):
dctx.decompress_content_dict_chain([b'foo' * 8])
no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
with self.assertRaisesRegexp(ValueError, 'chunk 0 missing content size in frame'):
dctx.decompress_content_dict_chain([no_size])
# Corrupt first frame.
frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
frame = frame[0:12] + frame[15:]
with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 0'):
dctx.decompress_content_dict_chain([frame])
def test_bad_subsequent_input(self):
initial = zstd.ZstdCompressor(write_content_size=True).compress(b'foo' * 64)
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
dctx.decompress_content_dict_chain([initial, u'foo'])
with self.assertRaisesRegexp(ValueError, 'chunk 1 must be bytes'):
dctx.decompress_content_dict_chain([initial, None])
with self.assertRaisesRegexp(ValueError, 'chunk 1 is too small to contain a zstd frame'):
dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER])
with self.assertRaisesRegexp(ValueError, 'chunk 1 is not a valid zstd frame'):
dctx.decompress_content_dict_chain([initial, b'foo' * 8])
no_size = zstd.ZstdCompressor().compress(b'foo' * 64)
with self.assertRaisesRegexp(ValueError, 'chunk 1 missing content size in frame'):
dctx.decompress_content_dict_chain([initial, no_size])
# Corrupt second frame.
cctx = zstd.ZstdCompressor(write_content_size=True, dict_data=zstd.ZstdCompressionDict(b'foo' * 64))
frame = cctx.compress(b'bar' * 64)
frame = frame[0:12] + frame[15:]
with self.assertRaisesRegexp(zstd.ZstdError, 'could not decompress chunk 1'):
dctx.decompress_content_dict_chain([initial, frame])
def test_simple(self):
original = [
b'foo' * 64,
b'foobar' * 64,
b'baz' * 64,
b'foobaz' * 64,
b'foobarbaz' * 64,
]
chunks = []
chunks.append(zstd.ZstdCompressor(write_content_size=True).compress(original[0]))
for i, chunk in enumerate(original[1:]):
d = zstd.ZstdCompressionDict(original[i])
cctx = zstd.ZstdCompressor(dict_data=d, write_content_size=True)
chunks.append(cctx.compress(chunk))
for i in range(1, len(original)):
chain = chunks[0:i]
expected = original[i - 1]
dctx = zstd.ZstdDecompressor()
decompressed = dctx.decompress_content_dict_chain(chain)
self.assertEqual(decompressed, expected)
# TODO enable for CFFI
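# Tests for ZstdDecompressor.multi_decompress_to_buffer(), which decompresses
# multiple frames in a single call, optionally across multiple threads.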
class TestDecompressor_multi_decompress_to_buffer(unittest.TestCase):
def test_invalid_inputs(self):
dctx = zstd.ZstdDecompressor()
with self.assertRaises(TypeError):
dctx.multi_decompress_to_buffer(True)
with self.assertRaises(TypeError):
dctx.multi_decompress_to_buffer((1, 2))
with self.assertRaisesRegexp(TypeError, 'item 0 not a bytes like object'):
dctx.multi_decompress_to_buffer([u'foo'])
with self.assertRaisesRegexp(ValueError, 'could not determine decompressed size of item 0'):
dctx.multi_decompress_to_buffer([b'foobarbaz'])
def test_list_input(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
original = [b'foo' * 4, b'bar' * 6]
frames = [cctx.compress(d) for d in original]
dctx = zstd.ZstdDecompressor()
result = dctx.multi_decompress_to_buffer(frames)
self.assertEqual(len(result), len(frames))
self.assertEqual(result.size(), sum(map(len, original)))
for i, data in enumerate(original):
self.assertEqual(result[i].tobytes(), data)
self.assertEqual(result[0].offset, 0)
self.assertEqual(len(result[0]), 12)
self.assertEqual(result[1].offset, 12)
self.assertEqual(len(result[1]), 18)
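    # Frames written without an embedded content size require the decompressed
    # sizes to be supplied explicitly as a packed array of 64-bit integers.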
def test_list_input_frame_sizes(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
frames = [cctx.compress(d) for d in original]
sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
dctx = zstd.ZstdDecompressor()
result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes)
self.assertEqual(len(result), len(frames))
self.assertEqual(result.size(), sum(map(len, original)))
for i, data in enumerate(original):
self.assertEqual(result[i].tobytes(), data)
def test_buffer_with_segments_input(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
original = [b'foo' * 4, b'bar' * 6]
frames = [cctx.compress(d) for d in original]
dctx = zstd.ZstdDecompressor()
segments = struct.pack('=QQQQ', 0, len(frames[0]), len(frames[0]), len(frames[1]))
b = zstd.BufferWithSegments(b''.join(frames), segments)
result = dctx.multi_decompress_to_buffer(b)
self.assertEqual(len(result), len(frames))
self.assertEqual(result[0].offset, 0)
self.assertEqual(len(result[0]), 12)
self.assertEqual(result[1].offset, 12)
self.assertEqual(len(result[1]), 18)
def test_buffer_with_segments_sizes(self):
cctx = zstd.ZstdCompressor(write_content_size=False)
original = [b'foo' * 4, b'bar' * 6, b'baz' * 8]
frames = [cctx.compress(d) for d in original]
sizes = struct.pack('=' + 'Q' * len(original), *map(len, original))
segments = struct.pack('=QQQQQQ', 0, len(frames[0]),
len(frames[0]), len(frames[1]),
len(frames[0]) + len(frames[1]), len(frames[2]))
b = zstd.BufferWithSegments(b''.join(frames), segments)
dctx = zstd.ZstdDecompressor()
result = dctx.multi_decompress_to_buffer(b, decompressed_sizes=sizes)
self.assertEqual(len(result), len(frames))
self.assertEqual(result.size(), sum(map(len, original)))
for i, data in enumerate(original):
self.assertEqual(result[i].tobytes(), data)
def test_buffer_with_segments_collection_input(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
original = [
b'foo0' * 2,
b'foo1' * 3,
b'foo2' * 4,
b'foo3' * 5,
b'foo4' * 6,
]
frames = cctx.multi_compress_to_buffer(original)
# Check round trip.
dctx = zstd.ZstdDecompressor()
decompressed = dctx.multi_decompress_to_buffer(frames, threads=3)
self.assertEqual(len(decompressed), len(original))
for i, data in enumerate(original):
self.assertEqual(data, decompressed[i].tobytes())
        # And again with manually constructed BufferWithSegments objects.
b = b''.join([frames[0].tobytes(), frames[1].tobytes()])
b1 = zstd.BufferWithSegments(b, struct.pack('=QQQQ',
0, len(frames[0]),
len(frames[0]), len(frames[1])))
b = b''.join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()])
b2 = zstd.BufferWithSegments(b, struct.pack('=QQQQQQ',
0, len(frames[2]),
len(frames[2]), len(frames[3]),
len(frames[2]) + len(frames[3]), len(frames[4])))
c = zstd.BufferWithSegmentsCollection(b1, b2)
dctx = zstd.ZstdDecompressor()
decompressed = dctx.multi_decompress_to_buffer(c)
self.assertEqual(len(decompressed), 5)
for i in range(5):
self.assertEqual(decompressed[i].tobytes(), original[i])
def test_multiple_threads(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
frames = []
frames.extend(cctx.compress(b'x' * 64) for i in range(256))
frames.extend(cctx.compress(b'y' * 64) for i in range(256))
dctx = zstd.ZstdDecompressor()
result = dctx.multi_decompress_to_buffer(frames, threads=-1)
self.assertEqual(len(result), len(frames))
self.assertEqual(result.size(), 2 * 64 * 256)
self.assertEqual(result[0].tobytes(), b'x' * 64)
self.assertEqual(result[256].tobytes(), b'y' * 64)
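    # A corrupt frame should report an error naming the failing item, both
    # serially and with multiple threads.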
def test_item_failure(self):
cctx = zstd.ZstdCompressor(write_content_size=True)
frames = [cctx.compress(b'x' * 128), cctx.compress(b'y' * 128)]
frames[1] = frames[1] + b'extra'
dctx = zstd.ZstdDecompressor()
with self.assertRaisesRegexp(zstd.ZstdError, 'error decompressing item 1: Src size incorrect'):
dctx.multi_decompress_to_buffer(frames)
with self.assertRaisesRegexp(zstd.ZstdError, 'error decompressing item 1: Src size incorrect'):
dctx.multi_decompress_to_buffer(frames, threads=2)