|
1 import io |
|
2 import random |
|
3 import struct |
|
4 import sys |
|
5 |
|
6 try: |
|
7 import unittest2 as unittest |
|
8 except ImportError: |
|
9 import unittest |
|
10 |
|
11 import zstd |
|
12 |
|
13 from .common import OpCountingBytesIO |
|
14 |
|
15 |
|
# Module-level ``next`` helper: advance an iterator using whichever method
# name the running major Python version gives to the iterator protocol.
if sys.version_info[0] >= 3:
    def next(it):
        return it.__next__()
else:
    def next(it):
        return it.next()
|
20 |
|
21 |
|
class TestDecompressor_decompress(unittest.TestCase):
    """Tests for the one-shot ``ZstdDecompressor.decompress()`` call."""

    def test_empty_input(self):
        decompressor = zstd.ZstdDecompressor()

        # An empty byte string is expected to be rejected.
        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            decompressor.decompress(b'')

    def test_invalid_input(self):
        decompressor = zstd.ZstdDecompressor()

        # Bytes that did not come out of a compressor are expected to be
        # rejected.
        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            decompressor.decompress(b'foobar')

    def test_no_content_size_in_frame(self):
        frame = zstd.ZstdCompressor(write_content_size=False).compress(b'foobar')

        # No max_output_size is supplied, so a frame produced with
        # write_content_size=False is expected to be refused.
        decompressor = zstd.ZstdDecompressor()
        with self.assertRaisesRegexp(zstd.ZstdError, 'input data invalid'):
            decompressor.decompress(frame)

    def test_content_size_present(self):
        frame = zstd.ZstdCompressor(write_content_size=True).compress(b'foobar')

        decompressor = zstd.ZstdDecompressor()
        self.assertEqual(decompressor.decompress(frame), b'foobar')

    def test_max_output_size(self):
        payload = b'foobar' * 256
        frame = zstd.ZstdCompressor(write_content_size=False).compress(payload)
        exact = len(payload)

        decompressor = zstd.ZstdDecompressor()

        # A buffer exactly as large as the original input is enough.
        self.assertEqual(
            decompressor.decompress(frame, max_output_size=exact), payload)

        # One byte less than the original input is not.
        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'Destination buffer is too small'):
            decompressor.decompress(frame, max_output_size=exact - 1)

        # One byte more, and a much larger buffer, both work.
        for limit in (exact + 1, exact * 64):
            self.assertEqual(
                decompressor.decompress(frame, max_output_size=limit), payload)

    def test_stupidly_large_output_buffer(self):
        frame = zstd.ZstdCompressor(write_content_size=False).compress(
            b'foobar' * 256)
        decompressor = zstd.ZstdDecompressor()

        # Python distributions that cannot handle really large integers
        # raise OverflowError instead of MemoryError.
        with self.assertRaises((MemoryError, OverflowError)):
            decompressor.decompress(frame, max_output_size=2**62)

    def test_dictionary(self):
        # 128 repetitions of the same three training samples.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
        dictionary = zstd.train_dictionary(8192, samples)

        payload = b'foobar' * 16384
        compressor = zstd.ZstdCompressor(level=1, dict_data=dictionary,
                                         write_content_size=True)
        frame = compressor.compress(payload)

        decompressor = zstd.ZstdDecompressor(dict_data=dictionary)
        self.assertEqual(decompressor.decompress(frame), payload)

    def test_dictionary_multiple(self):
        # 128 repetitions of the same three training samples.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
        dictionary = zstd.train_dictionary(8192, samples)

        sources = (b'foobar' * 8192, b'foo' * 8192, b'bar' * 8192)
        compressor = zstd.ZstdCompressor(level=1, dict_data=dictionary,
                                         write_content_size=True)
        frames = [compressor.compress(source) for source in sources]

        # A single decompressor instance handles every frame in turn.
        decompressor = zstd.ZstdDecompressor(dict_data=dictionary)
        for frame, expected in zip(frames, sources):
            self.assertEqual(decompressor.decompress(frame), expected)
