view contrib/python-zstandard/zstd_cffi.py @ 30890:22a4f664c1a5 stable

misc: replace domain of mercurial-devel ML address by mercurial-scm.org This patch also adds new check-code.py pattern to detect invalid usage of "mercurial-devel@selenic.com".
author FUJIWARA Katsunori <foozy@lares.dti.ne.jp>
date Sat, 11 Feb 2017 00:23:55 +0900
parents b86a448a2965
children c32454d69b85
line wrap: on
line source

# Copyright (c) 2016-present, Gregory Szorc
# All rights reserved.
#
# This software may be modified and distributed under the terms
# of the BSD license. See the LICENSE file for details.

"""Python interface to the Zstandard (zstd) compression library."""

from __future__ import absolute_import, unicode_literals

import io

from _zstd_cffi import (
    ffi,
    lib,
)


# Stream buffer sizes reported by the zstd library, queried once at import
# time. In this module _CSTREAM_IN_SIZE is the chunk size used when reading
# from an input file object, and _CSTREAM_OUT_SIZE is the size of every
# output buffer handed to the compressor.
_CSTREAM_IN_SIZE = lib.ZSTD_CStreamInSize()
_CSTREAM_OUT_SIZE = lib.ZSTD_CStreamOutSize()


class _ZstdCompressionWriter(object):
    """Context manager that compresses data and forwards it to ``writer``.

    Data passed to ``write()`` is fed through the zstd compression stream
    ``cstream``; whatever compressed output becomes available is passed to
    ``writer.write()``. Leaving the ``with`` block without an exception
    ends the zstd frame and writes the remaining output.
    """

    def __init__(self, cstream, writer):
        """Store the compression stream and the destination writer.

        ``cstream`` is an initialized ``ZSTD_CStream *``; ``writer`` is any
        object with a ``write(data)`` method accepting bytes.
        """
        self._cstream = cstream
        self._writer = writer

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Finish the compression stream unless an exception is in flight.

        Always returns False so that exceptions raised inside the ``with``
        block propagate. Raises ``Exception`` if zstd reports an error
        while ending the stream.
        """
        if not exc_type and not exc_value and not exc_tb:
            out_buffer = ffi.new('ZSTD_outBuffer *')
            # Assigning to a struct field only stores the raw pointer, so
            # the owning cdata object must be kept alive in a local for as
            # long as zstd writes into it.
            dst_buffer = ffi.new('char[]', _CSTREAM_OUT_SIZE)
            out_buffer.dst = dst_buffer
            out_buffer.size = _CSTREAM_OUT_SIZE
            out_buffer.pos = 0

            while True:
                res = lib.ZSTD_endStream(self._cstream, out_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('error ending compression stream: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    # Copy to bytes: the output buffer is reused on the
                    # next iteration.
                    self._writer.write(
                        ffi.buffer(dst_buffer, out_buffer.pos)[:])
                    out_buffer.pos = 0

                # ZSTD_endStream returns 0 once everything is flushed.
                if res == 0:
                    break

        return False

    def write(self, data):
        """Compress ``data`` and write any available output to the writer.

        Raises ``Exception`` if zstd reports a compression error.
        """
        out_buffer = ffi.new('ZSTD_outBuffer *')
        # Hold references to the allocations; the struct fields alone do
        # not keep the memory alive.
        dst_buffer = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer.dst = dst_buffer
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0

        # TODO can we reuse existing memory?
        in_buffer = ffi.new('ZSTD_inBuffer *')
        src_buffer = ffi.new('char[]', data)
        in_buffer.src = src_buffer
        in_buffer.size = len(data)
        in_buffer.pos = 0

        while in_buffer.pos < in_buffer.size:
            res = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('zstd compress error: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                # Copy to bytes: the output buffer is reused on the next
                # iteration.
                self._writer.write(ffi.buffer(dst_buffer, out_buffer.pos)[:])
                out_buffer.pos = 0

    @staticmethod
    def _error_name(code):
        """Return zstd's name for error ``code`` as a text string."""
        name = ffi.string(lib.ZSTD_getErrorName(code))
        return name.decode('utf-8', 'replace')


