diff -r 1ce7a55b09d1 -r b1fb341d8a61 contrib/python-zstandard/c-ext/decompressor.c --- a/contrib/python-zstandard/c-ext/decompressor.c Sun Apr 08 01:08:43 2018 +0200 +++ b/contrib/python-zstandard/c-ext/decompressor.c Mon Apr 09 10:13:29 2018 -0700 @@ -12,54 +12,40 @@ extern PyObject* ZstdError; /** - * Ensure the ZSTD_DStream on a ZstdDecompressor is initialized and reset. - * - * This should be called before starting a decompression operation with a - * ZSTD_DStream on a ZstdDecompressor. - */ -int init_dstream(ZstdDecompressor* decompressor) { - void* dictData = NULL; - size_t dictSize = 0; + * Ensure the ZSTD_DCtx on a decompressor is initiated and ready for a new operation. + */ +int ensure_dctx(ZstdDecompressor* decompressor, int loadDict) { size_t zresult; - /* Simple case of dstream already exists. Just reset it. */ - if (decompressor->dstream) { - zresult = ZSTD_resetDStream(decompressor->dstream); + ZSTD_DCtx_reset(decompressor->dctx); + + if (decompressor->maxWindowSize) { + zresult = ZSTD_DCtx_setMaxWindowSize(decompressor->dctx, decompressor->maxWindowSize); if (ZSTD_isError(zresult)) { - PyErr_Format(ZstdError, "could not reset DStream: %s", + PyErr_Format(ZstdError, "unable to set max window size: %s", ZSTD_getErrorName(zresult)); - return -1; + return 1; } - - return 0; } - decompressor->dstream = ZSTD_createDStream(); - if (!decompressor->dstream) { - PyErr_SetString(ZstdError, "could not create DStream"); - return -1; - } - - if (decompressor->dict) { - dictData = decompressor->dict->dictData; - dictSize = decompressor->dict->dictSize; + zresult = ZSTD_DCtx_setFormat(decompressor->dctx, decompressor->format); + if (ZSTD_isError(zresult)) { + PyErr_Format(ZstdError, "unable to set decoding format: %s", + ZSTD_getErrorName(zresult)); + return 1; } - if (dictData) { - zresult = ZSTD_initDStream_usingDict(decompressor->dstream, dictData, dictSize); - } - else { - zresult = ZSTD_initDStream(decompressor->dstream); - } + if (loadDict && 
decompressor->dict) { + if (ensure_ddict(decompressor->dict)) { + return 1; + } - if (ZSTD_isError(zresult)) { - /* Don't leave a reference to an invalid object. */ - ZSTD_freeDStream(decompressor->dstream); - decompressor->dstream = NULL; - - PyErr_Format(ZstdError, "could not initialize DStream: %s", - ZSTD_getErrorName(zresult)); - return -1; + zresult = ZSTD_DCtx_refDDict(decompressor->dctx, decompressor->dict->ddict); + if (ZSTD_isError(zresult)) { + PyErr_Format(ZstdError, "unable to reference prepared dictionary: %s", + ZSTD_getErrorName(zresult)); + return 1; + } } return 0; @@ -76,36 +62,46 @@ static int Decompressor_init(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { static char* kwlist[] = { "dict_data", + "max_window_size", + "format", NULL }; ZstdCompressionDict* dict = NULL; + size_t maxWindowSize = 0; + ZSTD_format_e format = ZSTD_f_zstd1; self->dctx = NULL; self->dict = NULL; - self->ddict = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!:ZstdDecompressor", kwlist, - &ZstdCompressionDictType, &dict)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!II:ZstdDecompressor", kwlist, + &ZstdCompressionDictType, &dict, &maxWindowSize, &format)) { return -1; } - /* TODO lazily initialize the reference ZSTD_DCtx on first use since - not instances of ZstdDecompressor will use a ZSTD_DCtx. 
*/ self->dctx = ZSTD_createDCtx(); if (!self->dctx) { PyErr_NoMemory(); goto except; } + self->maxWindowSize = maxWindowSize; + self->format = format; + if (dict) { self->dict = dict; Py_INCREF(dict); } + if (ensure_dctx(self, 1)) { + goto except; + } + return 0; except: + Py_CLEAR(self->dict); + if (self->dctx) { ZSTD_freeDCtx(self->dctx); self->dctx = NULL; @@ -117,16 +113,6 @@ static void Decompressor_dealloc(ZstdDecompressor* self) { Py_CLEAR(self->dict); - if (self->ddict) { - ZSTD_freeDDict(self->ddict); - self->ddict = NULL; - } - - if (self->dstream) { - ZSTD_freeDStream(self->dstream); - self->dstream = NULL; - } - if (self->dctx) { ZSTD_freeDCtx(self->dctx); self->dctx = NULL; @@ -135,6 +121,20 @@ PyObject_Del(self); } +PyDoc_STRVAR(Decompressor_memory_size__doc__, +"memory_size() -- Size of decompression context, in bytes\n" +); + +static PyObject* Decompressor_memory_size(ZstdDecompressor* self) { + if (self->dctx) { + return PyLong_FromSize_t(ZSTD_sizeof_DCtx(self->dctx)); + } + else { + PyErr_SetString(ZstdError, "no decompressor context found; this should never happen"); + return NULL; + } +} + PyDoc_STRVAR(Decompressor_copy_stream__doc__, "copy_stream(ifh, ofh[, read_size=default, write_size=default]) -- decompress data between streams\n" "\n" @@ -166,7 +166,7 @@ Py_ssize_t totalWrite = 0; char* readBuffer; Py_ssize_t readSize; - PyObject* readResult; + PyObject* readResult = NULL; PyObject* res = NULL; size_t zresult = 0; PyObject* writeResult; @@ -191,7 +191,7 @@ /* Prevent free on uninitialized memory in finally. 
*/ output.dst = NULL; - if (0 != init_dstream(self)) { + if (ensure_dctx(self, 1)) { res = NULL; goto finally; } @@ -229,7 +229,7 @@ while (input.pos < input.size) { Py_BEGIN_ALLOW_THREADS - zresult = ZSTD_decompressStream(self->dstream, &output, &input); + zresult = ZSTD_decompress_generic(self->dctx, &output, &input); Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) { @@ -252,6 +252,8 @@ output.pos = 0; } } + + Py_CLEAR(readResult); } /* Source stream is exhausted. Finish up. */ @@ -267,6 +269,8 @@ PyMem_Free(output.dst); } + Py_XDECREF(readResult); + return res; } @@ -300,98 +304,114 @@ NULL }; - const char* source; - Py_ssize_t sourceSize; + Py_buffer source; Py_ssize_t maxOutputSize = 0; unsigned long long decompressedSize; size_t destCapacity; PyObject* result = NULL; - void* dictData = NULL; - size_t dictSize = 0; size_t zresult; + ZSTD_outBuffer outBuffer; + ZSTD_inBuffer inBuffer; #if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y#|n:decompress", + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|n:decompress", #else - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#|n:decompress", + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s*|n:decompress", #endif - kwlist, &source, &sourceSize, &maxOutputSize)) { + kwlist, &source, &maxOutputSize)) { return NULL; } - if (self->dict) { - dictData = self->dict->dictData; - dictSize = self->dict->dictSize; + if (!PyBuffer_IsContiguous(&source, 'C') || source.ndim > 1) { + PyErr_SetString(PyExc_ValueError, + "data buffer should be contiguous and have at most one dimension"); + goto finally; } - if (dictData && !self->ddict) { - Py_BEGIN_ALLOW_THREADS - self->ddict = ZSTD_createDDict_byReference(dictData, dictSize); - Py_END_ALLOW_THREADS - - if (!self->ddict) { - PyErr_SetString(ZstdError, "could not create decompression dict"); - return NULL; - } + if (ensure_dctx(self, 1)) { + goto finally; } - decompressedSize = ZSTD_getDecompressedSize(source, sourceSize); - /* 0 returned if content 
size not in the zstd frame header */ - if (0 == decompressedSize) { + decompressedSize = ZSTD_getFrameContentSize(source.buf, source.len); + + if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) { + PyErr_SetString(ZstdError, "error determining content size from frame header"); + goto finally; + } + /* Special case of empty frame. */ + else if (0 == decompressedSize) { + result = PyBytes_FromStringAndSize("", 0); + goto finally; + } + /* Missing content size in frame header. */ + if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) { if (0 == maxOutputSize) { - PyErr_SetString(ZstdError, "input data invalid or missing content size " - "in frame header"); - return NULL; + PyErr_SetString(ZstdError, "could not determine content size in frame header"); + goto finally; } - else { - result = PyBytes_FromStringAndSize(NULL, maxOutputSize); - destCapacity = maxOutputSize; + + result = PyBytes_FromStringAndSize(NULL, maxOutputSize); + destCapacity = maxOutputSize; + decompressedSize = 0; + } + /* Size is recorded in frame header. 
*/ + else { + assert(SIZE_MAX >= PY_SSIZE_T_MAX); + if (decompressedSize > PY_SSIZE_T_MAX) { + PyErr_SetString(ZstdError, "frame is too large to decompress on this platform"); + goto finally; } - } - else { - result = PyBytes_FromStringAndSize(NULL, decompressedSize); - destCapacity = decompressedSize; + + result = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)decompressedSize); + destCapacity = (size_t)decompressedSize; } if (!result) { - return NULL; + goto finally; } + outBuffer.dst = PyBytes_AsString(result); + outBuffer.size = destCapacity; + outBuffer.pos = 0; + + inBuffer.src = source.buf; + inBuffer.size = source.len; + inBuffer.pos = 0; + Py_BEGIN_ALLOW_THREADS - if (self->ddict) { - zresult = ZSTD_decompress_usingDDict(self->dctx, - PyBytes_AsString(result), destCapacity, - source, sourceSize, self->ddict); - } - else { - zresult = ZSTD_decompressDCtx(self->dctx, - PyBytes_AsString(result), destCapacity, source, sourceSize); - } + zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer); Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) { PyErr_Format(ZstdError, "decompression error: %s", ZSTD_getErrorName(zresult)); - Py_DECREF(result); - return NULL; + Py_CLEAR(result); + goto finally; } - else if (decompressedSize && zresult != decompressedSize) { + else if (zresult) { + PyErr_Format(ZstdError, "decompression error: did not decompress full frame"); + Py_CLEAR(result); + goto finally; + } + else if (decompressedSize && outBuffer.pos != decompressedSize) { PyErr_Format(ZstdError, "decompression error: decompressed %zu bytes; expected %llu", zresult, decompressedSize); - Py_DECREF(result); - return NULL; + Py_CLEAR(result); + goto finally; } - else if (zresult < destCapacity) { - if (_PyBytes_Resize(&result, zresult)) { - Py_DECREF(result); - return NULL; + else if (outBuffer.pos < destCapacity) { + if (safe_pybytes_resize(&result, outBuffer.pos)) { + Py_CLEAR(result); + goto finally; } } +finally: + PyBuffer_Release(&source); return result; } 
PyDoc_STRVAR(Decompressor_decompressobj__doc__, -"decompressobj()\n" +"decompressobj([write_size=default])\n" "\n" "Incrementally feed data into a decompressor.\n" "\n" @@ -400,25 +420,43 @@ "callers can swap in the zstd decompressor while using the same API.\n" ); -static ZstdDecompressionObj* Decompressor_decompressobj(ZstdDecompressor* self) { - ZstdDecompressionObj* result = (ZstdDecompressionObj*)PyObject_CallObject((PyObject*)&ZstdDecompressionObjType, NULL); +static ZstdDecompressionObj* Decompressor_decompressobj(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { + static char* kwlist[] = { + "write_size", + NULL + }; + + ZstdDecompressionObj* result = NULL; + size_t outSize = ZSTD_DStreamOutSize(); + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|k:decompressobj", kwlist, &outSize)) { + return NULL; + } + + if (!outSize) { + PyErr_SetString(PyExc_ValueError, "write_size must be positive"); + return NULL; + } + + result = (ZstdDecompressionObj*)PyObject_CallObject((PyObject*)&ZstdDecompressionObjType, NULL); if (!result) { return NULL; } - if (0 != init_dstream(self)) { + if (ensure_dctx(self, 1)) { Py_DECREF(result); return NULL; } result->decompressor = self; Py_INCREF(result->decompressor); + result->outSize = outSize; return result; } -PyDoc_STRVAR(Decompressor_read_from__doc__, -"read_from(reader[, read_size=default, write_size=default, skip_bytes=0])\n" +PyDoc_STRVAR(Decompressor_read_to_iter__doc__, +"read_to_iter(reader[, read_size=default, write_size=default, skip_bytes=0])\n" "Read compressed data and return an iterator\n" "\n" "Returns an iterator of decompressed data chunks produced from reading from\n" @@ -437,7 +475,7 @@ "the source.\n" ); -static ZstdDecompressorIterator* Decompressor_read_from(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { +static ZstdDecompressorIterator* Decompressor_read_to_iter(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { static char* kwlist[] = { "reader", "read_size", @@ -452,7 
+490,7 @@ ZstdDecompressorIterator* result; size_t skipBytes = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_from", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kkk:read_to_iter", kwlist, &reader, &inSize, &outSize, &skipBytes)) { return NULL; } @@ -474,14 +512,7 @@ } else if (1 == PyObject_CheckBuffer(reader)) { /* Object claims it is a buffer. Try to get a handle to it. */ - result->buffer = PyMem_Malloc(sizeof(Py_buffer)); - if (!result->buffer) { - goto except; - } - - memset(result->buffer, 0, sizeof(Py_buffer)); - - if (0 != PyObject_GetBuffer(reader, result->buffer, PyBUF_CONTIG_RO)) { + if (0 != PyObject_GetBuffer(reader, &result->buffer, PyBUF_CONTIG_RO)) { goto except; } } @@ -498,7 +529,7 @@ result->outSize = outSize; result->skipBytes = skipBytes; - if (0 != init_dstream(self)) { + if (ensure_dctx(self, 1)) { goto except; } @@ -511,13 +542,6 @@ goto finally; except: - Py_CLEAR(result->reader); - - if (result->buffer) { - PyBuffer_Release(result->buffer); - Py_CLEAR(result->buffer); - } - Py_CLEAR(result); finally: @@ -525,7 +549,62 @@ return result; } -PyDoc_STRVAR(Decompressor_write_to__doc__, +PyDoc_STRVAR(Decompressor_stream_reader__doc__, +"stream_reader(source, [read_size=default])\n" +"\n" +"Obtain an object that behaves like an I/O stream that can be used for\n" +"reading decompressed output from an object.\n" +"\n" +"The source object can be any object with a ``read(size)`` method or that\n" +"conforms to the buffer protocol.\n" +); + +static ZstdDecompressionReader* Decompressor_stream_reader(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { + static char* kwlist[] = { + "source", + "read_size", + NULL + }; + + PyObject* source; + size_t readSize = ZSTD_DStreamInSize(); + ZstdDecompressionReader* result; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_reader", kwlist, + &source, &readSize)) { + return NULL; + } + + result = 
(ZstdDecompressionReader*)PyObject_CallObject((PyObject*)&ZstdDecompressionReaderType, NULL); + if (NULL == result) { + return NULL; + } + + if (PyObject_HasAttrString(source, "read")) { + result->reader = source; + Py_INCREF(source); + result->readSize = readSize; + } + else if (1 == PyObject_CheckBuffer(source)) { + if (0 != PyObject_GetBuffer(source, &result->buffer, PyBUF_CONTIG_RO)) { + Py_CLEAR(result); + return NULL; + } + } + else { + PyErr_SetString(PyExc_TypeError, + "must pass an object with a read() method or that conforms to the buffer protocol"); + Py_CLEAR(result); + return NULL; + } + + result->decompressor = self; + Py_INCREF(self); + + return result; +} + +PyDoc_STRVAR(Decompressor_stream_writer__doc__, "Create a context manager to write decompressed data to an object.\n" "\n" "The passed object must have a ``write()`` method.\n" @@ -538,7 +617,7 @@ "streaming decompressor.\n" ); -static ZstdDecompressionWriter* Decompressor_write_to(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { +static ZstdDecompressionWriter* Decompressor_stream_writer(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { static char* kwlist[] = { "writer", "write_size", @@ -549,7 +628,7 @@ size_t outSize = ZSTD_DStreamOutSize(); ZstdDecompressionWriter* result; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:write_to", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|k:stream_writer", kwlist, &writer, &outSize)) { return NULL; } @@ -579,7 +658,7 @@ "Decompress a series of chunks using the content dictionary chaining technique\n" ); -static PyObject* Decompressor_decompress_content_dict_chain(PyObject* self, PyObject* args, PyObject* kwargs) { +static PyObject* Decompressor_decompress_content_dict_chain(ZstdDecompressor* self, PyObject* args, PyObject* kwargs) { static char* kwlist[] = { "frames", NULL @@ -592,9 +671,8 @@ PyObject* chunk; char* chunkData; Py_ssize_t chunkSize; - ZSTD_DCtx* dctx = NULL; size_t zresult; - ZSTD_frameParams 
frameParams; + ZSTD_frameHeader frameHeader; void* buffer1 = NULL; size_t buffer1Size = 0; size_t buffer1ContentSize = 0; @@ -603,6 +681,8 @@ size_t buffer2ContentSize = 0; void* destBuffer = NULL; PyObject* result = NULL; + ZSTD_outBuffer outBuffer; + ZSTD_inBuffer inBuffer; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!:decompress_content_dict_chain", kwlist, &PyList_Type, &chunks)) { @@ -624,7 +704,7 @@ /* We require that all chunks be zstd frames and that they have content size set. */ PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize); - zresult = ZSTD_getFrameParams(&frameParams, (void*)chunkData, chunkSize); + zresult = ZSTD_getFrameHeader(&frameHeader, (void*)chunkData, chunkSize); if (ZSTD_isError(zresult)) { PyErr_SetString(PyExc_ValueError, "chunk 0 is not a valid zstd frame"); return NULL; @@ -634,32 +714,56 @@ return NULL; } - if (0 == frameParams.frameContentSize) { + if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) { PyErr_SetString(PyExc_ValueError, "chunk 0 missing content size in frame"); return NULL; } - dctx = ZSTD_createDCtx(); - if (!dctx) { - PyErr_NoMemory(); + assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize); + + /* We check against PY_SSIZE_T_MAX here because we ultimately cast the + * result to a Python object and its length can be no greater than + * Py_ssize_t. In theory, we could have an intermediate frame that is + * larger. But a) why would this API be used for frames that large, and b) + * it isn't worth the complexity to support.
*/ + assert(SIZE_MAX >= PY_SSIZE_T_MAX); + if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) { + PyErr_SetString(PyExc_ValueError, + "chunk 0 is too large to decompress on this platform"); + return NULL; + } + + if (ensure_dctx(self, 0)) { goto finally; } - buffer1Size = frameParams.frameContentSize; + buffer1Size = (size_t)frameHeader.frameContentSize; buffer1 = PyMem_Malloc(buffer1Size); if (!buffer1) { goto finally; } + outBuffer.dst = buffer1; + outBuffer.size = buffer1Size; + outBuffer.pos = 0; + + inBuffer.src = chunkData; + inBuffer.size = chunkSize; + inBuffer.pos = 0; + Py_BEGIN_ALLOW_THREADS - zresult = ZSTD_decompressDCtx(dctx, buffer1, buffer1Size, chunkData, chunkSize); + zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer); Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) { PyErr_Format(ZstdError, "could not decompress chunk 0: %s", ZSTD_getErrorName(zresult)); goto finally; } + else if (zresult) { + PyErr_Format(ZstdError, "chunk 0 did not decompress full frame"); + goto finally; + } - buffer1ContentSize = zresult; + buffer1ContentSize = outBuffer.pos; /* Special case of a simple chain. */ if (1 == chunksLen) { @@ -668,7 +772,7 @@ } /* This should ideally look at next chunk. But this is slightly simpler. 
*/ - buffer2Size = frameParams.frameContentSize; + buffer2Size = (size_t)frameHeader.frameContentSize; buffer2 = PyMem_Malloc(buffer2Size); if (!buffer2) { goto finally; @@ -688,7 +792,7 @@ } PyBytes_AsStringAndSize(chunk, &chunkData, &chunkSize); - zresult = ZSTD_getFrameParams(&frameParams, (void*)chunkData, chunkSize); + zresult = ZSTD_getFrameHeader(&frameHeader, (void*)chunkData, chunkSize); if (ZSTD_isError(zresult)) { PyErr_Format(PyExc_ValueError, "chunk %zd is not a valid zstd frame", chunkIndex); goto finally; @@ -698,18 +802,30 @@ goto finally; } - if (0 == frameParams.frameContentSize) { + if (ZSTD_CONTENTSIZE_UNKNOWN == frameHeader.frameContentSize) { PyErr_Format(PyExc_ValueError, "chunk %zd missing content size in frame", chunkIndex); goto finally; } + assert(ZSTD_CONTENTSIZE_ERROR != frameHeader.frameContentSize); + + if (frameHeader.frameContentSize > PY_SSIZE_T_MAX) { + PyErr_Format(PyExc_ValueError, + "chunk %zd is too large to decompress on this platform", chunkIndex); + goto finally; + } + + inBuffer.src = chunkData; + inBuffer.size = chunkSize; + inBuffer.pos = 0; + parity = chunkIndex % 2; /* This could definitely be abstracted to reduce code duplication. */ if (parity) { /* Resize destination buffer to hold larger content. 
*/ - if (buffer2Size < frameParams.frameContentSize) { - buffer2Size = frameParams.frameContentSize; + if (buffer2Size < frameHeader.frameContentSize) { + buffer2Size = (size_t)frameHeader.frameContentSize; destBuffer = PyMem_Realloc(buffer2, buffer2Size); if (!destBuffer) { goto finally; @@ -718,19 +834,38 @@ } Py_BEGIN_ALLOW_THREADS - zresult = ZSTD_decompress_usingDict(dctx, buffer2, buffer2Size, - chunkData, chunkSize, buffer1, buffer1ContentSize); + zresult = ZSTD_DCtx_refPrefix_advanced(self->dctx, + buffer1, buffer1ContentSize, ZSTD_dct_rawContent); + Py_END_ALLOW_THREADS + if (ZSTD_isError(zresult)) { + PyErr_Format(ZstdError, + "failed to load prefix dictionary at chunk %zd", chunkIndex); + goto finally; + } + + outBuffer.dst = buffer2; + outBuffer.size = buffer2Size; + outBuffer.pos = 0; + + Py_BEGIN_ALLOW_THREADS + zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer); Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) { PyErr_Format(ZstdError, "could not decompress chunk %zd: %s", chunkIndex, ZSTD_getErrorName(zresult)); goto finally; } - buffer2ContentSize = zresult; + else if (zresult) { + PyErr_Format(ZstdError, "chunk %zd did not decompress full frame", + chunkIndex); + goto finally; + } + + buffer2ContentSize = outBuffer.pos; } else { - if (buffer1Size < frameParams.frameContentSize) { - buffer1Size = frameParams.frameContentSize; + if (buffer1Size < frameHeader.frameContentSize) { + buffer1Size = (size_t)frameHeader.frameContentSize; destBuffer = PyMem_Realloc(buffer1, buffer1Size); if (!destBuffer) { goto finally; @@ -739,15 +874,34 @@ } Py_BEGIN_ALLOW_THREADS - zresult = ZSTD_decompress_usingDict(dctx, buffer1, buffer1Size, - chunkData, chunkSize, buffer2, buffer2ContentSize); + zresult = ZSTD_DCtx_refPrefix_advanced(self->dctx, + buffer2, buffer2ContentSize, ZSTD_dct_rawContent); + Py_END_ALLOW_THREADS + if (ZSTD_isError(zresult)) { + PyErr_Format(ZstdError, + "failed to load prefix dictionary at chunk %zd", chunkIndex); + goto 
finally; + } + + outBuffer.dst = buffer1; + outBuffer.size = buffer1Size; + outBuffer.pos = 0; + + Py_BEGIN_ALLOW_THREADS + zresult = ZSTD_decompress_generic(self->dctx, &outBuffer, &inBuffer); Py_END_ALLOW_THREADS if (ZSTD_isError(zresult)) { PyErr_Format(ZstdError, "could not decompress chunk %zd: %s", chunkIndex, ZSTD_getErrorName(zresult)); goto finally; } - buffer1ContentSize = zresult; + else if (zresult) { + PyErr_Format(ZstdError, "chunk %zd did not decompress full frame", + chunkIndex); + goto finally; + } + + buffer1ContentSize = outBuffer.pos; } } @@ -762,17 +916,13 @@ PyMem_Free(buffer1); } - if (dctx) { - ZSTD_freeDCtx(dctx); - } - return result; } typedef struct { void* sourceData; size_t sourceSize; - unsigned long long destSize; + size_t destSize; } FramePointer; typedef struct { @@ -806,7 +956,6 @@ /* Compression state and settings. */ ZSTD_DCtx* dctx; - ZSTD_DDict* ddict; int requireOutputSizes; /* Output storage. */ @@ -838,6 +987,14 @@ assert(0 == state->destCount); assert(state->endOffset - state->startOffset >= 0); + /* We could get here due to the way work is allocated. Ideally we wouldn't + get here. But that would require a bit of a refactor in the caller. */ + if (state->totalSourceSize > SIZE_MAX) { + state->error = WorkerError_memory; + state->errorOffset = 0; + return; + } + /* * We need to allocate a buffer to hold decompressed data. How we do this * depends on what we know about the output. The following scenarios are @@ -853,14 +1010,34 @@ /* Resolve ouput segments. 
*/ for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) { FramePointer* fp = &framePointers[frameIndex]; + unsigned long long decompressedSize; if (0 == fp->destSize) { - fp->destSize = ZSTD_getDecompressedSize(fp->sourceData, fp->sourceSize); - if (0 == fp->destSize && state->requireOutputSizes) { + decompressedSize = ZSTD_getFrameContentSize(fp->sourceData, fp->sourceSize); + + if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) { state->error = WorkerError_unknownSize; state->errorOffset = frameIndex; return; } + else if (ZSTD_CONTENTSIZE_UNKNOWN == decompressedSize) { + if (state->requireOutputSizes) { + state->error = WorkerError_unknownSize; + state->errorOffset = frameIndex; + return; + } + + /* This will fail the assert for .destSize > 0 below. */ + decompressedSize = 0; + } + + if (decompressedSize > SIZE_MAX) { + state->error = WorkerError_memory; + state->errorOffset = frameIndex; + return; + } + + fp->destSize = (size_t)decompressedSize; } totalOutputSize += fp->destSize; @@ -878,7 +1055,7 @@ assert(framePointers[state->startOffset].destSize > 0); /* For now. */ - allocationSize = roundpow2(state->totalSourceSize); + allocationSize = roundpow2((size_t)state->totalSourceSize); if (framePointers[state->startOffset].destSize > allocationSize) { allocationSize = roundpow2(framePointers[state->startOffset].destSize); @@ -902,6 +1079,8 @@ destBuffer->segmentsSize = remainingItems; for (frameIndex = state->startOffset; frameIndex <= state->endOffset; frameIndex++) { + ZSTD_outBuffer outBuffer; + ZSTD_inBuffer inBuffer; const void* source = framePointers[frameIndex].sourceData; const size_t sourceSize = framePointers[frameIndex].sourceSize; void* dest; @@ -956,7 +1135,7 @@ /* Don't take any chances will non-NULL pointers. 
*/ memset(destBuffer, 0, sizeof(DestBuffer)); - allocationSize = roundpow2(state->totalSourceSize); + allocationSize = roundpow2((size_t)state->totalSourceSize); if (decompressedSize > allocationSize) { allocationSize = roundpow2(decompressedSize); @@ -985,31 +1164,31 @@ dest = (char*)destBuffer->dest + destOffset; - if (state->ddict) { - zresult = ZSTD_decompress_usingDDict(state->dctx, dest, decompressedSize, - source, sourceSize, state->ddict); - } - else { - zresult = ZSTD_decompressDCtx(state->dctx, dest, decompressedSize, - source, sourceSize); - } + outBuffer.dst = dest; + outBuffer.size = decompressedSize; + outBuffer.pos = 0; + inBuffer.src = source; + inBuffer.size = sourceSize; + inBuffer.pos = 0; + + zresult = ZSTD_decompress_generic(state->dctx, &outBuffer, &inBuffer); if (ZSTD_isError(zresult)) { state->error = WorkerError_zstd; state->zresult = zresult; state->errorOffset = frameIndex; return; } - else if (zresult != decompressedSize) { + else if (zresult || outBuffer.pos != decompressedSize) { state->error = WorkerError_sizeMismatch; - state->zresult = zresult; + state->zresult = outBuffer.pos; state->errorOffset = frameIndex; return; } destBuffer->segments[localOffset].offset = destOffset; - destBuffer->segments[localOffset].length = decompressedSize; - destOffset += zresult; + destBuffer->segments[localOffset].length = outBuffer.pos; + destOffset += outBuffer.pos; localOffset++; remainingItems--; } @@ -1027,9 +1206,7 @@ } ZstdBufferWithSegmentsCollection* decompress_from_framesources(ZstdDecompressor* decompressor, FrameSources* frames, - unsigned int threadCount) { - void* dictData = NULL; - size_t dictSize = 0; + Py_ssize_t threadCount) { Py_ssize_t i = 0; int errored = 0; Py_ssize_t segmentsCount; @@ -1039,7 +1216,7 @@ ZstdBufferWithSegmentsCollection* result = NULL; FramePointer* framePointers = frames->frames; unsigned long long workerBytes = 0; - int currentThread = 0; + Py_ssize_t currentThread = 0; Py_ssize_t workerStartOffset = 0; 
POOL_ctx* pool = NULL; WorkerState* workerStates = NULL; @@ -1049,24 +1226,14 @@ assert(threadCount >= 1); /* More threads than inputs makes no sense under any conditions. */ - threadCount = frames->framesSize < threadCount ? (unsigned int)frames->framesSize + threadCount = frames->framesSize < threadCount ? frames->framesSize : threadCount; /* TODO lower thread count if input size is too small and threads would just add overhead. */ if (decompressor->dict) { - dictData = decompressor->dict->dictData; - dictSize = decompressor->dict->dictSize; - } - - if (dictData && !decompressor->ddict) { - Py_BEGIN_ALLOW_THREADS - decompressor->ddict = ZSTD_createDDict_byReference(dictData, dictSize); - Py_END_ALLOW_THREADS - - if (!decompressor->ddict) { - PyErr_SetString(ZstdError, "could not create decompression dict"); + if (ensure_ddict(decompressor->dict)) { return NULL; } } @@ -1091,7 +1258,14 @@ bytesPerWorker = frames->compressedSize / threadCount; + if (bytesPerWorker > SIZE_MAX) { + PyErr_SetString(ZstdError, "too much data per worker for this platform"); + goto finally; + } + for (i = 0; i < threadCount; i++) { + size_t zresult; + workerStates[i].dctx = ZSTD_createDCtx(); if (NULL == workerStates[i].dctx) { PyErr_NoMemory(); @@ -1100,7 +1274,15 @@ ZSTD_copyDCtx(workerStates[i].dctx, decompressor->dctx); - workerStates[i].ddict = decompressor->ddict; + if (decompressor->dict) { + zresult = ZSTD_DCtx_refDDict(workerStates[i].dctx, decompressor->dict->ddict); + if (zresult) { + PyErr_Format(ZstdError, "unable to reference prepared dictionary: %s", + ZSTD_getErrorName(zresult)); + goto finally; + } + } + workerStates[i].framePointers = framePointers; workerStates[i].requireOutputSizes = 1; } @@ -1178,7 +1360,7 @@ break; case WorkerError_sizeMismatch: - PyErr_Format(ZstdError, "error decompressing item %zd: decompressed %zu bytes; expected %llu", + PyErr_Format(ZstdError, "error decompressing item %zd: decompressed %zu bytes; expected %zu", workerStates[i].errorOffset, 
workerStates[i].zresult, framePointers[workerStates[i].errorOffset].destSize); errored = 1; @@ -1388,9 +1570,21 @@ decompressedSize = frameSizesP[i]; } + if (sourceSize > SIZE_MAX) { + PyErr_Format(PyExc_ValueError, + "item %zd is too large for this platform", i); + goto finally; + } + + if (decompressedSize > SIZE_MAX) { + PyErr_Format(PyExc_ValueError, + "decompressed size of item %zd is too large for this platform", i); + goto finally; + } + framePointers[i].sourceData = sourceData; - framePointers[i].sourceSize = sourceSize; - framePointers[i].destSize = decompressedSize; + framePointers[i].sourceSize = (size_t)sourceSize; + framePointers[i].destSize = (size_t)decompressedSize; } } else if (PyObject_TypeCheck(frames, &ZstdBufferWithSegmentsCollectionType)) { @@ -1419,17 +1613,33 @@ buffer = collection->buffers[i]; for (segmentIndex = 0; segmentIndex < buffer->segmentCount; segmentIndex++) { + unsigned long long decompressedSize = frameSizesP ? frameSizesP[offset] : 0; + if (buffer->segments[segmentIndex].offset + buffer->segments[segmentIndex].length > buffer->dataSize) { PyErr_Format(PyExc_ValueError, "item %zd has offset outside memory area", offset); goto finally; } + if (buffer->segments[segmentIndex].length > SIZE_MAX) { + PyErr_Format(PyExc_ValueError, + "item %zd in buffer %zd is too large for this platform", + segmentIndex, i); + goto finally; + } + + if (decompressedSize > SIZE_MAX) { + PyErr_Format(PyExc_ValueError, + "decompressed size of item %zd in buffer %zd is too large for this platform", + segmentIndex, i); + goto finally; + } + totalInputSize += buffer->segments[segmentIndex].length; framePointers[offset].sourceData = (char*)buffer->data + buffer->segments[segmentIndex].offset; - framePointers[offset].sourceSize = buffer->segments[segmentIndex].length; - framePointers[offset].destSize = frameSizesP ? 
frameSizesP[offset] : 0; + framePointers[offset].sourceSize = (size_t)buffer->segments[segmentIndex].length; + framePointers[offset].destSize = (size_t)decompressedSize; offset++; } @@ -1450,11 +1660,6 @@ goto finally; } - /* - * It is not clear whether Py_buffer.buf is still valid after - * PyBuffer_Release. So, we hold a reference to all Py_buffer instances - * for the duration of the operation. - */ frameBuffers = PyMem_Malloc(frameCount * sizeof(Py_buffer)); if (NULL == frameBuffers) { PyErr_NoMemory(); @@ -1465,6 +1670,8 @@ /* Do a pass to assemble info about our input buffers and output sizes. */ for (i = 0; i < frameCount; i++) { + unsigned long long decompressedSize = frameSizesP ? frameSizesP[i] : 0; + if (0 != PyObject_GetBuffer(PyList_GET_ITEM(frames, i), &frameBuffers[i], PyBUF_CONTIG_RO)) { PyErr_Clear(); @@ -1472,11 +1679,17 @@ goto finally; } + if (decompressedSize > SIZE_MAX) { + PyErr_Format(PyExc_ValueError, + "decompressed size of item %zd is too large for this platform", i); + goto finally; + } + totalInputSize += frameBuffers[i].len; framePointers[i].sourceData = frameBuffers[i].buf; framePointers[i].sourceSize = frameBuffers[i].len; - framePointers[i].destSize = frameSizesP ? 
frameSizesP[i] : 0; + framePointers[i].destSize = (size_t)decompressedSize; } } else { @@ -1514,16 +1727,26 @@ Decompressor_copy_stream__doc__ }, { "decompress", (PyCFunction)Decompressor_decompress, METH_VARARGS | METH_KEYWORDS, Decompressor_decompress__doc__ }, - { "decompressobj", (PyCFunction)Decompressor_decompressobj, METH_NOARGS, + { "decompressobj", (PyCFunction)Decompressor_decompressobj, METH_VARARGS | METH_KEYWORDS, Decompressor_decompressobj__doc__ }, - { "read_from", (PyCFunction)Decompressor_read_from, METH_VARARGS | METH_KEYWORDS, - Decompressor_read_from__doc__ }, - { "write_to", (PyCFunction)Decompressor_write_to, METH_VARARGS | METH_KEYWORDS, - Decompressor_write_to__doc__ }, + { "read_to_iter", (PyCFunction)Decompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS, + Decompressor_read_to_iter__doc__ }, + /* TODO Remove deprecated API */ + { "read_from", (PyCFunction)Decompressor_read_to_iter, METH_VARARGS | METH_KEYWORDS, + Decompressor_read_to_iter__doc__ }, + { "stream_reader", (PyCFunction)Decompressor_stream_reader, + METH_VARARGS | METH_KEYWORDS, Decompressor_stream_reader__doc__ }, + { "stream_writer", (PyCFunction)Decompressor_stream_writer, METH_VARARGS | METH_KEYWORDS, + Decompressor_stream_writer__doc__ }, + /* TODO remove deprecated API */ + { "write_to", (PyCFunction)Decompressor_stream_writer, METH_VARARGS | METH_KEYWORDS, + Decompressor_stream_writer__doc__ }, { "decompress_content_dict_chain", (PyCFunction)Decompressor_decompress_content_dict_chain, METH_VARARGS | METH_KEYWORDS, Decompressor_decompress_content_dict_chain__doc__ }, { "multi_decompress_to_buffer", (PyCFunction)Decompressor_multi_decompress_to_buffer, METH_VARARGS | METH_KEYWORDS, Decompressor_multi_decompress_to_buffer__doc__ }, + { "memory_size", (PyCFunction)Decompressor_memory_size, METH_NOARGS, + Decompressor_memory_size__doc__ }, { NULL, NULL } };