contrib/python-zstandard/zstandard/cffi.py
changeset 42070 675775c33ab6
parent 40121 73fef626dae3
child 42937 69de49c4e39c
equal deleted inserted replaced
42069:668eff08387f 42070:675775c33ab6
       
     1 # Copyright (c) 2016-present, Gregory Szorc
       
     2 # All rights reserved.
       
     3 #
       
     4 # This software may be modified and distributed under the terms
       
     5 # of the BSD license. See the LICENSE file for details.
       
     6 
       
     7 """Python interface to the Zstandard (zstd) compression library."""
       
     8 
       
     9 from __future__ import absolute_import, unicode_literals
       
    10 
       
# Public names exported by this module.
#
# This should match what the C extension exports.
#
# NOTE(review): the four Buffer* entries are commented out, presumably because
# this module does not provide them -- confirm against the C extension.
__all__ = [
    #'BufferSegment',
    #'BufferSegments',
    #'BufferWithSegments',
    #'BufferWithSegmentsCollection',
    'CompressionParameters',
    'ZstdCompressionDict',
    'ZstdCompressionParameters',
    'ZstdCompressor',
    'ZstdError',
    'ZstdDecompressor',
    'FrameParameters',
    'estimate_decompression_context_size',
    'frame_content_size',
    'frame_header_size',
    'get_frame_parameters',
    'train_dictionary',

    # Constants.
    'FLUSH_BLOCK',
    'FLUSH_FRAME',
    'COMPRESSOBJ_FLUSH_FINISH',
    'COMPRESSOBJ_FLUSH_BLOCK',
    'ZSTD_VERSION',
    'FRAME_HEADER',
    'CONTENTSIZE_UNKNOWN',
    'CONTENTSIZE_ERROR',
    'MAX_COMPRESSION_LEVEL',
    'COMPRESSION_RECOMMENDED_INPUT_SIZE',
    'COMPRESSION_RECOMMENDED_OUTPUT_SIZE',
    'DECOMPRESSION_RECOMMENDED_INPUT_SIZE',
    'DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE',
    'MAGIC_NUMBER',
    'BLOCKSIZELOG_MAX',
    'BLOCKSIZE_MAX',
    'WINDOWLOG_MIN',
    'WINDOWLOG_MAX',
    'CHAINLOG_MIN',
    'CHAINLOG_MAX',
    'HASHLOG_MIN',
    'HASHLOG_MAX',
    'HASHLOG3_MAX',
    'MINMATCH_MIN',
    'MINMATCH_MAX',
    'SEARCHLOG_MIN',
    'SEARCHLOG_MAX',
    'SEARCHLENGTH_MIN',
    'SEARCHLENGTH_MAX',
    'TARGETLENGTH_MIN',
    'TARGETLENGTH_MAX',
    'LDM_MINMATCH_MIN',
    'LDM_MINMATCH_MAX',
    'LDM_BUCKETSIZELOG_MAX',
    'STRATEGY_FAST',
    'STRATEGY_DFAST',
    'STRATEGY_GREEDY',
    'STRATEGY_LAZY',
    'STRATEGY_LAZY2',
    'STRATEGY_BTLAZY2',
    'STRATEGY_BTOPT',
    'STRATEGY_BTULTRA',
    'STRATEGY_BTULTRA2',
    'DICT_TYPE_AUTO',
    'DICT_TYPE_RAWCONTENT',
    'DICT_TYPE_FULLDICT',
    'FORMAT_ZSTD1',
    'FORMAT_ZSTD1_MAGICLESS',
]
       
    80 
       
    81 import io
       
    82 import os
       
    83 import sys
       
    84 
       
    85 from _zstd_cffi import (
       
    86     ffi,
       
    87     lib,
       
    88 )
       
    89 
       
    90 if sys.version_info[0] == 2:
       
    91     bytes_type = str
       
    92     int_type = long
       
    93 else:
       
    94     bytes_type = bytes
       
    95     int_type = int
       
    96 
       
    97 
       
# Module-level constants. Almost all of them are read straight from the
# compiled zstd bindings (``lib``) at import time.

# Buffer sizes reported by the zstd streaming size helpers.
COMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_CStreamInSize()
COMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_CStreamOutSize()
DECOMPRESSION_RECOMMENDED_INPUT_SIZE = lib.ZSTD_DStreamInSize()
DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE = lib.ZSTD_DStreamOutSize()

# Allocator created with should_clear_after_alloc=False, i.e. it is asked not
# to clear memory after allocating it.
new_nonzero = ffi.new_allocator(should_clear_after_alloc=False)


MAX_COMPRESSION_LEVEL = lib.ZSTD_maxCLevel()
MAGIC_NUMBER = lib.ZSTD_MAGICNUMBER
# NOTE(review): presumably the little-endian byte encoding of MAGIC_NUMBER --
# verify against the zstd frame format.
FRAME_HEADER = b'\x28\xb5\x2f\xfd'
CONTENTSIZE_UNKNOWN = lib.ZSTD_CONTENTSIZE_UNKNOWN
CONTENTSIZE_ERROR = lib.ZSTD_CONTENTSIZE_ERROR
# (major, minor, release) of the zstd library the bindings were built against.
ZSTD_VERSION = (lib.ZSTD_VERSION_MAJOR, lib.ZSTD_VERSION_MINOR, lib.ZSTD_VERSION_RELEASE)

# Bounds for individual compression parameters.
BLOCKSIZELOG_MAX = lib.ZSTD_BLOCKSIZELOG_MAX
BLOCKSIZE_MAX = lib.ZSTD_BLOCKSIZE_MAX
WINDOWLOG_MIN = lib.ZSTD_WINDOWLOG_MIN
WINDOWLOG_MAX = lib.ZSTD_WINDOWLOG_MAX
CHAINLOG_MIN = lib.ZSTD_CHAINLOG_MIN
CHAINLOG_MAX = lib.ZSTD_CHAINLOG_MAX
HASHLOG_MIN = lib.ZSTD_HASHLOG_MIN
HASHLOG_MAX = lib.ZSTD_HASHLOG_MAX
HASHLOG3_MAX = lib.ZSTD_HASHLOG3_MAX
MINMATCH_MIN = lib.ZSTD_MINMATCH_MIN
MINMATCH_MAX = lib.ZSTD_MINMATCH_MAX
SEARCHLOG_MIN = lib.ZSTD_SEARCHLOG_MIN
SEARCHLOG_MAX = lib.ZSTD_SEARCHLOG_MAX
# SEARCHLENGTH_* read the same lib values as MINMATCH_* above.
SEARCHLENGTH_MIN = lib.ZSTD_MINMATCH_MIN
SEARCHLENGTH_MAX = lib.ZSTD_MINMATCH_MAX
TARGETLENGTH_MIN = lib.ZSTD_TARGETLENGTH_MIN
TARGETLENGTH_MAX = lib.ZSTD_TARGETLENGTH_MAX
LDM_MINMATCH_MIN = lib.ZSTD_LDM_MINMATCH_MIN
LDM_MINMATCH_MAX = lib.ZSTD_LDM_MINMATCH_MAX
LDM_BUCKETSIZELOG_MAX = lib.ZSTD_LDM_BUCKETSIZELOG_MAX

# Compression strategy identifiers.
STRATEGY_FAST = lib.ZSTD_fast
STRATEGY_DFAST = lib.ZSTD_dfast
STRATEGY_GREEDY = lib.ZSTD_greedy
STRATEGY_LAZY = lib.ZSTD_lazy
STRATEGY_LAZY2 = lib.ZSTD_lazy2
STRATEGY_BTLAZY2 = lib.ZSTD_btlazy2
STRATEGY_BTOPT = lib.ZSTD_btopt
STRATEGY_BTULTRA = lib.ZSTD_btultra
STRATEGY_BTULTRA2 = lib.ZSTD_btultra2

# Dictionary content type identifiers.
DICT_TYPE_AUTO = lib.ZSTD_dct_auto
DICT_TYPE_RAWCONTENT = lib.ZSTD_dct_rawContent
DICT_TYPE_FULLDICT = lib.ZSTD_dct_fullDict

# Frame format identifiers.
FORMAT_ZSTD1 = lib.ZSTD_f_zstd1
FORMAT_ZSTD1_MAGICLESS = lib.ZSTD_f_zstd1_magicless

# Flush modes accepted by ZstdCompressionWriter.flush().
FLUSH_BLOCK = 0
FLUSH_FRAME = 1

