contrib/python-zstandard/zstd_cffi.py
changeset 31799 e0dc40530c5a
parent 30924 c32454d69b85
child 37495 b1fb341d8a61
equal deleted inserted replaced
31798:2b130e26c3a4 31799:e0dc40530c5a
     6 
     6 
     7 """Python interface to the Zstandard (zstd) compression library."""
     7 """Python interface to the Zstandard (zstd) compression library."""
     8 
     8 
     9 from __future__ import absolute_import, unicode_literals
     9 from __future__ import absolute_import, unicode_literals
    10 
    10 
       
    11 import os
    11 import sys
    12 import sys
    12 
    13 
    13 from _zstd_cffi import (
    14 from _zstd_cffi import (
    14     ffi,
    15     ffi,
    15     lib,
    16     lib,
    60 
    61 
    61 COMPRESSOBJ_FLUSH_FINISH = 0
    62 COMPRESSOBJ_FLUSH_FINISH = 0
    62 COMPRESSOBJ_FLUSH_BLOCK = 1
    63 COMPRESSOBJ_FLUSH_BLOCK = 1
    63 
    64 
    64 
    65 
       
    66 def _cpu_count():
       
    67     # os.cpu_count() was introducd in Python 3.4.
       
    68     try:
       
    69         return os.cpu_count() or 0
       
    70     except AttributeError:
       
    71         pass
       
    72 
       
    73     # Linux.
       
    74     try:
       
    75         if sys.version_info[0] == 2:
       
    76             return os.sysconf(b'SC_NPROCESSORS_ONLN')
       
    77         else:
       
    78             return os.sysconf(u'SC_NPROCESSORS_ONLN')
       
    79     except (AttributeError, ValueError):
       
    80         pass
       
    81 
       
    82     # TODO implement on other platforms.
       
    83     return 0
       
    84 
       
    85 
    65 class ZstdError(Exception):
    86 class ZstdError(Exception):
    66     pass
    87     pass
    67 
    88 
    68 
    89 
    69 class CompressionParameters(object):
    90 class CompressionParameters(object):
    95         self.hash_log = hash_log
   116         self.hash_log = hash_log
    96         self.search_log = search_log
   117         self.search_log = search_log
    97         self.search_length = search_length
   118         self.search_length = search_length
    98         self.target_length = target_length
   119         self.target_length = target_length
    99         self.strategy = strategy
   120         self.strategy = strategy
       
   121 
       
   122         zresult = lib.ZSTD_checkCParams(self.as_compression_parameters())
       
   123         if lib.ZSTD_isError(zresult):
       
   124             raise ValueError('invalid compression parameters: %s',
       
   125                              ffi.string(lib.ZSTD_getErrorName(zresult)))
       
   126 
       
   127     def estimated_compression_context_size(self):
       
   128         return lib.ZSTD_estimateCCtxSize(self.as_compression_parameters())
   100 
   129 
   101     def as_compression_parameters(self):
   130     def as_compression_parameters(self):
   102         p = ffi.new('ZSTD_compressionParameters *')[0]
   131         p = ffi.new('ZSTD_compressionParameters *')[0]
   103         p.windowLog = self.window_log
   132         p.windowLog = self.window_log
   104         p.chainLog = self.chain_log
   133         p.chainLog = self.chain_log
   138         self._compressor = compressor
   167         self._compressor = compressor
   139         self._writer = writer
   168         self._writer = writer
   140         self._source_size = source_size
   169         self._source_size = source_size
   141         self._write_size = write_size
   170         self._write_size = write_size
   142         self._entered = False
   171         self._entered = False
       
   172         self._mtcctx = compressor._cctx if compressor._multithreaded else None
   143 
   173 
   144     def __enter__(self):
   174     def __enter__(self):
   145         if self._entered:
   175         if self._entered:
   146             raise ZstdError('cannot __enter__ multiple times')
   176             raise ZstdError('cannot __enter__ multiple times')
   147 
   177 
   148         self._cstream = self._compressor._get_cstream(self._source_size)
   178         if self._mtcctx:
       
   179             self._compressor._init_mtcstream(self._source_size)
       
   180         else:
       
   181             self._compressor._ensure_cstream(self._source_size)
   149         self._entered = True
   182         self._entered = True
   150         return self
   183         return self
   151 
   184 
   152     def __exit__(self, exc_type, exc_value, exc_tb):
   185     def __exit__(self, exc_type, exc_value, exc_tb):
   153         self._entered = False
   186         self._entered = False
   158             out_buffer.dst = dst_buffer
   191             out_buffer.dst = dst_buffer
   159             out_buffer.size = self._write_size
   192             out_buffer.size = self._write_size
   160             out_buffer.pos = 0
   193             out_buffer.pos = 0
   161 
   194 
   162             while True:
   195             while True:
   163                 zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
   196                 if self._mtcctx:
       
   197                     zresult = lib.ZSTDMT_endStream(self._mtcctx, out_buffer)
       
   198                 else:
       
   199                     zresult = lib.ZSTD_endStream(self._compressor._cstream, out_buffer)
   164                 if lib.ZSTD_isError(zresult):
   200                 if lib.ZSTD_isError(zresult):
   165                     raise ZstdError('error ending compression stream: %s' %
   201                     raise ZstdError('error ending compression stream: %s' %
   166                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   202                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   167 
   203 
   168                 if out_buffer.pos:
   204                 if out_buffer.pos:
   170                     out_buffer.pos = 0
   206                     out_buffer.pos = 0
   171 
   207 
   172                 if zresult == 0:
   208                 if zresult == 0:
   173                     break
   209                     break
   174 
   210 
   175         self._cstream = None
       
   176         self._compressor = None
   211         self._compressor = None
   177 
   212 
   178         return False
   213         return False
   179 
   214 
   180     def memory_size(self):
   215     def memory_size(self):
   181         if not self._entered:
   216         if not self._entered:
   182             raise ZstdError('cannot determine size of an inactive compressor; '
   217             raise ZstdError('cannot determine size of an inactive compressor; '
   183                             'call when a context manager is active')
   218                             'call when a context manager is active')
   184 
   219 
   185         return lib.ZSTD_sizeof_CStream(self._cstream)
   220         return lib.ZSTD_sizeof_CStream(self._compressor._cstream)
   186 
   221 
   187     def write(self, data):
   222     def write(self, data):
   188         if not self._entered:
   223         if not self._entered:
   189             raise ZstdError('write() must be called from an active context '
   224             raise ZstdError('write() must be called from an active context '
   190                             'manager')
   225                             'manager')
   203         out_buffer.dst = dst_buffer
   238         out_buffer.dst = dst_buffer
   204         out_buffer.size = self._write_size
   239         out_buffer.size = self._write_size
   205         out_buffer.pos = 0
   240         out_buffer.pos = 0
   206 
   241 
   207         while in_buffer.pos < in_buffer.size:
   242         while in_buffer.pos < in_buffer.size:
   208             zresult = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
   243             if self._mtcctx:
       
   244                 zresult = lib.ZSTDMT_compressStream(self._mtcctx, out_buffer,
       
   245                                                     in_buffer)
       
   246             else:
       
   247                 zresult = lib.ZSTD_compressStream(self._compressor._cstream, out_buffer,
       
   248                                                   in_buffer)
   209             if lib.ZSTD_isError(zresult):
   249             if lib.ZSTD_isError(zresult):
   210                 raise ZstdError('zstd compress error: %s' %
   250                 raise ZstdError('zstd compress error: %s' %
   211                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   251                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   212 
   252 
   213             if out_buffer.pos:
   253             if out_buffer.pos:
   228         out_buffer.dst = dst_buffer
   268         out_buffer.dst = dst_buffer
   229         out_buffer.size = self._write_size
   269         out_buffer.size = self._write_size
   230         out_buffer.pos = 0
   270         out_buffer.pos = 0
   231 
   271 
   232         while True:
   272         while True:
   233             zresult = lib.ZSTD_flushStream(self._cstream, out_buffer)
   273             if self._mtcctx:
       
   274                 zresult = lib.ZSTDMT_flushStream(self._mtcctx, out_buffer)
       
   275             else:
       
   276                 zresult = lib.ZSTD_flushStream(self._compressor._cstream, out_buffer)
   234             if lib.ZSTD_isError(zresult):
   277             if lib.ZSTD_isError(zresult):
   235                 raise ZstdError('zstd compress error: %s' %
   278                 raise ZstdError('zstd compress error: %s' %
   236                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   279                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   237 
   280 
   238             if not out_buffer.pos:
   281             if not out_buffer.pos:
   257         source.pos = 0
   300         source.pos = 0
   258 
   301 
   259         chunks = []
   302         chunks = []
   260 
   303 
   261         while source.pos < len(data):
   304         while source.pos < len(data):
   262             zresult = lib.ZSTD_compressStream(self._cstream, self._out, source)
   305             if self._mtcctx:
       
   306                 zresult = lib.ZSTDMT_compressStream(self._mtcctx,
       
   307                                                     self._out, source)
       
   308             else:
       
   309                 zresult = lib.ZSTD_compressStream(self._compressor._cstream, self._out,
       
   310                                                   source)
   263             if lib.ZSTD_isError(zresult):
   311             if lib.ZSTD_isError(zresult):
   264                 raise ZstdError('zstd compress error: %s' %
   312                 raise ZstdError('zstd compress error: %s' %
   265                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   313                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   266 
   314 
   267             if self._out.pos:
   315             if self._out.pos:
   278             raise ZstdError('compressor object already finished')
   326             raise ZstdError('compressor object already finished')
   279 
   327 
   280         assert self._out.pos == 0
   328         assert self._out.pos == 0
   281 
   329 
   282         if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
   330         if flush_mode == COMPRESSOBJ_FLUSH_BLOCK:
   283             zresult = lib.ZSTD_flushStream(self._cstream, self._out)
   331             if self._mtcctx:
       
   332                 zresult = lib.ZSTDMT_flushStream(self._mtcctx, self._out)
       
   333             else:
       
   334                 zresult = lib.ZSTD_flushStream(self._compressor._cstream, self._out)
   284             if lib.ZSTD_isError(zresult):
   335             if lib.ZSTD_isError(zresult):
   285                 raise ZstdError('zstd compress error: %s' %
   336                 raise ZstdError('zstd compress error: %s' %
   286                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   337                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   287 
   338 
   288             # Output buffer is guaranteed to hold full block.
   339             # Output buffer is guaranteed to hold full block.
   299         self._finished = True
   350         self._finished = True
   300 
   351 
   301         chunks = []
   352         chunks = []
   302 
   353 
   303         while True:
   354         while True:
   304             zresult = lib.ZSTD_endStream(self._cstream, self._out)
   355             if self._mtcctx:
       
   356                 zresult = lib.ZSTDMT_endStream(self._mtcctx, self._out)
       
   357             else:
       
   358                 zresult = lib.ZSTD_endStream(self._compressor._cstream, self._out)
   305             if lib.ZSTD_isError(zresult):
   359             if lib.ZSTD_isError(zresult):
   306                 raise ZstdError('error ending compression stream: %s' %
   360                 raise ZstdError('error ending compression stream: %s' %
   307                                 ffi.string(lib.ZSTD_getErroName(zresult)))
   361                                 ffi.string(lib.ZSTD_getErroName(zresult)))
   308 
   362 
   309             if self._out.pos:
   363             if self._out.pos:
   311                 self._out.pos = 0
   365                 self._out.pos = 0
   312 
   366 
   313             if not zresult:
   367             if not zresult:
   314                 break
   368                 break
   315 
   369 
   316         # GC compression stream immediately.
       
   317         self._cstream = None
       
   318 
       
   319         return b''.join(chunks)
   370         return b''.join(chunks)
   320 
   371 
   321 
   372 
   322 class ZstdCompressor(object):
   373 class ZstdCompressor(object):
   323     def __init__(self, level=3, dict_data=None, compression_params=None,
   374     def __init__(self, level=3, dict_data=None, compression_params=None,
   324                  write_checksum=False, write_content_size=False,
   375                  write_checksum=False, write_content_size=False,
   325                  write_dict_id=True):
   376                  write_dict_id=True, threads=0):
   326         if level < 1:
   377         if level < 1:
   327             raise ValueError('level must be greater than 0')
   378             raise ValueError('level must be greater than 0')
   328         elif level > lib.ZSTD_maxCLevel():
   379         elif level > lib.ZSTD_maxCLevel():
   329             raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel())
   380             raise ValueError('level must be less than %d' % lib.ZSTD_maxCLevel())
       
   381 
       
   382         if threads < 0:
       
   383             threads = _cpu_count()
   330 
   384 
   331         self._compression_level = level
   385         self._compression_level = level
   332         self._dict_data = dict_data
   386         self._dict_data = dict_data
   333         self._cparams = compression_params
   387         self._cparams = compression_params
   334         self._fparams = ffi.new('ZSTD_frameParameters *')[0]
   388         self._fparams = ffi.new('ZSTD_frameParameters *')[0]
   335         self._fparams.checksumFlag = write_checksum
   389         self._fparams.checksumFlag = write_checksum
   336         self._fparams.contentSizeFlag = write_content_size
   390         self._fparams.contentSizeFlag = write_content_size
   337         self._fparams.noDictIDFlag = not write_dict_id
   391         self._fparams.noDictIDFlag = not write_dict_id
   338 
   392 
   339         cctx = lib.ZSTD_createCCtx()
   393         if threads:
   340         if cctx == ffi.NULL:
   394             cctx = lib.ZSTDMT_createCCtx(threads)
   341             raise MemoryError()
   395             if cctx == ffi.NULL:
   342 
   396                 raise MemoryError()
   343         self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx)
   397 
       
   398             self._cctx = ffi.gc(cctx, lib.ZSTDMT_freeCCtx)
       
   399             self._multithreaded = True
       
   400         else:
       
   401             cctx = lib.ZSTD_createCCtx()
       
   402             if cctx == ffi.NULL:
       
   403                 raise MemoryError()
       
   404 
       
   405             self._cctx = ffi.gc(cctx, lib.ZSTD_freeCCtx)
       
   406             self._multithreaded = False
       
   407 
       
   408         self._cstream = None
   344 
   409 
   345     def compress(self, data, allow_empty=False):
   410     def compress(self, data, allow_empty=False):
   346         if len(data) == 0 and self._fparams.contentSizeFlag and not allow_empty:
   411         if len(data) == 0 and self._fparams.contentSizeFlag and not allow_empty:
   347             raise ValueError('cannot write empty inputs when writing content sizes')
   412             raise ValueError('cannot write empty inputs when writing content sizes')
   348 
   413 
       
   414         if self._multithreaded and self._dict_data:
       
   415             raise ZstdError('compress() cannot be used with both dictionaries and multi-threaded compression')
       
   416 
       
   417         if self._multithreaded and self._cparams:
       
   418             raise ZstdError('compress() cannot be used with both compression parameters and multi-threaded compression')
       
   419 
   349         # TODO use a CDict for performance.
   420         # TODO use a CDict for performance.
   350         dict_data = ffi.NULL
   421         dict_data = ffi.NULL
   351         dict_size = 0
   422         dict_size = 0
   352 
   423 
   353         if self._dict_data:
   424         if self._dict_data:
   363         params.fParams = self._fparams
   434         params.fParams = self._fparams
   364 
   435 
   365         dest_size = lib.ZSTD_compressBound(len(data))
   436         dest_size = lib.ZSTD_compressBound(len(data))
   366         out = new_nonzero('char[]', dest_size)
   437         out = new_nonzero('char[]', dest_size)
   367 
   438 
   368         zresult = lib.ZSTD_compress_advanced(self._cctx,
   439         if self._multithreaded:
   369                                              ffi.addressof(out), dest_size,
   440             zresult = lib.ZSTDMT_compressCCtx(self._cctx,
   370                                              data, len(data),
   441                                               ffi.addressof(out), dest_size,
   371                                              dict_data, dict_size,
   442                                               data, len(data),
   372                                              params)
   443                                               self._compression_level)
       
   444         else:
       
   445             zresult = lib.ZSTD_compress_advanced(self._cctx,
       
   446                                                  ffi.addressof(out), dest_size,
       
   447                                                  data, len(data),
       
   448                                                  dict_data, dict_size,
       
   449                                                  params)
   373 
   450 
   374         if lib.ZSTD_isError(zresult):
   451         if lib.ZSTD_isError(zresult):
   375             raise ZstdError('cannot compress: %s' %
   452             raise ZstdError('cannot compress: %s' %
   376                             ffi.string(lib.ZSTD_getErrorName(zresult)))
   453                             ffi.string(lib.ZSTD_getErrorName(zresult)))
   377 
   454 
   378         return ffi.buffer(out, zresult)[:]
   455         return ffi.buffer(out, zresult)[:]
   379 
   456 
   380     def compressobj(self, size=0):
   457     def compressobj(self, size=0):
   381         cstream = self._get_cstream(size)
   458         if self._multithreaded:
       
   459             self._init_mtcstream(size)
       
   460         else:
       
   461             self._ensure_cstream(size)
       
   462 
   382         cobj = ZstdCompressionObj()
   463         cobj = ZstdCompressionObj()
   383         cobj._cstream = cstream
       
   384         cobj._out = ffi.new('ZSTD_outBuffer *')
   464         cobj._out = ffi.new('ZSTD_outBuffer *')
   385         cobj._dst_buffer = ffi.new('char[]', COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
   465         cobj._dst_buffer = ffi.new('char[]', COMPRESSION_RECOMMENDED_OUTPUT_SIZE)
   386         cobj._out.dst = cobj._dst_buffer
   466         cobj._out.dst = cobj._dst_buffer
   387         cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
   467         cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE
   388         cobj._out.pos = 0
   468         cobj._out.pos = 0
   389         cobj._compressor = self
   469         cobj._compressor = self
   390         cobj._finished = False
   470         cobj._finished = False
   391 
   471 
       
   472         if self._multithreaded:
       
   473             cobj._mtcctx = self._cctx
       
   474         else:
       
   475             cobj._mtcctx = None
       
   476 
   392         return cobj
   477         return cobj
   393 
   478 
   394     def copy_stream(self, ifh, ofh, size=0,
   479     def copy_stream(self, ifh, ofh, size=0,
   395                     read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
   480                     read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE,
   396                     write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
   481                     write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
   398         if not hasattr(ifh, 'read'):
   483         if not hasattr(ifh, 'read'):
   399             raise ValueError('first argument must have a read() method')
   484             raise ValueError('first argument must have a read() method')
   400         if not hasattr(ofh, 'write'):
   485         if not hasattr(ofh, 'write'):
   401             raise ValueError('second argument must have a write() method')
   486             raise ValueError('second argument must have a write() method')
   402 
   487 
   403         cstream = self._get_cstream(size)
   488         mt = self._multithreaded
       
   489         if mt:
       
   490             self._init_mtcstream(size)
       
   491         else:
       
   492             self._ensure_cstream(size)
   404 
   493 
   405         in_buffer = ffi.new('ZSTD_inBuffer *')
   494         in_buffer = ffi.new('ZSTD_inBuffer *')
   406         out_buffer = ffi.new('ZSTD_outBuffer *')
   495         out_buffer = ffi.new('ZSTD_outBuffer *')
   407 
   496 
   408         dst_buffer = ffi.new('char[]', write_size)
   497         dst_buffer = ffi.new('char[]', write_size)
   422             in_buffer.src = data_buffer
   511             in_buffer.src = data_buffer
   423             in_buffer.size = len(data_buffer)
   512             in_buffer.size = len(data_buffer)
   424             in_buffer.pos = 0
   513             in_buffer.pos = 0
   425 
   514 
   426             while in_buffer.pos < in_buffer.size:
   515             while in_buffer.pos < in_buffer.size:
   427                 zresult = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
   516                 if mt:
       
   517                     zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer)
       
   518                 else:
       
   519                     zresult = lib.ZSTD_compressStream(self._cstream,
       
   520                                                       out_buffer, in_buffer)
   428                 if lib.ZSTD_isError(zresult):
   521                 if lib.ZSTD_isError(zresult):
   429                     raise ZstdError('zstd compress error: %s' %
   522                     raise ZstdError('zstd compress error: %s' %
   430                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   523                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   431 
   524 
   432                 if out_buffer.pos:
   525                 if out_buffer.pos:
   434                     total_write += out_buffer.pos
   527                     total_write += out_buffer.pos
   435                     out_buffer.pos = 0
   528                     out_buffer.pos = 0
   436 
   529 
   437         # We've finished reading. Flush the compressor.
   530         # We've finished reading. Flush the compressor.
   438         while True:
   531         while True:
   439             zresult = lib.ZSTD_endStream(cstream, out_buffer)
   532             if mt:
       
   533                 zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer)
       
   534             else:
       
   535                 zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
   440             if lib.ZSTD_isError(zresult):
   536             if lib.ZSTD_isError(zresult):
   441                 raise ZstdError('error ending compression stream: %s' %
   537                 raise ZstdError('error ending compression stream: %s' %
   442                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   538                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   443 
   539 
   444             if out_buffer.pos:
   540             if out_buffer.pos:
   470             size = len(reader)
   566             size = len(reader)
   471         else:
   567         else:
   472             raise ValueError('must pass an object with a read() method or '
   568             raise ValueError('must pass an object with a read() method or '
   473                              'conforms to buffer protocol')
   569                              'conforms to buffer protocol')
   474 
   570 
   475         cstream = self._get_cstream(size)
   571         if self._multithreaded:
       
   572             self._init_mtcstream(size)
       
   573         else:
       
   574             self._ensure_cstream(size)
   476 
   575 
   477         in_buffer = ffi.new('ZSTD_inBuffer *')
   576         in_buffer = ffi.new('ZSTD_inBuffer *')
   478         out_buffer = ffi.new('ZSTD_outBuffer *')
   577         out_buffer = ffi.new('ZSTD_outBuffer *')
   479 
   578 
   480         in_buffer.src = ffi.NULL
   579         in_buffer.src = ffi.NULL
   510             in_buffer.src = read_buffer
   609             in_buffer.src = read_buffer
   511             in_buffer.size = len(read_buffer)
   610             in_buffer.size = len(read_buffer)
   512             in_buffer.pos = 0
   611             in_buffer.pos = 0
   513 
   612 
   514             while in_buffer.pos < in_buffer.size:
   613             while in_buffer.pos < in_buffer.size:
   515                 zresult = lib.ZSTD_compressStream(cstream, out_buffer, in_buffer)
   614                 if self._multithreaded:
       
   615                     zresult = lib.ZSTDMT_compressStream(self._cctx, out_buffer, in_buffer)
       
   616                 else:
       
   617                     zresult = lib.ZSTD_compressStream(self._cstream, out_buffer, in_buffer)
   516                 if lib.ZSTD_isError(zresult):
   618                 if lib.ZSTD_isError(zresult):
   517                     raise ZstdError('zstd compress error: %s' %
   619                     raise ZstdError('zstd compress error: %s' %
   518                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   620                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   519 
   621 
   520                 if out_buffer.pos:
   622                 if out_buffer.pos:
   529 
   631 
   530         # If we get here, input is exhausted. End the stream and emit what
   632         # If we get here, input is exhausted. End the stream and emit what
   531         # remains.
   633         # remains.
   532         while True:
   634         while True:
   533             assert out_buffer.pos == 0
   635             assert out_buffer.pos == 0
   534             zresult = lib.ZSTD_endStream(cstream, out_buffer)
   636             if self._multithreaded:
       
   637                 zresult = lib.ZSTDMT_endStream(self._cctx, out_buffer)
       
   638             else:
       
   639                 zresult = lib.ZSTD_endStream(self._cstream, out_buffer)
   535             if lib.ZSTD_isError(zresult):
   640             if lib.ZSTD_isError(zresult):
   536                 raise ZstdError('error ending compression stream: %s' %
   641                 raise ZstdError('error ending compression stream: %s' %
   537                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   642                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   538 
   643 
   539             if out_buffer.pos:
   644             if out_buffer.pos:
   542                 yield data
   647                 yield data
   543 
   648 
   544             if zresult == 0:
   649             if zresult == 0:
   545                 break
   650                 break
   546 
   651 
   547     def _get_cstream(self, size):
   652     def _ensure_cstream(self, size):
       
   653         if self._cstream:
       
   654             zresult = lib.ZSTD_resetCStream(self._cstream, size)
       
   655             if lib.ZSTD_isError(zresult):
       
   656                 raise ZstdError('could not reset CStream: %s' %
       
   657                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
       
   658 
       
   659             return
       
   660 
   548         cstream = lib.ZSTD_createCStream()
   661         cstream = lib.ZSTD_createCStream()
   549         if cstream == ffi.NULL:
   662         if cstream == ffi.NULL:
   550             raise MemoryError()
   663             raise MemoryError()
   551 
   664 
   552         cstream = ffi.gc(cstream, lib.ZSTD_freeCStream)
   665         cstream = ffi.gc(cstream, lib.ZSTD_freeCStream)
   569                                                 zparams, size)
   682                                                 zparams, size)
   570         if lib.ZSTD_isError(zresult):
   683         if lib.ZSTD_isError(zresult):
   571             raise Exception('cannot init CStream: %s' %
   684             raise Exception('cannot init CStream: %s' %
   572                             ffi.string(lib.ZSTD_getErrorName(zresult)))
   685                             ffi.string(lib.ZSTD_getErrorName(zresult)))
   573 
   686 
   574         return cstream
   687         self._cstream = cstream
       
   688 
       
   689     def _init_mtcstream(self, size):
       
   690         assert self._multithreaded
       
   691 
       
   692         dict_data = ffi.NULL
       
   693         dict_size = 0
       
   694         if self._dict_data:
       
   695             dict_data = self._dict_data.as_bytes()
       
   696             dict_size = len(self._dict_data)
       
   697 
       
   698         zparams = ffi.new('ZSTD_parameters *')[0]
       
   699         if self._cparams:
       
   700             zparams.cParams = self._cparams.as_compression_parameters()
       
   701         else:
       
   702             zparams.cParams = lib.ZSTD_getCParams(self._compression_level,
       
   703                                                   size, dict_size)
       
   704 
       
   705         zparams.fParams = self._fparams
       
   706 
       
   707         zresult = lib.ZSTDMT_initCStream_advanced(self._cctx, dict_data, dict_size,
       
   708                                                   zparams, size)
       
   709 
       
   710         if lib.ZSTD_isError(zresult):
       
   711             raise ZstdError('cannot init CStream: %s' %
       
   712                             ffi.string(lib.ZSTD_getErrorName(zresult)))
   575 
   713 
   576 
   714 
   577 class FrameParameters(object):
   715 class FrameParameters(object):
   578     def __init__(self, fparams):
   716     def __init__(self, fparams):
   579         self.content_size = fparams.frameContentSize
   717         self.content_size = fparams.frameContentSize
   599 
   737 
   600     return FrameParameters(params[0])
   738     return FrameParameters(params[0])
   601 
   739 
   602 
   740 
   603 class ZstdCompressionDict(object):
   741 class ZstdCompressionDict(object):
   604     def __init__(self, data):
   742     def __init__(self, data, k=0, d=0):
   605         assert isinstance(data, bytes_type)
   743         assert isinstance(data, bytes_type)
   606         self._data = data
   744         self._data = data
       
   745         self.k = k
       
   746         self.d = d
   607 
   747 
   608     def __len__(self):
   748     def __len__(self):
   609         return len(self._data)
   749         return len(self._data)
   610 
   750 
   611     def dict_id(self):
   751     def dict_id(self):
   613 
   753 
   614     def as_bytes(self):
   754     def as_bytes(self):
   615         return self._data
   755         return self._data
   616 
   756 
   617 
   757 
   618 def train_dictionary(dict_size, samples, parameters=None):
   758 def train_dictionary(dict_size, samples, selectivity=0, level=0,
       
   759                      notifications=0, dict_id=0):
   619     if not isinstance(samples, list):
   760     if not isinstance(samples, list):
   620         raise TypeError('samples must be a list')
   761         raise TypeError('samples must be a list')
   621 
   762 
   622     total_size = sum(map(len, samples))
   763     total_size = sum(map(len, samples))
   623 
   764 
   634         offset += l
   775         offset += l
   635         sample_sizes[i] = l
   776         sample_sizes[i] = l
   636 
   777 
   637     dict_data = new_nonzero('char[]', dict_size)
   778     dict_data = new_nonzero('char[]', dict_size)
   638 
   779 
   639     zresult = lib.ZDICT_trainFromBuffer(ffi.addressof(dict_data), dict_size,
   780     dparams = ffi.new('ZDICT_params_t *')[0]
   640                                         ffi.addressof(samples_buffer),
   781     dparams.selectivityLevel = selectivity
   641                                         ffi.addressof(sample_sizes, 0),
   782     dparams.compressionLevel = level
   642                                         len(samples))
   783     dparams.notificationLevel = notifications
       
   784     dparams.dictID = dict_id
       
   785 
       
   786     zresult = lib.ZDICT_trainFromBuffer_advanced(
       
   787         ffi.addressof(dict_data), dict_size,
       
   788         ffi.addressof(samples_buffer),
       
   789         ffi.addressof(sample_sizes, 0), len(samples),
       
   790         dparams)
       
   791 
   643     if lib.ZDICT_isError(zresult):
   792     if lib.ZDICT_isError(zresult):
   644         raise ZstdError('Cannot train dict: %s' %
   793         raise ZstdError('Cannot train dict: %s' %
   645                         ffi.string(lib.ZDICT_getErrorName(zresult)))
   794                         ffi.string(lib.ZDICT_getErrorName(zresult)))
   646 
   795 
   647     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:])
   796     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:])
   648 
   797 
   649 
   798 
       
   799 def train_cover_dictionary(dict_size, samples, k=0, d=0,
       
   800                            notifications=0, dict_id=0, level=0, optimize=False,
       
   801                            steps=0, threads=0):
       
   802     if not isinstance(samples, list):
       
   803         raise TypeError('samples must be a list')
       
   804 
       
   805     if threads < 0:
       
   806         threads = _cpu_count()
       
   807 
       
   808     total_size = sum(map(len, samples))
       
   809 
       
   810     samples_buffer = new_nonzero('char[]', total_size)
       
   811     sample_sizes = new_nonzero('size_t[]', len(samples))
       
   812 
       
   813     offset = 0
       
   814     for i, sample in enumerate(samples):
       
   815         if not isinstance(sample, bytes_type):
       
   816             raise ValueError('samples must be bytes')
       
   817 
       
   818         l = len(sample)
       
   819         ffi.memmove(samples_buffer + offset, sample, l)
       
   820         offset += l
       
   821         sample_sizes[i] = l
       
   822 
       
   823     dict_data = new_nonzero('char[]', dict_size)
       
   824 
       
   825     dparams = ffi.new('COVER_params_t *')[0]
       
   826     dparams.k = k
       
   827     dparams.d = d
       
   828     dparams.steps = steps
       
   829     dparams.nbThreads = threads
       
   830     dparams.notificationLevel = notifications
       
   831     dparams.dictID = dict_id
       
   832     dparams.compressionLevel = level
       
   833 
       
   834     if optimize:
       
   835         zresult = lib.COVER_optimizeTrainFromBuffer(
       
   836             ffi.addressof(dict_data), dict_size,
       
   837             ffi.addressof(samples_buffer),
       
   838             ffi.addressof(sample_sizes, 0), len(samples),
       
   839             ffi.addressof(dparams))
       
   840     else:
       
   841         zresult = lib.COVER_trainFromBuffer(
       
   842             ffi.addressof(dict_data), dict_size,
       
   843             ffi.addressof(samples_buffer),
       
   844             ffi.addressof(sample_sizes, 0), len(samples),
       
   845             dparams)
       
   846 
       
   847     if lib.ZDICT_isError(zresult):
       
   848         raise ZstdError('cannot train dict: %s' %
       
   849                         ffi.string(lib.ZDICT_getErrorName(zresult)))
       
   850 
       
   851     return ZstdCompressionDict(ffi.buffer(dict_data, zresult)[:],
       
   852                                k=dparams.k, d=dparams.d)
       
   853 
       
   854 
   650 class ZstdDecompressionObj(object):
   855 class ZstdDecompressionObj(object):
   651     def __init__(self, decompressor):
   856     def __init__(self, decompressor):
   652         self._decompressor = decompressor
   857         self._decompressor = decompressor
   653         self._dstream = self._decompressor._get_dstream()
       
   654         self._finished = False
   858         self._finished = False
   655 
   859 
   656     def decompress(self, data):
   860     def decompress(self, data):
   657         if self._finished:
   861         if self._finished:
   658             raise ZstdError('cannot use a decompressobj multiple times')
   862             raise ZstdError('cannot use a decompressobj multiple times')
       
   863 
       
   864         assert(self._decompressor._dstream)
   659 
   865 
   660         in_buffer = ffi.new('ZSTD_inBuffer *')
   866         in_buffer = ffi.new('ZSTD_inBuffer *')
   661         out_buffer = ffi.new('ZSTD_outBuffer *')
   867         out_buffer = ffi.new('ZSTD_outBuffer *')
   662 
   868 
   663         data_buffer = ffi.from_buffer(data)
   869         data_buffer = ffi.from_buffer(data)
   671         out_buffer.pos = 0
   877         out_buffer.pos = 0
   672 
   878 
   673         chunks = []
   879         chunks = []
   674 
   880 
   675         while in_buffer.pos < in_buffer.size:
   881         while in_buffer.pos < in_buffer.size:
   676             zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
   882             zresult = lib.ZSTD_decompressStream(self._decompressor._dstream,
       
   883                                                 out_buffer, in_buffer)
   677             if lib.ZSTD_isError(zresult):
   884             if lib.ZSTD_isError(zresult):
   678                 raise ZstdError('zstd decompressor error: %s' %
   885                 raise ZstdError('zstd decompressor error: %s' %
   679                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   886                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   680 
   887 
   681             if zresult == 0:
   888             if zresult == 0:
   682                 self._finished = True
   889                 self._finished = True
   683                 self._dstream = None
       
   684                 self._decompressor = None
   890                 self._decompressor = None
   685 
   891 
   686             if out_buffer.pos:
   892             if out_buffer.pos:
   687                 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
   893                 chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:])
   688                 out_buffer.pos = 0
   894                 out_buffer.pos = 0
   693 class ZstdDecompressionWriter(object):
   899 class ZstdDecompressionWriter(object):
   694     def __init__(self, decompressor, writer, write_size):
   900     def __init__(self, decompressor, writer, write_size):
   695         self._decompressor = decompressor
   901         self._decompressor = decompressor
   696         self._writer = writer
   902         self._writer = writer
   697         self._write_size = write_size
   903         self._write_size = write_size
   698         self._dstream = None
       
   699         self._entered = False
   904         self._entered = False
   700 
   905 
   701     def __enter__(self):
   906     def __enter__(self):
   702         if self._entered:
   907         if self._entered:
   703             raise ZstdError('cannot __enter__ multiple times')
   908             raise ZstdError('cannot __enter__ multiple times')
   704 
   909 
   705         self._dstream = self._decompressor._get_dstream()
   910         self._decompressor._ensure_dstream()
   706         self._entered = True
   911         self._entered = True
   707 
   912 
   708         return self
   913         return self
   709 
   914 
   710     def __exit__(self, exc_type, exc_value, exc_tb):
   915     def __exit__(self, exc_type, exc_value, exc_tb):
   711         self._entered = False
   916         self._entered = False
   712         self._dstream = None
       
   713 
   917 
   714     def memory_size(self):
   918     def memory_size(self):
   715         if not self._dstream:
   919         if not self._decompressor._dstream:
   716             raise ZstdError('cannot determine size of inactive decompressor '
   920             raise ZstdError('cannot determine size of inactive decompressor '
   717                             'call when context manager is active')
   921                             'call when context manager is active')
   718 
   922 
   719         return lib.ZSTD_sizeof_DStream(self._dstream)
   923         return lib.ZSTD_sizeof_DStream(self._decompressor._dstream)
   720 
   924 
   721     def write(self, data):
   925     def write(self, data):
   722         if not self._entered:
   926         if not self._entered:
   723             raise ZstdError('write must be called from an active context manager')
   927             raise ZstdError('write must be called from an active context manager')
   724 
   928 
   735         dst_buffer = ffi.new('char[]', self._write_size)
   939         dst_buffer = ffi.new('char[]', self._write_size)
   736         out_buffer.dst = dst_buffer
   940         out_buffer.dst = dst_buffer
   737         out_buffer.size = len(dst_buffer)
   941         out_buffer.size = len(dst_buffer)
   738         out_buffer.pos = 0
   942         out_buffer.pos = 0
   739 
   943 
       
   944         dstream = self._decompressor._dstream
       
   945 
   740         while in_buffer.pos < in_buffer.size:
   946         while in_buffer.pos < in_buffer.size:
   741             zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
   947             zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
   742             if lib.ZSTD_isError(zresult):
   948             if lib.ZSTD_isError(zresult):
   743                 raise ZstdError('zstd decompress error: %s' %
   949                 raise ZstdError('zstd decompress error: %s' %
   744                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   950                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
   745 
   951 
   746             if out_buffer.pos:
   952             if out_buffer.pos:
   758         dctx = lib.ZSTD_createDCtx()
   964         dctx = lib.ZSTD_createDCtx()
   759         if dctx == ffi.NULL:
   965         if dctx == ffi.NULL:
   760             raise MemoryError()
   966             raise MemoryError()
   761 
   967 
   762         self._refdctx = ffi.gc(dctx, lib.ZSTD_freeDCtx)
   968         self._refdctx = ffi.gc(dctx, lib.ZSTD_freeDCtx)
       
   969         self._dstream = None
   763 
   970 
   764     @property
   971     @property
   765     def _ddict(self):
   972     def _ddict(self):
   766         if self._dict_data:
   973         if self._dict_data:
   767             dict_data = self._dict_data.as_bytes()
   974             dict_data = self._dict_data.as_bytes()
   814                             (zresult, output_size))
  1021                             (zresult, output_size))
   815 
  1022 
   816         return ffi.buffer(result_buffer, zresult)[:]
  1023         return ffi.buffer(result_buffer, zresult)[:]
   817 
  1024 
   818     def decompressobj(self):
  1025     def decompressobj(self):
       
  1026         self._ensure_dstream()
   819         return ZstdDecompressionObj(self)
  1027         return ZstdDecompressionObj(self)
   820 
  1028 
   821     def read_from(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
  1029     def read_from(self, reader, read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
   822                   write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
  1030                   write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE,
   823                   skip_bytes=0):
  1031                   skip_bytes=0):
   841                 if skip_bytes > size:
  1049                 if skip_bytes > size:
   842                     raise ValueError('skip_bytes larger than first input chunk')
  1050                     raise ValueError('skip_bytes larger than first input chunk')
   843 
  1051 
   844                 buffer_offset = skip_bytes
  1052                 buffer_offset = skip_bytes
   845 
  1053 
   846         dstream = self._get_dstream()
  1054         self._ensure_dstream()
   847 
  1055 
   848         in_buffer = ffi.new('ZSTD_inBuffer *')
  1056         in_buffer = ffi.new('ZSTD_inBuffer *')
   849         out_buffer = ffi.new('ZSTD_outBuffer *')
  1057         out_buffer = ffi.new('ZSTD_outBuffer *')
   850 
  1058 
   851         dst_buffer = ffi.new('char[]', write_size)
  1059         dst_buffer = ffi.new('char[]', write_size)
   876             in_buffer.pos = 0
  1084             in_buffer.pos = 0
   877 
  1085 
   878             while in_buffer.pos < in_buffer.size:
  1086             while in_buffer.pos < in_buffer.size:
   879                 assert out_buffer.pos == 0
  1087                 assert out_buffer.pos == 0
   880 
  1088 
   881                 zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
  1089                 zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
   882                 if lib.ZSTD_isError(zresult):
  1090                 if lib.ZSTD_isError(zresult):
   883                     raise ZstdError('zstd decompress error: %s' %
  1091                     raise ZstdError('zstd decompress error: %s' %
   884                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
  1092                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   885 
  1093 
   886                 if out_buffer.pos:
  1094                 if out_buffer.pos:
   908         if not hasattr(ifh, 'read'):
  1116         if not hasattr(ifh, 'read'):
   909             raise ValueError('first argument must have a read() method')
  1117             raise ValueError('first argument must have a read() method')
   910         if not hasattr(ofh, 'write'):
  1118         if not hasattr(ofh, 'write'):
   911             raise ValueError('second argument must have a write() method')
  1119             raise ValueError('second argument must have a write() method')
   912 
  1120 
   913         dstream = self._get_dstream()
  1121         self._ensure_dstream()
   914 
  1122 
   915         in_buffer = ffi.new('ZSTD_inBuffer *')
  1123         in_buffer = ffi.new('ZSTD_inBuffer *')
   916         out_buffer = ffi.new('ZSTD_outBuffer *')
  1124         out_buffer = ffi.new('ZSTD_outBuffer *')
   917 
  1125 
   918         dst_buffer = ffi.new('char[]', write_size)
  1126         dst_buffer = ffi.new('char[]', write_size)
   934             in_buffer.size = len(data_buffer)
  1142             in_buffer.size = len(data_buffer)
   935             in_buffer.pos = 0
  1143             in_buffer.pos = 0
   936 
  1144 
   937             # Flush all read data to output.
  1145             # Flush all read data to output.
   938             while in_buffer.pos < in_buffer.size:
  1146             while in_buffer.pos < in_buffer.size:
   939                 zresult = lib.ZSTD_decompressStream(dstream, out_buffer, in_buffer)
  1147                 zresult = lib.ZSTD_decompressStream(self._dstream, out_buffer, in_buffer)
   940                 if lib.ZSTD_isError(zresult):
  1148                 if lib.ZSTD_isError(zresult):
   941                     raise ZstdError('zstd decompressor error: %s' %
  1149                     raise ZstdError('zstd decompressor error: %s' %
   942                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
  1150                                     ffi.string(lib.ZSTD_getErrorName(zresult)))
   943 
  1151 
   944                 if out_buffer.pos:
  1152                 if out_buffer.pos:
  1019             last_buffer = dest_buffer
  1227             last_buffer = dest_buffer
  1020             i += 1
  1228             i += 1
  1021 
  1229 
  1022         return ffi.buffer(last_buffer, len(last_buffer))[:]
  1230         return ffi.buffer(last_buffer, len(last_buffer))[:]
  1023 
  1231 
  1024     def _get_dstream(self):
  1232     def _ensure_dstream(self):
  1025         dstream = lib.ZSTD_createDStream()
  1233         if self._dstream:
  1026         if dstream == ffi.NULL:
  1234             zresult = lib.ZSTD_resetDStream(self._dstream)
       
  1235             if lib.ZSTD_isError(zresult):
       
  1236                 raise ZstdError('could not reset DStream: %s' %
       
  1237                                 ffi.string(lib.ZSTD_getErrorName(zresult)))
       
  1238 
       
  1239             return
       
  1240 
       
  1241         self._dstream = lib.ZSTD_createDStream()
       
  1242         if self._dstream == ffi.NULL:
  1027             raise MemoryError()
  1243             raise MemoryError()
  1028 
  1244 
  1029         dstream = ffi.gc(dstream, lib.ZSTD_freeDStream)
  1245         self._dstream = ffi.gc(self._dstream, lib.ZSTD_freeDStream)
  1030 
  1246 
  1031         if self._dict_data:
  1247         if self._dict_data:
  1032             zresult = lib.ZSTD_initDStream_usingDict(dstream,
  1248             zresult = lib.ZSTD_initDStream_usingDict(self._dstream,
  1033                                                      self._dict_data.as_bytes(),
  1249                                                      self._dict_data.as_bytes(),
  1034                                                      len(self._dict_data))
  1250                                                      len(self._dict_data))
  1035         else:
  1251         else:
  1036             zresult = lib.ZSTD_initDStream(dstream)
  1252             zresult = lib.ZSTD_initDStream(self._dstream)
  1037 
  1253 
  1038         if lib.ZSTD_isError(zresult):
  1254         if lib.ZSTD_isError(zresult):
       
  1255             self._dstream = None
  1039             raise ZstdError('could not initialize DStream: %s' %
  1256             raise ZstdError('could not initialize DStream: %s' %
  1040                             ffi.string(lib.ZSTD_getErrorName(zresult)))
  1257                             ffi.string(lib.ZSTD_getErrorName(zresult)))
  1041 
       
  1042         return dstream