|
120 |
|
121 |
|
class TestDecompressor_copy_stream(unittest.TestCase):
    """Tests for ``ZstdDecompressor.copy_stream()``."""

    def test_no_read(self):
        decompressor = zstd.ZstdDecompressor()
        unreadable = object()
        sink = io.BytesIO()

        # A plain object() as the source is expected to raise ValueError.
        with self.assertRaises(ValueError):
            decompressor.copy_stream(unreadable, sink)

    def test_no_write(self):
        decompressor = zstd.ZstdDecompressor()
        origin = io.BytesIO()
        unwritable = object()

        # A plain object() as the destination is expected to raise ValueError.
        with self.assertRaises(ValueError):
            decompressor.copy_stream(origin, unwritable)

    def test_empty(self):
        origin = io.BytesIO()
        sink = io.BytesIO()

        decompressor = zstd.ZstdDecompressor()
        # TODO should this raise an error?
        read_total, written_total = decompressor.copy_stream(origin, sink)

        self.assertEqual(read_total, 0)
        self.assertEqual(written_total, 0)
        self.assertEqual(sink.getvalue(), b'')

    def test_large_data(self):
        # 255 runs of 16384 identical bytes, one run per byte value 0..254.
        pack_byte = struct.Struct('>B').pack
        source = io.BytesIO(
            b''.join(pack_byte(value) * 16384 for value in range(255)))

        compressed = io.BytesIO()
        zstd.ZstdCompressor().copy_stream(source, compressed)
        compressed.seek(0)

        dest = io.BytesIO()
        decompressor = zstd.ZstdDecompressor()
        read_total, written_total = decompressor.copy_stream(compressed, dest)

        self.assertEqual(read_total, len(compressed.getvalue()))
        self.assertEqual(written_total, len(source.getvalue()))

    def test_read_write_size(self):
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        source = OpCountingBytesIO(frame)
        dest = OpCountingBytesIO()

        decompressor = zstd.ZstdDecompressor()
        read_total, written_total = decompressor.copy_stream(
            source, dest, read_size=1, write_size=1)

        self.assertEqual(read_total, len(source.getvalue()))
        self.assertEqual(written_total, len(b'foobarfoobar'))
        # With one-byte reads the expected read count is the compressed
        # size plus one; with one-byte writes it is one write per output byte.
        self.assertEqual(source._read_count, len(source.getvalue()) + 1)
        self.assertEqual(dest._write_count, len(dest.getvalue()))
|
181 |
|
182 |
|
class TestDecompressor_decompressobj(unittest.TestCase):
    """Tests for the object returned by ``ZstdDecompressor.decompressobj()``."""

    def test_simple(self):
        frame = zstd.ZstdCompressor(level=1).compress(b'foobar')

        decompressor = zstd.ZstdDecompressor()
        stream_obj = decompressor.decompressobj()
        self.assertEqual(stream_obj.decompress(frame), b'foobar')

    def test_reuse(self):
        frame = zstd.ZstdCompressor(level=1).compress(b'foobar')

        decompressor = zstd.ZstdDecompressor()
        stream_obj = decompressor.decompressobj()
        stream_obj.decompress(frame)

        # A second decompress() on the same object is expected to fail.
        with self.assertRaisesRegexp(zstd.ZstdError,
                                     'cannot use a decompressobj'):
            stream_obj.decompress(frame)
|
200 |
|
201 |
|
def decompress_via_writer(data):
    """Decompress ``data`` by writing it through ``ZstdDecompressor.write_to()``.

    The decompressed output is collected in an in-memory ``io.BytesIO`` and
    returned as a byte string.
    """
    sink = io.BytesIO()
    # Keep the decompressor bound to a local for the whole call.
    decompressor = zstd.ZstdDecompressor()
    with decompressor.write_to(sink) as writer:
        writer.write(data)
    return sink.getvalue()