# Flush modes accepted by ZstdCompressionObj.flush().
COMPRESSOBJ_FLUSH_FINISH = 0
COMPRESSOBJ_FLUSH_BLOCK = 1
       
   156 
       
   157 
       
   158 def _cpu_count():
       
   159     # os.cpu_count() was introducd in Python 3.4.
       
   160     try:
       
   161         return os.cpu_count() or 0
       
   162     except AttributeError:
       
   163         pass
       
   164 
       
   165     # Linux.
       
   166     try:
       
   167         if sys.version_info[0] == 2:
       
   168             return os.sysconf(b'SC_NPROCESSORS_ONLN')
       
   169         else:
       
   170             return os.sysconf(u'SC_NPROCESSORS_ONLN')
       
   171     except (AttributeError, ValueError):
       
   172         pass
       
   173 
       
   174     # TODO implement on other platforms.
       
   175     return 0
       
   176 
       
   177 
       
   178 class ZstdError(Exception):
       
   179     pass
       
   180 
       
   181 
       
   182 def _zstd_error(zresult):
       
   183     # Resolves to bytes on Python 2 and 3. We use the string for formatting
       
   184     # into error messages, which will be literal unicode. So convert it to
       
   185     # unicode.
       
   186     return ffi.string(lib.ZSTD_getErrorName(zresult)).decode('utf-8')
       
   187 
       
   188 def _make_cctx_params(params):
       
   189     res = lib.ZSTD_createCCtxParams()
       
   190     if res == ffi.NULL:
       
   191         raise MemoryError()
       
   192 
       
   193     res = ffi.gc(res, lib.ZSTD_freeCCtxParams)
       
   194 
       
   195     attrs = [
       
   196         (lib.ZSTD_c_format, params.format),
       
   197         (lib.ZSTD_c_compressionLevel, params.compression_level),
       
   198         (lib.ZSTD_c_windowLog, params.window_log),
       
   199         (lib.ZSTD_c_hashLog, params.hash_log),
       
   200         (lib.ZSTD_c_chainLog, params.chain_log),
       
   201         (lib.ZSTD_c_searchLog, params.search_log),
       
   202         (lib.ZSTD_c_minMatch, params.min_match),
       
   203         (lib.ZSTD_c_targetLength, params.target_length),
       
   204         (lib.ZSTD_c_strategy, params.compression_strategy),
       
   205         (lib.ZSTD_c_contentSizeFlag, params.write_content_size),
       
   206         (lib.ZSTD_c_checksumFlag, params.write_checksum),
       
   207         (lib.ZSTD_c_dictIDFlag, params.write_dict_id),
       
   208         (lib.ZSTD_c_nbWorkers, params.threads),
       
   209         (lib.ZSTD_c_jobSize, params.job_size),
       
   210         (lib.ZSTD_c_overlapLog, params.overlap_log),
       
   211         (lib.ZSTD_c_forceMaxWindow, params.force_max_window),
       
   212         (lib.ZSTD_c_enableLongDistanceMatching, params.enable_ldm),
       
   213         (lib.ZSTD_c_ldmHashLog, params.ldm_hash_log),
       
   214         (lib.ZSTD_c_ldmMinMatch, params.ldm_min_match),
       
   215         (lib.ZSTD_c_ldmBucketSizeLog, params.ldm_bucket_size_log),
       
   216         (lib.ZSTD_c_ldmHashRateLog, params.ldm_hash_rate_log),
       
   217     ]
       
   218 
       
   219     for param, value in attrs:
       
   220         _set_compression_parameter(res, param, value)
       
   221 
       
   222     return res
       
   223 
       
   224 class ZstdCompressionParameters(object):
       
   225     @staticmethod
       
   226     def from_level(level, source_size=0, dict_size=0, **kwargs):
       
   227         params = lib.ZSTD_getCParams(level, source_size, dict_size)
       
   228 
       
   229         args = {
       
   230             'window_log': 'windowLog',
       
   231             'chain_log': 'chainLog',
       
   232             'hash_log': 'hashLog',
       
   233             'search_log': 'searchLog',
       
   234             'min_match': 'minMatch',
       
   235             'target_length': 'targetLength',
       
   236             'compression_strategy': 'strategy',
       
   237         }
       
   238 
       
   239         for arg, attr in args.items():
       
   240             if arg not in kwargs:
       
   241                 kwargs[arg] = getattr(params, attr)
       
   242 
       
   243         return ZstdCompressionParameters(**kwargs)
       
   244 
       
   245     def __init__(self, format=0, compression_level=0, window_log=0, hash_log=0,
       
   246                  chain_log=0, search_log=0, min_match=0, target_length=0,
       
   247                  strategy=-1, compression_strategy=-1,
       
   248                  write_content_size=1, write_checksum=0,
       
   249                  write_dict_id=0, job_size=0, overlap_log=-1,
       
   250                  overlap_size_log=-1, force_max_window=0, enable_ldm=0,
       
   251                  ldm_hash_log=0, ldm_min_match=0, ldm_bucket_size_log=0,
       
   252                  ldm_hash_rate_log=-1, ldm_hash_every_log=-1, threads=0):
       
   253 
       
   254         params = lib.ZSTD_createCCtxParams()
       
   255         if params == ffi.NULL:
       
   256             raise MemoryError()
       
   257 
       
   258         params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
       
   259 
       
   260         self._params = params
       
   261 
       
   262         if threads < 0:
       
   263             threads = _cpu_count()
       
   264 
       
   265         # We need to set ZSTD_c_nbWorkers before ZSTD_c_jobSize and ZSTD_c_overlapLog
       
   266         # because setting ZSTD_c_nbWorkers resets the other parameters.
       
   267         _set_compression_parameter(params, lib.ZSTD_c_nbWorkers, threads)
       
   268 
       
   269         _set_compression_parameter(params, lib.ZSTD_c_format, format)
       
   270         _set_compression_parameter(params, lib.ZSTD_c_compressionLevel, compression_level)
       
   271         _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log)
       
   272         _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log)
       
   273         _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log)
       
   274         _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log)
       
   275         _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match)
       
   276         _set_compression_parameter(params, lib.ZSTD_c_targetLength, target_length)
       
   277 
       
   278         if strategy != -1 and compression_strategy != -1:
       
   279             raise ValueError('cannot specify both compression_strategy and strategy')
       
   280 
       
   281         if compression_strategy != -1:
       
   282             strategy = compression_strategy
       
   283         elif strategy == -1:
       
   284             strategy = 0
       
   285 
       
   286         _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy)
       
   287         _set_compression_parameter(params, lib.ZSTD_c_contentSizeFlag, write_content_size)
       
   288         _set_compression_parameter(params, lib.ZSTD_c_checksumFlag, write_checksum)
       
   289         _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id)
       
   290         _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size)
       
   291 
       
   292         if overlap_log != -1 and overlap_size_log != -1:
       
   293             raise ValueError('cannot specify both overlap_log and overlap_size_log')
       
   294 
       
   295         if overlap_size_log != -1:
       
   296             overlap_log = overlap_size_log
       
   297         elif overlap_log == -1:
       
   298             overlap_log = 0
       
   299 
       
   300         _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log)
       
   301         _set_compression_parameter(params, lib.ZSTD_c_forceMaxWindow, force_max_window)
       
   302         _set_compression_parameter(params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm)
       
   303         _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log)
       
   304         _set_compression_parameter(params, lib.ZSTD_c_ldmMinMatch, ldm_min_match)
       
   305         _set_compression_parameter(params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log)
       
   306 
       
   307         if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1:
       
   308             raise ValueError('cannot specify both ldm_hash_rate_log and ldm_hash_every_log')
       
   309 
       
   310         if ldm_hash_every_log != -1:
       
   311             ldm_hash_rate_log = ldm_hash_every_log
       
   312         elif ldm_hash_rate_log == -1:
       
   313             ldm_hash_rate_log = 0
       
   314 
       
   315         _set_compression_parameter(params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log)
       
   316 
       
   317     @property
       
   318     def format(self):
       
   319         return _get_compression_parameter(self._params, lib.ZSTD_c_format)
       
   320 
       
   321     @property
       
   322     def compression_level(self):
       
   323         return _get_compression_parameter(self._params, lib.ZSTD_c_compressionLevel)
       
   324 
       
   325     @property
       
   326     def window_log(self):
       
   327         return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog)
       
   328 
       
   329     @property
       
   330     def hash_log(self):
       
   331         return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog)
       
   332 
       
   333     @property
       
   334     def chain_log(self):
       
   335         return _get_compression_parameter(self._params, lib.ZSTD_c_chainLog)
       
   336 
       
   337     @property
       
   338     def search_log(self):
       
   339         return _get_compression_parameter(self._params, lib.ZSTD_c_searchLog)
       
   340 
       
   341     @property
       
   342     def min_match(self):
       
   343         return _get_compression_parameter(self._params, lib.ZSTD_c_minMatch)
       
   344 
       
   345     @property
       
   346     def target_length(self):
       
   347         return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength)
       
   348 
       
   349     @property
       
   350     def compression_strategy(self):
       
   351         return _get_compression_parameter(self._params, lib.ZSTD_c_strategy)
       
   352 
       
   353     @property
       
   354     def write_content_size(self):
       
   355         return _get_compression_parameter(self._params, lib.ZSTD_c_contentSizeFlag)
       
   356 
       
   357     @property
       
   358     def write_checksum(self):
       
   359         return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag)
       
   360 
       
   361     @property
       
   362     def write_dict_id(self):
       
   363         return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag)
       
   364 
       
   365     @property
       
   366     def job_size(self):
       
   367         return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize)
       
   368 
       
   369     @property
       
   370     def overlap_log(self):
       
   371         return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog)
       
   372 
       
   373     @property
       
   374     def overlap_size_log(self):
       
   375         return self.overlap_log
       
   376 
       
   377     @property
       
   378     def force_max_window(self):
       
   379         return _get_compression_parameter(self._params, lib.ZSTD_c_forceMaxWindow)
       
   380 
       
   381     @property
       
   382     def enable_ldm(self):
       
   383         return _get_compression_parameter(self._params, lib.ZSTD_c_enableLongDistanceMatching)
       
   384 
       
   385     @property
       
   386     def ldm_hash_log(self):
       
   387         return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog)
       
   388 
       
   389     @property
       
   390     def ldm_min_match(self):
       
   391         return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch)
       
   392 
       
   393     @property
       
   394     def ldm_bucket_size_log(self):
       
   395         return _get_compression_parameter(self._params, lib.ZSTD_c_ldmBucketSizeLog)
       
   396 
       
   397     @property
       
   398     def ldm_hash_rate_log(self):
       
   399         return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashRateLog)
       
   400 
       
   401     @property
       
   402     def ldm_hash_every_log(self):
       
   403         return self.ldm_hash_rate_log
       
   404 
       
   405     @property
       
   406     def threads(self):
       
   407         return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers)
       
   408 
       
   409     def estimated_compression_context_size(self):
       
   410         return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params)
       
   411 
       
   412 CompressionParameters = ZstdCompressionParameters
       
   413 
       
   414 def estimate_decompression_context_size():
       
   415     return lib.ZSTD_estimateDCtxSize()
       
   416 
       
   417 
       
   418 def _set_compression_parameter(params, param, value):
       
   419     zresult = lib.ZSTD_CCtxParam_setParameter(params, param, value)
       
   420     if lib.ZSTD_isError(zresult):
       
   421         raise ZstdError('unable to set compression context parameter: %s' %
       
   422                         _zstd_error(zresult))
       
   423 
       
   424 
       
   425 def _get_compression_parameter(params, param):
       
   426     result = ffi.new('int *')
       
   427 
       
   428     zresult = lib.ZSTD_CCtxParam_getParameter(params, param, result)
       
   429     if lib.ZSTD_isError(zresult):
       
   430         raise ZstdError('unable to get compression context parameter: %s' %
       
   431                         _zstd_error(zresult))
       
   432 
       
   433     return result[0]
       
   434 
       
   435 
       
   436 class ZstdCompressionWriter(object):
       
   437     def __init__(self, compressor, writer, source_size, write_size,
       
   438                  write_return_read):
       
   439         self._compressor = compressor
       
   440         self._writer = writer
       
   441         self._write_size = write_size
       
   442         self._write_return_read = bool(write_return_read)
       
   443         self._entered = False
       
   444         self._closed = False
       
   445         self._bytes_compressed = 0
       
   446 
       
   447         self._dst_buffer = ffi.new('char[]', write_size)
       
   448         self._out_buffer = ffi.new('ZSTD_outBuffer *')
       
   449         self._out_buffer.dst = self._dst_buffer
       
   450         self._out_buffer.size = len(self._dst_buffer)
       
   451         self._out_buffer.pos = 0
       
   452 
       
   453         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx,
       
   454                                                   source_size)
       
   455         if lib.ZSTD_isError(zresult):
       
   456             raise ZstdError('error setting source size: %s' %
       
   457                             _zstd_error(zresult))
       
   458 
       
   459     def __enter__(self):
       
   460         if self._closed:
       
   461             raise ValueError('stream is closed')
       
   462 
       
   463         if self._entered:
       
   464             raise ZstdError('cannot __enter__ multiple times')
       
   465 
       
   466         self._entered = True
       
   467         return self
       
   468 
       
   469     def __exit__(self, exc_type, exc_value, exc_tb):
       
   470         self._entered = False
       
   471 
       
   472         if not exc_type and not exc_value and not exc_tb:
       
   473             self.close()
       
   474 
       
   475         self._compressor = None
       
   476 
       
   477         return False
       
   478 
       
   479     def memory_size(self):
       
   480         return lib.ZSTD_sizeof_CCtx(self._compressor._cctx)
       
   481 
       
   482     def fileno(self):
       
   483         f = getattr(self._writer, 'fileno', None)
       
   484         if f:
       
   485             return f()
       
   486         else:
       
   487             raise OSError('fileno not available on underlying writer')
       
   488 
       
   489     def close(self):
       
   490         if self._closed:
       
   491             return
       
   492 
       
   493         try:
       
   494             self.flush(FLUSH_FRAME)
       
   495         finally:
       
   496             self._closed = True
       
   497 
       
   498         # Call close() on underlying stream as well.
       
   499         f = getattr(self._writer, 'close', None)
       
   500         if f:
       
   501             f()
       
   502 
       
   503     @property
       
   504     def closed(self):
       
   505         return self._closed
       
   506 
       
   507     def isatty(self):
       
   508         return False
       
   509 
       
   510     def readable(self):
       
   511         return False
       
   512 
       
   513     def readline(self, size=-1):
       
   514         raise io.UnsupportedOperation()
       
   515 
       
   516     def readlines(self, hint=-1):
       
   517         raise io.UnsupportedOperation()
       
   518 
       
   519     def seek(self, offset, whence=None):
       
   520         raise io.UnsupportedOperation()
       
   521 
       
   522     def seekable(self):
       
   523         return False
       
   524 
       
   525     def truncate(self, size=None):
       
   526         raise io.UnsupportedOperation()
       
   527 
       
   528     def writable(self):
       
   529         return True
       
   530 
       
   531     def writelines(self, lines):
       
   532         raise NotImplementedError('writelines() is not yet implemented')
       
   533 
       
   534     def read(self, size=-1):
       
   535         raise io.UnsupportedOperation()
       
   536 
       
   537     def readall(self):
       
   538         raise io.UnsupportedOperation()
       
   539 
       
   540     def readinto(self, b):
       
   541         raise io.UnsupportedOperation()
       
   542 
       
   543     def write(self, data):
       
   544         if self._closed:
       
   545             raise ValueError('stream is closed')
       
   546 
       
   547         total_write = 0
       
   548 
       
   549         data_buffer = ffi.from_buffer(data)
       
   550 
       
   551         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   552         in_buffer.src = data_buffer
       
   553         in_buffer.size = len(data_buffer)
       
   554         in_buffer.pos = 0
       
   555 
       
   556         out_buffer = self._out_buffer
       
   557         out_buffer.pos = 0
       
   558 
       
   559         while in_buffer.pos < in_buffer.size:
       
   560             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   561                                                out_buffer, in_buffer,
       
   562                                                lib.ZSTD_e_continue)
       
   563             if lib.ZSTD_isError(zresult):
       
   564                 raise ZstdError('zstd compress error: %s' %
       
   565                                 _zstd_error(zresult))
       
   566 
       
   567             if out_buffer.pos:
       
   568                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
   569                 total_write += out_buffer.pos
       
   570                 self._bytes_compressed += out_buffer.pos
       
   571                 out_buffer.pos = 0
       
   572 
       
   573         if self._write_return_read:
       
   574             return in_buffer.pos
       
   575         else:
       
   576             return total_write
       
   577 
       
   578     def flush(self, flush_mode=FLUSH_BLOCK):
       
   579         if flush_mode == FLUSH_BLOCK:
       
   580             flush = lib.ZSTD_e_flush
       
   581         elif flush_mode == FLUSH_FRAME:
       
   582             flush = lib.ZSTD_e_end
       
   583         else:
       
   584             raise ValueError('unknown flush_mode: %r' % flush_mode)
       
   585 
       
   586         if self._closed:
       
   587             raise ValueError('stream is closed')
       
   588 
       
   589         total_write = 0
       
   590 
       
   591         out_buffer = self._out_buffer
       
   592         out_buffer.pos = 0
       
   593 
       
   594         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   595         in_buffer.src = ffi.NULL
       
   596         in_buffer.size = 0
       
   597         in_buffer.pos = 0
       
   598 
       
   599         while True:
       
   600             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   601                                                out_buffer, in_buffer,
       
   602                                                flush)
       
   603             if lib.ZSTD_isError(zresult):
       
   604                 raise ZstdError('zstd compress error: %s' %
       
   605                                 _zstd_error(zresult))
       
   606 
       
   607             if out_buffer.pos:
       
   608                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
   609                 total_write += out_buffer.pos
       
   610                 self._bytes_compressed += out_buffer.pos
       
   611                 out_buffer.pos = 0
       
   612 
       
   613             if not zresult:
       
   614                 break
       
   615 
       
   616         return total_write
       
   617 
       
   618     def tell(self):
       
   619         return self._bytes_compressed
       
   620 
       
   621 
       
   622 class ZstdCompressionObj(object):
       
   623     def compress(self, data):
       
   624         if self._finished:
       
   625             raise ZstdError('cannot call compress() after compressor finished')
       
   626 
       
   627         data_buffer = ffi.from_buffer(data)
       
   628         source = ffi.new('ZSTD_inBuffer *')
       
   629         source.src = data_buffer
       
   630         source.size = len(data_buffer)
       
   631         source.pos = 0
       
   632 
       
   633         chunks = []
       
   634 
       
   635         while source.pos < len(data):
       
   636             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   637                                                self._out,
       
   638                                                source,
       
   639                                                lib.ZSTD_e_continue)
       
   640             if lib.ZSTD_isError(zresult):
       
   641                 raise ZstdError('zstd compress error: %s' %
       
   642                                 _zstd_error(zresult))
       
   643 
       
   644             if self._out.pos:
       
   645                 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
       
   646                 self._out.pos = 0
       
   647 
       
   648         return b''.join(chunks)
       
   649 
       
   650     def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH):
       
   651         if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK):
       
   652             raise ValueError('flush mode not recognized')
       
   653 
       
   654         if self._finished:
       
   655             raise ZstdError('compressor object already finished')
       
   656 
       
   657         if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
       
   658             z_flush_mode = lib.ZSTD_e_flush
       
   659         elif flush_mode == COMPRESSOBJ_FLUSH_FINISH:
       
   660             z_flush_mode = lib.ZSTD_e_end
       
   661             self._finished = True
       
   662         else:
       
   663             raise ZstdError('unhandled flush mode')
       
   664 
       
   665         assert self._out.pos == 0
       
   666 
       
   667         in_buffer = ffi.new('ZSTD_inBuffer *')
       
   668         in_buffer.src = ffi.NULL
       
   669         in_buffer.size = 0
       
   670         in_buffer.pos = 0
       
   671 
       
   672         chunks = []
       
   673 
       
   674         while True:
       
   675             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   676                                                self._out,
       
   677                                                in_buffer,
       
   678                                                z_flush_mode)
       
   679             if lib.ZSTD_isError(zresult):
       
   680                 raise ZstdError('error ending compression stream: %s' %
       
   681                                 _zstd_error(zresult))
       
   682 
       
   683             if self._out.pos:
       
   684                 chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:])
       
   685                 self._out.pos = 0
       
   686 
       
   687             if not zresult:
       
   688                 break
       
   689 
       
   690         return b''.join(chunks)
       
   691 
       
   692 
       
   693 class ZstdCompressionChunker(object):
       
   694     def __init__(self, compressor, chunk_size):
       
   695         self._compressor = compressor
       
   696         self._out = ffi.new('ZSTD_outBuffer *')
       
   697         self._dst_buffer = ffi.new('char[]', chunk_size)
       
   698         self._out.dst = self._dst_buffer
       
   699         self._out.size = chunk_size
       
   700         self._out.pos = 0
       
   701 
       
   702         self._in = ffi.new('ZSTD_inBuffer *')
       
   703         self._in.src = ffi.NULL
       
   704         self._in.size = 0
       
   705         self._in.pos = 0
       
   706         self._finished = False
       
   707 
       
   708     def compress(self, data):
       
   709         if self._finished:
       
   710             raise ZstdError('cannot call compress() after compression finished')
       
   711 
       
   712         if self._in.src != ffi.NULL:
       
   713             raise ZstdError('cannot perform operation before consuming output '
       
   714                             'from previous operation')
       
   715 
       
   716         data_buffer = ffi.from_buffer(data)
       
   717 
       
   718         if not len(data_buffer):
       
   719             return
       
   720 
       
   721         self._in.src = data_buffer
       
   722         self._in.size = len(data_buffer)
       
   723         self._in.pos = 0
       
   724 
       
   725         while self._in.pos < self._in.size:
       
   726             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   727                                                self._out,
       
   728                                                self._in,
       
   729                                                lib.ZSTD_e_continue)
       
   730 
       
   731             if self._in.pos == self._in.size:
       
   732                 self._in.src = ffi.NULL
       
   733                 self._in.size = 0
       
   734                 self._in.pos = 0
       
   735 
       
   736             if lib.ZSTD_isError(zresult):
       
   737                 raise ZstdError('zstd compress error: %s' %
       
   738                                 _zstd_error(zresult))
       
   739 
       
   740             if self._out.pos == self._out.size:
       
   741                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   742                 self._out.pos = 0
       
   743 
       
   744     def flush(self):
       
   745         if self._finished:
       
   746             raise ZstdError('cannot call flush() after compression finished')
       
   747 
       
   748         if self._in.src != ffi.NULL:
       
   749             raise ZstdError('cannot call flush() before consuming output from '
       
   750                             'previous operation')
       
   751 
       
   752         while True:
       
   753             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   754                                                self._out, self._in,
       
   755                                                lib.ZSTD_e_flush)
       
   756             if lib.ZSTD_isError(zresult):
       
   757                 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
       
   758 
       
   759             if self._out.pos:
       
   760                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   761                 self._out.pos = 0
       
   762 
       
   763             if not zresult:
       
   764                 return
       
   765 
       
   766     def finish(self):
       
   767         if self._finished:
       
   768             raise ZstdError('cannot call finish() after compression finished')
       
   769 
       
   770         if self._in.src != ffi.NULL:
       
   771             raise ZstdError('cannot call finish() before consuming output from '
       
   772                             'previous operation')
       
   773 
       
   774         while True:
       
   775             zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   776                                                self._out, self._in,
       
   777                                                lib.ZSTD_e_end)
       
   778             if lib.ZSTD_isError(zresult):
       
   779                 raise ZstdError('zstd compress error: %s' % _zstd_error(zresult))
       
   780 
       
   781             if self._out.pos:
       
   782                 yield ffi.buffer(self._out.dst, self._out.pos)[:]
       
   783                 self._out.pos = 0
       
   784 
       
   785             if not zresult:
       
   786                 self._finished = True
       
   787                 return
       
   788 
       
   789 
       
   790 class ZstdCompressionReader(object):
       
   791     def __init__(self, compressor, source, read_size):
       
   792         self._compressor = compressor
       
   793         self._source = source
       
   794         self._read_size = read_size
       
   795         self._entered = False
       
   796         self._closed = False
       
   797         self._bytes_compressed = 0
       
   798         self._finished_input = False
       
   799         self._finished_output = False
       
   800 
       
   801         self._in_buffer = ffi.new('ZSTD_inBuffer *')
       
   802         # Holds a ref so backing bytes in self._in_buffer stay alive.
       
   803         self._source_buffer = None
       
   804 
       
   805     def __enter__(self):
       
   806         if self._entered:
       
   807             raise ValueError('cannot __enter__ multiple times')
       
   808 
       
   809         self._entered = True
       
   810         return self
       
   811 
       
   812     def __exit__(self, exc_type, exc_value, exc_tb):
       
   813         self._entered = False
       
   814         self._closed = True
       
   815         self._source = None
       
   816         self._compressor = None
       
   817 
       
   818         return False
       
   819 
       
   820     def readable(self):
       
   821         return True
       
   822 
       
   823     def writable(self):
       
   824         return False
       
   825 
       
   826     def seekable(self):
       
   827         return False
       
   828 
       
   829     def readline(self):
       
   830         raise io.UnsupportedOperation()
       
   831 
       
   832     def readlines(self):
       
   833         raise io.UnsupportedOperation()
       
   834 
       
   835     def write(self, data):
       
   836         raise OSError('stream is not writable')
       
   837 
       
   838     def writelines(self, ignored):
       
   839         raise OSError('stream is not writable')
       
   840 
       
   841     def isatty(self):
       
   842         return False
       
   843 
       
   844     def flush(self):
       
   845         return None
       
   846 
       
   847     def close(self):
       
   848         self._closed = True
       
   849         return None
       
   850 
       
   851     @property
       
   852     def closed(self):
       
   853         return self._closed
       
   854 
       
   855     def tell(self):
       
   856         return self._bytes_compressed
       
   857 
       
   858     def readall(self):
       
   859         chunks = []
       
   860 
       
   861         while True:
       
   862             chunk = self.read(1048576)
       
   863             if not chunk:
       
   864                 break
       
   865 
       
   866             chunks.append(chunk)
       
   867 
       
   868         return b''.join(chunks)
       
   869 
       
   870     def __iter__(self):
       
   871         raise io.UnsupportedOperation()
       
   872 
       
   873     def __next__(self):
       
   874         raise io.UnsupportedOperation()
       
   875 
       
   876     next = __next__
       
   877 
       
   878     def _read_input(self):
       
   879         if self._finished_input:
       
   880             return
       
   881 
       
   882         if hasattr(self._source, 'read'):
       
   883             data = self._source.read(self._read_size)
       
   884 
       
   885             if not data:
       
   886                 self._finished_input = True
       
   887                 return
       
   888 
       
   889             self._source_buffer = ffi.from_buffer(data)
       
   890             self._in_buffer.src = self._source_buffer
       
   891             self._in_buffer.size = len(self._source_buffer)
       
   892             self._in_buffer.pos = 0
       
   893         else:
       
   894             self._source_buffer = ffi.from_buffer(self._source)
       
   895             self._in_buffer.src = self._source_buffer
       
   896             self._in_buffer.size = len(self._source_buffer)
       
   897             self._in_buffer.pos = 0
       
   898 
       
   899     def _compress_into_buffer(self, out_buffer):
       
   900         if self._in_buffer.pos >= self._in_buffer.size:
       
   901             return
       
   902 
       
   903         old_pos = out_buffer.pos
       
   904 
       
   905         zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   906                                            out_buffer, self._in_buffer,
       
   907                                            lib.ZSTD_e_continue)
       
   908 
       
   909         self._bytes_compressed += out_buffer.pos - old_pos
       
   910 
       
   911         if self._in_buffer.pos == self._in_buffer.size:
       
   912             self._in_buffer.src = ffi.NULL
       
   913             self._in_buffer.pos = 0
       
   914             self._in_buffer.size = 0
       
   915             self._source_buffer = None
       
   916 
       
   917             if not hasattr(self._source, 'read'):
       
   918                 self._finished_input = True
       
   919 
       
   920         if lib.ZSTD_isError(zresult):
       
   921             raise ZstdError('zstd compress error: %s',
       
   922                             _zstd_error(zresult))
       
   923 
       
   924         return out_buffer.pos and out_buffer.pos == out_buffer.size
       
   925 
       
   926     def read(self, size=-1):
       
   927         if self._closed:
       
   928             raise ValueError('stream is closed')
       
   929 
       
   930         if size < -1:
       
   931             raise ValueError('cannot read negative amounts less than -1')
       
   932 
       
   933         if size == -1:
       
   934             return self.readall()
       
   935 
       
   936         if self._finished_output or size == 0:
       
   937             return b''
       
   938 
       
   939         # Need a dedicated ref to dest buffer otherwise it gets collected.
       
   940         dst_buffer = ffi.new('char[]', size)
       
   941         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   942         out_buffer.dst = dst_buffer
       
   943         out_buffer.size = size
       
   944         out_buffer.pos = 0
       
   945 
       
   946         if self._compress_into_buffer(out_buffer):
       
   947             return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
   948 
       
   949         while not self._finished_input:
       
   950             self._read_input()
       
   951 
       
   952             if self._compress_into_buffer(out_buffer):
       
   953                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
   954 
       
   955         # EOF
       
   956         old_pos = out_buffer.pos
       
   957 
       
   958         zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
   959                                            out_buffer, self._in_buffer,
       
   960                                            lib.ZSTD_e_end)
       
   961 
       
   962         self._bytes_compressed += out_buffer.pos - old_pos
       
   963 
       
   964         if lib.ZSTD_isError(zresult):
       
   965             raise ZstdError('error ending compression stream: %s',
       
   966                             _zstd_error(zresult))
       
   967 
       
   968         if zresult == 0:
       
   969             self._finished_output = True
       
   970 
       
   971         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
   972 
       
   973     def read1(self, size=-1):
       
   974         if self._closed:
       
   975             raise ValueError('stream is closed')
       
   976 
       
   977         if size < -1:
       
   978             raise ValueError('cannot read negative amounts less than -1')
       
   979 
       
   980         if self._finished_output or size == 0:
       
   981             return b''
       
   982 
       
   983         # -1 returns arbitrary number of bytes.
       
   984         if size == -1:
       
   985             size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
       
   986 
       
   987         dst_buffer = ffi.new('char[]', size)
       
   988         out_buffer = ffi.new('ZSTD_outBuffer *')
       
   989         out_buffer.dst = dst_buffer
       
   990         out_buffer.size = size
       
   991         out_buffer.pos = 0
       
   992 
       
   993         # read1() dictates that we can perform at most 1 call to the
       
   994         # underlying stream to get input. However, we can't satisfy this
       
   995         # restriction with compression because not all input generates output.
       
   996         # It is possible to perform a block flush in order to ensure output.
       
   997         # But this may not be desirable behavior. So we allow multiple read()
       
   998         # to the underlying stream. But unlike read(), we stop once we have
       
   999         # any output.
       
  1000 
       
  1001         self._compress_into_buffer(out_buffer)
       
  1002         if out_buffer.pos:
       
  1003             return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1004 
       
  1005         while not self._finished_input:
       
  1006             self._read_input()
       
  1007 
       
  1008             # If we've filled the output buffer, return immediately.
       
  1009             if self._compress_into_buffer(out_buffer):
       
  1010                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1011 
       
  1012             # If we've populated the output buffer and we're not at EOF,
       
  1013             # also return, as we've satisfied the read1() limits.
       
  1014             if out_buffer.pos and not self._finished_input:
       
  1015                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1016 
       
  1017             # Else if we're at EOS and we have room left in the buffer,
       
  1018             # fall through to below and try to add more data to the output.
       
  1019 
       
  1020         # EOF.
       
  1021         old_pos = out_buffer.pos
       
  1022 
       
  1023         zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
  1024                                            out_buffer, self._in_buffer,
       
  1025                                            lib.ZSTD_e_end)
       
  1026 
       
  1027         self._bytes_compressed += out_buffer.pos - old_pos
       
  1028 
       
  1029         if lib.ZSTD_isError(zresult):
       
  1030             raise ZstdError('error ending compression stream: %s' %
       
  1031                             _zstd_error(zresult))
       
  1032 
       
  1033         if zresult == 0:
       
  1034             self._finished_output = True
       
  1035 
       
  1036         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1037 
       
  1038     def readinto(self, b):
       
  1039         if self._closed:
       
  1040             raise ValueError('stream is closed')
       
  1041 
       
  1042         if self._finished_output:
       
  1043             return 0
       
  1044 
       
  1045         # TODO use writable=True once we require CFFI >= 1.12.
       
  1046         dest_buffer = ffi.from_buffer(b)
       
  1047         ffi.memmove(b, b'', 0)
       
  1048         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1049         out_buffer.dst = dest_buffer
       
  1050         out_buffer.size = len(dest_buffer)
       
  1051         out_buffer.pos = 0
       
  1052 
       
  1053         if self._compress_into_buffer(out_buffer):
       
  1054             return out_buffer.pos
       
  1055 
       
  1056         while not self._finished_input:
       
  1057             self._read_input()
       
  1058             if self._compress_into_buffer(out_buffer):
       
  1059                 return out_buffer.pos
       
  1060 
       
  1061         # EOF.
       
  1062         old_pos = out_buffer.pos
       
  1063         zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
  1064                                            out_buffer, self._in_buffer,
       
  1065                                            lib.ZSTD_e_end)
       
  1066 
       
  1067         self._bytes_compressed += out_buffer.pos - old_pos
       
  1068 
       
  1069         if lib.ZSTD_isError(zresult):
       
  1070             raise ZstdError('error ending compression stream: %s',
       
  1071                             _zstd_error(zresult))
       
  1072 
       
  1073         if zresult == 0:
       
  1074             self._finished_output = True
       
  1075 
       
  1076         return out_buffer.pos
       
  1077 
       
  1078     def readinto1(self, b):
       
  1079         if self._closed:
       
  1080             raise ValueError('stream is closed')
       
  1081 
       
  1082         if self._finished_output:
       
  1083             return 0
       
  1084 
       
  1085         # TODO use writable=True once we require CFFI >= 1.12.
       
  1086         dest_buffer = ffi.from_buffer(b)
       
  1087         ffi.memmove(b, b'', 0)
       
  1088 
       
  1089         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1090         out_buffer.dst = dest_buffer
       
  1091         out_buffer.size = len(dest_buffer)
       
  1092         out_buffer.pos = 0
       
  1093 
       
  1094         self._compress_into_buffer(out_buffer)
       
  1095         if out_buffer.pos:
       
  1096             return out_buffer.pos
       
  1097 
       
  1098         while not self._finished_input:
       
  1099             self._read_input()
       
  1100 
       
  1101             if self._compress_into_buffer(out_buffer):
       
  1102                 return out_buffer.pos
       
  1103 
       
  1104             if out_buffer.pos and not self._finished_input:
       
  1105                 return out_buffer.pos
       
  1106 
       
  1107         # EOF.
       
  1108         old_pos = out_buffer.pos
       
  1109 
       
  1110         zresult = lib.ZSTD_compressStream2(self._compressor._cctx,
       
  1111                                            out_buffer, self._in_buffer,
       
  1112                                            lib.ZSTD_e_end)
       
  1113 
       
  1114         self._bytes_compressed += out_buffer.pos - old_pos
       
  1115 
       
  1116         if lib.ZSTD_isError(zresult):
       
  1117             raise ZstdError('error ending compression stream: %s' %
       
  1118                             _zstd_error(zresult))
       
  1119 
       
  1120         if zresult == 0:
       
  1121             self._finished_output = True
       
  1122 
       
  1123         return out_buffer.pos
       
  1124 
       
  1125 
       
  1126 class ZstdCompressor(object):
       
  1127     def __init__(self, level=3, dict_data=None, compression_params=None,
       
  1128                  write_checksum=None, write_content_size=None,
       
  1129                  write_dict_id=None, threads=0):
       
  1130         if level > lib.ZSTD_maxCLevel():
       
  1131             raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel())
       
  1132 
       
  1133         if threads < 0:
       
  1134             threads = _cpu_count()
       
  1135 
       
  1136         if compression_params and write_checksum is not None:
       
  1137             raise ValueError('cannot define compression_params and '
       
  1138                              'write_checksum')
       
  1139 
       
  1140         if compression_params and write_content_size is not None:
       
  1141             raise ValueError('cannot define compression_params and '
       
  1142                              'write_content_size')
       
  1143 
       
  1144         if compression_params and write_dict_id is not None:
       
  1145             raise ValueError('cannot define compression_params and '
       
  1146                              'write_dict_id')
       
  1147 
       
  1148         if compression_params and threads:
       
  1149             raise ValueError('cannot define compression_params and threads')
       
  1150 
       
  1151         if compression_params:
       
  1152             self._params = _make_cctx_params(compression_params)
       
  1153         else:
       
  1154             if write_dict_id is None:
       
  1155                 write_dict_id = True
       
  1156 
       
  1157             params = lib.ZSTD_createCCtxParams()
       
  1158             if params == ffi.NULL:
       
  1159                 raise MemoryError()
       
  1160 
       
  1161             self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams)
       
  1162 
       
  1163             _set_compression_parameter(self._params,
       
  1164                                        lib.ZSTD_c_compressionLevel,
       
  1165                                        level)
       
  1166 
       
  1167             _set_compression_parameter(
       
  1168                 self._params,
       
  1169                 lib.ZSTD_c_contentSizeFlag,
       
  1170                 write_content_size if write_content_size is not None else 1)
       
  1171 
       
  1172             _set_compression_parameter(self._params,
       
  1173                                        lib.ZSTD_c_checksumFlag,
       
  1174                                        1 if write_checksum else 0)
       
  1175 
       
  1176             _set_compression_parameter(self._params,
       
  1177                                        lib.ZSTD_c_dictIDFlag,
       
  1178                                        1 if write_dict_id else 0)
       
  1179 
       
  1180             if threads:
       
  1181                 _set_compression_parameter(self._params,
       
  1182                                            lib.ZSTD_c_nbWorkers,
       
  1183                                            threads)
       
  1184 
       
  1185         cctx = lib.ZSTD_createCCtx()
       
  1186         if cctx == ffi.NULL:
       
  1187             raise MemoryError()
       
  1188 
       
  1189         self._cctx = cctx
       
  1190         self._dict_data = dict_data
       
  1191 
       
  1192         # We defer setting up garbage collection until after calling
       
  1193         # _setup_cctx() to ensure the memory size estimate is more accurate.
       
  1194         try:
       
  1195             self._setup_cctx()
       
  1196         finally:
       
  1197             self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx,
       
  1198                                 size=lib.ZSTD_sizeof_CCtx(cctx))
       
  1199 
       
  1200     def _setup_cctx(self):
       
  1201         zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx,
       
  1202                                                              self._params)
       
  1203         if lib.ZSTD_isError(zresult):
       
  1204             raise ZstdError('could not set compression parameters: %s' %
       
  1205                             _zstd_error(zresult))
       
  1206 
       
  1207         dict_data = self._dict_data
       
  1208 
       
  1209         if dict_data:
       
  1210             if dict_data._cdict:
       
  1211                 zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict)
       
  1212             else:
       
  1213                 zresult = lib.ZSTD_CCtx_loadDictionary_advanced(
       
  1214                     self._cctx, dict_data.as_bytes(), len(dict_data),
       
  1215                     lib.ZSTD_dlm_byRef, dict_data._dict_type)
       
  1216 
       
  1217             if lib.ZSTD_isError(zresult):
       
  1218                 raise ZstdError('could not load compression dictionary: %s' %
       
  1219                                 _zstd_error(zresult))
       
  1220 
       
  1221     def memory_size(self):
       
  1222         return lib.ZSTD_sizeof_CCtx(self._cctx)
       
  1223 
       
  1224     def compress(self, data):
       
  1225         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1226 
       
  1227         data_buffer = ffi.from_buffer(data)
       
  1228 
       
  1229         dest_size = lib.ZSTD_compressBound(len(data_buffer))
       
  1230         out = new_nonzero('char[]', dest_size)
       
  1231 
       
  1232         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer))
       
  1233         if lib.ZSTD_isError(zresult):
       
  1234             raise ZstdError('error setting source size: %s' %
       
  1235                             _zstd_error(zresult))
       
  1236 
       
  1237         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1238         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1239 
       
  1240         out_buffer.dst = out
       
  1241         out_buffer.size = dest_size
       
  1242         out_buffer.pos = 0
       
  1243 
       
  1244         in_buffer.src = data_buffer
       
  1245         in_buffer.size = len(data_buffer)
       
  1246         in_buffer.pos = 0
       
  1247 
       
  1248         zresult = lib.ZSTD_compressStream2(self._cctx,
       
  1249                                            out_buffer,
       
  1250                                            in_buffer,
       
  1251                                            lib.ZSTD_e_end)
       
  1252 
       
  1253         if lib.ZSTD_isError(zresult):
       
  1254             raise ZstdError('cannot compress: %s' %
       
  1255                             _zstd_error(zresult))
       
  1256         elif zresult:
       
  1257             raise ZstdError('unexpected partial frame flush')
       
  1258 
       
  1259         return ffi.buffer(out, out_buffer.pos)[:]
       
  1260 
       
  1261     def compressobj(self, size=-1):
       
  1262         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1263 
       
  1264         if size < 0:
       
  1265             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1266 
       
  1267         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1268         if lib.ZSTD_isError(zresult):
       
  1269             raise ZstdError('error setting source size: %s' %
       
  1270                             _zstd_error(zresult))
       
  1271 
       
  1272         cobj = ZstdCompressionObj()
       
  1273         cobj._out = ffi.new('ZSTD_outBuffer *')
       
  1274         cobj._dst_buffer = ffi.new('char[]', COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
       
  1275         cobj._out.dst = cobj._dst_buffer
       
  1276         cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
       
  1277         cobj._out.pos = 0
       
  1278         cobj._compressor = self
       
  1279         cobj._finished = False
       
  1280 
       
  1281         return cobj
       
  1282 
       
  1283     def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1284         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1285 
       
  1286         if size < 0:
       
  1287             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1288 
       
  1289         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1290         if lib.ZSTD_isError(zresult):
       
  1291             raise ZstdError('error setting source size: %s' %
       
  1292                             _zstd_error(zresult))
       
  1293 
       
  1294         return ZstdCompressionChunker(self, chunk_size=chunk_size)
       
  1295 
       
  1296     def copy_stream(self, ifh, ofh, size=-1,
       
  1297                     read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  1298                     write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1299 
       
  1300         if not hasattr(ifh, 'read'):
       
  1301             raise ValueError('first argument must have a read() method')
       
  1302         if not hasattr(ofh, 'write'):
       
  1303             raise ValueError('second argument must have a write() method')
       
  1304 
       
  1305         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1306 
       
  1307         if size < 0:
       
  1308             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1309 
       
  1310         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1311         if lib.ZSTD_isError(zresult):
       
  1312             raise ZstdError('error setting source size: %s' %
       
  1313                             _zstd_error(zresult))
       
  1314 
       
  1315         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1316         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1317 
       
  1318         dst_buffer = ffi.new('char[]', write_size)
       
  1319         out_buffer.dst = dst_buffer
       
  1320         out_buffer.size = write_size
       
  1321         out_buffer.pos = 0
       
  1322 
       
  1323         total_read, total_write = 0, 0
       
  1324 
       
  1325         while True:
       
  1326             data = ifh.read(read_size)
       
  1327             if not data:
       
  1328                 break
       
  1329 
       
  1330             data_buffer = ffi.from_buffer(data)
       
  1331             total_read += len(data_buffer)
       
  1332             in_buffer.src = data_buffer
       
  1333             in_buffer.size = len(data_buffer)
       
  1334             in_buffer.pos = 0
       
  1335 
       
  1336             while in_buffer.pos < in_buffer.size:
       
  1337                 zresult = lib.ZSTD_compressStream2(self._cctx,
       
  1338                                                    out_buffer,
       
  1339                                                    in_buffer,
       
  1340                                                    lib.ZSTD_e_continue)
       
  1341                 if lib.ZSTD_isError(zresult):
       
  1342                     raise ZstdError('zstd compress error: %s' %
       
  1343                                     _zstd_error(zresult))
       
  1344 
       
  1345                 if out_buffer.pos:
       
  1346                     ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
  1347                     total_write += out_buffer.pos
       
  1348                     out_buffer.pos = 0
       
  1349 
       
  1350         # We've finished reading. Flush the compressor.
       
  1351         while True:
       
  1352             zresult = lib.ZSTD_compressStream2(self._cctx,
       
  1353                                                out_buffer,
       
  1354                                                in_buffer,
       
  1355                                                lib.ZSTD_e_end)
       
  1356             if lib.ZSTD_isError(zresult):
       
  1357                 raise ZstdError('error ending compression stream: %s' %
       
  1358                                 _zstd_error(zresult))
       
  1359 
       
  1360             if out_buffer.pos:
       
  1361                 ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
  1362                 total_write += out_buffer.pos
       
  1363                 out_buffer.pos = 0
       
  1364 
       
  1365             if zresult == 0:
       
  1366                 break
       
  1367 
       
  1368         return total_read, total_write
       
  1369 
       
  1370     def stream_reader(self, source, size=-1,
       
  1371                       read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE):
       
  1372         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1373 
       
  1374         try:
       
  1375             size = len(source)
       
  1376         except Exception:
       
  1377             pass
       
  1378 
       
  1379         if size < 0:
       
  1380             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1381 
       
  1382         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1383         if lib.ZSTD_isError(zresult):
       
  1384             raise ZstdError('error setting source size: %s' %
       
  1385                             _zstd_error(zresult))
       
  1386 
       
  1387         return ZstdCompressionReader(self, source, read_size)
       
  1388 
       
  1389     def stream_writer(self, writer, size=-1,
       
  1390                  write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
       
  1391                  write_return_read=False):
       
  1392 
       
  1393         if not hasattr(writer, 'write'):
       
  1394             raise ValueError('must pass an object with a write() method')
       
  1395 
       
  1396         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1397 
       
  1398         if size < 0:
       
  1399             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1400 
       
  1401         return ZstdCompressionWriter(self, writer, size, write_size,
       
  1402                                      write_return_read)
       
  1403 
       
    # Alias: the same method is also reachable as ``write_to``.
    # NOTE(review): presumably kept for backwards compatibility -- confirm.
    write_to = stream_writer
       
  1405 
       
  1406     def read_to_iter(self, reader, size=-1,
       
  1407                      read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  1408                      write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  1409         if hasattr(reader, 'read'):
       
  1410             have_read = True
       
  1411         elif hasattr(reader, '__getitem__'):
       
  1412             have_read = False
       
  1413             buffer_offset = 0
       
  1414             size = len(reader)
       
  1415         else:
       
  1416             raise ValueError('must pass an object with a read() method or '
       
  1417                              'conforms to buffer protocol')
       
  1418 
       
  1419         lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)
       
  1420 
       
  1421         if size < 0:
       
  1422             size = lib.ZSTD_CONTENTSIZE_UNKNOWN
       
  1423 
       
  1424         zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size)
       
  1425         if lib.ZSTD_isError(zresult):
       
  1426             raise ZstdError('error setting source size: %s' %
       
  1427                             _zstd_error(zresult))
       
  1428 
       
  1429         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1430         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1431 
       
  1432         in_buffer.src = ffi.NULL
       
  1433         in_buffer.size = 0
       
  1434         in_buffer.pos = 0
       
  1435 
       
  1436         dst_buffer = ffi.new('char[]', write_size)
       
  1437         out_buffer.dst = dst_buffer
       
  1438         out_buffer.size = write_size
       
  1439         out_buffer.pos = 0
       
  1440 
       
  1441         while True:
       
  1442             # We should never have output data sitting around after a previous
       
  1443             # iteration.
       
  1444             assert out_buffer.pos == 0
       
  1445 
       
  1446             # Collect input data.
       
  1447             if have_read:
       
  1448                 read_result = reader.read(read_size)
       
  1449             else:
       
  1450                 remaining = len(reader) - buffer_offset
       
  1451                 slice_size = min(remaining, read_size)
       
  1452                 read_result = reader[buffer_offset:buffer_offset + slice_size]
       
  1453                 buffer_offset += slice_size
       
  1454 
       
  1455             # No new input data. Break out of the read loop.
       
  1456             if not read_result:
       
  1457                 break
       
  1458 
       
  1459             # Feed all read data into the compressor and emit output until
       
  1460             # exhausted.
       
  1461             read_buffer = ffi.from_buffer(read_result)
       
  1462             in_buffer.src = read_buffer
       
  1463             in_buffer.size = len(read_buffer)
       
  1464             in_buffer.pos = 0
       
  1465 
       
  1466             while in_buffer.pos < in_buffer.size:
       
  1467                 zresult = lib.ZSTD_compressStream2(self._cctx, out_buffer, in_buffer,
       
  1468                                                    lib.ZSTD_e_continue)
       
  1469                 if lib.ZSTD_isError(zresult):
       
  1470                     raise ZstdError('zstd compress error: %s' %
       
  1471                                     _zstd_error(zresult))
       
  1472 
       
  1473                 if out_buffer.pos:
       
  1474                     data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1475                     out_buffer.pos = 0
       
  1476                     yield data
       
  1477 
       
  1478             assert out_buffer.pos == 0
       
  1479 
       
  1480             # And repeat the loop to collect more data.
       
  1481             continue
       
  1482 
       
  1483         # If we get here, input is exhausted. End the stream and emit what
       
  1484         # remains.
       
  1485         while True:
       
  1486             assert out_buffer.pos == 0
       
  1487             zresult = lib.ZSTD_compressStream2(self._cctx,
       
  1488                                                out_buffer,
       
  1489                                                in_buffer,
       
  1490                                                lib.ZSTD_e_end)
       
  1491             if lib.ZSTD_isError(zresult):
       
  1492                 raise ZstdError('error ending compression stream: %s' %
       
  1493                                 _zstd_error(zresult))
       
  1494 
       
  1495             if out_buffer.pos:
       
  1496                 data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1497                 out_buffer.pos = 0
       
  1498                 yield data
       
  1499 
       
  1500             if zresult == 0:
       
  1501                 break
       
  1502 
       
    # Alias: the same generator method is also reachable as ``read_from``.
    # NOTE(review): presumably kept for backwards compatibility -- confirm.
    read_from = read_to_iter
       
  1504 
       
  1505     def frame_progression(self):
       
  1506         progression = lib.ZSTD_getFrameProgression(self._cctx)
       
  1507 
       
  1508         return progression.ingested, progression.consumed, progression.produced
       
  1509 
       
  1510 
       
  1511 class FrameParameters(object):
       
  1512     def __init__(self, fparams):
       
  1513         self.content_size = fparams.frameContentSize
       
  1514         self.window_size = fparams.windowSize
       
  1515         self.dict_id = fparams.dictID
       
  1516         self.has_checksum = bool(fparams.checksumFlag)
       
  1517 
       
  1518 
       
  1519 def frame_content_size(data):
       
  1520     data_buffer = ffi.from_buffer(data)
       
  1521 
       
  1522     size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
       
  1523 
       
  1524     if size == lib.ZSTD_CONTENTSIZE_ERROR:
       
  1525         raise ZstdError('error when determining content size')
       
  1526     elif size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  1527         return -1
       
  1528     else:
       
  1529         return size
       
  1530 
       
  1531 
       
  1532 def frame_header_size(data):
       
  1533     data_buffer = ffi.from_buffer(data)
       
  1534 
       
  1535     zresult = lib.ZSTD_frameHeaderSize(data_buffer, len(data_buffer))
       
  1536     if lib.ZSTD_isError(zresult):
       
  1537         raise ZstdError('could not determine frame header size: %s' %
       
  1538                         _zstd_error(zresult))
       
  1539 
       
  1540     return zresult
       
  1541 
       
  1542 
       
  1543 def get_frame_parameters(data):
       
  1544     params = ffi.new('ZSTD_frameHeader *')
       
  1545 
       
  1546     data_buffer = ffi.from_buffer(data)
       
  1547     zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer))
       
  1548     if lib.ZSTD_isError(zresult):
       
  1549         raise ZstdError('cannot get frame parameters: %s' %
       
  1550                         _zstd_error(zresult))
       
  1551 
       
  1552     if zresult:
       
  1553         raise ZstdError('not enough data for frame parameters; need %d bytes' %
       
  1554                         zresult)
       
  1555 
       
  1556     return FrameParameters(params[0])
       
  1557 
       
  1558 
       
  1559 class ZstdCompressionDict(object):
       
  1560     def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
       
  1561         assert isinstance(data, bytes_type)
       
  1562         self._data = data
       
  1563         self.k = k
       
  1564         self.d = d
       
  1565 
       
  1566         if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT,
       
  1567                              DICT_TYPE_FULLDICT):
       
  1568             raise ValueError('invalid dictionary load mode: %d; must use '
       
  1569                              'DICT_TYPE_* constants')
       
  1570 
       
  1571         self._dict_type = dict_type
       
  1572         self._cdict = None
       
  1573 
       
  1574     def __len__(self):
       
  1575         return len(self._data)
       
  1576 
       
  1577     def dict_id(self):
       
  1578         return int_type(lib.ZDICT_getDictID(self._data, len(self._data)))
       
  1579 
       
  1580     def as_bytes(self):
       
  1581         return self._data
       
  1582 
       
  1583     def precompute_compress(self, level=0, compression_params=None):
       
  1584         if level and compression_params:
       
  1585             raise ValueError('must only specify one of level or '
       
  1586                              'compression_params')
       
  1587 
       
  1588         if not level and not compression_params:
       
  1589             raise ValueError('must specify one of level or compression_params')
       
  1590 
       
  1591         if level:
       
  1592             cparams = lib.ZSTD_getCParams(level, 0, len(self._data))
       
  1593         else:
       
  1594             cparams = ffi.new('ZSTD_compressionParameters')
       
  1595             cparams.chainLog = compression_params.chain_log
       
  1596             cparams.hashLog = compression_params.hash_log
       
  1597             cparams.minMatch = compression_params.min_match
       
  1598             cparams.searchLog = compression_params.search_log
       
  1599             cparams.strategy = compression_params.compression_strategy
       
  1600             cparams.targetLength = compression_params.target_length
       
  1601             cparams.windowLog = compression_params.window_log
       
  1602 
       
  1603         cdict = lib.ZSTD_createCDict_advanced(self._data, len(self._data),
       
  1604                                               lib.ZSTD_dlm_byRef,
       
  1605                                               self._dict_type,
       
  1606                                               cparams,
       
  1607                                               lib.ZSTD_defaultCMem)
       
  1608         if cdict == ffi.NULL:
       
  1609             raise ZstdError('unable to precompute dictionary')
       
  1610 
       
  1611         self._cdict = ffi.gc(cdict, lib.ZSTD_freeCDict,
       
  1612                              size=lib.ZSTD_sizeof_CDict(cdict))
       
  1613 
       
  1614     @property
       
  1615     def _ddict(self):
       
  1616         ddict = lib.ZSTD_createDDict_advanced(self._data, len(self._data),
       
  1617                                               lib.ZSTD_dlm_byRef,
       
  1618                                               self._dict_type,
       
  1619                                               lib.ZSTD_defaultCMem)
       
  1620 
       
  1621         if ddict == ffi.NULL:
       
  1622             raise ZstdError('could not create decompression dict')
       
  1623 
       
  1624         ddict = ffi.gc(ddict, lib.ZSTD_freeDDict,
       
  1625                        size=lib.ZSTD_sizeof_DDict(ddict))
       
  1626         self.__dict__['_ddict'] = ddict
       
  1627 
       
  1628         return ddict
       
  1629 
       
  1630 def train_dictionary(dict_size, samples, k=0, d=0, notifications=0, dict_id=0,
       
  1631                      level=0, steps=0, threads=0):
       
  1632     if not isinstance(samples, list):
       
  1633         raise TypeError('samples must be a list')
       
  1634 
       
  1635     if threads < 0:
       
  1636         threads = _cpu_count()
       
  1637 
       
  1638     total_size = sum(map(len, samples))
       
  1639 
       
  1640     samples_buffer = new_nonzero('char[]', total_size)
       
  1641     sample_sizes = new_nonzero('size_t[]', len(samples))
       
  1642 
       
  1643     offset = 0
       
  1644     for i, sample in enumerate(samples):
       
  1645         if not isinstance(sample, bytes_type):
       
  1646             raise ValueError('samples must be bytes')
       
  1647 
       
  1648         l = len(sample)
       
  1649         ffi.memmove(samples_buffer + offset, sample, l)
       
  1650         offset += l
       
  1651         sample_sizes[i] = l
       
  1652 
       
  1653     dict_data = new_nonzero('char[]', dict_size)
       
  1654 
       
  1655     dparams = ffi.new('ZDICT_cover_params_t *')[0]
       
  1656     dparams.k = k
       
  1657     dparams.d = d
       
  1658     dparams.steps = steps
       
  1659     dparams.nbThreads = threads
       
  1660     dparams.zParams.notificationLevel = notifications
       
  1661     dparams.zParams.dictID = dict_id
       
  1662     dparams.zParams.compressionLevel = level
       
  1663 
       
  1664     if (not dparams.k and not dparams.d and not dparams.steps
       
  1665         and not dparams.nbThreads and not dparams.zParams.notificationLevel
       
  1666         and not dparams.zParams.dictID
       
  1667         and not dparams.zParams.compressionLevel):
       
  1668         zresult = lib.ZDICT_trainFromBuffer(
       
  1669             ffi.addressof(dict_data), dict_size,
       
  1670             ffi.addressof(samples_buffer),
       
  1671             ffi.addressof(sample_sizes, 0), len(samples))
       
  1672     elif dparams.steps or dparams.nbThreads:
       
  1673         zresult = lib.ZDICT_optimizeTrainFromBuffer_cover(
       
  1674             ffi.addressof(dict_data), dict_size,
       
  1675             ffi.addressof(samples_buffer),
       
  1676             ffi.addressof(sample_sizes, 0), len(samples),
       
  1677             ffi.addressof(dparams))
       
  1678     else:
       
  1679         zresult = lib.ZDICT_trainFromBuffer_cover(
       
  1680             ffi.addressof(dict_data), dict_size,
       
  1681             ffi.addressof(samples_buffer),
       
  1682             ffi.addressof(sample_sizes, 0), len(samples),
       
  1683             dparams)
       
  1684 
       
  1685     if lib.ZDICT_isError(zresult):
       
  1686         msg = ffi.string(lib.ZDICT_getErrorName(zresult)).decode('utf-8')
       
  1687         raise ZstdError('cannot train dict: %s' % msg)
       
  1688 
       
  1689     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:],
       
  1690                                dict_type=DICT_TYPE_FULLDICT,
       
  1691                                k=dparams.k, d=dparams.d)
       
  1692 
       
  1693 
       
  1694 class ZstdDecompressionObj(object):
       
  1695     def __init__(self, decompressor, write_size):
       
  1696         self._decompressor = decompressor
       
  1697         self._write_size = write_size
       
  1698         self._finished = False
       
  1699 
       
  1700     def decompress(self, data):
       
  1701         if self._finished:
       
  1702             raise ZstdError('cannot use a decompressobj multiple times')
       
  1703 
       
  1704         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1705         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1706 
       
  1707         data_buffer = ffi.from_buffer(data)
       
  1708 
       
  1709         if len(data_buffer) == 0:
       
  1710             return b''
       
  1711 
       
  1712         in_buffer.src = data_buffer
       
  1713         in_buffer.size = len(data_buffer)
       
  1714         in_buffer.pos = 0
       
  1715 
       
  1716         dst_buffer = ffi.new('char[]', self._write_size)
       
  1717         out_buffer.dst = dst_buffer
       
  1718         out_buffer.size = len(dst_buffer)
       
  1719         out_buffer.pos = 0
       
  1720 
       
  1721         chunks = []
       
  1722 
       
  1723         while True:
       
  1724             zresult = lib.ZSTD_decompressStream(self._decompressor._dctx,
       
  1725                                                 out_buffer, in_buffer)
       
  1726             if lib.ZSTD_isError(zresult):
       
  1727                 raise ZstdError('zstd decompressor error: %s' %
       
  1728                                 _zstd_error(zresult))
       
  1729 
       
  1730             if zresult == 0:
       
  1731                 self._finished = True
       
  1732                 self._decompressor = None
       
  1733 
       
  1734             if out_buffer.pos:
       
  1735                 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
  1736 
       
  1737             if (zresult == 0 or
       
  1738                     (in_buffer.pos == in_buffer.size and out_buffer.pos == 0)):
       
  1739                 break
       
  1740 
       
  1741             out_buffer.pos = 0
       
  1742 
       
  1743         return b''.join(chunks)
       
  1744 
       
  1745     def flush(self, length=0):
       
  1746         pass
       
  1747 
       
  1748 
       
  1749 class ZstdDecompressionReader(object):
       
  1750     def __init__(self, decompressor, source, read_size, read_across_frames):
       
  1751         self._decompressor = decompressor
       
  1752         self._source = source
       
  1753         self._read_size = read_size
       
  1754         self._read_across_frames = bool(read_across_frames)
       
  1755         self._entered = False
       
  1756         self._closed = False
       
  1757         self._bytes_decompressed = 0
       
  1758         self._finished_input = False
       
  1759         self._finished_output = False
       
  1760         self._in_buffer = ffi.new('ZSTD_inBuffer *')
       
  1761         # Holds a ref to self._in_buffer.src.
       
  1762         self._source_buffer = None
       
  1763 
       
  1764     def __enter__(self):
       
  1765         if self._entered:
       
  1766             raise ValueError('cannot __enter__ multiple times')
       
  1767 
       
  1768         self._entered = True
       
  1769         return self
       
  1770 
       
  1771     def __exit__(self, exc_type, exc_value, exc_tb):
       
  1772         self._entered = False
       
  1773         self._closed = True
       
  1774         self._source = None
       
  1775         self._decompressor = None
       
  1776 
       
  1777         return False
       
  1778 
       
  1779     def readable(self):
       
  1780         return True
       
  1781 
       
  1782     def writable(self):
       
  1783         return False
       
  1784 
       
  1785     def seekable(self):
       
  1786         return True
       
  1787 
       
  1788     def readline(self):
       
  1789         raise io.UnsupportedOperation()
       
  1790 
       
  1791     def readlines(self):
       
  1792         raise io.UnsupportedOperation()
       
  1793 
       
  1794     def write(self, data):
       
  1795         raise io.UnsupportedOperation()
       
  1796 
       
  1797     def writelines(self, lines):
       
  1798         raise io.UnsupportedOperation()
       
  1799 
       
  1800     def isatty(self):
       
  1801         return False
       
  1802 
       
  1803     def flush(self):
       
  1804         return None
       
  1805 
       
  1806     def close(self):
       
  1807         self._closed = True
       
  1808         return None
       
  1809 
       
  1810     @property
       
  1811     def closed(self):
       
  1812         return self._closed
       
  1813 
       
  1814     def tell(self):
       
  1815         return self._bytes_decompressed
       
  1816 
       
  1817     def readall(self):
       
  1818         chunks = []
       
  1819 
       
  1820         while True:
       
  1821             chunk = self.read(1048576)
       
  1822             if not chunk:
       
  1823                 break
       
  1824 
       
  1825             chunks.append(chunk)
       
  1826 
       
  1827         return b''.join(chunks)
       
  1828 
       
  1829     def __iter__(self):
       
  1830         raise io.UnsupportedOperation()
       
  1831 
       
  1832     def __next__(self):
       
  1833         raise io.UnsupportedOperation()
       
  1834 
       
    # Alias: expose __next__ under the name ``next`` as well.
    # NOTE(review): presumably for the Python 2 iterator protocol -- confirm.
    next = __next__
       
  1836 
       
  1837     def _read_input(self):
       
  1838         # We have data left over in the input buffer. Use it.
       
  1839         if self._in_buffer.pos < self._in_buffer.size:
       
  1840             return
       
  1841 
       
  1842         # All input data exhausted. Nothing to do.
       
  1843         if self._finished_input:
       
  1844             return
       
  1845 
       
  1846         # Else populate the input buffer from our source.
       
  1847         if hasattr(self._source, 'read'):
       
  1848             data = self._source.read(self._read_size)
       
  1849 
       
  1850             if not data:
       
  1851                 self._finished_input = True
       
  1852                 return
       
  1853 
       
  1854             self._source_buffer = ffi.from_buffer(data)
       
  1855             self._in_buffer.src = self._source_buffer
       
  1856             self._in_buffer.size = len(self._source_buffer)
       
  1857             self._in_buffer.pos = 0
       
  1858         else:
       
  1859             self._source_buffer = ffi.from_buffer(self._source)
       
  1860             self._in_buffer.src = self._source_buffer
       
  1861             self._in_buffer.size = len(self._source_buffer)
       
  1862             self._in_buffer.pos = 0
       
  1863 
       
  1864     def _decompress_into_buffer(self, out_buffer):
       
  1865         """Decompress available input into an output buffer.
       
  1866 
       
  1867         Returns True if data in output buffer should be emitted.
       
  1868         """
       
  1869         zresult = lib.ZSTD_decompressStream(self._decompressor._dctx,
       
  1870                                             out_buffer, self._in_buffer)
       
  1871 
       
  1872         if self._in_buffer.pos == self._in_buffer.size:
       
  1873             self._in_buffer.src = ffi.NULL
       
  1874             self._in_buffer.pos = 0
       
  1875             self._in_buffer.size = 0
       
  1876             self._source_buffer = None
       
  1877 
       
  1878             if not hasattr(self._source, 'read'):
       
  1879                 self._finished_input = True
       
  1880 
       
  1881         if lib.ZSTD_isError(zresult):
       
  1882             raise ZstdError('zstd decompress error: %s' %
       
  1883                             _zstd_error(zresult))
       
  1884 
       
  1885         # Emit data if there is data AND either:
       
  1886         # a) output buffer is full (read amount is satisfied)
       
  1887         # b) we're at end of a frame and not in frame spanning mode
       
  1888         return (out_buffer.pos and
       
  1889                 (out_buffer.pos == out_buffer.size or
       
  1890                  zresult == 0 and not self._read_across_frames))
       
  1891 
       
  1892     def read(self, size=-1):
       
  1893         if self._closed:
       
  1894             raise ValueError('stream is closed')
       
  1895 
       
  1896         if size < -1:
       
  1897             raise ValueError('cannot read negative amounts less than -1')
       
  1898 
       
  1899         if size == -1:
       
  1900             # This is recursive. But it gets the job done.
       
  1901             return self.readall()
       
  1902 
       
  1903         if self._finished_output or size == 0:
       
  1904             return b''
       
  1905 
       
  1906         # We /could/ call into readinto() here. But that introduces more
       
  1907         # overhead.
       
  1908         dst_buffer = ffi.new('char[]', size)
       
  1909         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1910         out_buffer.dst = dst_buffer
       
  1911         out_buffer.size = size
       
  1912         out_buffer.pos = 0
       
  1913 
       
  1914         self._read_input()
       
  1915         if self._decompress_into_buffer(out_buffer):
       
  1916             self._bytes_decompressed += out_buffer.pos
       
  1917             return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1918 
       
  1919         while not self._finished_input:
       
  1920             self._read_input()
       
  1921             if self._decompress_into_buffer(out_buffer):
       
  1922                 self._bytes_decompressed += out_buffer.pos
       
  1923                 return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1924 
       
  1925         self._bytes_decompressed += out_buffer.pos
       
  1926         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1927 
       
  1928     def readinto(self, b):
       
  1929         if self._closed:
       
  1930             raise ValueError('stream is closed')
       
  1931 
       
  1932         if self._finished_output:
       
  1933             return 0
       
  1934 
       
  1935         # TODO use writable=True once we require CFFI >= 1.12.
       
  1936         dest_buffer = ffi.from_buffer(b)
       
  1937         ffi.memmove(b, b'', 0)
       
  1938         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1939         out_buffer.dst = dest_buffer
       
  1940         out_buffer.size = len(dest_buffer)
       
  1941         out_buffer.pos = 0
       
  1942 
       
  1943         self._read_input()
       
  1944         if self._decompress_into_buffer(out_buffer):
       
  1945             self._bytes_decompressed += out_buffer.pos
       
  1946             return out_buffer.pos
       
  1947 
       
  1948         while not self._finished_input:
       
  1949             self._read_input()
       
  1950             if self._decompress_into_buffer(out_buffer):
       
  1951                 self._bytes_decompressed += out_buffer.pos
       
  1952                 return out_buffer.pos
       
  1953 
       
  1954         self._bytes_decompressed += out_buffer.pos
       
  1955         return out_buffer.pos
       
  1956 
       
  1957     def read1(self, size=-1):
       
  1958         if self._closed:
       
  1959             raise ValueError('stream is closed')
       
  1960 
       
  1961         if size < -1:
       
  1962             raise ValueError('cannot read negative amounts less than -1')
       
  1963 
       
  1964         if self._finished_output or size == 0:
       
  1965             return b''
       
  1966 
       
  1967         # -1 returns arbitrary number of bytes.
       
  1968         if size == -1:
       
  1969             size = DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE
       
  1970 
       
  1971         dst_buffer = ffi.new('char[]', size)
       
  1972         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  1973         out_buffer.dst = dst_buffer
       
  1974         out_buffer.size = size
       
  1975         out_buffer.pos = 0
       
  1976 
       
  1977         # read1() dictates that we can perform at most 1 call to underlying
       
  1978         # stream to get input. However, we can't satisfy this restriction with
       
  1979         # decompression because not all input generates output. So we allow
       
  1980         # multiple read(). But unlike read(), we stop once we have any output.
       
  1981         while not self._finished_input:
       
  1982             self._read_input()
       
  1983             self._decompress_into_buffer(out_buffer)
       
  1984 
       
  1985             if out_buffer.pos:
       
  1986                 break
       
  1987 
       
  1988         self._bytes_decompressed += out_buffer.pos
       
  1989         return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  1990 
       
  1991     def readinto1(self, b):
       
  1992         if self._closed:
       
  1993             raise ValueError('stream is closed')
       
  1994 
       
  1995         if self._finished_output:
       
  1996             return 0
       
  1997 
       
  1998         # TODO use writable=True once we require CFFI >= 1.12.
       
  1999         dest_buffer = ffi.from_buffer(b)
       
  2000         ffi.memmove(b, b'', 0)
       
  2001 
       
  2002         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2003         out_buffer.dst = dest_buffer
       
  2004         out_buffer.size = len(dest_buffer)
       
  2005         out_buffer.pos = 0
       
  2006 
       
  2007         while not self._finished_input and not self._finished_output:
       
  2008             self._read_input()
       
  2009             self._decompress_into_buffer(out_buffer)
       
  2010 
       
  2011             if out_buffer.pos:
       
  2012                 break
       
  2013 
       
  2014         self._bytes_decompressed += out_buffer.pos
       
  2015         return out_buffer.pos
       
  2016 
       
  2017     def seek(self, pos, whence=os.SEEK_SET):
       
  2018         if self._closed:
       
  2019             raise ValueError('stream is closed')
       
  2020 
       
  2021         read_amount = 0
       
  2022 
       
  2023         if whence == os.SEEK_SET:
       
  2024             if pos < 0:
       
  2025                 raise ValueError('cannot seek to negative position with SEEK_SET')
       
  2026 
       
  2027             if pos < self._bytes_decompressed:
       
  2028                 raise ValueError('cannot seek zstd decompression stream '
       
  2029                                  'backwards')
       
  2030 
       
  2031             read_amount = pos - self._bytes_decompressed
       
  2032 
       
  2033         elif whence == os.SEEK_CUR:
       
  2034             if pos < 0:
       
  2035                 raise ValueError('cannot seek zstd decompression stream '
       
  2036                                  'backwards')
       
  2037 
       
  2038             read_amount = pos
       
  2039         elif whence == os.SEEK_END:
       
  2040             raise ValueError('zstd decompression streams cannot be seeked '
       
  2041                              'with SEEK_END')
       
  2042 
       
  2043         while read_amount:
       
  2044             result = self.read(min(read_amount,
       
  2045                                    DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE))
       
  2046 
       
  2047             if not result:
       
  2048                 break
       
  2049 
       
  2050             read_amount -= len(result)
       
  2051 
       
  2052         return self._bytes_decompressed
       
  2053 
       
  2054 class ZstdDecompressionWriter(object):
       
  2055     def __init__(self, decompressor, writer, write_size, write_return_read):
       
  2056         decompressor._ensure_dctx()
       
  2057 
       
  2058         self._decompressor = decompressor
       
  2059         self._writer = writer
       
  2060         self._write_size = write_size
       
  2061         self._write_return_read = bool(write_return_read)
       
  2062         self._entered = False
       
  2063         self._closed = False
       
  2064 
       
  2065     def __enter__(self):
       
  2066         if self._closed:
       
  2067             raise ValueError('stream is closed')
       
  2068 
       
  2069         if self._entered:
       
  2070             raise ZstdError('cannot __enter__ multiple times')
       
  2071 
       
  2072         self._entered = True
       
  2073 
       
  2074         return self
       
  2075 
       
  2076     def __exit__(self, exc_type, exc_value, exc_tb):
       
  2077         self._entered = False
       
  2078         self.close()
       
  2079 
       
  2080     def memory_size(self):
       
  2081         return lib.ZSTD_sizeof_DCtx(self._decompressor._dctx)
       
  2082 
       
  2083     def close(self):
       
  2084         if self._closed:
       
  2085             return
       
  2086 
       
  2087         try:
       
  2088             self.flush()
       
  2089         finally:
       
  2090             self._closed = True
       
  2091 
       
  2092         f = getattr(self._writer, 'close', None)
       
  2093         if f:
       
  2094             f()
       
  2095 
       
  2096     @property
       
  2097     def closed(self):
       
  2098         return self._closed
       
  2099 
       
  2100     def fileno(self):
       
  2101         f = getattr(self._writer, 'fileno', None)
       
  2102         if f:
       
  2103             return f()
       
  2104         else:
       
  2105             raise OSError('fileno not available on underlying writer')
       
  2106 
       
  2107     def flush(self):
       
  2108         if self._closed:
       
  2109             raise ValueError('stream is closed')
       
  2110 
       
  2111         f = getattr(self._writer, 'flush', None)
       
  2112         if f:
       
  2113             return f()
       
  2114 
       
  2115     def isatty(self):
       
  2116         return False
       
  2117 
       
  2118     def readable(self):
       
  2119         return False
       
  2120 
       
  2121     def readline(self, size=-1):
       
  2122         raise io.UnsupportedOperation()
       
  2123 
       
  2124     def readlines(self, hint=-1):
       
  2125         raise io.UnsupportedOperation()
       
  2126 
       
  2127     def seek(self, offset, whence=None):
       
  2128         raise io.UnsupportedOperation()
       
  2129 
       
  2130     def seekable(self):
       
  2131         return False
       
  2132 
       
  2133     def tell(self):
       
  2134         raise io.UnsupportedOperation()
       
  2135 
       
  2136     def truncate(self, size=None):
       
  2137         raise io.UnsupportedOperation()
       
  2138 
       
  2139     def writable(self):
       
  2140         return True
       
  2141 
       
  2142     def writelines(self, lines):
       
  2143         raise io.UnsupportedOperation()
       
  2144 
       
  2145     def read(self, size=-1):
       
  2146         raise io.UnsupportedOperation()
       
  2147 
       
  2148     def readall(self):
       
  2149         raise io.UnsupportedOperation()
       
  2150 
       
  2151     def readinto(self, b):
       
  2152         raise io.UnsupportedOperation()
       
  2153 
       
  2154     def write(self, data):
       
  2155         if self._closed:
       
  2156             raise ValueError('stream is closed')
       
  2157 
       
  2158         total_write = 0
       
  2159 
       
  2160         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  2161         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2162 
       
  2163         data_buffer = ffi.from_buffer(data)
       
  2164         in_buffer.src = data_buffer
       
  2165         in_buffer.size = len(data_buffer)
       
  2166         in_buffer.pos = 0
       
  2167 
       
  2168         dst_buffer = ffi.new('char[]', self._write_size)
       
  2169         out_buffer.dst = dst_buffer
       
  2170         out_buffer.size = len(dst_buffer)
       
  2171         out_buffer.pos = 0
       
  2172 
       
  2173         dctx = self._decompressor._dctx
       
  2174 
       
  2175         while in_buffer.pos < in_buffer.size:
       
  2176             zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer)
       
  2177             if lib.ZSTD_isError(zresult):
       
  2178                 raise ZstdError('zstd decompress error: %s' %
       
  2179                                 _zstd_error(zresult))
       
  2180 
       
  2181             if out_buffer.pos:
       
  2182                 self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
       
  2183                 total_write += out_buffer.pos
       
  2184                 out_buffer.pos = 0
       
  2185 
       
  2186         if self._write_return_read:
       
  2187             return in_buffer.pos
       
  2188         else:
       
  2189             return total_write
       
  2190 
       
  2191 
       
  2192 class ZstdDecompressor(object):
       
  2193     def __init__(self, dict_data=None, max_window_size=0, format=FORMAT_ZSTD1):
       
  2194         self._dict_data = dict_data
       
  2195         self._max_window_size = max_window_size
       
  2196         self._format = format
       
  2197 
       
  2198         dctx = lib.ZSTD_createDCtx()
       
  2199         if dctx == ffi.NULL:
       
  2200             raise MemoryError()
       
  2201 
       
  2202         self._dctx = dctx
       
  2203 
       
  2204         # Defer setting up garbage collection until full state is loaded so
       
  2205         # the memory size is more accurate.
       
  2206         try:
       
  2207             self._ensure_dctx()
       
  2208         finally:
       
  2209             self._dctx = ffi.gc(dctx, lib.ZSTD_freeDCtx,
       
  2210                                 size=lib.ZSTD_sizeof_DCtx(dctx))
       
  2211 
       
  2212     def memory_size(self):
       
  2213         return lib.ZSTD_sizeof_DCtx(self._dctx)
       
  2214 
       
  2215     def decompress(self, data, max_output_size=0):
       
  2216         self._ensure_dctx()
       
  2217 
       
  2218         data_buffer = ffi.from_buffer(data)
       
  2219 
       
  2220         output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer))
       
  2221 
       
  2222         if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
       
  2223             raise ZstdError('error determining content size from frame header')
       
  2224         elif output_size == 0:
       
  2225             return b''
       
  2226         elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  2227             if not max_output_size:
       
  2228                 raise ZstdError('could not determine content size in frame header')
       
  2229 
       
  2230             result_buffer = ffi.new('char[]', max_output_size)
       
  2231             result_size = max_output_size
       
  2232             output_size = 0
       
  2233         else:
       
  2234             result_buffer = ffi.new('char[]', output_size)
       
  2235             result_size = output_size
       
  2236 
       
  2237         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2238         out_buffer.dst = result_buffer
       
  2239         out_buffer.size = result_size
       
  2240         out_buffer.pos = 0
       
  2241 
       
  2242         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  2243         in_buffer.src = data_buffer
       
  2244         in_buffer.size = len(data_buffer)
       
  2245         in_buffer.pos = 0
       
  2246 
       
  2247         zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
       
  2248         if lib.ZSTD_isError(zresult):
       
  2249             raise ZstdError('decompression error: %s' %
       
  2250                             _zstd_error(zresult))
       
  2251         elif zresult:
       
  2252             raise ZstdError('decompression error: did not decompress full frame')
       
  2253         elif output_size and out_buffer.pos != output_size:
       
  2254             raise ZstdError('decompression error: decompressed %d bytes; expected %d' %
       
  2255                             (zresult, output_size))
       
  2256 
       
  2257         return ffi.buffer(result_buffer, out_buffer.pos)[:]
       
  2258 
       
  2259     def stream_reader(self, source, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  2260                       read_across_frames=False):
       
  2261         self._ensure_dctx()
       
  2262         return ZstdDecompressionReader(self, source, read_size, read_across_frames)
       
  2263 
       
  2264     def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  2265         if write_size < 1:
       
  2266             raise ValueError('write_size must be positive')
       
  2267 
       
  2268         self._ensure_dctx()
       
  2269         return ZstdDecompressionObj(self, write_size=write_size)
       
  2270 
       
  2271     def read_to_iter(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  2272                      write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
       
  2273                      skip_bytes=0):
       
  2274         if skip_bytes >= read_size:
       
  2275             raise ValueError('skip_bytes must be smaller than read_size')
       
  2276 
       
  2277         if hasattr(reader, 'read'):
       
  2278             have_read = True
       
  2279         elif hasattr(reader, '__getitem__'):
       
  2280             have_read = False
       
  2281             buffer_offset = 0
       
  2282             size = len(reader)
       
  2283         else:
       
  2284             raise ValueError('must pass an object with a read() method or '
       
  2285                              'conforms to buffer protocol')
       
  2286 
       
  2287         if skip_bytes:
       
  2288             if have_read:
       
  2289                 reader.read(skip_bytes)
       
  2290             else:
       
  2291                 if skip_bytes > size:
       
  2292                     raise ValueError('skip_bytes larger than first input chunk')
       
  2293 
       
  2294                 buffer_offset = skip_bytes
       
  2295 
       
  2296         self._ensure_dctx()
       
  2297 
       
  2298         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  2299         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2300 
       
  2301         dst_buffer = ffi.new('char[]', write_size)
       
  2302         out_buffer.dst = dst_buffer
       
  2303         out_buffer.size = len(dst_buffer)
       
  2304         out_buffer.pos = 0
       
  2305 
       
  2306         while True:
       
  2307             assert out_buffer.pos == 0
       
  2308 
       
  2309             if have_read:
       
  2310                 read_result = reader.read(read_size)
       
  2311             else:
       
  2312                 remaining = size - buffer_offset
       
  2313                 slice_size = min(remaining, read_size)
       
  2314                 read_result = reader[buffer_offset:buffer_offset + slice_size]
       
  2315                 buffer_offset += slice_size
       
  2316 
       
  2317             # No new input. Break out of read loop.
       
  2318             if not read_result:
       
  2319                 break
       
  2320 
       
  2321             # Feed all read data into decompressor and emit output until
       
  2322             # exhausted.
       
  2323             read_buffer = ffi.from_buffer(read_result)
       
  2324             in_buffer.src = read_buffer
       
  2325             in_buffer.size = len(read_buffer)
       
  2326             in_buffer.pos = 0
       
  2327 
       
  2328             while in_buffer.pos < in_buffer.size:
       
  2329                 assert out_buffer.pos == 0
       
  2330 
       
  2331                 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
       
  2332                 if lib.ZSTD_isError(zresult):
       
  2333                     raise ZstdError('zstd decompress error: %s' %
       
  2334                                     _zstd_error(zresult))
       
  2335 
       
  2336                 if out_buffer.pos:
       
  2337                     data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:]
       
  2338                     out_buffer.pos = 0
       
  2339                     yield data
       
  2340 
       
  2341                 if zresult == 0:
       
  2342                     return
       
  2343 
       
  2344             # Repeat loop to collect more input data.
       
  2345             continue
       
  2346 
       
  2347         # If we get here, input is exhausted.
       
  2348 
       
  2349     read_from = read_to_iter
       
  2350 
       
  2351     def stream_writer(self, writer, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
       
  2352                       write_return_read=False):
       
  2353         if not hasattr(writer, 'write'):
       
  2354             raise ValueError('must pass an object with a write() method')
       
  2355 
       
  2356         return ZstdDecompressionWriter(self, writer, write_size,
       
  2357                                        write_return_read)
       
  2358 
       
  2359     write_to = stream_writer
       
  2360 
       
  2361     def copy_stream(self, ifh, ofh,
       
  2362                     read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
       
  2363                     write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
       
  2364         if not hasattr(ifh, 'read'):
       
  2365             raise ValueError('first argument must have a read() method')
       
  2366         if not hasattr(ofh, 'write'):
       
  2367             raise ValueError('second argument must have a write() method')
       
  2368 
       
  2369         self._ensure_dctx()
       
  2370 
       
  2371         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  2372         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2373 
       
  2374         dst_buffer = ffi.new('char[]', write_size)
       
  2375         out_buffer.dst = dst_buffer
       
  2376         out_buffer.size = write_size
       
  2377         out_buffer.pos = 0
       
  2378 
       
  2379         total_read, total_write = 0, 0
       
  2380 
       
  2381         # Read all available input.
       
  2382         while True:
       
  2383             data = ifh.read(read_size)
       
  2384             if not data:
       
  2385                 break
       
  2386 
       
  2387             data_buffer = ffi.from_buffer(data)
       
  2388             total_read += len(data_buffer)
       
  2389             in_buffer.src = data_buffer
       
  2390             in_buffer.size = len(data_buffer)
       
  2391             in_buffer.pos = 0
       
  2392 
       
  2393             # Flush all read data to output.
       
  2394             while in_buffer.pos < in_buffer.size:
       
  2395                 zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
       
  2396                 if lib.ZSTD_isError(zresult):
       
  2397                     raise ZstdError('zstd decompressor error: %s' %
       
  2398                                     _zstd_error(zresult))
       
  2399 
       
  2400                 if out_buffer.pos:
       
  2401                     ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos))
       
  2402                     total_write += out_buffer.pos
       
  2403                     out_buffer.pos = 0
       
  2404 
       
  2405             # Continue loop to keep reading.
       
  2406 
       
  2407         return total_read, total_write
       
  2408 
       
  2409     def decompress_content_dict_chain(self, frames):
       
  2410         if not isinstance(frames, list):
       
  2411             raise TypeError('argument must be a list')
       
  2412 
       
  2413         if not frames:
       
  2414             raise ValueError('empty input chain')
       
  2415 
       
  2416         # First chunk should not be using a dictionary. We handle it specially.
       
  2417         chunk = frames[0]
       
  2418         if not isinstance(chunk, bytes_type):
       
  2419             raise ValueError('chunk 0 must be bytes')
       
  2420 
       
  2421         # All chunks should be zstd frames and should have content size set.
       
  2422         chunk_buffer = ffi.from_buffer(chunk)
       
  2423         params = ffi.new('ZSTD_frameHeader *')
       
  2424         zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
       
  2425         if lib.ZSTD_isError(zresult):
       
  2426             raise ValueError('chunk 0 is not a valid zstd frame')
       
  2427         elif zresult:
       
  2428             raise ValueError('chunk 0 is too small to contain a zstd frame')
       
  2429 
       
  2430         if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  2431             raise ValueError('chunk 0 missing content size in frame')
       
  2432 
       
  2433         self._ensure_dctx(load_dict=False)
       
  2434 
       
  2435         last_buffer = ffi.new('char[]', params.frameContentSize)
       
  2436 
       
  2437         out_buffer = ffi.new('ZSTD_outBuffer *')
       
  2438         out_buffer.dst = last_buffer
       
  2439         out_buffer.size = len(last_buffer)
       
  2440         out_buffer.pos = 0
       
  2441 
       
  2442         in_buffer = ffi.new('ZSTD_inBuffer *')
       
  2443         in_buffer.src = chunk_buffer
       
  2444         in_buffer.size = len(chunk_buffer)
       
  2445         in_buffer.pos = 0
       
  2446 
       
  2447         zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
       
  2448         if lib.ZSTD_isError(zresult):
       
  2449             raise ZstdError('could not decompress chunk 0: %s' %
       
  2450                             _zstd_error(zresult))
       
  2451         elif zresult:
       
  2452             raise ZstdError('chunk 0 did not decompress full frame')
       
  2453 
       
  2454         # Special case of chain length of 1
       
  2455         if len(frames) == 1:
       
  2456             return ffi.buffer(last_buffer, len(last_buffer))[:]
       
  2457 
       
  2458         i = 1
       
  2459         while i < len(frames):
       
  2460             chunk = frames[i]
       
  2461             if not isinstance(chunk, bytes_type):
       
  2462                 raise ValueError('chunk %d must be bytes' % i)
       
  2463 
       
  2464             chunk_buffer = ffi.from_buffer(chunk)
       
  2465             zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer))
       
  2466             if lib.ZSTD_isError(zresult):
       
  2467                 raise ValueError('chunk %d is not a valid zstd frame' % i)
       
  2468             elif zresult:
       
  2469                 raise ValueError('chunk %d is too small to contain a zstd frame' % i)
       
  2470 
       
  2471             if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN:
       
  2472                 raise ValueError('chunk %d missing content size in frame' % i)
       
  2473 
       
  2474             dest_buffer = ffi.new('char[]', params.frameContentSize)
       
  2475 
       
  2476             out_buffer.dst = dest_buffer
       
  2477             out_buffer.size = len(dest_buffer)
       
  2478             out_buffer.pos = 0
       
  2479 
       
  2480             in_buffer.src = chunk_buffer
       
  2481             in_buffer.size = len(chunk_buffer)
       
  2482             in_buffer.pos = 0
       
  2483 
       
  2484             zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
       
  2485             if lib.ZSTD_isError(zresult):
       
  2486                 raise ZstdError('could not decompress chunk %d: %s' %
       
  2487                                 _zstd_error(zresult))
       
  2488             elif zresult:
       
  2489                 raise ZstdError('chunk %d did not decompress full frame' % i)
       
  2490 
       
  2491             last_buffer = dest_buffer
       
  2492             i += 1
       
  2493 
       
  2494         return ffi.buffer(last_buffer, len(last_buffer))[:]
       
  2495 
       
  2496     def _ensure_dctx(self, load_dict=True):
       
  2497         lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only)
       
  2498 
       
  2499         if self._max_window_size:
       
  2500             zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx,
       
  2501                                                      self._max_window_size)
       
  2502             if lib.ZSTD_isError(zresult):
       
  2503                 raise ZstdError('unable to set max window size: %s' %
       
  2504                                 _zstd_error(zresult))
       
  2505 
       
  2506         zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format)
       
  2507         if lib.ZSTD_isError(zresult):
       
  2508             raise ZstdError('unable to set decoding format: %s' %
       
  2509                             _zstd_error(zresult))
       
  2510 
       
  2511         if self._dict_data and load_dict:
       
  2512             zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict)
       
  2513             if lib.ZSTD_isError(zresult):
       
  2514                 raise ZstdError('unable to reference prepared dictionary: %s' %
       
  2515                                 _zstd_error(zresult))