comparison contrib/python-zstandard/c-ext/decompressor.c @ 37495:b1fb341d8a61

zstandard: vendor python-zstandard 0.9.0. This was just released. It features a number of goodies. More info at https://gregoryszorc.com/blog/2018/04/09/release-of-python-zstandard-0.9/. The clang-format ignore list was updated to reflect the new source of files. The project contains a vendored copy of zstandard 1.3.4. The old version was 1.1.3. One of the changes between those versions is that zstandard is now dual licensed BSD + GPLv2 and the patent rights grant has been removed. Good riddance. The API should be backwards compatible, so no changes in core should be needed. However, there were a number of changes in the library that we'll want to adapt to. Those will be addressed in subsequent commits. Differential Revision: https://phab.mercurial-scm.org/D3198
author Gregory Szorc <gregory.szorc@gmail.com>
date Mon, 09 Apr 2018 10:13:29 -0700
parents e0dc40530c5a
children 73fef626dae3
comparison
equal deleted inserted replaced
37494:1ce7a55b09d1 37495:b1fb341d8a61
10 #include "pool.h" 10 #include "pool.h"
11 11
12 extern PyObject* ZstdError; 12 extern PyObject* ZstdError;
13 13
14 /** 14 /**
15 * Ensure the ZSTD_DStream on a ZstdDecompressor is initialized and reset. 15 * Ensure the ZSTD_DCtx on a decompressor is initiated and ready for a new operation.
16 * 16 */
17 * This should be called before starting a decompression operation with a 17 int ensure_dctx(ZstdDecompressor* decompressor, int loadDict) {
18 * ZSTD_DStream on a ZstdDecompressor.
19 */
20 int init_dstream(ZstdDecompressor* decompressor) {
21 void* dictData = NULL;
22 size_t dictSize = 0;
23 size_t zresult; 18 size_t zresult;
24 19
25 /* Simple case of dstream already exists. Just reset it. */ 20 ZSTD_DCtx_reset(decompressor->dctx);
26 if (decompressor->dstream) { 21
27 zresult = ZSTD_resetDStream(decompressor->dstream); 22 if (decompressor->maxWindowSize) {
23 zresult = ZSTD_DCtx_setMaxWindowSize(decompressor->dctx, decompressor->maxWindowSize);
28 if (ZSTD_isError(zresult)) { 24 if (ZSTD_isError(zresult)) {
29 PyErr_Format(ZstdError, "could not reset DStream: %s", 25 PyErr_Format(ZstdError, "unable to set max window size: %s",
30 ZSTD_getErrorName(zresult)); 26 ZSTD_getErrorName(zresult));
31 return -1; 27 return 1;
32 } 28 }
33 29 }
34 return 0; 30
35 } 31 zresult = ZSTD_DCtx_setFormat(decompressor->dctx, decompressor->format);
36
37 decompressor->dstream = ZSTD_createDStream();
38 if (!decompressor->dstream) {
39 PyErr_SetString(ZstdError, "could not create DStream");
40 return -1;
41 }
42
43 if (decompressor->dict) {
44 dictData = decompressor->dict->dictData;
45 dictSize = decompressor->dict->dictSize;
46 }
47
48 if (dictData) {
49 zresult = ZSTD_initDStream_usingDict(decompressor->dstream, dictData, dictSize);
50 }
51 else {
52 zresult = ZSTD_initDStream(decompressor->dstream);
53 }
54
55 if (ZSTD_isError(zresult)) { 32 if (ZSTD_isError(zresult)) {
56 /* Don't leave a reference to an invalid object. */ 33 PyErr_Format(ZstdError, "unable to set decoding format: %s",
57 ZSTD_freeDStream(decompressor->dstream);
58 decompressor->dstream = NULL;
59
60 PyErr_Format(ZstdError, "could not initialize DStream: %s",
61 ZSTD_getErrorName(zresult)); 34 ZSTD_getErrorName(zresult));
62 return -1; 35 return 1;
36 }
37
38 if (loadDict && decompressor->dict) {
39 if (ensure_ddict(decompressor->dict)) {
40 return 1;
41 }
42
43 zresult = ZSTD_DCtx_refDDict(decompressor->dctx, decompressor->dict->ddict);
44 if (ZSTD_isError(zresult)) {
45 PyErr_Format(ZstdError, "unable to reference prepared dictionary: %s",
46 ZSTD_getErrorName(zresult));
47 return 1;
48 }
63 } 49 }
64 50
65 return 0; 51 return 0;
66 } 52 }
67 53
74 ); 60 );
75 61
76 static int Decompressor_init(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { 62 static int Decompressor_init(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
77 static char* kwlist[] = { 63 static char* kwlist[] = {
78 "dict_data", 64 "dict_data",
65 "max_window_size",
66 "format",
79 NULL 67 NULL
80 }; 68 };
81 69
82 ZstdCompressionDict* dict = NULL; 70 ZstdCompressionDict* dict = NULL;
71 size_t maxWindowSize = 0;
72 ZSTD_format_e format = ZSTD_f_zstd1;
83 73
84 self->dctx = NULL; 74 self->dctx = NULL;
85 self->dict = NULL; 75 self->dict = NULL;
86 self->ddict = NULL; 76
87 77 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!II:ZstdDecompressor", kwlist,
88 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!:ZstdDecompressor", kwlist, 78 &ZstdCompressionDictType, &dict, &maxWindowSize, &format)) {
89 &ZstdCompressionDictType, &dict)) {
90 return -1; 79 return -1;
91 } 80 }
92 81
93 /* TODO lazily initialize the reference ZSTD_DCtx on first use since
94 not instances of ZstdDecompressor will use a ZSTD_DCtx. */
95 self->dctx = ZSTD_createDCtx(); 82 self->dctx = ZSTD_createDCtx();
96 if (!self->dctx) { 83 if (!self->dctx) {
97 PyErr_NoMemory(); 84 PyErr_NoMemory();
98 goto except; 85 goto except;
99 } 86 }
100 87
88 self->maxWindowSize = maxWindowSize;
89 self->format = format;
90
101 if (dict) { 91 if (dict) {
102 self->dict = dict; 92 self->dict = dict;
103 Py_INCREF(dict); 93 Py_INCREF(dict);
104 } 94 }
105 95
96 if (ensure_dctx(self, 1)) {
97 goto except;
98 }
99
106 return 0; 100 return 0;
107 101
108 except: 102 except:
103 Py_CLEAR(self->dict);
104
109 if (self->dctx) { 105 if (self->dctx) {
110 ZSTD_freeDCtx(self->dctx); 106 ZSTD_freeDCtx(self->dctx);
111 self->dctx = NULL; 107 self->dctx = NULL;
112 } 108 }
113 109
115 } 111 }
116 112
117 static void Decompressor_dealloc(ZstdDecompressor* self) { 113 static void Decompressor_dealloc(ZstdDecompressor* self) {
118 Py_CLEAR(self->dict); 114 Py_CLEAR(self->dict);
119 115
120 if (self->ddict) {
121 ZSTD_freeDDict(self->ddict);
122 self->ddict = NULL;
123 }
124
125 if (self->dstream) {
126 ZSTD_freeDStream(self->dstream);
127 self->dstream = NULL;
128 }
129
130 if (self->dctx) { 116 if (self->dctx) {
131 ZSTD_freeDCtx(self->dctx); 117 ZSTD_freeDCtx(self->dctx);
132 self->dctx = NULL; 118 self->dctx = NULL;
133 } 119 }
134 120
135 PyObject_Del(self); 121 PyObject_Del(self);
122 }
123
124 PyDoc_STRVAR(Decompressor_memory_size__doc__,
125 "memory_size() -- Size of decompression context, in bytes\n"
126 );
127
128 static PyObject* Decompressor_memory_size(ZstdDecompressor* self) {
129 if (self->dctx) {
130 return PyLong_FromSize_t(ZSTD_sizeof_DCtx(self->dctx));
131 }
132 else {
133 PyErr_SetString(ZstdError, "no decompressor context found; this should never happen");
134 return NULL;
135 }
136 } 136 }
137 137
138 PyDoc_STRVAR(Decompressor_copy_stream__doc__, 138 PyDoc_STRVAR(Decompressor_copy_stream__doc__,
139 "copy_stream(ifh, ofh[, read_size=default, write_size=default]) -- decompress data between streams\n" 139 "copy_stream(ifh, ofh[, read_size=default, write_size=default]) -- decompress data between streams\n"
140 "\n" 140 "\n"
164 ZSTD_outBuffer output; 164 ZSTD_outBuffer output;
165 Py_ssize_t totalRead = 0; 165 Py_ssize_t totalRead = 0;
166 Py_ssize_t totalWrite = 0; 166 Py_ssize_t totalWrite = 0;
167 char* readBuffer; 167 char* readBuffer;
168 Py_ssize_t readSize; 168 Py_ssize_t readSize;
169 PyObject* readResult; 169 PyObject* readResult = NULL;
170 PyObject* res = NULL; 170 PyObject* res = NULL;
171 size_t zresult = 0; 171 size_t zresult = 0;
172 PyObject* writeResult; 172 PyObject* writeResult;
173 PyObject* totalReadPy; 173 PyObject* totalReadPy;
174 PyObject* totalWritePy; 174 PyObject* totalWritePy;
189 } 189 }
190 190
191 /* Prevent free on uninitialized memory in finally. */ 191 /* Prevent free on uninitialized memory in finally. */
192 output.dst = NULL; 192 output.dst = NULL;
193 193
194 if (0 != init_dstream(self)) { 194 if (ensure_dctx(self, 1)) {
195 res = NULL; 195 res = NULL;
196 goto finally; 196 goto finally;
197 } 197 }
198 198
199 output.dst = PyMem_Malloc(outSize); 199 output.dst = PyMem_Malloc(outSize);
227 input.size = readSize; 227 input.size = readSize;
228 input.pos = 0; 228 input.pos = 0;
229 229
230 while (input.pos < input.size) { 230 while (input.pos < input.size) {
231 Py_BEGIN_ALLOW_THREADS 231 Py_BEGIN_ALLOW_THREADS
232 zresult = ZSTD_decompressStream(self->dstream, &output, &input); 232 zresult = ZSTD_decompress_generic(self->dctx, &output, &input);
233 Py_END_ALLOW_THREADS 233 Py_END_ALLOW_THREADS
234 234
235 if (ZSTD_isError(zresult)) { 235 if (ZSTD_isError(zresult)) {
236 PyErr_Format(ZstdError, "zstd decompressor error: %s", 236 PyErr_Format(ZstdError, "zstd decompressor error: %s",
237 ZSTD_getErrorName(zresult)); 237 ZSTD_getErrorName(zresult));
250 Py_XDECREF(writeResult); 250 Py_XDECREF(writeResult);
251 totalWrite += output.pos; 251 totalWrite += output.pos;
252 output.pos = 0; 252 output.pos = 0;
253 } 253 }
254 } 254 }
255
256 Py_CLEAR(readResult);
255 } 257 }
256 258
257 /* Source stream is exhausted. Finish up. */ 259 /* Source stream is exhausted. Finish up. */
258 260
259 totalReadPy = PyLong_FromSsize_t(totalRead); 261 totalReadPy = PyLong_FromSsize_t(totalRead);
264 266
265 finally: 267 finally:
266 if (output.dst) { 268 if (output.dst) {
267 PyMem_Free(output.dst); 269 PyMem_Free(output.dst);
268 } 270 }
271
272 Py_XDECREF(readResult);
269 273
270 return res; 274 return res;
271 } 275 }
272 276
273 PyDoc_STRVAR(Decompressor_decompress__doc__, 277 PyDoc_STRVAR(Decompressor_decompress__doc__,
298 "data", 302 "data",
299 "max_output_size", 303 "max_output_size",
300 NULL 304 NULL
301 }; 305 };
302 306
303 const char* source; 307 Py_buffer source;
304 Py_ssize_t sourceSize;
305 Py_ssize_t maxOutputSize = 0; 308 Py_ssize_t maxOutputSize = 0;
306 unsigned long long decompressedSize; 309 unsigned long long decompressedSize;
307 size_t destCapacity; 310 size_t destCapacity;
308 PyObject* result = NULL; 311 PyObject* result = NULL;
309 void* dictData = NULL;
310 size_t dictSize = 0;
311 size_t zresult; 312 size_t zresult;
313 ZSTD_outBuffer outBuffer;
314 ZSTD_inBuffer inBuffer;
312 315
313 #if PY_MAJOR_VERSION >= 3 316 #if PY_MAJOR_VERSION >= 3
314 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y#|n:decompress", 317 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|n:decompress",
315 #else 318 #else
316 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|n:decompress", 319 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s*|n:decompress",
317 #endif 320 #endif
318 kwlist, &source, &sourceSize, &maxOutputSize)) { 321 kwlist, &source, &maxOutputSize)) {
319 return NULL; 322 return NULL;
320 } 323 }
321 324
322 if (self->dict) { 325 if (!PyBuffer_IsContiguous(&source, 'C') || source.ndim > 1) {
323 dictData = self->dict->dictData; 326 PyErr_SetString(PyExc_ValueError,
324 dictSize = self->dict->dictSize; 327 "data buffer should be contiguous and have at most one dimension");
325 } 328 goto finally;
326 329 }
327 if (dictData && !self->ddict) { 330
328 Py_BEGIN_ALLOW_THREADS 331 if (ensure_dctx(self, 1)) {
329 self->ddict = ZSTD_createDDict_byReference(dictData, dictSize); 332 goto finally;
330 Py_END_ALLOW_THREADS 333 }
331 334
332 if (!self->ddict) { 335 decompressedSize = ZSTD_getFrameContentSize(source.buf, source.len);
333 PyErr_SetString(ZstdError, "could not create decompression dict"); 336
334 return NULL; 337 if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
335 } 338 PyErr_SetString(ZstdError, "error determining content size from frame header");
336 } 339 goto finally;
337 340 }
338 decompressedSize = ZSTD_getDecompressedSize(source, sourceSize); 341 /* Special case of empty frame. */
339 /* 0 returned if content size not in the zstd frame header */ 342 else if (0 == decompressedSize) {
340 if (0 == decompressedSize) { 343 result = PyBytes_FromStringAndSize("", 0);
344 goto finally;
345 }
346 /* Missing content size in frame header. */
347 if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
341 if (0 == maxOutputSize) { 348 if (0 == maxOutputSize) {
342 PyErr_SetString(ZstdError, "input data invalid or missing content size " 349 PyErr_SetString(ZstdError, "could not determine content size in frame header");
343 "in frame header"); 350 goto finally;
344 return NULL; 351 }
345 } 352
346 else { 353 result = PyBytes_FromStringAndSize(NULL, maxOutputSize);
347 result = PyBytes_FromStringAndSize(NULL, maxOutputSize); 354 destCapacity = maxOutputSize;
348 destCapacity = maxOutputSize; 355 decompressedSize = 0;
349 } 356 }
350 } 357 /* Size is recorded in frame header. */
351 else { 358 else {
352 result = PyBytes_FromStringAndSize(NULL, decompressedSize); 359 assert(SIZE_MAX >= PY_SSIZE_T_MAX);
353 destCapacity = decompressedSize; 360 if (decompressedSize > PY_SSIZE_T_MAX) {
361 PyErr_SetString(ZstdError, "frame is too large to decompress on this platform");
362 goto finally;
363 }
364
365 result = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)decompressedSize);
366 destCapacity = (size_t)decompressedSize;
354 } 367 }
355 368
356 if (!result) { 369 if (!result) {
357 return NULL; 370 goto finally;
358 } 371 }
372
373 outBuffer.dst = PyBytes_AsString(result);
374 outBuffer.size = destCapacity;
375 outBuffer.pos = 0;
376
377 inBuffer.src = source.buf;
378 inBuffer.size = source.len;
379 inBuffer.pos = 0;
359 380
360 Py_BEGIN_ALLOW_THREADS 381 Py_BEGIN_ALLOW_THREADS
361 if (self->ddict) { 382 zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
362 zresult = ZSTD_decompress_usingDDict(self->dctx,
363 PyBytes_AsString(result), destCapacity,
364 source, sourceSize, self->ddict);
365 }
366 else {
367 zresult = ZSTD_decompressDCtx(self->dctx,
368 PyBytes_AsString(result), destCapacity, source, sourceSize);
369 }
370 Py_END_ALLOW_THREADS 383 Py_END_ALLOW_THREADS
371 384
372 if (ZSTD_isError(zresult)) { 385 if (ZSTD_isError(zresult)) {
373 PyErr_Format(ZstdError, "decompression error: %s", ZSTD_getErrorName(zresult)); 386 PyErr_Format(ZstdError, "decompression error: %s", ZSTD_getErrorName(zresult));
374 Py_DECREF(result); 387 Py_CLEAR(result);
375 return NULL; 388 goto finally;
376 } 389 }
377 else if (decompressedSize && zresult != decompressedSize) { 390 else if (zresult) {
391 PyErr_Format(ZstdError, "decompression error: did not decompress full frame");
392 Py_CLEAR(result);
393 goto finally;
394 }
395 else if (decompressedSize && outBuffer.pos != decompressedSize) {
378 PyErr_Format(ZstdError, "decompression error: decompressed %zu bytes; expected %llu", 396 PyErr_Format(ZstdError, "decompression error: decompressed %zu bytes; expected %llu",
379 zresult, decompressedSize); 397 zresult, decompressedSize);
380 Py_DECREF(result); 398 Py_CLEAR(result);
381 return NULL; 399 goto finally;
382 } 400 }
383 else if (zresult < destCapacity) { 401 else if (outBuffer.pos < destCapacity) {
384 if (_PyBytes_Resize(&result, zresult)) { 402 if (safe_pybytes_resize(&result, outBuffer.pos)) {
385 Py_DECREF(result); 403 Py_CLEAR(result);
386 return NULL; 404 goto finally;
387 } 405 }
388 } 406 }
389 407
408 finally:
409 PyBuffer_Release(&source);
390 return result; 410 return result;
391 } 411 }
392 412
393 PyDoc_STRVAR(Decompressor_decompressobj__doc__, 413 PyDoc_STRVAR(Decompressor_decompressobj__doc__,
394 "decompressobj()\n" 414 "decompressobj([write_size=default])\n"
395 "\n" 415 "\n"
396 "Incrementally feed data into a decompressor.\n" 416 "Incrementally feed data into a decompressor.\n"
397 "\n" 417 "\n"
398 "The returned object exposes a ``decompress(data)`` method. This makes it\n" 418 "The returned object exposes a ``decompress(data)`` method. This makes it\n"
399 "compatible with ``zlib.decompressobj`` and ``bz2.BZ2Decompressor`` so that\n" 419 "compatible with ``zlib.decompressobj`` and ``bz2.BZ2Decompressor`` so that\n"
400 "callers can swap in the zstd decompressor while using the same API.\n" 420 "callers can swap in the zstd decompressor while using the same API.\n"
401 ); 421 );
402 422
403 static ZstdDecompressionObj* Decompressor_decompressobj(ZstdDecompressor* self) { 423 static ZstdDecompressionObj* Decompressor_decompressobj(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
404 ZstdDecompressionObj* result = (ZstdDecompressionObj*)PyObject_CallObject((PyObject*)&ZstdDecompressionObjType, NULL); 424 static char* kwlist[] = {
425 "write_size",
426 NULL
427 };
428
429 ZstdDecompressionObj* result = NULL;
430 size_t outSize = ZSTD_DStreamOutSize();
431
432 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|k:decompressobj", kwlist, &outSize)) {
433 return NULL;
434 }
435
436 if (!outSize) {
437 PyErr_SetString(PyExc_ValueError, "write_size must be positive");
438 return NULL;
439 }
440
441 result = (ZstdDecompressionObj*)PyObject_CallObject((PyObject*)&ZstdDecompressionObjType, NULL);
405 if (!result) { 442 if (!result) {
406 return NULL; 443 return NULL;
407 } 444 }
408 445
409 if (0 != init_dstream(self)) { 446 if (ensure_dctx(self, 1)) {
410 Py_DECREF(result); 447 Py_DECREF(result);
411 return NULL; 448 return NULL;
412 } 449 }
413 450
414 result->decompressor = self; 451 result->decompressor = self;
415 Py_INCREF(result->decompressor); 452 Py_INCREF(result->decompressor);
453 result->outSize = outSize;
416 454
417 return result; 455 return result;
418 } 456 }
419 457
420 PyDoc_STRVAR(Decompressor_read_from__doc__, 458 PyDoc_STRVAR(Decompressor_read_to_iter__doc__,
421 "read_from(reader[, read_size=default, write_size=default, skip_bytes=0])\n" 459 "read_to_iter(reader[, read_size=default, write_size=default, skip_bytes=0])\n"
422 "Read compressed data and return an iterator\n" 460 "Read compressed data and return an iterator\n"
423 "\n" 461 "\n"
424 "Returns an iterator of decompressed data chunks produced from reading from\n" 462 "Returns an iterator of decompressed data chunks produced from reading from\n"
425 "the ``reader``.\n" 463 "the ``reader``.\n"
426 "\n" 464 "\n"
435 "\n" 473 "\n"
436 "There is also support for skipping the first ``skip_bytes`` of data from\n" 474 "There is also support for skipping the first ``skip_bytes`` of data from\n"
437 "the source.\n" 475 "the source.\n"
438 ); 476 );
439 477
440 static ZstdDecompressorIterator* Decompressor_read_from(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { 478 static ZstdDecompressorIterator* Decompressor_read_to_iter(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
441 static char* kwlist[] = { 479 static char* kwlist[] = {
442 "reader", 480 "reader",
443 "read_size", 481 "read_size",
444 "write_size", 482 "write_size",
445 "skip_bytes", 483 "skip_bytes",
450 size_t inSize = ZSTD_DStreamInSize(); 488 size_t inSize = ZSTD_DStreamInSize();
451 size_t outSize = ZSTD_DStreamOutSize(); 489 size_t outSize = ZSTD_DStreamOutSize();
452 ZstdDecompressorIterator* result; 490 ZstdDecompressorIterator* result;
453 size_t skipBytes = 0; 491 size_t skipBytes = 0;
454 492
455 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_from", kwlist, 493 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_to_iter", kwlist,
456 &reader, &inSize, &outSize, &skipBytes)) { 494 &reader, &inSize, &outSize, &skipBytes)) {
457 return NULL; 495 return NULL;
458 } 496 }
459 497
460 if (skipBytes >= inSize) { 498 if (skipBytes >= inSize) {
472 result->reader = reader; 510 result->reader = reader;
473 Py_INCREF(result->reader); 511 Py_INCREF(result->reader);
474 } 512 }
475 else if (1 == PyObject_CheckBuffer(reader)) { 513 else if (1 == PyObject_CheckBuffer(reader)) {
476 /* Object claims it is a buffer. Try to get a handle to it. */ 514 /* Object claims it is a buffer. Try to get a handle to it. */
477 result->buffer = PyMem_Malloc(sizeof(Py_buffer)); 515 if (0 != PyObject_GetBuffer(reader, &result->buffer, PyBUF_CONTIG_RO)) {
478 if (!result->buffer) {
479 goto except;
480 }
481
482 memset(result->buffer, 0, sizeof(Py_buffer));
483
484 if (0 != PyObject_GetBuffer(reader, result->buffer, PyBUF_CONTIG_RO)) {
485 goto except; 516 goto except;
486 } 517 }
487 } 518 }
488 else { 519 else {
489 PyErr_SetString(PyExc_ValueError, 520 PyErr_SetString(PyExc_ValueError,
496 527
497 result->inSize = inSize; 528 result->inSize = inSize;
498 result->outSize = outSize; 529 result->outSize = outSize;
499 result->skipBytes = skipBytes; 530 result->skipBytes = skipBytes;
500 531
501 if (0 != init_dstream(self)) { 532 if (ensure_dctx(self, 1)) {
502 goto except; 533 goto except;
503 } 534 }
504 535
505 result->input.src = PyMem_Malloc(inSize); 536 result->input.src = PyMem_Malloc(inSize);
506 if (!result->input.src) { 537 if (!result->input.src) {
509 } 540 }
510 541
511 goto finally; 542 goto finally;
512 543
513 except: 544 except:
514 Py_CLEAR(result->reader);
515
516 if (result->buffer) {
517 PyBuffer_Release(result->buffer);
518 Py_CLEAR(result->buffer);
519 }
520
521 Py_CLEAR(result); 545 Py_CLEAR(result);
522 546
523 finally: 547 finally:
524 548
525 return result; 549 return result;
526 } 550 }
527 551
528 PyDoc_STRVAR(Decompressor_write_to__doc__, 552 PyDoc_STRVAR(Decompressor_stream_reader__doc__,
553 "stream_reader(source, [read_size=default])\n"
554 "\n"
555 "Obtain an object that behaves like an I/O stream that can be used for\n"
556 "reading decompressed output from an object.\n"
557 "\n"
558 "The source object can be any object with a ``read(size)`` method or that\n"
559 "conforms to the buffer protocol.\n"
560 );
561
562 static ZstdDecompressionReader* Decompressor_stream_reader(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
563 static char* kwlist[] = {
564 "source",
565 "read_size",
566 NULL
567 };
568
569 PyObject* source;
570 size_t readSize = ZSTD_DStreamInSize();
571 ZstdDecompressionReader* result;
572
573 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_reader", kwlist,
574 &source, &readSize)) {
575 return NULL;
576 }
577
578 result = (ZstdDecompressionReader*)PyObject_CallObject((PyObject*)&ZstdDecompressionReaderType, NULL);
579 if (NULL == result) {
580 return NULL;
581 }
582
583 if (PyObject_HasAttrString(source, "read")) {
584 result->reader = source;
585 Py_INCREF(source);
586 result->readSize = readSize;
587 }
588 else if (1 == PyObject_CheckBuffer(source)) {
589 if (0 != PyObject_GetBuffer(source, &result->buffer, PyBUF_CONTIG_RO)) {
590 Py_CLEAR(result);
591 return NULL;
592 }
593 }
594 else {
595 PyErr_SetString(PyExc_TypeError,
596 "must pass an object with a read() method or that conforms to the buffer protocol");
597 Py_CLEAR(result);
598 return NULL;
599 }
600
601 result->decompressor = self;
602 Py_INCREF(self);
603
604 return result;
605 }
606
607 PyDoc_STRVAR(Decompressor_stream_writer__doc__,
529 "Create a context manager to write decompressed data to an object.\n" 608 "Create a context manager to write decompressed data to an object.\n"
530 "\n" 609 "\n"
531 "The passed object must have a ``write()`` method.\n" 610 "The passed object must have a ``write()`` method.\n"
532 "\n" 611 "\n"
533 "The caller feeds intput data to the object by calling ``write(data)``.\n" 612 "The caller feeds intput data to the object by calling ``write(data)``.\n"
536 "An optional ``write_size`` argument defines the size of chunks to\n" 615 "An optional ``write_size`` argument defines the size of chunks to\n"
537 "``write()`` to the writer. It defaults to the default output size for a zstd\n" 616 "``write()`` to the writer. It defaults to the default output size for a zstd\n"
538 "streaming decompressor.\n" 617 "streaming decompressor.\n"
539 ); 618 );
540 619
541 static ZstdDecompressionWriter* Decompressor_write_to(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { 620 static ZstdDecompressionWriter* Decompressor_stream_writer(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
542 static char* kwlist[] = { 621 static char* kwlist[] = {
543 "writer", 622 "writer",
544 "write_size", 623 "write_size",
545 NULL 624 NULL
546 }; 625 };
547 626
548 PyObject* writer; 627 PyObject* writer;
549 size_t outSize = ZSTD_DStreamOutSize(); 628 size_t outSize = ZSTD_DStreamOutSize();
550 ZstdDecompressionWriter* result; 629 ZstdDecompressionWriter* result;
551 630
552 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:write_to", kwlist, 631 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_writer", kwlist,
553 &writer, &outSize)) { 632 &writer, &outSize)) {
554 return NULL; 633 return NULL;
555 } 634 }
556 635
557 if (!PyObject_HasAttrString(writer, "write")) { 636 if (!PyObject_HasAttrString(writer, "write")) {
577 656
578 PyDoc_STRVAR(Decompressor_decompress_content_dict_chain__doc__, 657 PyDoc_STRVAR(Decompressor_decompress_content_dict_chain__doc__,
579 "Decompress a series of chunks using the content dictionary chaining technique\n" 658 "Decompress a series of chunks using the content dictionary chaining technique\n"
580 ); 659 );
581 660
582 static PyObject* Decompressor_decompress_content_dict_chain(PyObject* self, PyObject* args, PyObject* kwargs) { 661 static PyObject* Decompressor_decompress_content_dict_chain(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) {
583 static char* kwlist[] = { 662 static char* kwlist[] = {
584 "frames", 663 "frames",
585 NULL 664 NULL
586 }; 665 };
587 666
590 Py_ssize_t chunkIndex; 669 Py_ssize_t chunkIndex;
591 char parity = 0; 670 char parity = 0;
592 PyObject* chunk; 671 PyObject* chunk;
593 char* chunkData; 672 char* chunkData;
594 Py_ssize_t chunkSize; 673 Py_ssize_t chunkSize;
595 ZSTD_DCtx* dctx = NULL;
596 size_t zresult; 674 size_t zresult;
597 ZSTD_frameParams frameParams; 675 ZSTD_frameHeader frameHeader;
598 void* buffer1 = NULL; 676 void* buffer1 = NULL;
599 size_t buffer1Size = 0; 677 size_t buffer1Size = 0;
600 size_t buffer1ContentSize = 0; 678 size_t buffer1ContentSize = 0;
601 void* buffer2 = NULL; 679 void* buffer2 = NULL;
602 size_t buffer2Size = 0; 680 size_t buffer2Size = 0;
603 size_t buffer2ContentSize = 0; 681 size_t buffer2ContentSize = 0;
604 void* destBuffer = NULL; 682 void* destBuffer = NULL;
605 PyObject* result = NULL; 683 PyObject* result = NULL;
684 ZSTD_outBuffer outBuffer;
685 ZSTD_inBuffer inBuffer;
606 686
607 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!:decompress_content_dict_chain", 687 if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!:decompress_content_dict_chain",
608 kwlist, &PyList_Type, &chunks)) { 688 kwlist, &PyList_Type, &chunks)) {
609 return NULL; 689 return NULL;
610 } 690 }
622 return NULL; 702 return NULL;
623 } 703 }
624 704
625 /* We require that all chunks be zstd frames and that they have content size set. */ 705 /* We require that all chunks be zstd frames and that they have content size set. */
626 PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize); 706 PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
627 zresult = ZSTD_getFrameParams(&frameParams, (void*)chunkData, chunkSize); 707 zresult = ZSTD_getFrameHeader(&frameHeader, (void*)chunkData, chunkSize);
628 if (ZSTD_isError(zresult)) { 708 if (ZSTD_isError(zresult)) {
629 PyErr_SetString(PyExc_ValueError, "chunk 0 is not a valid zstd frame"); 709 PyErr_SetString(PyExc_ValueError, "chunk 0 is not a valid zstd frame");
630 return NULL; 710 return NULL;
631 } 711 }
632 else if (zresult) { 712 else if (zresult) {
633 PyErr_SetString(PyExc_ValueError, "chunk 0 is too small to contain a zstd frame"); 713 PyErr_SetString(PyExc_ValueError, "chunk 0 is too small to contain a zstd frame");
634 return NULL; 714 return NULL;
635 } 715 }
636 716
637 if (0 == frameParams.frameContentSize) { 717 if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
638 PyErr_SetString(PyExc_ValueError, "chunk 0 missing content size in frame"); 718 PyErr_SetString(PyExc_ValueError, "chunk 0 missing content size in frame");
639 return NULL; 719 return NULL;
640 } 720 }
641 721
642 dctx = ZSTD_createDCtx(); 722 assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
643 if (!dctx) { 723
644 PyErr_NoMemory(); 724 /* We check against PY_SSIZE_T_MAX here because we ultimately cast the
645 goto finally; 725 * result to a Python object and it's length can be no greater than
646 } 726 * Py_ssize_t. In theory, we could have an intermediate frame that is
647 727 * larger. But a) why would this API be used for frames that large b)
648 buffer1Size = frameParams.frameContentSize; 728 * it isn't worth the complexity to support. */
729 assert(SIZE_MAX >= PY_SSIZE_T_MAX);
730 if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
731 PyErr_SetString(PyExc_ValueError,
732 "chunk 0 is too large to decompress on this platform");
733 return NULL;
734 }
735
736 if (ensure_dctx(self, 0)) {
737 goto finally;
738 }
739
740 buffer1Size = (size_t)frameHeader.frameContentSize;
649 buffer1 = PyMem_Malloc(buffer1Size); 741 buffer1 = PyMem_Malloc(buffer1Size);
650 if (!buffer1) { 742 if (!buffer1) {
651 goto finally; 743 goto finally;
652 } 744 }
653 745
746 outBuffer.dst = buffer1;
747 outBuffer.size = buffer1Size;
748 outBuffer.pos = 0;
749
750 inBuffer.src = chunkData;
751 inBuffer.size = chunkSize;
752 inBuffer.pos = 0;
753
654 Py_BEGIN_ALLOW_THREADS 754 Py_BEGIN_ALLOW_THREADS
655 zresult = ZSTD_decompressDCtx(dctx, buffer1, buffer1Size, chunkData, chunkSize); 755 zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
656 Py_END_ALLOW_THREADS 756 Py_END_ALLOW_THREADS
657 if (ZSTD_isError(zresult)) { 757 if (ZSTD_isError(zresult)) {
658 PyErr_Format(ZstdError, "could not decompress chunk 0: %s", ZSTD_getErrorName(zresult)); 758 PyErr_Format(ZstdError, "could not decompress chunk 0: %s", ZSTD_getErrorName(zresult));
659 goto finally; 759 goto finally;
660 } 760 }
661 761 else if (zresult) {
662 buffer1ContentSize = zresult; 762 PyErr_Format(ZstdError, "chunk 0 did not decompress full frame");
763 goto finally;
764 }
765
766 buffer1ContentSize = outBuffer.pos;
663 767
664 /* Special case of a simple chain. */ 768 /* Special case of a simple chain. */
665 if (1 == chunksLen) { 769 if (1 == chunksLen) {
666 result = PyBytes_FromStringAndSize(buffer1, buffer1Size); 770 result = PyBytes_FromStringAndSize(buffer1, buffer1Size);
667 goto finally; 771 goto finally;
668 } 772 }
669 773
670 /* This should ideally look at next chunk. But this is slightly simpler. */ 774 /* This should ideally look at next chunk. But this is slightly simpler. */
671 buffer2Size = frameParams.frameContentSize; 775 buffer2Size = (size_t)frameHeader.frameContentSize;
672 buffer2 = PyMem_Malloc(buffer2Size); 776 buffer2 = PyMem_Malloc(buffer2Size);
673 if (!buffer2) { 777 if (!buffer2) {
674 goto finally; 778 goto finally;
675 } 779 }
676 780
686 PyErr_Format(PyExc_ValueError, "chunk %zd must be bytes", chunkIndex); 790 PyErr_Format(PyExc_ValueError, "chunk %zd must be bytes", chunkIndex);
687 goto finally; 791 goto finally;
688 } 792 }
689 793
690 PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize); 794 PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize);
691 zresult = ZSTD_getFrameParams(&frameParams, (void*)chunkData, chunkSize); 795 zresult = ZSTD_getFrameHeader(&frameHeader, (void*)chunkData, chunkSize);
692 if (ZSTD_isError(zresult)) { 796 if (ZSTD_isError(zresult)) {
693 PyErr_Format(PyExc_ValueError, "chunk %zd is not a valid zstd frame", chunkIndex); 797 PyErr_Format(PyExc_ValueError, "chunk %zd is not a valid zstd frame", chunkIndex);
694 goto finally; 798 goto finally;
695 } 799 }
696 else if (zresult) { 800 else if (zresult) {
697 PyErr_Format(PyExc_ValueError, "chunk %zd is too small to contain a zstd frame", chunkIndex); 801 PyErr_Format(PyExc_ValueError, "chunk %zd is too small to contain a zstd frame", chunkIndex);
698 goto finally; 802 goto finally;
699 } 803 }
700 804
701 if (0 == frameParams.frameContentSize) { 805 if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) {
702 PyErr_Format(PyExc_ValueError, "chunk %zd missing content size in frame", chunkIndex); 806 PyErr_Format(PyExc_ValueError, "chunk %zd missing content size in frame", chunkIndex);
703 goto finally; 807 goto finally;
704 } 808 }
809
810 assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize);
811
812 if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) {
813 PyErr_Format(PyExc_ValueError,
814 "chunk %zd is too large to decompress on this platform", chunkIndex);
815 goto finally;
816 }
817
818 inBuffer.src = chunkData;
819 inBuffer.size = chunkSize;
820 inBuffer.pos = 0;
705 821
706 parity = chunkIndex % 2; 822 parity = chunkIndex % 2;
707 823
708 /* This could definitely be abstracted to reduce code duplication. */ 824 /* This could definitely be abstracted to reduce code duplication. */
709 if (parity) { 825 if (parity) {
710 /* Resize destination buffer to hold larger content. */ 826 /* Resize destination buffer to hold larger content. */
711 if (buffer2Size < frameParams.frameContentSize) { 827 if (buffer2Size < frameHeader.frameContentSize) {
712 buffer2Size = frameParams.frameContentSize; 828 buffer2Size = (size_t)frameHeader.frameContentSize;
713 destBuffer = PyMem_Realloc(buffer2, buffer2Size); 829 destBuffer = PyMem_Realloc(buffer2, buffer2Size);
714 if (!destBuffer) { 830 if (!destBuffer) {
715 goto finally; 831 goto finally;
716 } 832 }
717 buffer2 = destBuffer; 833 buffer2 = destBuffer;
718 } 834 }
719 835
720 Py_BEGIN_ALLOW_THREADS 836 Py_BEGIN_ALLOW_THREADS
721 zresult = ZSTD_decompress_usingDict(dctx, buffer2, buffer2Size, 837 zresult = ZSTD_DCtx_refPrefix_advanced(self->dctx,
722 chunkData, chunkSize, buffer1, buffer1ContentSize); 838 buffer1, buffer1ContentSize, ZSTD_dct_rawContent);
839 Py_END_ALLOW_THREADS
840 if (ZSTD_isError(zresult)) {
841 PyErr_Format(ZstdError,
842 "failed to load prefix dictionary at chunk %zd", chunkIndex);
843 goto finally;
844 }
845
846 outBuffer.dst = buffer2;
847 outBuffer.size = buffer2Size;
848 outBuffer.pos = 0;
849
850 Py_BEGIN_ALLOW_THREADS
851 zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
723 Py_END_ALLOW_THREADS 852 Py_END_ALLOW_THREADS
724 if (ZSTD_isError(zresult)) { 853 if (ZSTD_isError(zresult)) {
725 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s", 854 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
726 chunkIndex, ZSTD_getErrorName(zresult)); 855 chunkIndex, ZSTD_getErrorName(zresult));
727 goto finally; 856 goto finally;
728 } 857 }
729 buffer2ContentSize = zresult; 858 else if (zresult) {
859 PyErr_Format(ZstdError, "chunk %zd did not decompress full frame",
860 chunkIndex);
861 goto finally;
862 }
863
864 buffer2ContentSize = outBuffer.pos;
730 } 865 }
731 else { 866 else {
732 if (buffer1Size < frameParams.frameContentSize) { 867 if (buffer1Size < frameHeader.frameContentSize) {
733 buffer1Size = frameParams.frameContentSize; 868 buffer1Size = (size_t)frameHeader.frameContentSize;
734 destBuffer = PyMem_Realloc(buffer1, buffer1Size); 869 destBuffer = PyMem_Realloc(buffer1, buffer1Size);
735 if (!destBuffer) { 870 if (!destBuffer) {
736 goto finally; 871 goto finally;
737 } 872 }
738 buffer1 = destBuffer; 873 buffer1 = destBuffer;
739 } 874 }
740 875
741 Py_BEGIN_ALLOW_THREADS 876 Py_BEGIN_ALLOW_THREADS
742 zresult = ZSTD_decompress_usingDict(dctx, buffer1, buffer1Size, 877 zresult = ZSTD_DCtx_refPrefix_advanced(self->dctx,
743 chunkData, chunkSize, buffer2, buffer2ContentSize); 878 buffer2, buffer2ContentSize, ZSTD_dct_rawContent);
879 Py_END_ALLOW_THREADS
880 if (ZSTD_isError(zresult)) {
881 PyErr_Format(ZstdError,
882 "failed to load prefix dictionary at chunk %zd", chunkIndex);
883 goto finally;
884 }
885
886 outBuffer.dst = buffer1;
887 outBuffer.size = buffer1Size;
888 outBuffer.pos = 0;
889
890 Py_BEGIN_ALLOW_THREADS
891 zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer);
744 Py_END_ALLOW_THREADS 892 Py_END_ALLOW_THREADS
745 if (ZSTD_isError(zresult)) { 893 if (ZSTD_isError(zresult)) {
746 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s", 894 PyErr_Format(ZstdError, "could not decompress chunk %zd: %s",
747 chunkIndex, ZSTD_getErrorName(zresult)); 895 chunkIndex, ZSTD_getErrorName(zresult));
748 goto finally; 896 goto finally;
749 } 897 }
750 buffer1ContentSize = zresult; 898 else if (zresult) {
899 PyErr_Format(ZstdError, "chunk %zd did not decompress full frame",
900 chunkIndex);
901 goto finally;
902 }
903
904 buffer1ContentSize = outBuffer.pos;
751 } 905 }
752 } 906 }
753 907
754 result = PyBytes_FromStringAndSize(parity ? buffer2 : buffer1, 908 result = PyBytes_FromStringAndSize(parity ? buffer2 : buffer1,
755 parity ? buffer2ContentSize : buffer1ContentSize); 909 parity ? buffer2ContentSize : buffer1ContentSize);
760 } 914 }
761 if (buffer1) { 915 if (buffer1) {
762 PyMem_Free(buffer1); 916 PyMem_Free(buffer1);
763 } 917 }
764 918
765 if (dctx) {
766 ZSTD_freeDCtx(dctx);
767 }
768
769 return result; 919 return result;
770 } 920 }
771 921
/*
 * FramePointer describes one compressed frame to be decompressed.
 *
 *   sourceData / sourceSize  locate the compressed bytes; they are handed to
 *                            the decompressor as the input buffer.
 *   destSize                 expected size of the decompressed output. A value
 *                            of 0 means the caller did not supply a size; the
 *                            worker then reads the size from the frame itself
 *                            before allocating output space.
 *
 * The newer revision narrows destSize from unsigned long long to size_t, so
 * code filling this struct range-checks against SIZE_MAX before casting.
 */