|
208 |
|
209 |
|
class TestDecompressor_write_to(unittest.TestCase):
    """Tests for the writer returned by ``ZstdDecompressor.write_to()``."""

    def test_empty_roundtrip(self):
        frame = zstd.ZstdCompressor().compress(b'')
        self.assertEqual(decompress_via_writer(frame), b'')

    def test_large_roundtrip(self):
        # 255 runs of 16384 identical bytes, one run per byte value 0..254.
        pack_byte = struct.Struct('>B').pack
        payload = b''.join(pack_byte(value) * 16384 for value in range(255))
        frame = zstd.ZstdCompressor().compress(payload)

        self.assertEqual(decompress_via_writer(frame), payload)

    def test_multiple_calls(self):
        pack_byte = struct.Struct('>B').pack
        payload = b''.join(
            pack_byte(value) * repeat
            for repeat in range(255)
            for value in range(255))
        frame = zstd.ZstdCompressor().compress(payload)

        sink = io.BytesIO()
        decompressor = zstd.ZstdDecompressor()
        with decompressor.write_to(sink) as writer:
            # Feed the frame to the writer in 8192-byte slices.
            for start in range(0, len(frame), 8192):
                writer.write(frame[start:start + 8192])
        self.assertEqual(sink.getvalue(), payload)

    def test_dictionary(self):
        # 128 repetitions of the same three training samples.
        samples = [b'foo' * 64, b'bar' * 64, b'foobar' * 64] * 128
        dictionary = zstd.train_dictionary(8192, samples)

        payload = b'foobar' * 16384

        # Compress through the streaming writer using the dictionary.
        compressed_sink = io.BytesIO()
        compressor = zstd.ZstdCompressor(dict_data=dictionary)
        with compressor.write_to(compressed_sink) as cwriter:
            cwriter.write(payload)
        frame = compressed_sink.getvalue()

        # Decompress through the streaming writer with the same dictionary.
        plain_sink = io.BytesIO()
        decompressor = zstd.ZstdDecompressor(dict_data=dictionary)
        with decompressor.write_to(plain_sink) as dwriter:
            dwriter.write(frame)

        self.assertEqual(plain_sink.getvalue(), payload)

    def test_memory_size(self):
        decompressor = zstd.ZstdDecompressor()
        sink = io.BytesIO()
        with decompressor.write_to(sink) as writer:
            reported = writer.memory_size()

        self.assertGreater(reported, 100000)

    def test_write_size(self):
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        dest = OpCountingBytesIO()
        decompressor = zstd.ZstdDecompressor()
        packer = struct.Struct('>B')
        with decompressor.write_to(dest, write_size=1) as writer:
            for item in frame:
                # Iterating the frame yields either str items or integers
                # depending on the interpreter; integers are packed back
                # into a single byte before being written.
                writer.write(item if isinstance(item, str) else packer.pack(item))

        self.assertEqual(dest.getvalue(), b'foobarfoobar')
        # write_size=1 is expected to give one write per output byte.
        self.assertEqual(dest._write_count, len(dest.getvalue()))