class ZstdCompressor(object):
    """Compress data with zstd at a fixed compression level.

    Every operation creates its own compression stream, so a single
    instance can be used for any number of independent compressions.
    """

    def __init__(self, level=3, dict_data=None, compression_params=None):
        """Create a compressor using compression level ``level``.

        ``dict_data`` and ``compression_params`` are accepted for API
        compatibility but are not implemented; passing a truthy value for
        either raises ``Exception``.
        """
        if dict_data:
            raise Exception('dict_data not yet supported')
        if compression_params:
            raise Exception('compression_params not yet supported')

        self._compression_level = level

    def compress(self, data):
        """Compress ``data`` in one shot and return the compressed bytes."""
        # Just use the stream API for now.
        output = io.BytesIO()
        with self.write_to(output) as compressor:
            compressor.write(data)
        return output.getvalue()

    def copy_stream(self, ifh, ofh):
        """Compress everything read from ``ifh`` and write it to ``ofh``.

        ``ifh`` must provide ``read(size)`` and ``ofh`` must provide
        ``write(data)``. Returns a ``(bytes_read, bytes_written)`` tuple.
        Raises ``Exception`` if zstd reports an error.
        """
        cstream = self._get_cstream()

        in_buffer = ffi.new('ZSTD_inBuffer *')
        out_buffer = ffi.new('ZSTD_outBuffer *')

        # Assigning to a struct field only stores the raw pointer, so the
        # owning cdata object must be kept alive in a local for as long as
        # zstd uses the memory.
        dst_buffer = ffi.new('char[]', _CSTREAM_OUT_SIZE)
        out_buffer.dst = dst_buffer
        out_buffer.size = _CSTREAM_OUT_SIZE
        out_buffer.pos = 0

        total_read, total_write = 0, 0

        while True:
            data = ifh.read(_CSTREAM_IN_SIZE)
            if not data:
                break

            total_read += len(data)

            # Keep the input copy referenced while zstd consumes it.
            src_buffer = ffi.new('char[]', data)
            in_buffer.src = src_buffer
            in_buffer.size = len(data)
            in_buffer.pos = 0

            while in_buffer.pos < in_buffer.size:
                res = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
                if lib.ZSTD_isError(res):
                    raise Exception('zstd compress error: %s' %
                                    self._error_name(res))

                if out_buffer.pos:
                    # Copy to bytes: the output buffer is reused on the
                    # next iteration.
                    ofh.write(ffi.buffer(dst_buffer, out_buffer.pos)[:])
                    # Accumulate; a plain assignment here would discard
                    # the bytes written by earlier iterations.
                    total_write += out_buffer.pos
                    out_buffer.pos = 0

        # We've finished reading. Flush the compressor.
        while True:
            res = lib.ZSTD_endStream(cstream, out_buffer)
            if lib.ZSTD_isError(res):
                raise Exception('error ending compression stream: %s' %
                                self._error_name(res))

            if out_buffer.pos:
                ofh.write(ffi.buffer(dst_buffer, out_buffer.pos)[:])
                total_write += out_buffer.pos
                out_buffer.pos = 0

            # ZSTD_endStream returns 0 once everything is flushed.
            if res == 0:
                break

        return total_read, total_write

    def write_to(self, writer):
        """Return a context manager that compresses writes into ``writer``."""
        return _ZstdCompressionWriter(self._get_cstream(), writer)

    def _get_cstream(self):
        """Create and initialize a new compression stream.

        The stream is freed with ``ZSTD_freeCStream`` when the returned
        object is garbage collected. Raises ``MemoryError`` if the stream
        cannot be allocated and ``Exception`` if initialization fails.
        """
        cstream = lib.ZSTD_createCStream()
        if cstream == ffi.NULL:
            raise MemoryError()
        cstream = ffi.gc(cstream, lib.ZSTD_freeCStream)

        res = lib.ZSTD_initCStream(cstream, self._compression_level)
        if lib.ZSTD_isError(res):
            raise Exception('cannot init CStream: %s' %
                            self._error_name(res))

        return cstream

    @staticmethod
    def _error_name(code):
        """Return zstd's name for error ``code`` as a text string."""
        name = ffi.string(lib.ZSTD_getErrorName(code))
        return name.decode('utf-8', 'replace')