772 typedef struct { 922 typedef struct {
773 void* sourceData; 923 void* sourceData;
774 size_t sourceSize; 924 size_t sourceSize;
775 unsigned long long destSize; 925 size_t destSize;
776 } FramePointer; 926 } FramePointer;
777 927
778 typedef struct { 928 typedef struct {
779 FramePointer* frames; 929 FramePointer* frames;
780 Py_ssize_t framesSize; 930 Py_ssize_t framesSize;
804 Py_ssize_t endOffset; 954 Py_ssize_t endOffset;
805 unsigned long long totalSourceSize; 955 unsigned long long totalSourceSize;
806 956
807 /* Compression state and settings. */ 957 /* Compression state and settings. */
808 ZSTD_DCtx* dctx; 958 ZSTD_DCtx* dctx;
809 ZSTD_DDict* ddict;
810 int requireOutputSizes; 959 int requireOutputSizes;
811 960
812 /* Output storage. */ 961 /* Output storage. */
813 DestBuffer* destBuffers; 962 DestBuffer* destBuffers;
814 Py_ssize_t destCount; 963 Py_ssize_t destCount;
836 985
837 assert(NULL == state->destBuffers); 986 assert(NULL == state->destBuffers);
838 assert(0 == state->destCount); 987 assert(0 == state->destCount);
839 assert(state->endOffset - state->startOffset >= 0); 988 assert(state->endOffset - state->startOffset >= 0);
840 989
990 /* We could get here due to the way work is allocated. Ideally we wouldn't
991 get here. But that would require a bit of a refactor in the caller. */
992 if (state->totalSourceSize > SIZE_MAX) {
993 state->error = WorkerError_memory;
994 state->errorOffset = 0;
995 return;
996 }
997
841 /* 998 /*
842 * We need to allocate a buffer to hold decompressed data. How we do this 999 * We need to allocate a buffer to hold decompressed data. How we do this
843 * depends on what we know about the output. The following scenarios are 1000 * depends on what we know about the output. The following scenarios are
844 * possible: 1001 * possible:
845 * 1002 *
851 */ 1008 */
852 1009
853 /* Resolve output segments. */ 1010 /* Resolve output segments. */
854 for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) { 1011 for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) {
855 FramePointer* fp = &framePointers[frameIndex]; 1012 FramePointer* fp = &framePointers[frameIndex];
1013 unsigned long long decompressedSize;
856 1014
857 if (0 == fp->destSize) { 1015 if (0 == fp->destSize) {
858 fp->destSize = ZSTD_getDecompressedSize(fp->sourceData, fp->sourceSize); 1016 decompressedSize = ZSTD_getFrameContentSize(fp->sourceData, fp->sourceSize);
859 if (0 == fp->destSize && state->requireOutputSizes) { 1017
1018 if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) {
860 state->error = WorkerError_unknownSize; 1019 state->error = WorkerError_unknownSize;
861 state->errorOffset = frameIndex; 1020 state->errorOffset = frameIndex;
862 return; 1021 return;
863 } 1022 }
1023 else if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) {
1024 if (state->requireOutputSizes) {
1025 state->error = WorkerError_unknownSize;
1026 state->errorOffset = frameIndex;
1027 return;
1028 }
1029
1030 /* This will fail the assert for .destSize > 0 below. */
1031 decompressedSize = 0;
1032 }
1033
1034 if (decompressedSize > SIZE_MAX) {
1035 state->error = WorkerError_memory;
1036 state->errorOffset = frameIndex;
1037 return;
1038 }
1039
1040 fp->destSize = (size_t)decompressedSize;
864 } 1041 }
865 1042
866 totalOutputSize += fp->destSize; 1043 totalOutputSize += fp->destSize;
867 } 1044 }
868 1045
876 1053
877 destBuffer = &state->destBuffers[state->destCount - 1]; 1054 destBuffer = &state->destBuffers[state->destCount - 1];
878 1055
879 assert(framePointers[state->startOffset].destSize > 0); /* For now. */ 1056 assert(framePointers[state->startOffset].destSize > 0); /* For now. */
880 1057
881 allocationSize = roundpow2(state->totalSourceSize); 1058 allocationSize = roundpow2((size_t)state->totalSourceSize);
882 1059
883 if (framePointers[state->startOffset].destSize > allocationSize) { 1060 if (framePointers[state->startOffset].destSize > allocationSize) {
884 allocationSize = roundpow2(framePointers[state->startOffset].destSize); 1061 allocationSize = roundpow2(framePointers[state->startOffset].destSize);
885 } 1062 }
886 1063
900 } 1077 }
901 1078
902 destBuffer->segmentsSize = remainingItems; 1079 destBuffer->segmentsSize = remainingItems;
903 1080
904 for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) { 1081 for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) {
1082 ZSTD_outBuffer outBuffer;
1083 ZSTD_inBuffer inBuffer;
905 const void* source = framePointers[frameIndex].sourceData; 1084 const void* source = framePointers[frameIndex].sourceData;
906 const size_t sourceSize = framePointers[frameIndex].sourceSize; 1085 const size_t sourceSize = framePointers[frameIndex].sourceSize;
907 void* dest; 1086 void* dest;
908 const size_t decompressedSize = framePointers[frameIndex].destSize; 1087 const size_t decompressedSize = framePointers[frameIndex].destSize;
909 size_t destAvailable = destBuffer->destSize - destOffset; 1088 size_t destAvailable = destBuffer->destSize - destOffset;
954 destBuffer = &state->destBuffers[state->destCount - 1]; 1133 destBuffer = &state->destBuffers[state->destCount - 1];
955 1134
956 /* Don't take any chances with non-NULL pointers. */ 1135 /* Don't take any chances with non-NULL pointers. */
957 memset(destBuffer, 0, sizeof(DestBuffer)); 1136 memset(destBuffer, 0, sizeof(DestBuffer));
958 1137
959 allocationSize = roundpow2(state->totalSourceSize); 1138 allocationSize = roundpow2((size_t)state->totalSourceSize);
960 1139
961 if (decompressedSize > allocationSize) { 1140 if (decompressedSize > allocationSize) {
962 allocationSize = roundpow2(decompressedSize); 1141 allocationSize = roundpow2(decompressedSize);
963 } 1142 }
964 1143
983 currentBufferStartIndex = frameIndex; 1162 currentBufferStartIndex = frameIndex;
984 } 1163 }
985 1164
986 dest = (char*)destBuffer->dest + destOffset; 1165 dest = (char*)destBuffer->dest + destOffset;
987 1166
988 if (state->ddict) { 1167 outBuffer.dst = dest;
989 zresult = ZSTD_decompress_usingDDict(state->dctx, dest, decompressedSize, 1168 outBuffer.size = decompressedSize;
990 source, sourceSize, state->ddict); 1169 outBuffer.pos = 0;
991 } 1170
992 else { 1171 inBuffer.src = source;
993 zresult = ZSTD_decompressDCtx(state->dctx, dest, decompressedSize, 1172 inBuffer.size = sourceSize;
994 source, sourceSize); 1173 inBuffer.pos = 0;
995 } 1174
996 1175 zresult = ZSTD_decompress_generic(state->dctx, &outBuffer, &inBuffer);
997 if (ZSTD_isError(zresult)) { 1176 if (ZSTD_isError(zresult)) {
998 state->error = WorkerError_zstd; 1177 state->error = WorkerError_zstd;
999 state->zresult = zresult; 1178 state->zresult = zresult;
1000 state->errorOffset = frameIndex; 1179 state->errorOffset = frameIndex;
1001 return; 1180 return;
1002 } 1181 }
1003 else if (zresult != decompressedSize) { 1182 else if (zresult || outBuffer.pos != decompressedSize) {
1004 state->error = WorkerError_sizeMismatch; 1183 state->error = WorkerError_sizeMismatch;
1005 state->zresult = zresult; 1184 state->zresult = outBuffer.pos;
1006 state->errorOffset = frameIndex; 1185 state->errorOffset = frameIndex;
1007 return; 1186 return;
1008 } 1187 }
1009 1188
1010 destBuffer->segments[localOffset].offset = destOffset; 1189 destBuffer->segments[localOffset].offset = destOffset;
1011 destBuffer->segments[localOffset].length = decompressedSize; 1190 destBuffer->segments[localOffset].length = outBuffer.pos;
1012 destOffset += zresult; 1191 destOffset += outBuffer.pos;
1013 localOffset++; 1192 localOffset++;
1014 remainingItems--; 1193 remainingItems--;
1015 } 1194 }
1016 1195
1017 if (destBuffer->destSize > destOffset) { 1196 if (destBuffer->destSize > destOffset) {
1025 destBuffer->destSize = destOffset; 1204 destBuffer->destSize = destOffset;
1026 } 1205 }
1027 } 1206 }
1028 1207
1029 ZstdBufferWithSegmentsCollection* decompress_from_framesources(ZstdDecompressor* decompressor, FrameSources* frames, 1208 ZstdBufferWithSegmentsCollection* decompress_from_framesources(ZstdDecompressor* decompressor, FrameSources* frames,
1030 unsigned int threadCount) { 1209 Py_ssize_t threadCount) {
1031 void* dictData = NULL;
1032 size_t dictSize = 0;
1033 Py_ssize_t i = 0; 1210 Py_ssize_t i = 0;
1034 int errored = 0; 1211 int errored = 0;
1035 Py_ssize_t segmentsCount; 1212 Py_ssize_t segmentsCount;
1036 ZstdBufferWithSegments* bws = NULL; 1213 ZstdBufferWithSegments* bws = NULL;
1037 PyObject* resultArg = NULL; 1214 PyObject* resultArg = NULL;
1038 Py_ssize_t resultIndex; 1215 Py_ssize_t resultIndex;
1039 ZstdBufferWithSegmentsCollection* result = NULL; 1216 ZstdBufferWithSegmentsCollection* result = NULL;
1040 FramePointer* framePointers = frames->frames; 1217 FramePointer* framePointers = frames->frames;
1041 unsigned long long workerBytes = 0; 1218 unsigned long long workerBytes = 0;
1042 int currentThread = 0; 1219 Py_ssize_t currentThread = 0;
1043 Py_ssize_t workerStartOffset = 0; 1220 Py_ssize_t workerStartOffset = 0;
1044 POOL_ctx* pool = NULL; 1221 POOL_ctx* pool = NULL;
1045 WorkerState* workerStates = NULL; 1222 WorkerState* workerStates = NULL;
1046 unsigned long long bytesPerWorker; 1223 unsigned long long bytesPerWorker;
1047 1224
1048 /* Caller should normalize 0 and negative values to 1 or larger. */ 1225 /* Caller should normalize 0 and negative values to 1 or larger. */
1049 assert(threadCount >= 1); 1226 assert(threadCount >= 1);
1050 1227
1051 /* More threads than inputs makes no sense under any conditions. */ 1228 /* More threads than inputs makes no sense under any conditions. */
1052 threadCount = frames->framesSize < threadCount ? (unsigned int)frames->framesSize 1229 threadCount = frames->framesSize < threadCount ? frames->framesSize
1053 : threadCount; 1230 : threadCount;
1054 1231
1055 /* TODO lower thread count if input size is too small and threads would just 1232 /* TODO lower thread count if input size is too small and threads would just
1056 add overhead. */ 1233 add overhead. */
1057 1234
1058 if (decompressor->dict) { 1235 if (decompressor->dict) {
1059 dictData = decompressor->dict->dictData; 1236 if (ensure_ddict(decompressor->dict)) {
1060 dictSize = decompressor->dict->dictSize;
1061 }
1062
1063 if (dictData && !decompressor->ddict) {
1064 Py_BEGIN_ALLOW_THREADS
1065 decompressor->ddict = ZSTD_createDDict_byReference(dictData, dictSize);
1066 Py_END_ALLOW_THREADS
1067
1068 if (!decompressor->ddict) {
1069 PyErr_SetString(ZstdError, "could not create decompression dict");
1070 return NULL; 1237 return NULL;
1071 } 1238 }
1072 } 1239 }
1073 1240
1074 /* If threadCount==1, we don't start a thread pool. But we do leverage the 1241 /* If threadCount==1, we don't start a thread pool. But we do leverage the
1089 } 1256 }
1090 } 1257 }
1091 1258
1092 bytesPerWorker = frames->compressedSize / threadCount; 1259 bytesPerWorker = frames->compressedSize / threadCount;
1093 1260
1261 if (bytesPerWorker > SIZE_MAX) {
1262 PyErr_SetString(ZstdError, "too much data per worker for this platform");
1263 goto finally;
1264 }
1265
1094 for (i = 0; i < threadCount; i++) { 1266 for (i = 0; i < threadCount; i++) {
1267 size_t zresult;
1268
1095 workerStates[i].dctx = ZSTD_createDCtx(); 1269 workerStates[i].dctx = ZSTD_createDCtx();
1096 if (NULL == workerStates[i].dctx) { 1270 if (NULL == workerStates[i].dctx) {
1097 PyErr_NoMemory(); 1271 PyErr_NoMemory();
1098 goto finally; 1272 goto finally;
1099 } 1273 }
1100 1274
1101 ZSTD_copyDCtx(workerStates[i].dctx, decompressor->dctx); 1275 ZSTD_copyDCtx(workerStates[i].dctx, decompressor->dctx);
1102 1276
1103 workerStates[i].ddict = decompressor->ddict; 1277 if (decompressor->dict) {
1278 zresult = ZSTD_DCtx_refDDict(workerStates[i].dctx, decompressor->dict->ddict);
1279 if (zresult) {
1280 PyErr_Format(ZstdError, "unable to reference prepared dictionary: %s",
1281 ZSTD_getErrorName(zresult));
1282 goto finally;
1283 }
1284 }
1285
1104 workerStates[i].framePointers = framePointers; 1286 workerStates[i].framePointers = framePointers;
1105 workerStates[i].requireOutputSizes = 1; 1287 workerStates[i].requireOutputSizes = 1;
1106 } 1288 }
1107 1289
1108 Py_BEGIN_ALLOW_THREADS 1290 Py_BEGIN_ALLOW_THREADS
1176 PyErr_NoMemory(); 1358 PyErr_NoMemory();
1177 errored = 1; 1359 errored = 1;
1178 break; 1360 break;
1179 1361
1180 case WorkerError_sizeMismatch: 1362 case WorkerError_sizeMismatch:
1181 PyErr_Format(ZstdError, "error decompressing item %zd: decompressed %zu bytes; expected %llu", 1363 PyErr_Format(ZstdError, "error decompressing item %zd: decompressed %zu bytes; expected %zu",
1182 workerStates[i].errorOffset, workerStates[i].zresult, 1364 workerStates[i].errorOffset, workerStates[i].zresult,
1183 framePointers[workerStates[i].errorOffset].destSize); 1365 framePointers[workerStates[i].errorOffset].destSize);
1184 errored = 1; 1366 errored = 1;
1185 break; 1367 break;
1186 1368
1386 1568
1387 if (frameSizesP) { 1569 if (frameSizesP) {
1388 decompressedSize = frameSizesP[i]; 1570 decompressedSize = frameSizesP[i];
1389 } 1571 }
1390 1572
1573 if (sourceSize > SIZE_MAX) {
1574 PyErr_Format(PyExc_ValueError,
1575 "item %zd is too large for this platform", i);
1576 goto finally;
1577 }
1578
1579 if (decompressedSize > SIZE_MAX) {
1580 PyErr_Format(PyExc_ValueError,
1581 "decompressed size of item %zd is too large for this platform", i);
1582 goto finally;
1583 }
1584
1391 framePointers[i].sourceData = sourceData; 1585 framePointers[i].sourceData = sourceData;
1392 framePointers[i].sourceSize = sourceSize; 1586 framePointers[i].sourceSize = (size_t)sourceSize;
1393 framePointers[i].destSize = decompressedSize; 1587 framePointers[i].destSize = (size_t)decompressedSize;
1394 } 1588 }
1395 } 1589 }
1396 else if (PyObject_TypeCheck(frames, &ZstdBufferWithSegmentsCollectionType)) { 1590 else if (PyObject_TypeCheck(frames, &ZstdBufferWithSegmentsCollectionType)) {
1397 Py_ssize_t offset = 0; 1591 Py_ssize_t offset = 0;
1398 ZstdBufferWithSegments* buffer; 1592 ZstdBufferWithSegments* buffer;
1417 for (i = 0; i < collection->bufferCount; i++) { 1611 for (i = 0; i < collection->bufferCount; i++) {
1418 Py_ssize_t segmentIndex; 1612 Py_ssize_t segmentIndex;
1419 buffer = collection->buffers[i]; 1613 buffer = collection->buffers[i];
1420 1614
1421 for (segmentIndex = 0; segmentIndex < buffer->segmentCount; segmentIndex++) { 1615 for (segmentIndex = 0; segmentIndex < buffer->segmentCount; segmentIndex++) {
1616 unsigned long long decompressedSize = frameSizesP ? frameSizesP[offset] : 0;
1617
1422 if (buffer->segments[segmentIndex].offset + buffer->segments[segmentIndex].length > buffer->dataSize) { 1618 if (buffer->segments[segmentIndex].offset + buffer->segments[segmentIndex].length > buffer->dataSize) {
1423 PyErr_Format(PyExc_ValueError, "item %zd has offset outside memory area", 1619 PyErr_Format(PyExc_ValueError, "item %zd has offset outside memory area",
1424 offset); 1620 offset);
1425 goto finally; 1621 goto finally;
1426 } 1622 }
1427 1623
1624 if (buffer->segments[segmentIndex].length > SIZE_MAX) {
1625 PyErr_Format(PyExc_ValueError,
1626 "item %zd in buffer %zd is too large for this platform",
1627 segmentIndex, i);
1628 goto finally;
1629 }
1630
1631 if (decompressedSize > SIZE_MAX) {
1632 PyErr_Format(PyExc_ValueError,
1633 "decompressed size of item %zd in buffer %zd is too large for this platform",
1634 segmentIndex, i);
1635 goto finally;
1636 }
1637
1428 totalInputSize += buffer->segments[segmentIndex].length; 1638 totalInputSize += buffer->segments[segmentIndex].length;
1429 1639
1430 framePointers[offset].sourceData = (char*)buffer->data + buffer->segments[segmentIndex].offset; 1640 framePointers[offset].sourceData = (char*)buffer->data + buffer->segments[segmentIndex].offset;
1431 framePointers[offset].sourceSize = buffer->segments[segmentIndex].length; 1641 framePointers[offset].sourceSize = (size_t)buffer->segments[segmentIndex].length;
1432 framePointers[offset].destSize = frameSizesP ? frameSizesP[offset] : 0; 1642 framePointers[offset].destSize = (size_t)decompressedSize;
1433 1643
1434 offset++; 1644 offset++;
1435 } 1645 }
1436 } 1646 }
1437 } 1647 }
1448 if (!framePointers) { 1658 if (!framePointers) {
1449 PyErr_NoMemory(); 1659 PyErr_NoMemory();
1450 goto finally; 1660 goto finally;
1451 } 1661 }
1452 1662
1453 /*
1454 * It is not clear whether Py_buffer.buf is still valid after
1455 * PyBuffer_Release. So, we hold a reference to all Py_buffer instances
1456 * for the duration of the operation.
1457 */
1458 frameBuffers = PyMem_Malloc(frameCount * sizeof(Py_buffer)); 1663 frameBuffers = PyMem_Malloc(frameCount * sizeof(Py_buffer));
1459 if (NULL == frameBuffers) { 1664 if (NULL == frameBuffers) {
1460 PyErr_NoMemory(); 1665 PyErr_NoMemory();
1461 goto finally; 1666 goto finally;
1462 } 1667 }
1463 1668
1464 memset(frameBuffers, 0, frameCount * sizeof(Py_buffer)); 1669 memset(frameBuffers, 0, frameCount * sizeof(Py_buffer));
1465 1670
1466 /* Do a pass to assemble info about our input buffers and output sizes. */ 1671 /* Do a pass to assemble info about our input buffers and output sizes. */
1467 for (i = 0; i < frameCount; i++) { 1672 for (i = 0; i < frameCount; i++) {
1673 unsigned long long decompressedSize = frameSizesP ? frameSizesP[i] : 0;
1674
1468 if (0 != PyObject_GetBuffer(PyList_GET_ITEM(frames, i), 1675 if (0 != PyObject_GetBuffer(PyList_GET_ITEM(frames, i),
1469 &frameBuffers[i], PyBUF_CONTIG_RO)) { 1676 &frameBuffers[i], PyBUF_CONTIG_RO)) {
1470 PyErr_Clear(); 1677 PyErr_Clear();
1471 PyErr_Format(PyExc_TypeError, "item %zd not a bytes like object", i); 1678 PyErr_Format(PyExc_TypeError, "item %zd not a bytes like object", i);
1472 goto finally; 1679 goto finally;
1473 } 1680 }
1474 1681
1682 if (decompressedSize > SIZE_MAX) {
1683 PyErr_Format(PyExc_ValueError,
1684 "decompressed size of item %zd is too large for this platform", i);
1685 goto finally;
1686 }
1687
1475 totalInputSize += frameBuffers[i].len; 1688 totalInputSize += frameBuffers[i].len;
1476 1689
1477 framePointers[i].sourceData = frameBuffers[i].buf; 1690 framePointers[i].sourceData = frameBuffers[i].buf;
1478 framePointers[i].sourceSize = frameBuffers[i].len; 1691 framePointers[i].sourceSize = frameBuffers[i].len;
1479 framePointers[i].destSize = frameSizesP ? frameSizesP[i] : 0; 1692 framePointers[i].destSize = (size_t)decompressedSize;
1480 } 1693 }
1481 } 1694 }
1482 else { 1695 else {
1483 PyErr_SetString(PyExc_TypeError, "argument must be list or BufferWithSegments"); 1696 PyErr_SetString(PyExc_TypeError, "argument must be list or BufferWithSegments");
1484 goto finally; 1697 goto finally;
/*
 * Method table for the decompressor object. Each entry gives the
 * Python-visible method name, the C function implementing it, the
 * calling-convention flags (METH_NOARGS or METH_VARARGS | METH_KEYWORDS)
 * and the docstring. The { NULL, NULL } entry terminates the table.
 *
 * In the newer revision "read_from" and "write_to" remain registered as
 * deprecated aliases: they point at the same C functions and docstrings as
 * "read_to_iter" and "stream_writer" respectively. "decompressobj" changes
 * from METH_NOARGS to METH_VARARGS | METH_KEYWORDS, and "stream_reader" and
 * "memory_size" are new entries.
 *
 * NOTE(review): presumably this table is installed via ZstdDecompressorType
 * below; its initializer is not fully visible here - confirm.
 */
1512 static PyMethodDef Decompressor_methods[] = { 1725 static PyMethodDef Decompressor_methods[] = {
1513 { "copy_stream", (PyCFunction)Decompressor_copy_stream, METH_VARARGS | METH_KEYWORDS, 1726 { "copy_stream", (PyCFunction)Decompressor_copy_stream, METH_VARARGS | METH_KEYWORDS,
1514 Decompressor_copy_stream__doc__ }, 1727 Decompressor_copy_stream__doc__ },
1515 { "decompress", (PyCFunction)Decompressor_decompress, METH_VARARGS | METH_KEYWORDS, 1728 { "decompress", (PyCFunction)Decompressor_decompress, METH_VARARGS | METH_KEYWORDS,
1516 Decompressor_decompress__doc__ }, 1729 Decompressor_decompress__doc__ },
1517 { "decompressobj", (PyCFunction)Decompressor_decompressobj, METH_NOARGS, 1730 { "decompressobj", (PyCFunction)Decompressor_decompressobj, METH_VARARGS | METH_KEYWORDS,
1518 Decompressor_decompressobj__doc__ }, 1731 Decompressor_decompressobj__doc__ },
1519 { "read_from", (PyCFunction)Decompressor_read_from, METH_VARARGS | METH_KEYWORDS, 1732 { "read_to_iter", (PyCFunction)Decompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS,
1520 Decompressor_read_from__doc__ }, 1733 Decompressor_read_to_iter__doc__ },
1521 { "write_to", (PyCFunction)Decompressor_write_to, METH_VARARGS | METH_KEYWORDS, 1734 /* TODO Remove deprecated API */
1522 Decompressor_write_to__doc__ }, 1735 { "read_from", (PyCFunction)Decompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS,
1736 Decompressor_read_to_iter__doc__ },
1737 { "stream_reader", (PyCFunction)Decompressor_stream_reader,
1738 METH_VARARGS | METH_KEYWORDS, Decompressor_stream_reader__doc__ },
1739 { "stream_writer", (PyCFunction)Decompressor_stream_writer, METH_VARARGS | METH_KEYWORDS,
1740 Decompressor_stream_writer__doc__ },
1741 /* TODO remove deprecated API */
1742 { "write_to", (PyCFunction)Decompressor_stream_writer, METH_VARARGS | METH_KEYWORDS,
1743 Decompressor_stream_writer__doc__ },
1523 { "decompress_content_dict_chain", (PyCFunction)Decompressor_decompress_content_dict_chain, 1744 { "decompress_content_dict_chain", (PyCFunction)Decompressor_decompress_content_dict_chain,
1524 METH_VARARGS | METH_KEYWORDS, Decompressor_decompress_content_dict_chain__doc__ }, 1745 METH_VARARGS | METH_KEYWORDS, Decompressor_decompress_content_dict_chain__doc__ },
1525 { "multi_decompress_to_buffer", (PyCFunction)Decompressor_multi_decompress_to_buffer, 1746 { "multi_decompress_to_buffer", (PyCFunction)Decompressor_multi_decompress_to_buffer,
1526 METH_VARARGS | METH_KEYWORDS, Decompressor_multi_decompress_to_buffer__doc__ }, 1747 METH_VARARGS | METH_KEYWORDS, Decompressor_multi_decompress_to_buffer__doc__ },
1748 { "memory_size", (PyCFunction)Decompressor_memory_size, METH_NOARGS,
1749 Decompressor_memory_size__doc__ },
1527 { NULL, NULL } 1750 { NULL, NULL }
1528 }; 1751 };
1529 1752
1530 PyTypeObject ZstdDecompressorType = { 1753 PyTypeObject ZstdDecompressorType = {
1531 PyVarObject_HEAD_INIT(NULL, 0) 1754 PyVarObject_HEAD_INIT(NULL, 0)