|
292 |
|
293 |
|
class TestDecompressor_read_from(unittest.TestCase):
    """Tests for the iterator returned by ``ZstdDecompressor.read_from()``."""

    def _take_exactly(self, it, count):
        # Pull ``count`` chunks from ``it``, check that the next pull raises
        # StopIteration, and return the chunks joined into one byte string.
        chunks = []
        for _ in range(count):
            chunks.append(next(it))

        with self.assertRaises(StopIteration):
            next(it)

        return b''.join(chunks)

    def test_type_validation(self):
        decompressor = zstd.ZstdDecompressor()

        # An object with read() is accepted.
        decompressor.read_from(io.BytesIO())

        # So is an object supporting the buffer protocol.
        decompressor.read_from(b'foobar')

        with self.assertRaisesRegexp(ValueError,
                                     'must pass an object with a read'):
            decompressor.read_from(True)

    def test_empty_input(self):
        decompressor = zstd.ZstdDecompressor()

        # TODO this is arguably wrong. Should get an error about missing
        # frame foo.
        # Checked once for a stream source and once for a bytes source.
        for empty_source in (io.BytesIO(), b''):
            it = decompressor.read_from(empty_source)
            with self.assertRaises(StopIteration):
                next(it)

    def test_invalid_input(self):
        decompressor = zstd.ZstdDecompressor()

        # Checked once for a stream source and once for a bytes source.
        for bogus_source in (io.BytesIO(b'foobar'), b'foobar'):
            it = decompressor.read_from(bogus_source)
            with self.assertRaisesRegexp(zstd.ZstdError,
                                         'Unknown frame descriptor'):
                next(it)

    def test_empty_roundtrip(self):
        compressor = zstd.ZstdCompressor(level=1, write_content_size=False)
        frame = compressor.compress(b'')

        source = io.BytesIO(frame)
        source.seek(0)

        decompressor = zstd.ZstdDecompressor()
        it = decompressor.read_from(source)

        # No chunks should be emitted since there is no data; the second
        # pass repeats the check for good measure.
        for _ in range(2):
            with self.assertRaises(StopIteration):
                next(it)

    def test_skip_bytes_too_large(self):
        decompressor = zstd.ZstdDecompressor()

        with self.assertRaisesRegexp(
                ValueError, 'skip_bytes must be smaller than read_size'):
            decompressor.read_from(b'', skip_bytes=1, read_size=1)

        with self.assertRaisesRegexp(
                ValueError, 'skip_bytes larger than first input chunk'):
            b''.join(decompressor.read_from(b'foobar', skip_bytes=10))

    def test_skip_bytes(self):
        frame = zstd.ZstdCompressor(write_content_size=False).compress(b'foobar')

        # Three leading non-frame bytes are skipped before decompression.
        decompressor = zstd.ZstdDecompressor()
        chunks = decompressor.read_from(b'hdr' + frame, skip_bytes=3)
        self.assertEqual(b''.join(chunks), b'foobar')

    def test_large_output(self):
        # One byte more than the recommended output size.
        expected = b'f' * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE + b'o'

        compressed = io.BytesIO(zstd.ZstdCompressor(level=1).compress(expected))
        compressed.seek(0)

        decompressor = zstd.ZstdDecompressor()

        # Exactly two chunks are expected from the stream source...
        it = decompressor.read_from(compressed)
        self.assertEqual(self._take_exactly(it, 2), expected)

        # ...and again with the buffer protocol.
        it = decompressor.read_from(compressed.getvalue())
        self.assertEqual(self._take_exactly(it, 2), expected)

    def test_large_input(self):
        byte_values = [struct.Struct('>B').pack(i) for i in range(256)]
        compressed = io.BytesIO()
        input_size = 0
        compressor = zstd.ZstdCompressor(level=1)
        with compressor.write_to(compressed) as writer:
            # Write random single bytes until both the compressed output and
            # the raw input have passed their size thresholds.
            done = False
            while not done:
                writer.write(random.choice(byte_values))
                input_size += 1

                enough_compressed = (
                    len(compressed.getvalue()) >
                    zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)
                enough_raw = (
                    input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2)
                done = enough_compressed and enough_raw

        compressed.seek(0)
        self.assertGreater(len(compressed.getvalue()),
                           zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE)

        decompressor = zstd.ZstdDecompressor()

        # Exactly three chunks are expected from the stream source...
        it = decompressor.read_from(compressed)
        self.assertEqual(len(self._take_exactly(it, 3)), input_size)

        # ...and again with the buffer protocol.
        it = decompressor.read_from(compressed.getvalue())
        self.assertEqual(len(self._take_exactly(it, 3)), input_size)

    def test_interesting(self):
        # Found this edge case via fuzzing.
        compressor = zstd.ZstdCompressor(level=1)

        # 256 writes of 1024 zero bytes each.
        chunk = b'\0' * 1024
        expected = chunk * 256

        compressed = io.BytesIO()
        with compressor.write_to(compressed) as writer:
            for _ in range(256):
                writer.write(chunk)

        decompressor = zstd.ZstdDecompressor()

        simple = decompressor.decompress(compressed.getvalue(),
                                         max_output_size=len(expected))
        self.assertEqual(simple, expected)

        compressed.seek(0)
        streamed = b''.join(decompressor.read_from(compressed))
        self.assertEqual(streamed, expected)

    def test_read_write_size(self):
        frame = zstd.ZstdCompressor().compress(b'foobarfoobar')
        source = OpCountingBytesIO(frame)
        decompressor = zstd.ZstdDecompressor()

        # write_size=1 is expected to give one-byte output chunks.
        for piece in decompressor.read_from(source, read_size=1, write_size=1):
            self.assertEqual(len(piece), 1)

        self.assertEqual(source._read_count, len(source.getvalue()))