contrib/python-zstandard/tests/test_compressor.py
changeset 43999 de7838053207
parent 42941 69de49c4e39c
child 44232 5e84a96d865b
equal deleted inserted replaced
43998:873d0fecb9a3 43999:de7838053207
    11 
    11 
    12 from .common import (
    12 from .common import (
    13     make_cffi,
    13     make_cffi,
    14     NonClosingBytesIO,
    14     NonClosingBytesIO,
    15     OpCountingBytesIO,
    15     OpCountingBytesIO,
       
    16     TestCase,
    16 )
    17 )
    17 
    18 
    18 
    19 
    19 if sys.version_info[0] >= 3:
    20 if sys.version_info[0] >= 3:
    20     next = lambda it: it.__next__()
    21     next = lambda it: it.__next__()
    21 else:
    22 else:
    22     next = lambda it: it.next()
    23     next = lambda it: it.next()
    23 
    24 
    24 
    25 
    25 def multithreaded_chunk_size(level, source_size=0):
    26 def multithreaded_chunk_size(level, source_size=0):
    26     params = zstd.ZstdCompressionParameters.from_level(level,
    27     params = zstd.ZstdCompressionParameters.from_level(level, source_size=source_size)
    27                                                        source_size=source_size)
       
    28 
    28 
    29     return 1 << (params.window_log + 2)
    29     return 1 << (params.window_log + 2)
    30 
    30 
    31 
    31 
    32 @make_cffi
    32 @make_cffi
    33 class TestCompressor(unittest.TestCase):
    33 class TestCompressor(TestCase):
    34     def test_level_bounds(self):
    34     def test_level_bounds(self):
    35         with self.assertRaises(ValueError):
    35         with self.assertRaises(ValueError):
    36             zstd.ZstdCompressor(level=23)
    36             zstd.ZstdCompressor(level=23)
    37 
    37 
    38     def test_memory_size(self):
    38     def test_memory_size(self):
    39         cctx = zstd.ZstdCompressor(level=1)
    39         cctx = zstd.ZstdCompressor(level=1)
    40         self.assertGreater(cctx.memory_size(), 100)
    40         self.assertGreater(cctx.memory_size(), 100)
    41 
    41 
    42 
    42 
    43 @make_cffi
    43 @make_cffi
    44 class TestCompressor_compress(unittest.TestCase):
    44 class TestCompressor_compress(TestCase):
    45     def test_compress_empty(self):
    45     def test_compress_empty(self):
    46         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    46         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    47         result = cctx.compress(b'')
    47         result = cctx.compress(b"")
    48         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
    48         self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
    49         params = zstd.get_frame_parameters(result)
    49         params = zstd.get_frame_parameters(result)
    50         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
    50         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
    51         self.assertEqual(params.window_size, 524288)
    51         self.assertEqual(params.window_size, 524288)
    52         self.assertEqual(params.dict_id, 0)
    52         self.assertEqual(params.dict_id, 0)
    53         self.assertFalse(params.has_checksum, 0)
    53         self.assertFalse(params.has_checksum, 0)
    54 
    54 
    55         cctx = zstd.ZstdCompressor()
    55         cctx = zstd.ZstdCompressor()
    56         result = cctx.compress(b'')
    56         result = cctx.compress(b"")
    57         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00')
    57         self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x00\x01\x00\x00")
    58         params = zstd.get_frame_parameters(result)
    58         params = zstd.get_frame_parameters(result)
    59         self.assertEqual(params.content_size, 0)
    59         self.assertEqual(params.content_size, 0)
    60 
    60 
    61     def test_input_types(self):
    61     def test_input_types(self):
    62         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    62         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    63         expected = b'\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f'
    63         expected = b"\x28\xb5\x2f\xfd\x00\x00\x19\x00\x00\x66\x6f\x6f"
    64 
    64 
    65         mutable_array = bytearray(3)
    65         mutable_array = bytearray(3)
    66         mutable_array[:] = b'foo'
    66         mutable_array[:] = b"foo"
    67 
    67 
    68         sources = [
    68         sources = [
    69             memoryview(b'foo'),
    69             memoryview(b"foo"),
    70             bytearray(b'foo'),
    70             bytearray(b"foo"),
    71             mutable_array,
    71             mutable_array,
    72         ]
    72         ]
    73 
    73 
    74         for source in sources:
    74         for source in sources:
    75             self.assertEqual(cctx.compress(source), expected)
    75             self.assertEqual(cctx.compress(source), expected)
    76 
    76 
    77     def test_compress_large(self):
    77     def test_compress_large(self):
    78         chunks = []
    78         chunks = []
    79         for i in range(255):
    79         for i in range(255):
    80             chunks.append(struct.Struct('>B').pack(i) * 16384)
    80             chunks.append(struct.Struct(">B").pack(i) * 16384)
    81 
    81 
    82         cctx = zstd.ZstdCompressor(level=3, write_content_size=False)
    82         cctx = zstd.ZstdCompressor(level=3, write_content_size=False)
    83         result = cctx.compress(b''.join(chunks))
    83         result = cctx.compress(b"".join(chunks))
    84         self.assertEqual(len(result), 999)
    84         self.assertEqual(len(result), 999)
    85         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
    85         self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
    86 
    86 
    87         # This matches the test for read_to_iter() below.
    87         # This matches the test for read_to_iter() below.
    88         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    88         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
    89         result = cctx.compress(b'f' * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b'o')
    89         result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o")
    90         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00'
    90         self.assertEqual(
    91                                  b'\x10\x66\x66\x01\x00\xfb\xff\x39\xc0'
    91             result,
    92                                  b'\x02\x09\x00\x00\x6f')
    92             b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00"
       
    93             b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0"
       
    94             b"\x02\x09\x00\x00\x6f",
       
    95         )
    93 
    96 
    94     def test_negative_level(self):
    97     def test_negative_level(self):
    95         cctx = zstd.ZstdCompressor(level=-4)
    98         cctx = zstd.ZstdCompressor(level=-4)
    96         result = cctx.compress(b'foo' * 256)
    99         result = cctx.compress(b"foo" * 256)
    97 
   100 
    98     def test_no_magic(self):
   101     def test_no_magic(self):
       
   102         params = zstd.ZstdCompressionParameters.from_level(1, format=zstd.FORMAT_ZSTD1)
       
   103         cctx = zstd.ZstdCompressor(compression_params=params)
       
   104         magic = cctx.compress(b"foobar")
       
   105 
    99         params = zstd.ZstdCompressionParameters.from_level(
   106         params = zstd.ZstdCompressionParameters.from_level(
   100             1, format=zstd.FORMAT_ZSTD1)
   107             1, format=zstd.FORMAT_ZSTD1_MAGICLESS
       
   108         )
   101         cctx = zstd.ZstdCompressor(compression_params=params)
   109         cctx = zstd.ZstdCompressor(compression_params=params)
   102         magic = cctx.compress(b'foobar')
   110         no_magic = cctx.compress(b"foobar")
   103 
   111 
   104         params = zstd.ZstdCompressionParameters.from_level(
   112         self.assertEqual(magic[0:4], b"\x28\xb5\x2f\xfd")
   105             1, format=zstd.FORMAT_ZSTD1_MAGICLESS)
       
   106         cctx = zstd.ZstdCompressor(compression_params=params)
       
   107         no_magic = cctx.compress(b'foobar')
       
   108 
       
   109         self.assertEqual(magic[0:4], b'\x28\xb5\x2f\xfd')
       
   110         self.assertEqual(magic[4:], no_magic)
   113         self.assertEqual(magic[4:], no_magic)
   111 
   114 
   112     def test_write_checksum(self):
   115     def test_write_checksum(self):
   113         cctx = zstd.ZstdCompressor(level=1)
   116         cctx = zstd.ZstdCompressor(level=1)
   114         no_checksum = cctx.compress(b'foobar')
   117         no_checksum = cctx.compress(b"foobar")
   115         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   118         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   116         with_checksum = cctx.compress(b'foobar')
   119         with_checksum = cctx.compress(b"foobar")
   117 
   120 
   118         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   121         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   119 
   122 
   120         no_params = zstd.get_frame_parameters(no_checksum)
   123         no_params = zstd.get_frame_parameters(no_checksum)
   121         with_params = zstd.get_frame_parameters(with_checksum)
   124         with_params = zstd.get_frame_parameters(with_checksum)
   123         self.assertFalse(no_params.has_checksum)
   126         self.assertFalse(no_params.has_checksum)
   124         self.assertTrue(with_params.has_checksum)
   127         self.assertTrue(with_params.has_checksum)
   125 
   128 
   126     def test_write_content_size(self):
   129     def test_write_content_size(self):
   127         cctx = zstd.ZstdCompressor(level=1)
   130         cctx = zstd.ZstdCompressor(level=1)
   128         with_size = cctx.compress(b'foobar' * 256)
   131         with_size = cctx.compress(b"foobar" * 256)
   129         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   132         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   130         no_size = cctx.compress(b'foobar' * 256)
   133         no_size = cctx.compress(b"foobar" * 256)
   131 
   134 
   132         self.assertEqual(len(with_size), len(no_size) + 1)
   135         self.assertEqual(len(with_size), len(no_size) + 1)
   133 
   136 
   134         no_params = zstd.get_frame_parameters(no_size)
   137         no_params = zstd.get_frame_parameters(no_size)
   135         with_params = zstd.get_frame_parameters(with_size)
   138         with_params = zstd.get_frame_parameters(with_size)
   137         self.assertEqual(with_params.content_size, 1536)
   140         self.assertEqual(with_params.content_size, 1536)
   138 
   141 
   139     def test_no_dict_id(self):
   142     def test_no_dict_id(self):
   140         samples = []
   143         samples = []
   141         for i in range(128):
   144         for i in range(128):
   142             samples.append(b'foo' * 64)
   145             samples.append(b"foo" * 64)
   143             samples.append(b'bar' * 64)
   146             samples.append(b"bar" * 64)
   144             samples.append(b'foobar' * 64)
   147             samples.append(b"foobar" * 64)
   145 
   148 
   146         d = zstd.train_dictionary(1024, samples)
   149         d = zstd.train_dictionary(1024, samples)
   147 
   150 
   148         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   151         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   149         with_dict_id = cctx.compress(b'foobarfoobar')
   152         with_dict_id = cctx.compress(b"foobarfoobar")
   150 
   153 
   151         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
   154         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
   152         no_dict_id = cctx.compress(b'foobarfoobar')
   155         no_dict_id = cctx.compress(b"foobarfoobar")
   153 
   156 
   154         self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
   157         self.assertEqual(len(with_dict_id), len(no_dict_id) + 4)
   155 
   158 
   156         no_params = zstd.get_frame_parameters(no_dict_id)
   159         no_params = zstd.get_frame_parameters(no_dict_id)
   157         with_params = zstd.get_frame_parameters(with_dict_id)
   160         with_params = zstd.get_frame_parameters(with_dict_id)
   159         self.assertEqual(with_params.dict_id, 1880053135)
   162         self.assertEqual(with_params.dict_id, 1880053135)
   160 
   163 
   161     def test_compress_dict_multiple(self):
   164     def test_compress_dict_multiple(self):
   162         samples = []
   165         samples = []
   163         for i in range(128):
   166         for i in range(128):
   164             samples.append(b'foo' * 64)
   167             samples.append(b"foo" * 64)
   165             samples.append(b'bar' * 64)
   168             samples.append(b"bar" * 64)
   166             samples.append(b'foobar' * 64)
   169             samples.append(b"foobar" * 64)
   167 
   170 
   168         d = zstd.train_dictionary(8192, samples)
   171         d = zstd.train_dictionary(8192, samples)
   169 
   172 
   170         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   173         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   171 
   174 
   172         for i in range(32):
   175         for i in range(32):
   173             cctx.compress(b'foo bar foobar foo bar foobar')
   176             cctx.compress(b"foo bar foobar foo bar foobar")
   174 
   177 
   175     def test_dict_precompute(self):
   178     def test_dict_precompute(self):
   176         samples = []
   179         samples = []
   177         for i in range(128):
   180         for i in range(128):
   178             samples.append(b'foo' * 64)
   181             samples.append(b"foo" * 64)
   179             samples.append(b'bar' * 64)
   182             samples.append(b"bar" * 64)
   180             samples.append(b'foobar' * 64)
   183             samples.append(b"foobar" * 64)
   181 
   184 
   182         d = zstd.train_dictionary(8192, samples)
   185         d = zstd.train_dictionary(8192, samples)
   183         d.precompute_compress(level=1)
   186         d.precompute_compress(level=1)
   184 
   187 
   185         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   188         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
   186 
   189 
   187         for i in range(32):
   190         for i in range(32):
   188             cctx.compress(b'foo bar foobar foo bar foobar')
   191             cctx.compress(b"foo bar foobar foo bar foobar")
   189 
   192 
   190     def test_multithreaded(self):
   193     def test_multithreaded(self):
   191         chunk_size = multithreaded_chunk_size(1)
   194         chunk_size = multithreaded_chunk_size(1)
   192         source = b''.join([b'x' * chunk_size, b'y' * chunk_size])
   195         source = b"".join([b"x" * chunk_size, b"y" * chunk_size])
   193 
   196 
   194         cctx = zstd.ZstdCompressor(level=1, threads=2)
   197         cctx = zstd.ZstdCompressor(level=1, threads=2)
   195         compressed = cctx.compress(source)
   198         compressed = cctx.compress(source)
   196 
   199 
   197         params = zstd.get_frame_parameters(compressed)
   200         params = zstd.get_frame_parameters(compressed)
   203         self.assertEqual(dctx.decompress(compressed), source)
   206         self.assertEqual(dctx.decompress(compressed), source)
   204 
   207 
   205     def test_multithreaded_dict(self):
   208     def test_multithreaded_dict(self):
   206         samples = []
   209         samples = []
   207         for i in range(128):
   210         for i in range(128):
   208             samples.append(b'foo' * 64)
   211             samples.append(b"foo" * 64)
   209             samples.append(b'bar' * 64)
   212             samples.append(b"bar" * 64)
   210             samples.append(b'foobar' * 64)
   213             samples.append(b"foobar" * 64)
   211 
   214 
   212         d = zstd.train_dictionary(1024, samples)
   215         d = zstd.train_dictionary(1024, samples)
   213 
   216 
   214         cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
   217         cctx = zstd.ZstdCompressor(dict_data=d, threads=2)
   215 
   218 
   216         result = cctx.compress(b'foo')
   219         result = cctx.compress(b"foo")
   217         params = zstd.get_frame_parameters(result);
   220         params = zstd.get_frame_parameters(result)
   218         self.assertEqual(params.content_size, 3);
   221         self.assertEqual(params.content_size, 3)
   219         self.assertEqual(params.dict_id, d.dict_id())
   222         self.assertEqual(params.dict_id, d.dict_id())
   220 
   223 
   221         self.assertEqual(result,
   224         self.assertEqual(
   222                          b'\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00'
   225             result,
   223                          b'\x66\x6f\x6f')
   226             b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" b"\x66\x6f\x6f",
       
   227         )
   224 
   228 
   225     def test_multithreaded_compression_params(self):
   229     def test_multithreaded_compression_params(self):
   226         params = zstd.ZstdCompressionParameters.from_level(0, threads=2)
   230         params = zstd.ZstdCompressionParameters.from_level(0, threads=2)
   227         cctx = zstd.ZstdCompressor(compression_params=params)
   231         cctx = zstd.ZstdCompressor(compression_params=params)
   228 
   232 
   229         result = cctx.compress(b'foo')
   233         result = cctx.compress(b"foo")
   230         params = zstd.get_frame_parameters(result);
   234         params = zstd.get_frame_parameters(result)
   231         self.assertEqual(params.content_size, 3);
   235         self.assertEqual(params.content_size, 3)
   232 
   236 
   233         self.assertEqual(result,
   237         self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f")
   234                          b'\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f')
       
   235 
   238 
   236 
   239 
   237 @make_cffi
   240 @make_cffi
   238 class TestCompressor_compressobj(unittest.TestCase):
   241 class TestCompressor_compressobj(TestCase):
   239     def test_compressobj_empty(self):
   242     def test_compressobj_empty(self):
   240         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   243         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   241         cobj = cctx.compressobj()
   244         cobj = cctx.compressobj()
   242         self.assertEqual(cobj.compress(b''), b'')
   245         self.assertEqual(cobj.compress(b""), b"")
   243         self.assertEqual(cobj.flush(),
   246         self.assertEqual(cobj.flush(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
   244                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   245 
   247 
   246     def test_input_types(self):
   248     def test_input_types(self):
   247         expected = b'\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f'
   249         expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
   248         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   250         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   249 
   251 
   250         mutable_array = bytearray(3)
   252         mutable_array = bytearray(3)
   251         mutable_array[:] = b'foo'
   253         mutable_array[:] = b"foo"
   252 
   254 
   253         sources = [
   255         sources = [
   254             memoryview(b'foo'),
   256             memoryview(b"foo"),
   255             bytearray(b'foo'),
   257             bytearray(b"foo"),
   256             mutable_array,
   258             mutable_array,
   257         ]
   259         ]
   258 
   260 
   259         for source in sources:
   261         for source in sources:
   260             cobj = cctx.compressobj()
   262             cobj = cctx.compressobj()
   261             self.assertEqual(cobj.compress(source), b'')
   263             self.assertEqual(cobj.compress(source), b"")
   262             self.assertEqual(cobj.flush(), expected)
   264             self.assertEqual(cobj.flush(), expected)
   263 
   265 
   264     def test_compressobj_large(self):
   266     def test_compressobj_large(self):
   265         chunks = []
   267         chunks = []
   266         for i in range(255):
   268         for i in range(255):
   267             chunks.append(struct.Struct('>B').pack(i) * 16384)
   269             chunks.append(struct.Struct(">B").pack(i) * 16384)
   268 
   270 
   269         cctx = zstd.ZstdCompressor(level=3)
   271         cctx = zstd.ZstdCompressor(level=3)
   270         cobj = cctx.compressobj()
   272         cobj = cctx.compressobj()
   271 
   273 
   272         result = cobj.compress(b''.join(chunks)) + cobj.flush()
   274         result = cobj.compress(b"".join(chunks)) + cobj.flush()
   273         self.assertEqual(len(result), 999)
   275         self.assertEqual(len(result), 999)
   274         self.assertEqual(result[0:4], b'\x28\xb5\x2f\xfd')
   276         self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd")
   275 
   277 
   276         params = zstd.get_frame_parameters(result)
   278         params = zstd.get_frame_parameters(result)
   277         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   279         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   278         self.assertEqual(params.window_size, 2097152)
   280         self.assertEqual(params.window_size, 2097152)
   279         self.assertEqual(params.dict_id, 0)
   281         self.assertEqual(params.dict_id, 0)
   280         self.assertFalse(params.has_checksum)
   282         self.assertFalse(params.has_checksum)
   281 
   283 
   282     def test_write_checksum(self):
   284     def test_write_checksum(self):
   283         cctx = zstd.ZstdCompressor(level=1)
   285         cctx = zstd.ZstdCompressor(level=1)
   284         cobj = cctx.compressobj()
   286         cobj = cctx.compressobj()
   285         no_checksum = cobj.compress(b'foobar') + cobj.flush()
   287         no_checksum = cobj.compress(b"foobar") + cobj.flush()
   286         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   288         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   287         cobj = cctx.compressobj()
   289         cobj = cctx.compressobj()
   288         with_checksum = cobj.compress(b'foobar') + cobj.flush()
   290         with_checksum = cobj.compress(b"foobar") + cobj.flush()
   289 
   291 
   290         no_params = zstd.get_frame_parameters(no_checksum)
   292         no_params = zstd.get_frame_parameters(no_checksum)
   291         with_params = zstd.get_frame_parameters(with_checksum)
   293         with_params = zstd.get_frame_parameters(with_checksum)
   292         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   294         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   293         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   295         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   298 
   300 
   299         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   301         self.assertEqual(len(with_checksum), len(no_checksum) + 4)
   300 
   302 
   301     def test_write_content_size(self):
   303     def test_write_content_size(self):
   302         cctx = zstd.ZstdCompressor(level=1)
   304         cctx = zstd.ZstdCompressor(level=1)
   303         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   305         cobj = cctx.compressobj(size=len(b"foobar" * 256))
   304         with_size = cobj.compress(b'foobar' * 256) + cobj.flush()
   306         with_size = cobj.compress(b"foobar" * 256) + cobj.flush()
   305         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   307         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   306         cobj = cctx.compressobj(size=len(b'foobar' * 256))
   308         cobj = cctx.compressobj(size=len(b"foobar" * 256))
   307         no_size = cobj.compress(b'foobar' * 256) + cobj.flush()
   309         no_size = cobj.compress(b"foobar" * 256) + cobj.flush()
   308 
   310 
   309         no_params = zstd.get_frame_parameters(no_size)
   311         no_params = zstd.get_frame_parameters(no_size)
   310         with_params = zstd.get_frame_parameters(with_size)
   312         with_params = zstd.get_frame_parameters(with_size)
   311         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   313         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   312         self.assertEqual(with_params.content_size, 1536)
   314         self.assertEqual(with_params.content_size, 1536)
   319 
   321 
   320     def test_compress_after_finished(self):
   322     def test_compress_after_finished(self):
   321         cctx = zstd.ZstdCompressor()
   323         cctx = zstd.ZstdCompressor()
   322         cobj = cctx.compressobj()
   324         cobj = cctx.compressobj()
   323 
   325 
   324         cobj.compress(b'foo')
   326         cobj.compress(b"foo")
   325         cobj.flush()
   327         cobj.flush()
   326 
   328 
   327         with self.assertRaisesRegexp(zstd.ZstdError, r'cannot call compress\(\) after compressor'):
   329         with self.assertRaisesRegex(
   328             cobj.compress(b'foo')
   330             zstd.ZstdError, r"cannot call compress\(\) after compressor"
   329 
   331         ):
   330         with self.assertRaisesRegexp(zstd.ZstdError, 'compressor object already finished'):
   332             cobj.compress(b"foo")
       
   333 
       
   334         with self.assertRaisesRegex(
       
   335             zstd.ZstdError, "compressor object already finished"
       
   336         ):
   331             cobj.flush()
   337             cobj.flush()
   332 
   338 
   333     def test_flush_block_repeated(self):
   339     def test_flush_block_repeated(self):
   334         cctx = zstd.ZstdCompressor(level=1)
   340         cctx = zstd.ZstdCompressor(level=1)
   335         cobj = cctx.compressobj()
   341         cobj = cctx.compressobj()
   336 
   342 
   337         self.assertEqual(cobj.compress(b'foo'), b'')
   343         self.assertEqual(cobj.compress(b"foo"), b"")
   338         self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
   344         self.assertEqual(
   339                          b'\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo')
   345             cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
   340         self.assertEqual(cobj.compress(b'bar'), b'')
   346             b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo",
       
   347         )
       
   348         self.assertEqual(cobj.compress(b"bar"), b"")
   341         # 3 byte header plus content.
   349         # 3 byte header plus content.
   342         self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK),
   350         self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar")
   343                          b'\x18\x00\x00bar')
   351         self.assertEqual(cobj.flush(), b"\x01\x00\x00")
   344         self.assertEqual(cobj.flush(), b'\x01\x00\x00')
       
   345 
   352 
   346     def test_flush_empty_block(self):
   353     def test_flush_empty_block(self):
   347         cctx = zstd.ZstdCompressor(write_checksum=True)
   354         cctx = zstd.ZstdCompressor(write_checksum=True)
   348         cobj = cctx.compressobj()
   355         cobj = cctx.compressobj()
   349 
   356 
   350         cobj.compress(b'foobar')
   357         cobj.compress(b"foobar")
   351         cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
   358         cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK)
   352         # No-op if no block is active (this is internal to zstd).
   359         # No-op if no block is active (this is internal to zstd).
   353         self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b'')
   360         self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"")
   354 
   361 
   355         trailing = cobj.flush()
   362         trailing = cobj.flush()
   356         # 3 bytes block header + 4 bytes frame checksum
   363         # 3 bytes block header + 4 bytes frame checksum
   357         self.assertEqual(len(trailing), 7)
   364         self.assertEqual(len(trailing), 7)
   358         header = trailing[0:3]
   365         header = trailing[0:3]
   359         self.assertEqual(header, b'\x01\x00\x00')
   366         self.assertEqual(header, b"\x01\x00\x00")
   360 
   367 
   361     def test_multithreaded(self):
   368     def test_multithreaded(self):
   362         source = io.BytesIO()
   369         source = io.BytesIO()
   363         source.write(b'a' * 1048576)
   370         source.write(b"a" * 1048576)
   364         source.write(b'b' * 1048576)
   371         source.write(b"b" * 1048576)
   365         source.write(b'c' * 1048576)
   372         source.write(b"c" * 1048576)
   366         source.seek(0)
   373         source.seek(0)
   367 
   374 
   368         cctx = zstd.ZstdCompressor(level=1, threads=2)
   375         cctx = zstd.ZstdCompressor(level=1, threads=2)
   369         cobj = cctx.compressobj()
   376         cobj = cctx.compressobj()
   370 
   377 
   376 
   383 
   377             chunks.append(cobj.compress(d))
   384             chunks.append(cobj.compress(d))
   378 
   385 
   379         chunks.append(cobj.flush())
   386         chunks.append(cobj.flush())
   380 
   387 
   381         compressed = b''.join(chunks)
   388         compressed = b"".join(chunks)
   382 
   389 
   383         self.assertEqual(len(compressed), 295)
   390         self.assertEqual(len(compressed), 119)
   384 
   391 
   385     def test_frame_progression(self):
   392     def test_frame_progression(self):
   386         cctx = zstd.ZstdCompressor()
   393         cctx = zstd.ZstdCompressor()
   387 
   394 
   388         self.assertEqual(cctx.frame_progression(), (0, 0, 0))
   395         self.assertEqual(cctx.frame_progression(), (0, 0, 0))
   389 
   396 
   390         cobj = cctx.compressobj()
   397         cobj = cctx.compressobj()
   391 
   398 
   392         cobj.compress(b'foobar')
   399         cobj.compress(b"foobar")
   393         self.assertEqual(cctx.frame_progression(), (6, 0, 0))
   400         self.assertEqual(cctx.frame_progression(), (6, 0, 0))
   394 
   401 
   395         cobj.flush()
   402         cobj.flush()
   396         self.assertEqual(cctx.frame_progression(), (6, 6, 15))
   403         self.assertEqual(cctx.frame_progression(), (6, 6, 15))
   397 
   404 
   398     def test_bad_size(self):
   405     def test_bad_size(self):
   399         cctx = zstd.ZstdCompressor()
   406         cctx = zstd.ZstdCompressor()
   400 
   407 
   401         cobj = cctx.compressobj(size=2)
   408         cobj = cctx.compressobj(size=2)
   402         with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
   409         with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
   403             cobj.compress(b'foo')
   410             cobj.compress(b"foo")
   404 
   411 
   405         # Try another operation on this instance.
   412         # Try another operation on this instance.
   406         with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
   413         with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
   407             cobj.compress(b'aa')
   414             cobj.compress(b"aa")
   408 
   415 
   409         # Try another operation on the compressor.
   416         # Try another operation on the compressor.
   410         cctx.compressobj(size=4)
   417         cctx.compressobj(size=4)
   411         cctx.compress(b'foobar')
   418         cctx.compress(b"foobar")
   412 
   419 
   413 
   420 
   414 @make_cffi
   421 @make_cffi
   415 class TestCompressor_copy_stream(unittest.TestCase):
   422 class TestCompressor_copy_stream(TestCase):
   416     def test_no_read(self):
   423     def test_no_read(self):
   417         source = object()
   424         source = object()
   418         dest = io.BytesIO()
   425         dest = io.BytesIO()
   419 
   426 
   420         cctx = zstd.ZstdCompressor()
   427         cctx = zstd.ZstdCompressor()
   436         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   443         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   437         r, w = cctx.copy_stream(source, dest)
   444         r, w = cctx.copy_stream(source, dest)
   438         self.assertEqual(int(r), 0)
   445         self.assertEqual(int(r), 0)
   439         self.assertEqual(w, 9)
   446         self.assertEqual(w, 9)
   440 
   447 
   441         self.assertEqual(dest.getvalue(),
   448         self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
   442                          b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
       
   443 
   449 
   444     def test_large_data(self):
   450     def test_large_data(self):
   445         source = io.BytesIO()
   451         source = io.BytesIO()
   446         for i in range(255):
   452         for i in range(255):
   447             source.write(struct.Struct('>B').pack(i) * 16384)
   453             source.write(struct.Struct(">B").pack(i) * 16384)
   448         source.seek(0)
   454         source.seek(0)
   449 
   455 
   450         dest = io.BytesIO()
   456         dest = io.BytesIO()
   451         cctx = zstd.ZstdCompressor()
   457         cctx = zstd.ZstdCompressor()
   452         r, w = cctx.copy_stream(source, dest)
   458         r, w = cctx.copy_stream(source, dest)
   459         self.assertEqual(params.window_size, 2097152)
   465         self.assertEqual(params.window_size, 2097152)
   460         self.assertEqual(params.dict_id, 0)
   466         self.assertEqual(params.dict_id, 0)
   461         self.assertFalse(params.has_checksum)
   467         self.assertFalse(params.has_checksum)
   462 
   468 
   463     def test_write_checksum(self):
   469     def test_write_checksum(self):
   464         source = io.BytesIO(b'foobar')
   470         source = io.BytesIO(b"foobar")
   465         no_checksum = io.BytesIO()
   471         no_checksum = io.BytesIO()
   466 
   472 
   467         cctx = zstd.ZstdCompressor(level=1)
   473         cctx = zstd.ZstdCompressor(level=1)
   468         cctx.copy_stream(source, no_checksum)
   474         cctx.copy_stream(source, no_checksum)
   469 
   475 
   470         source.seek(0)
   476         source.seek(0)
   471         with_checksum = io.BytesIO()
   477         with_checksum = io.BytesIO()
   472         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   478         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
   473         cctx.copy_stream(source, with_checksum)
   479         cctx.copy_stream(source, with_checksum)
   474 
   480 
   475         self.assertEqual(len(with_checksum.getvalue()),
   481         self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
   476                          len(no_checksum.getvalue()) + 4)
       
   477 
   482 
   478         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
   483         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
   479         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
   484         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
   480         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   485         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   481         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   486         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   483         self.assertEqual(with_params.dict_id, 0)
   488         self.assertEqual(with_params.dict_id, 0)
   484         self.assertFalse(no_params.has_checksum)
   489         self.assertFalse(no_params.has_checksum)
   485         self.assertTrue(with_params.has_checksum)
   490         self.assertTrue(with_params.has_checksum)
   486 
   491 
   487     def test_write_content_size(self):
   492     def test_write_content_size(self):
   488         source = io.BytesIO(b'foobar' * 256)
   493         source = io.BytesIO(b"foobar" * 256)
   489         no_size = io.BytesIO()
   494         no_size = io.BytesIO()
   490 
   495 
   491         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   496         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   492         cctx.copy_stream(source, no_size)
   497         cctx.copy_stream(source, no_size)
   493 
   498 
   495         with_size = io.BytesIO()
   500         with_size = io.BytesIO()
   496         cctx = zstd.ZstdCompressor(level=1)
   501         cctx = zstd.ZstdCompressor(level=1)
   497         cctx.copy_stream(source, with_size)
   502         cctx.copy_stream(source, with_size)
   498 
   503 
   499         # Source content size is unknown, so no content size written.
   504         # Source content size is unknown, so no content size written.
   500         self.assertEqual(len(with_size.getvalue()),
   505         self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
   501                          len(no_size.getvalue()))
       
   502 
   506 
   503         source.seek(0)
   507         source.seek(0)
   504         with_size = io.BytesIO()
   508         with_size = io.BytesIO()
   505         cctx.copy_stream(source, with_size, size=len(source.getvalue()))
   509         cctx.copy_stream(source, with_size, size=len(source.getvalue()))
   506 
   510 
   507         # We specified source size, so content size header is present.
   511         # We specified source size, so content size header is present.
   508         self.assertEqual(len(with_size.getvalue()),
   512         self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
   509                          len(no_size.getvalue()) + 1)
       
   510 
   513 
   511         no_params = zstd.get_frame_parameters(no_size.getvalue())
   514         no_params = zstd.get_frame_parameters(no_size.getvalue())
   512         with_params = zstd.get_frame_parameters(with_size.getvalue())
   515         with_params = zstd.get_frame_parameters(with_size.getvalue())
   513         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   516         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   514         self.assertEqual(with_params.content_size, 1536)
   517         self.assertEqual(with_params.content_size, 1536)
   516         self.assertEqual(with_params.dict_id, 0)
   519         self.assertEqual(with_params.dict_id, 0)
   517         self.assertFalse(no_params.has_checksum)
   520         self.assertFalse(no_params.has_checksum)
   518         self.assertFalse(with_params.has_checksum)
   521         self.assertFalse(with_params.has_checksum)
   519 
   522 
   520     def test_read_write_size(self):
   523     def test_read_write_size(self):
   521         source = OpCountingBytesIO(b'foobarfoobar')
   524         source = OpCountingBytesIO(b"foobarfoobar")
   522         dest = OpCountingBytesIO()
   525         dest = OpCountingBytesIO()
   523         cctx = zstd.ZstdCompressor()
   526         cctx = zstd.ZstdCompressor()
   524         r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
   527         r, w = cctx.copy_stream(source, dest, read_size=1, write_size=1)
   525 
   528 
   526         self.assertEqual(r, len(source.getvalue()))
   529         self.assertEqual(r, len(source.getvalue()))
   528         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   531         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
   529         self.assertEqual(dest._write_count, len(dest.getvalue()))
   532         self.assertEqual(dest._write_count, len(dest.getvalue()))
   530 
   533 
   531     def test_multithreaded(self):
   534     def test_multithreaded(self):
   532         source = io.BytesIO()
   535         source = io.BytesIO()
   533         source.write(b'a' * 1048576)
   536         source.write(b"a" * 1048576)
   534         source.write(b'b' * 1048576)
   537         source.write(b"b" * 1048576)
   535         source.write(b'c' * 1048576)
   538         source.write(b"c" * 1048576)
   536         source.seek(0)
   539         source.seek(0)
   537 
   540 
   538         dest = io.BytesIO()
   541         dest = io.BytesIO()
   539         cctx = zstd.ZstdCompressor(threads=2, write_content_size=False)
   542         cctx = zstd.ZstdCompressor(threads=2, write_content_size=False)
   540         r, w = cctx.copy_stream(source, dest)
   543         r, w = cctx.copy_stream(source, dest)
   541         self.assertEqual(r, 3145728)
   544         self.assertEqual(r, 3145728)
   542         self.assertEqual(w, 295)
   545         self.assertEqual(w, 111)
   543 
   546 
   544         params = zstd.get_frame_parameters(dest.getvalue())
   547         params = zstd.get_frame_parameters(dest.getvalue())
   545         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   548         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   546         self.assertEqual(params.dict_id, 0)
   549         self.assertEqual(params.dict_id, 0)
   547         self.assertFalse(params.has_checksum)
   550         self.assertFalse(params.has_checksum)
   557         self.assertEqual(params.dict_id, 0)
   560         self.assertEqual(params.dict_id, 0)
   558         self.assertTrue(params.has_checksum)
   561         self.assertTrue(params.has_checksum)
   559 
   562 
   560     def test_bad_size(self):
   563     def test_bad_size(self):
   561         source = io.BytesIO()
   564         source = io.BytesIO()
   562         source.write(b'a' * 32768)
   565         source.write(b"a" * 32768)
   563         source.write(b'b' * 32768)
   566         source.write(b"b" * 32768)
   564         source.seek(0)
   567         source.seek(0)
   565 
   568 
   566         dest = io.BytesIO()
   569         dest = io.BytesIO()
   567 
   570 
   568         cctx = zstd.ZstdCompressor()
   571         cctx = zstd.ZstdCompressor()
   569 
   572 
   570         with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
   573         with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
   571             cctx.copy_stream(source, dest, size=42)
   574             cctx.copy_stream(source, dest, size=42)
   572 
   575 
   573         # Try another operation on this compressor.
   576         # Try another operation on this compressor.
   574         source.seek(0)
   577         source.seek(0)
   575         dest = io.BytesIO()
   578         dest = io.BytesIO()
   576         cctx.copy_stream(source, dest)
   579         cctx.copy_stream(source, dest)
   577 
   580 
   578 
   581 
   579 @make_cffi
   582 @make_cffi
   580 class TestCompressor_stream_reader(unittest.TestCase):
   583 class TestCompressor_stream_reader(TestCase):
   581     def test_context_manager(self):
   584     def test_context_manager(self):
   582         cctx = zstd.ZstdCompressor()
   585         cctx = zstd.ZstdCompressor()
   583 
   586 
   584         with cctx.stream_reader(b'foo') as reader:
   587         with cctx.stream_reader(b"foo") as reader:
   585             with self.assertRaisesRegexp(ValueError, 'cannot __enter__ multiple times'):
   588             with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"):
   586                 with reader as reader2:
   589                 with reader as reader2:
   587                     pass
   590                     pass
   588 
   591 
   589     def test_no_context_manager(self):
   592     def test_no_context_manager(self):
   590         cctx = zstd.ZstdCompressor()
   593         cctx = zstd.ZstdCompressor()
   591 
   594 
   592         reader = cctx.stream_reader(b'foo')
   595         reader = cctx.stream_reader(b"foo")
   593         reader.read(4)
   596         reader.read(4)
   594         self.assertFalse(reader.closed)
   597         self.assertFalse(reader.closed)
   595 
   598 
   596         reader.close()
   599         reader.close()
   597         self.assertTrue(reader.closed)
   600         self.assertTrue(reader.closed)
   598         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   601         with self.assertRaisesRegex(ValueError, "stream is closed"):
   599             reader.read(1)
   602             reader.read(1)
   600 
   603 
   601     def test_not_implemented(self):
   604     def test_not_implemented(self):
   602         cctx = zstd.ZstdCompressor()
   605         cctx = zstd.ZstdCompressor()
   603 
   606 
   604         with cctx.stream_reader(b'foo' * 60) as reader:
   607         with cctx.stream_reader(b"foo" * 60) as reader:
   605             with self.assertRaises(io.UnsupportedOperation):
   608             with self.assertRaises(io.UnsupportedOperation):
   606                 reader.readline()
   609                 reader.readline()
   607 
   610 
   608             with self.assertRaises(io.UnsupportedOperation):
   611             with self.assertRaises(io.UnsupportedOperation):
   609                 reader.readlines()
   612                 reader.readlines()
   616 
   619 
   617             with self.assertRaises(OSError):
   620             with self.assertRaises(OSError):
   618                 reader.writelines([])
   621                 reader.writelines([])
   619 
   622 
   620             with self.assertRaises(OSError):
   623             with self.assertRaises(OSError):
   621                 reader.write(b'foo')
   624                 reader.write(b"foo")
   622 
   625 
   623     def test_constant_methods(self):
   626     def test_constant_methods(self):
   624         cctx = zstd.ZstdCompressor()
   627         cctx = zstd.ZstdCompressor()
   625 
   628 
   626         with cctx.stream_reader(b'boo') as reader:
   629         with cctx.stream_reader(b"boo") as reader:
   627             self.assertTrue(reader.readable())
   630             self.assertTrue(reader.readable())
   628             self.assertFalse(reader.writable())
   631             self.assertFalse(reader.writable())
   629             self.assertFalse(reader.seekable())
   632             self.assertFalse(reader.seekable())
   630             self.assertFalse(reader.isatty())
   633             self.assertFalse(reader.isatty())
   631             self.assertFalse(reader.closed)
   634             self.assertFalse(reader.closed)
   635         self.assertTrue(reader.closed)
   638         self.assertTrue(reader.closed)
   636 
   639 
   637     def test_read_closed(self):
   640     def test_read_closed(self):
   638         cctx = zstd.ZstdCompressor()
   641         cctx = zstd.ZstdCompressor()
   639 
   642 
   640         with cctx.stream_reader(b'foo' * 60) as reader:
   643         with cctx.stream_reader(b"foo" * 60) as reader:
   641             reader.close()
   644             reader.close()
   642             self.assertTrue(reader.closed)
   645             self.assertTrue(reader.closed)
   643             with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   646             with self.assertRaisesRegex(ValueError, "stream is closed"):
   644                 reader.read(10)
   647                 reader.read(10)
   645 
   648 
   646     def test_read_sizes(self):
   649     def test_read_sizes(self):
   647         cctx = zstd.ZstdCompressor()
   650         cctx = zstd.ZstdCompressor()
   648         foo = cctx.compress(b'foo')
   651         foo = cctx.compress(b"foo")
   649 
   652 
   650         with cctx.stream_reader(b'foo') as reader:
   653         with cctx.stream_reader(b"foo") as reader:
   651             with self.assertRaisesRegexp(ValueError, 'cannot read negative amounts less than -1'):
   654             with self.assertRaisesRegex(
       
   655                 ValueError, "cannot read negative amounts less than -1"
       
   656             ):
   652                 reader.read(-2)
   657                 reader.read(-2)
   653 
   658 
   654             self.assertEqual(reader.read(0), b'')
   659             self.assertEqual(reader.read(0), b"")
   655             self.assertEqual(reader.read(), foo)
   660             self.assertEqual(reader.read(), foo)
   656 
   661 
   657     def test_read_buffer(self):
   662     def test_read_buffer(self):
   658         cctx = zstd.ZstdCompressor()
   663         cctx = zstd.ZstdCompressor()
   659 
   664 
   660         source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
   665         source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
   661         frame = cctx.compress(source)
   666         frame = cctx.compress(source)
   662 
   667 
   663         with cctx.stream_reader(source) as reader:
   668         with cctx.stream_reader(source) as reader:
   664             self.assertEqual(reader.tell(), 0)
   669             self.assertEqual(reader.tell(), 0)
   665 
   670 
   666             # We should get entire frame in one read.
   671             # We should get entire frame in one read.
   667             result = reader.read(8192)
   672             result = reader.read(8192)
   668             self.assertEqual(result, frame)
   673             self.assertEqual(result, frame)
   669             self.assertEqual(reader.tell(), len(result))
   674             self.assertEqual(reader.tell(), len(result))
   670             self.assertEqual(reader.read(), b'')
   675             self.assertEqual(reader.read(), b"")
   671             self.assertEqual(reader.tell(), len(result))
   676             self.assertEqual(reader.tell(), len(result))
   672 
   677 
   673     def test_read_buffer_small_chunks(self):
   678     def test_read_buffer_small_chunks(self):
   674         cctx = zstd.ZstdCompressor()
   679         cctx = zstd.ZstdCompressor()
   675 
   680 
   676         source = b'foo' * 60
   681         source = b"foo" * 60
   677         chunks = []
   682         chunks = []
   678 
   683 
   679         with cctx.stream_reader(source) as reader:
   684         with cctx.stream_reader(source) as reader:
   680             self.assertEqual(reader.tell(), 0)
   685             self.assertEqual(reader.tell(), 0)
   681 
   686 
   685                     break
   690                     break
   686 
   691 
   687                 chunks.append(chunk)
   692                 chunks.append(chunk)
   688                 self.assertEqual(reader.tell(), sum(map(len, chunks)))
   693                 self.assertEqual(reader.tell(), sum(map(len, chunks)))
   689 
   694 
   690         self.assertEqual(b''.join(chunks), cctx.compress(source))
   695         self.assertEqual(b"".join(chunks), cctx.compress(source))
   691 
   696 
   692     def test_read_stream(self):
   697     def test_read_stream(self):
   693         cctx = zstd.ZstdCompressor()
   698         cctx = zstd.ZstdCompressor()
   694 
   699 
   695         source = b''.join([b'foo' * 60, b'bar' * 60, b'baz' * 60])
   700         source = b"".join([b"foo" * 60, b"bar" * 60, b"baz" * 60])
   696         frame = cctx.compress(source)
   701         frame = cctx.compress(source)
   697 
   702 
   698         with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
   703         with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
   699             self.assertEqual(reader.tell(), 0)
   704             self.assertEqual(reader.tell(), 0)
   700 
   705 
   701             chunk = reader.read(8192)
   706             chunk = reader.read(8192)
   702             self.assertEqual(chunk, frame)
   707             self.assertEqual(chunk, frame)
   703             self.assertEqual(reader.tell(), len(chunk))
   708             self.assertEqual(reader.tell(), len(chunk))
   704             self.assertEqual(reader.read(), b'')
   709             self.assertEqual(reader.read(), b"")
   705             self.assertEqual(reader.tell(), len(chunk))
   710             self.assertEqual(reader.tell(), len(chunk))
   706 
   711 
   707     def test_read_stream_small_chunks(self):
   712     def test_read_stream_small_chunks(self):
   708         cctx = zstd.ZstdCompressor()
   713         cctx = zstd.ZstdCompressor()
   709 
   714 
   710         source = b'foo' * 60
   715         source = b"foo" * 60
   711         chunks = []
   716         chunks = []
   712 
   717 
   713         with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
   718         with cctx.stream_reader(io.BytesIO(source), size=len(source)) as reader:
   714             self.assertEqual(reader.tell(), 0)
   719             self.assertEqual(reader.tell(), 0)
   715 
   720 
   719                     break
   724                     break
   720 
   725 
   721                 chunks.append(chunk)
   726                 chunks.append(chunk)
   722                 self.assertEqual(reader.tell(), sum(map(len, chunks)))
   727                 self.assertEqual(reader.tell(), sum(map(len, chunks)))
   723 
   728 
   724         self.assertEqual(b''.join(chunks), cctx.compress(source))
   729         self.assertEqual(b"".join(chunks), cctx.compress(source))
   725 
   730 
   726     def test_read_after_exit(self):
   731     def test_read_after_exit(self):
   727         cctx = zstd.ZstdCompressor()
   732         cctx = zstd.ZstdCompressor()
   728 
   733 
   729         with cctx.stream_reader(b'foo' * 60) as reader:
   734         with cctx.stream_reader(b"foo" * 60) as reader:
   730             while reader.read(8192):
   735             while reader.read(8192):
   731                 pass
   736                 pass
   732 
   737 
   733         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   738         with self.assertRaisesRegex(ValueError, "stream is closed"):
   734             reader.read(10)
   739             reader.read(10)
   735 
   740 
   736     def test_bad_size(self):
   741     def test_bad_size(self):
   737         cctx = zstd.ZstdCompressor()
   742         cctx = zstd.ZstdCompressor()
   738 
   743 
   739         source = io.BytesIO(b'foobar')
   744         source = io.BytesIO(b"foobar")
   740 
   745 
   741         with cctx.stream_reader(source, size=2) as reader:
   746         with cctx.stream_reader(source, size=2) as reader:
   742             with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
   747             with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
   743                 reader.read(10)
   748                 reader.read(10)
   744 
   749 
   745         # Try another compression operation.
   750         # Try another compression operation.
   746         with cctx.stream_reader(source, size=42):
   751         with cctx.stream_reader(source, size=42):
   747             pass
   752             pass
   748 
   753 
   749     def test_readall(self):
   754     def test_readall(self):
   750         cctx = zstd.ZstdCompressor()
   755         cctx = zstd.ZstdCompressor()
   751         frame = cctx.compress(b'foo' * 1024)
   756         frame = cctx.compress(b"foo" * 1024)
   752 
   757 
   753         reader = cctx.stream_reader(b'foo' * 1024)
   758         reader = cctx.stream_reader(b"foo" * 1024)
   754         self.assertEqual(reader.readall(), frame)
   759         self.assertEqual(reader.readall(), frame)
   755 
   760 
   756     def test_readinto(self):
   761     def test_readinto(self):
   757         cctx = zstd.ZstdCompressor()
   762         cctx = zstd.ZstdCompressor()
   758         foo = cctx.compress(b'foo')
   763         foo = cctx.compress(b"foo")
   759 
   764 
   760         reader = cctx.stream_reader(b'foo')
   765         reader = cctx.stream_reader(b"foo")
   761         with self.assertRaises(Exception):
   766         with self.assertRaises(Exception):
   762             reader.readinto(b'foobar')
   767             reader.readinto(b"foobar")
   763 
   768 
   764         # readinto() with sufficiently large destination.
   769         # readinto() with sufficiently large destination.
   765         b = bytearray(1024)
   770         b = bytearray(1024)
   766         reader = cctx.stream_reader(b'foo')
   771         reader = cctx.stream_reader(b"foo")
   767         self.assertEqual(reader.readinto(b), len(foo))
   772         self.assertEqual(reader.readinto(b), len(foo))
   768         self.assertEqual(b[0:len(foo)], foo)
   773         self.assertEqual(b[0 : len(foo)], foo)
   769         self.assertEqual(reader.readinto(b), 0)
   774         self.assertEqual(reader.readinto(b), 0)
   770         self.assertEqual(b[0:len(foo)], foo)
   775         self.assertEqual(b[0 : len(foo)], foo)
   771 
   776 
   772         # readinto() with small reads.
   777         # readinto() with small reads.
   773         b = bytearray(1024)
   778         b = bytearray(1024)
   774         reader = cctx.stream_reader(b'foo', read_size=1)
   779         reader = cctx.stream_reader(b"foo", read_size=1)
   775         self.assertEqual(reader.readinto(b), len(foo))
   780         self.assertEqual(reader.readinto(b), len(foo))
   776         self.assertEqual(b[0:len(foo)], foo)
   781         self.assertEqual(b[0 : len(foo)], foo)
   777 
   782 
   778         # Too small destination buffer.
   783         # Too small destination buffer.
   779         b = bytearray(2)
   784         b = bytearray(2)
   780         reader = cctx.stream_reader(b'foo')
   785         reader = cctx.stream_reader(b"foo")
   781         self.assertEqual(reader.readinto(b), 2)
   786         self.assertEqual(reader.readinto(b), 2)
   782         self.assertEqual(b[:], foo[0:2])
   787         self.assertEqual(b[:], foo[0:2])
   783         self.assertEqual(reader.readinto(b), 2)
   788         self.assertEqual(reader.readinto(b), 2)
   784         self.assertEqual(b[:], foo[2:4])
   789         self.assertEqual(b[:], foo[2:4])
   785         self.assertEqual(reader.readinto(b), 2)
   790         self.assertEqual(reader.readinto(b), 2)
   786         self.assertEqual(b[:], foo[4:6])
   791         self.assertEqual(b[:], foo[4:6])
   787 
   792 
   788     def test_readinto1(self):
   793     def test_readinto1(self):
   789         cctx = zstd.ZstdCompressor()
   794         cctx = zstd.ZstdCompressor()
   790         foo = b''.join(cctx.read_to_iter(io.BytesIO(b'foo')))
   795         foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
   791 
   796 
   792         reader = cctx.stream_reader(b'foo')
   797         reader = cctx.stream_reader(b"foo")
   793         with self.assertRaises(Exception):
   798         with self.assertRaises(Exception):
   794             reader.readinto1(b'foobar')
   799             reader.readinto1(b"foobar")
   795 
   800 
   796         b = bytearray(1024)
   801         b = bytearray(1024)
   797         source = OpCountingBytesIO(b'foo')
   802         source = OpCountingBytesIO(b"foo")
   798         reader = cctx.stream_reader(source)
   803         reader = cctx.stream_reader(source)
   799         self.assertEqual(reader.readinto1(b), len(foo))
   804         self.assertEqual(reader.readinto1(b), len(foo))
   800         self.assertEqual(b[0:len(foo)], foo)
   805         self.assertEqual(b[0 : len(foo)], foo)
   801         self.assertEqual(source._read_count, 2)
   806         self.assertEqual(source._read_count, 2)
   802 
   807 
   803         # readinto1() with small reads.
   808         # readinto1() with small reads.
   804         b = bytearray(1024)
   809         b = bytearray(1024)
   805         source = OpCountingBytesIO(b'foo')
   810         source = OpCountingBytesIO(b"foo")
   806         reader = cctx.stream_reader(source, read_size=1)
   811         reader = cctx.stream_reader(source, read_size=1)
   807         self.assertEqual(reader.readinto1(b), len(foo))
   812         self.assertEqual(reader.readinto1(b), len(foo))
   808         self.assertEqual(b[0:len(foo)], foo)
   813         self.assertEqual(b[0 : len(foo)], foo)
   809         self.assertEqual(source._read_count, 4)
   814         self.assertEqual(source._read_count, 4)
   810 
   815 
   811     def test_read1(self):
   816     def test_read1(self):
   812         cctx = zstd.ZstdCompressor()
   817         cctx = zstd.ZstdCompressor()
   813         foo = b''.join(cctx.read_to_iter(io.BytesIO(b'foo')))
   818         foo = b"".join(cctx.read_to_iter(io.BytesIO(b"foo")))
   814 
   819 
   815         b = OpCountingBytesIO(b'foo')
   820         b = OpCountingBytesIO(b"foo")
   816         reader = cctx.stream_reader(b)
   821         reader = cctx.stream_reader(b)
   817 
   822 
   818         self.assertEqual(reader.read1(), foo)
   823         self.assertEqual(reader.read1(), foo)
   819         self.assertEqual(b._read_count, 2)
   824         self.assertEqual(b._read_count, 2)
   820 
   825 
   821         b = OpCountingBytesIO(b'foo')
   826         b = OpCountingBytesIO(b"foo")
   822         reader = cctx.stream_reader(b)
   827         reader = cctx.stream_reader(b)
   823 
   828 
   824         self.assertEqual(reader.read1(0), b'')
   829         self.assertEqual(reader.read1(0), b"")
   825         self.assertEqual(reader.read1(2), foo[0:2])
   830         self.assertEqual(reader.read1(2), foo[0:2])
   826         self.assertEqual(b._read_count, 2)
   831         self.assertEqual(b._read_count, 2)
   827         self.assertEqual(reader.read1(2), foo[2:4])
   832         self.assertEqual(reader.read1(2), foo[2:4])
   828         self.assertEqual(reader.read1(1024), foo[4:])
   833         self.assertEqual(reader.read1(1024), foo[4:])
   829 
   834 
   830 
   835 
   831 @make_cffi
   836 @make_cffi
   832 class TestCompressor_stream_writer(unittest.TestCase):
   837 class TestCompressor_stream_writer(TestCase):
   833     def test_io_api(self):
   838     def test_io_api(self):
   834         buffer = io.BytesIO()
   839         buffer = io.BytesIO()
   835         cctx = zstd.ZstdCompressor()
   840         cctx = zstd.ZstdCompressor()
   836         writer = cctx.stream_writer(buffer)
   841         writer = cctx.stream_writer(buffer)
   837 
   842 
   897             writer.fileno()
   902             writer.fileno()
   898 
   903 
   899         self.assertFalse(writer.closed)
   904         self.assertFalse(writer.closed)
   900 
   905 
   901     def test_fileno_file(self):
   906     def test_fileno_file(self):
   902         with tempfile.TemporaryFile('wb') as tf:
   907         with tempfile.TemporaryFile("wb") as tf:
   903             cctx = zstd.ZstdCompressor()
   908             cctx = zstd.ZstdCompressor()
   904             writer = cctx.stream_writer(tf)
   909             writer = cctx.stream_writer(tf)
   905 
   910 
   906             self.assertEqual(writer.fileno(), tf.fileno())
   911             self.assertEqual(writer.fileno(), tf.fileno())
   907 
   912 
   908     def test_close(self):
   913     def test_close(self):
   909         buffer = NonClosingBytesIO()
   914         buffer = NonClosingBytesIO()
   910         cctx = zstd.ZstdCompressor(level=1)
   915         cctx = zstd.ZstdCompressor(level=1)
   911         writer = cctx.stream_writer(buffer)
   916         writer = cctx.stream_writer(buffer)
   912 
   917 
   913         writer.write(b'foo' * 1024)
   918         writer.write(b"foo" * 1024)
   914         self.assertFalse(writer.closed)
   919         self.assertFalse(writer.closed)
   915         self.assertFalse(buffer.closed)
   920         self.assertFalse(buffer.closed)
   916         writer.close()
   921         writer.close()
   917         self.assertTrue(writer.closed)
   922         self.assertTrue(writer.closed)
   918         self.assertTrue(buffer.closed)
   923         self.assertTrue(buffer.closed)
   919 
   924 
   920         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   925         with self.assertRaisesRegex(ValueError, "stream is closed"):
   921             writer.write(b'foo')
   926             writer.write(b"foo")
   922 
   927 
   923         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   928         with self.assertRaisesRegex(ValueError, "stream is closed"):
   924             writer.flush()
   929             writer.flush()
   925 
   930 
   926         with self.assertRaisesRegexp(ValueError, 'stream is closed'):
   931         with self.assertRaisesRegex(ValueError, "stream is closed"):
   927             with writer:
   932             with writer:
   928                 pass
   933                 pass
   929 
   934 
   930         self.assertEqual(buffer.getvalue(),
   935         self.assertEqual(
   931                          b'\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f'
   936             buffer.getvalue(),
   932                          b'\x6f\x01\x00\xfa\xd3\x77\x43')
   937             b"\x28\xb5\x2f\xfd\x00\x48\x55\x00\x00\x18\x66\x6f"
       
   938             b"\x6f\x01\x00\xfa\xd3\x77\x43",
       
   939         )
   933 
   940 
   934         # Context manager exit should close stream.
   941         # Context manager exit should close stream.
   935         buffer = io.BytesIO()
   942         buffer = io.BytesIO()
   936         writer = cctx.stream_writer(buffer)
   943         writer = cctx.stream_writer(buffer)
   937 
   944 
   938         with writer:
   945         with writer:
   939             writer.write(b'foo')
   946             writer.write(b"foo")
   940 
   947 
   941         self.assertTrue(writer.closed)
   948         self.assertTrue(writer.closed)
   942 
   949 
   943     def test_empty(self):
   950     def test_empty(self):
   944         buffer = NonClosingBytesIO()
   951         buffer = NonClosingBytesIO()
   945         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   952         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
   946         with cctx.stream_writer(buffer) as compressor:
   953         with cctx.stream_writer(buffer) as compressor:
   947             compressor.write(b'')
   954             compressor.write(b"")
   948 
   955 
   949         result = buffer.getvalue()
   956         result = buffer.getvalue()
   950         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
   957         self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
   951 
   958 
   952         params = zstd.get_frame_parameters(result)
   959         params = zstd.get_frame_parameters(result)
   953         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   960         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   954         self.assertEqual(params.window_size, 524288)
   961         self.assertEqual(params.window_size, 524288)
   955         self.assertEqual(params.dict_id, 0)
   962         self.assertEqual(params.dict_id, 0)
   956         self.assertFalse(params.has_checksum)
   963         self.assertFalse(params.has_checksum)
   957 
   964 
   958         # Test without context manager.
   965         # Test without context manager.
   959         buffer = io.BytesIO()
   966         buffer = io.BytesIO()
   960         compressor = cctx.stream_writer(buffer)
   967         compressor = cctx.stream_writer(buffer)
   961         self.assertEqual(compressor.write(b''), 0)
   968         self.assertEqual(compressor.write(b""), 0)
   962         self.assertEqual(buffer.getvalue(), b'')
   969         self.assertEqual(buffer.getvalue(), b"")
   963         self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9)
   970         self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 9)
   964         result = buffer.getvalue()
   971         result = buffer.getvalue()
   965         self.assertEqual(result, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
   972         self.assertEqual(result, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
   966 
   973 
   967         params = zstd.get_frame_parameters(result)
   974         params = zstd.get_frame_parameters(result)
   968         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   975         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
   969         self.assertEqual(params.window_size, 524288)
   976         self.assertEqual(params.window_size, 524288)
   970         self.assertEqual(params.dict_id, 0)
   977         self.assertEqual(params.dict_id, 0)
   971         self.assertFalse(params.has_checksum)
   978         self.assertFalse(params.has_checksum)
   972 
   979 
   973         # Test write_return_read=True
   980         # Test write_return_read=True
   974         compressor = cctx.stream_writer(buffer, write_return_read=True)
   981         compressor = cctx.stream_writer(buffer, write_return_read=True)
   975         self.assertEqual(compressor.write(b''), 0)
   982         self.assertEqual(compressor.write(b""), 0)
   976 
   983 
   977     def test_input_types(self):
   984     def test_input_types(self):
   978         expected = b'\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f'
   985         expected = b"\x28\xb5\x2f\xfd\x00\x48\x19\x00\x00\x66\x6f\x6f"
   979         cctx = zstd.ZstdCompressor(level=1)
   986         cctx = zstd.ZstdCompressor(level=1)
   980 
   987 
   981         mutable_array = bytearray(3)
   988         mutable_array = bytearray(3)
   982         mutable_array[:] = b'foo'
   989         mutable_array[:] = b"foo"
   983 
   990 
   984         sources = [
   991         sources = [
   985             memoryview(b'foo'),
   992             memoryview(b"foo"),
   986             bytearray(b'foo'),
   993             bytearray(b"foo"),
   987             mutable_array,
   994             mutable_array,
   988         ]
   995         ]
   989 
   996 
   990         for source in sources:
   997         for source in sources:
   991             buffer = NonClosingBytesIO()
   998             buffer = NonClosingBytesIO()
   999 
  1006 
  1000     def test_multiple_compress(self):
  1007     def test_multiple_compress(self):
  1001         buffer = NonClosingBytesIO()
  1008         buffer = NonClosingBytesIO()
  1002         cctx = zstd.ZstdCompressor(level=5)
  1009         cctx = zstd.ZstdCompressor(level=5)
  1003         with cctx.stream_writer(buffer) as compressor:
  1010         with cctx.stream_writer(buffer) as compressor:
  1004             self.assertEqual(compressor.write(b'foo'), 0)
  1011             self.assertEqual(compressor.write(b"foo"), 0)
  1005             self.assertEqual(compressor.write(b'bar'), 0)
  1012             self.assertEqual(compressor.write(b"bar"), 0)
  1006             self.assertEqual(compressor.write(b'x' * 8192), 0)
  1013             self.assertEqual(compressor.write(b"x" * 8192), 0)
  1007 
  1014 
  1008         result = buffer.getvalue()
  1015         result = buffer.getvalue()
  1009         self.assertEqual(result,
  1016         self.assertEqual(
  1010                          b'\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f'
  1017             result,
  1011                          b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')
  1018             b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
       
  1019             b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
       
  1020         )
  1012 
  1021 
  1013         # Test without context manager.
  1022         # Test without context manager.
  1014         buffer = io.BytesIO()
  1023         buffer = io.BytesIO()
  1015         compressor = cctx.stream_writer(buffer)
  1024         compressor = cctx.stream_writer(buffer)
  1016         self.assertEqual(compressor.write(b'foo'), 0)
  1025         self.assertEqual(compressor.write(b"foo"), 0)
  1017         self.assertEqual(compressor.write(b'bar'), 0)
  1026         self.assertEqual(compressor.write(b"bar"), 0)
  1018         self.assertEqual(compressor.write(b'x' * 8192), 0)
  1027         self.assertEqual(compressor.write(b"x" * 8192), 0)
  1019         self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
  1028         self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
  1020         result = buffer.getvalue()
  1029         result = buffer.getvalue()
  1021         self.assertEqual(result,
  1030         self.assertEqual(
  1022                          b'\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f'
  1031             result,
  1023                          b'\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23')
  1032             b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x38\x66\x6f"
       
  1033             b"\x6f\x62\x61\x72\x78\x01\x00\xfc\xdf\x03\x23",
       
  1034         )
  1024 
  1035 
  1025         # Test with write_return_read=True.
  1036         # Test with write_return_read=True.
  1026         compressor = cctx.stream_writer(buffer, write_return_read=True)
  1037         compressor = cctx.stream_writer(buffer, write_return_read=True)
  1027         self.assertEqual(compressor.write(b'foo'), 3)
  1038         self.assertEqual(compressor.write(b"foo"), 3)
  1028         self.assertEqual(compressor.write(b'barbiz'), 6)
  1039         self.assertEqual(compressor.write(b"barbiz"), 6)
  1029         self.assertEqual(compressor.write(b'x' * 8192), 8192)
  1040         self.assertEqual(compressor.write(b"x" * 8192), 8192)
  1030 
  1041 
  1031     def test_dictionary(self):
  1042     def test_dictionary(self):
  1032         samples = []
  1043         samples = []
  1033         for i in range(128):
  1044         for i in range(128):
  1034             samples.append(b'foo' * 64)
  1045             samples.append(b"foo" * 64)
  1035             samples.append(b'bar' * 64)
  1046             samples.append(b"bar" * 64)
  1036             samples.append(b'foobar' * 64)
  1047             samples.append(b"foobar" * 64)
  1037 
  1048 
  1038         d = zstd.train_dictionary(8192, samples)
  1049         d = zstd.train_dictionary(8192, samples)
  1039 
  1050 
  1040         h = hashlib.sha1(d.as_bytes()).hexdigest()
  1051         h = hashlib.sha1(d.as_bytes()).hexdigest()
  1041         self.assertEqual(h, '7a2e59a876db958f74257141045af8f912e00d4e')
  1052         self.assertEqual(h, "7a2e59a876db958f74257141045af8f912e00d4e")
  1042 
  1053 
  1043         buffer = NonClosingBytesIO()
  1054         buffer = NonClosingBytesIO()
  1044         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
  1055         cctx = zstd.ZstdCompressor(level=9, dict_data=d)
  1045         with cctx.stream_writer(buffer) as compressor:
  1056         with cctx.stream_writer(buffer) as compressor:
  1046             self.assertEqual(compressor.write(b'foo'), 0)
  1057             self.assertEqual(compressor.write(b"foo"), 0)
  1047             self.assertEqual(compressor.write(b'bar'), 0)
  1058             self.assertEqual(compressor.write(b"bar"), 0)
  1048             self.assertEqual(compressor.write(b'foo' * 16384), 0)
  1059             self.assertEqual(compressor.write(b"foo" * 16384), 0)
  1049 
  1060 
  1050         compressed = buffer.getvalue()
  1061         compressed = buffer.getvalue()
  1051 
  1062 
  1052         params = zstd.get_frame_parameters(compressed)
  1063         params = zstd.get_frame_parameters(compressed)
  1053         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1064         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1054         self.assertEqual(params.window_size, 2097152)
  1065         self.assertEqual(params.window_size, 2097152)
  1055         self.assertEqual(params.dict_id, d.dict_id())
  1066         self.assertEqual(params.dict_id, d.dict_id())
  1056         self.assertFalse(params.has_checksum)
  1067         self.assertFalse(params.has_checksum)
  1057 
  1068 
  1058         h = hashlib.sha1(compressed).hexdigest()
  1069         h = hashlib.sha1(compressed).hexdigest()
  1059         self.assertEqual(h, '0a7c05635061f58039727cdbe76388c6f4cfef06')
  1070         self.assertEqual(h, "0a7c05635061f58039727cdbe76388c6f4cfef06")
  1060 
  1071 
  1061         source = b'foo' + b'bar' + (b'foo' * 16384)
  1072         source = b"foo" + b"bar" + (b"foo" * 16384)
  1062 
  1073 
  1063         dctx = zstd.ZstdDecompressor(dict_data=d)
  1074         dctx = zstd.ZstdDecompressor(dict_data=d)
  1064 
  1075 
  1065         self.assertEqual(dctx.decompress(compressed, max_output_size=len(source)),
  1076         self.assertEqual(
  1066                          source)
  1077             dctx.decompress(compressed, max_output_size=len(source)), source
       
  1078         )
  1067 
  1079 
  1068     def test_compression_params(self):
  1080     def test_compression_params(self):
  1069         params = zstd.ZstdCompressionParameters(
  1081         params = zstd.ZstdCompressionParameters(
  1070             window_log=20,
  1082             window_log=20,
  1071             chain_log=6,
  1083             chain_log=6,
  1072             hash_log=12,
  1084             hash_log=12,
  1073             min_match=5,
  1085             min_match=5,
  1074             search_log=4,
  1086             search_log=4,
  1075             target_length=10,
  1087             target_length=10,
  1076             strategy=zstd.STRATEGY_FAST)
  1088             strategy=zstd.STRATEGY_FAST,
       
  1089         )
  1077 
  1090 
  1078         buffer = NonClosingBytesIO()
  1091         buffer = NonClosingBytesIO()
  1079         cctx = zstd.ZstdCompressor(compression_params=params)
  1092         cctx = zstd.ZstdCompressor(compression_params=params)
  1080         with cctx.stream_writer(buffer) as compressor:
  1093         with cctx.stream_writer(buffer) as compressor:
  1081             self.assertEqual(compressor.write(b'foo'), 0)
  1094             self.assertEqual(compressor.write(b"foo"), 0)
  1082             self.assertEqual(compressor.write(b'bar'), 0)
  1095             self.assertEqual(compressor.write(b"bar"), 0)
  1083             self.assertEqual(compressor.write(b'foobar' * 16384), 0)
  1096             self.assertEqual(compressor.write(b"foobar" * 16384), 0)
  1084 
  1097 
  1085         compressed = buffer.getvalue()
  1098         compressed = buffer.getvalue()
  1086 
  1099 
  1087         params = zstd.get_frame_parameters(compressed)
  1100         params = zstd.get_frame_parameters(compressed)
  1088         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1101         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1089         self.assertEqual(params.window_size, 1048576)
  1102         self.assertEqual(params.window_size, 1048576)
  1090         self.assertEqual(params.dict_id, 0)
  1103         self.assertEqual(params.dict_id, 0)
  1091         self.assertFalse(params.has_checksum)
  1104         self.assertFalse(params.has_checksum)
  1092 
  1105 
  1093         h = hashlib.sha1(compressed).hexdigest()
  1106         h = hashlib.sha1(compressed).hexdigest()
  1094         self.assertEqual(h, 'dd4bb7d37c1a0235b38a2f6b462814376843ef0b')
  1107         self.assertEqual(h, "dd4bb7d37c1a0235b38a2f6b462814376843ef0b")
  1095 
  1108 
  1096     def test_write_checksum(self):
  1109     def test_write_checksum(self):
  1097         no_checksum = NonClosingBytesIO()
  1110         no_checksum = NonClosingBytesIO()
  1098         cctx = zstd.ZstdCompressor(level=1)
  1111         cctx = zstd.ZstdCompressor(level=1)
  1099         with cctx.stream_writer(no_checksum) as compressor:
  1112         with cctx.stream_writer(no_checksum) as compressor:
  1100             self.assertEqual(compressor.write(b'foobar'), 0)
  1113             self.assertEqual(compressor.write(b"foobar"), 0)
  1101 
  1114 
  1102         with_checksum = NonClosingBytesIO()
  1115         with_checksum = NonClosingBytesIO()
  1103         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
  1116         cctx = zstd.ZstdCompressor(level=1, write_checksum=True)
  1104         with cctx.stream_writer(with_checksum) as compressor:
  1117         with cctx.stream_writer(with_checksum) as compressor:
  1105             self.assertEqual(compressor.write(b'foobar'), 0)
  1118             self.assertEqual(compressor.write(b"foobar"), 0)
  1106 
  1119 
  1107         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
  1120         no_params = zstd.get_frame_parameters(no_checksum.getvalue())
  1108         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
  1121         with_params = zstd.get_frame_parameters(with_checksum.getvalue())
  1109         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1122         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1110         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1123         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1111         self.assertEqual(no_params.dict_id, 0)
  1124         self.assertEqual(no_params.dict_id, 0)
  1112         self.assertEqual(with_params.dict_id, 0)
  1125         self.assertEqual(with_params.dict_id, 0)
  1113         self.assertFalse(no_params.has_checksum)
  1126         self.assertFalse(no_params.has_checksum)
  1114         self.assertTrue(with_params.has_checksum)
  1127         self.assertTrue(with_params.has_checksum)
  1115 
  1128 
  1116         self.assertEqual(len(with_checksum.getvalue()),
  1129         self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4)
  1117                          len(no_checksum.getvalue()) + 4)
       
  1118 
  1130 
  1119     def test_write_content_size(self):
  1131     def test_write_content_size(self):
  1120         no_size = NonClosingBytesIO()
  1132         no_size = NonClosingBytesIO()
  1121         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1133         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1122         with cctx.stream_writer(no_size) as compressor:
  1134         with cctx.stream_writer(no_size) as compressor:
  1123             self.assertEqual(compressor.write(b'foobar' * 256), 0)
  1135             self.assertEqual(compressor.write(b"foobar" * 256), 0)
  1124 
  1136 
  1125         with_size = NonClosingBytesIO()
  1137         with_size = NonClosingBytesIO()
  1126         cctx = zstd.ZstdCompressor(level=1)
  1138         cctx = zstd.ZstdCompressor(level=1)
  1127         with cctx.stream_writer(with_size) as compressor:
  1139         with cctx.stream_writer(with_size) as compressor:
  1128             self.assertEqual(compressor.write(b'foobar' * 256), 0)
  1140             self.assertEqual(compressor.write(b"foobar" * 256), 0)
  1129 
  1141 
  1130         # Source size is not known in streaming mode, so header not
  1142         # Source size is not known in streaming mode, so header not
  1131         # written.
  1143         # written.
  1132         self.assertEqual(len(with_size.getvalue()),
  1144         self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()))
  1133                          len(no_size.getvalue()))
       
  1134 
  1145 
  1135         # Declaring size will write the header.
  1146         # Declaring size will write the header.
  1136         with_size = NonClosingBytesIO()
  1147         with_size = NonClosingBytesIO()
  1137         with cctx.stream_writer(with_size, size=len(b'foobar' * 256)) as compressor:
  1148         with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor:
  1138             self.assertEqual(compressor.write(b'foobar' * 256), 0)
  1149             self.assertEqual(compressor.write(b"foobar" * 256), 0)
  1139 
  1150 
  1140         no_params = zstd.get_frame_parameters(no_size.getvalue())
  1151         no_params = zstd.get_frame_parameters(no_size.getvalue())
  1141         with_params = zstd.get_frame_parameters(with_size.getvalue())
  1152         with_params = zstd.get_frame_parameters(with_size.getvalue())
  1142         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1153         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1143         self.assertEqual(with_params.content_size, 1536)
  1154         self.assertEqual(with_params.content_size, 1536)
  1144         self.assertEqual(no_params.dict_id, 0)
  1155         self.assertEqual(no_params.dict_id, 0)
  1145         self.assertEqual(with_params.dict_id, 0)
  1156         self.assertEqual(with_params.dict_id, 0)
  1146         self.assertFalse(no_params.has_checksum)
  1157         self.assertFalse(no_params.has_checksum)
  1147         self.assertFalse(with_params.has_checksum)
  1158         self.assertFalse(with_params.has_checksum)
  1148 
  1159 
  1149         self.assertEqual(len(with_size.getvalue()),
  1160         self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue()) + 1)
  1150                          len(no_size.getvalue()) + 1)
       
  1151 
  1161 
  1152     def test_no_dict_id(self):
  1162     def test_no_dict_id(self):
  1153         samples = []
  1163         samples = []
  1154         for i in range(128):
  1164         for i in range(128):
  1155             samples.append(b'foo' * 64)
  1165             samples.append(b"foo" * 64)
  1156             samples.append(b'bar' * 64)
  1166             samples.append(b"bar" * 64)
  1157             samples.append(b'foobar' * 64)
  1167             samples.append(b"foobar" * 64)
  1158 
  1168 
  1159         d = zstd.train_dictionary(1024, samples)
  1169         d = zstd.train_dictionary(1024, samples)
  1160 
  1170 
  1161         with_dict_id = NonClosingBytesIO()
  1171         with_dict_id = NonClosingBytesIO()
  1162         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
  1172         cctx = zstd.ZstdCompressor(level=1, dict_data=d)
  1163         with cctx.stream_writer(with_dict_id) as compressor:
  1173         with cctx.stream_writer(with_dict_id) as compressor:
  1164             self.assertEqual(compressor.write(b'foobarfoobar'), 0)
  1174             self.assertEqual(compressor.write(b"foobarfoobar"), 0)
  1165 
  1175 
  1166         self.assertEqual(with_dict_id.getvalue()[4:5], b'\x03')
  1176         self.assertEqual(with_dict_id.getvalue()[4:5], b"\x03")
  1167 
  1177 
  1168         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
  1178         cctx = zstd.ZstdCompressor(level=1, dict_data=d, write_dict_id=False)
  1169         no_dict_id = NonClosingBytesIO()
  1179         no_dict_id = NonClosingBytesIO()
  1170         with cctx.stream_writer(no_dict_id) as compressor:
  1180         with cctx.stream_writer(no_dict_id) as compressor:
  1171             self.assertEqual(compressor.write(b'foobarfoobar'), 0)
  1181             self.assertEqual(compressor.write(b"foobarfoobar"), 0)
  1172 
  1182 
  1173         self.assertEqual(no_dict_id.getvalue()[4:5], b'\x00')
  1183         self.assertEqual(no_dict_id.getvalue()[4:5], b"\x00")
  1174 
  1184 
  1175         no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
  1185         no_params = zstd.get_frame_parameters(no_dict_id.getvalue())
  1176         with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
  1186         with_params = zstd.get_frame_parameters(with_dict_id.getvalue())
  1177         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1187         self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1178         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1188         self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1179         self.assertEqual(no_params.dict_id, 0)
  1189         self.assertEqual(no_params.dict_id, 0)
  1180         self.assertEqual(with_params.dict_id, d.dict_id())
  1190         self.assertEqual(with_params.dict_id, d.dict_id())
  1181         self.assertFalse(no_params.has_checksum)
  1191         self.assertFalse(no_params.has_checksum)
  1182         self.assertFalse(with_params.has_checksum)
  1192         self.assertFalse(with_params.has_checksum)
  1183 
  1193 
  1184         self.assertEqual(len(with_dict_id.getvalue()),
  1194         self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4)
  1185                          len(no_dict_id.getvalue()) + 4)
       
  1186 
  1195 
  1187     def test_memory_size(self):
  1196     def test_memory_size(self):
  1188         cctx = zstd.ZstdCompressor(level=3)
  1197         cctx = zstd.ZstdCompressor(level=3)
  1189         buffer = io.BytesIO()
  1198         buffer = io.BytesIO()
  1190         with cctx.stream_writer(buffer) as compressor:
  1199         with cctx.stream_writer(buffer) as compressor:
  1191             compressor.write(b'foo')
  1200             compressor.write(b"foo")
  1192             size = compressor.memory_size()
  1201             size = compressor.memory_size()
  1193 
  1202 
  1194         self.assertGreater(size, 100000)
  1203         self.assertGreater(size, 100000)
  1195 
  1204 
  1196     def test_write_size(self):
  1205     def test_write_size(self):
  1197         cctx = zstd.ZstdCompressor(level=3)
  1206         cctx = zstd.ZstdCompressor(level=3)
  1198         dest = OpCountingBytesIO()
  1207         dest = OpCountingBytesIO()
  1199         with cctx.stream_writer(dest, write_size=1) as compressor:
  1208         with cctx.stream_writer(dest, write_size=1) as compressor:
  1200             self.assertEqual(compressor.write(b'foo'), 0)
  1209             self.assertEqual(compressor.write(b"foo"), 0)
  1201             self.assertEqual(compressor.write(b'bar'), 0)
  1210             self.assertEqual(compressor.write(b"bar"), 0)
  1202             self.assertEqual(compressor.write(b'foobar'), 0)
  1211             self.assertEqual(compressor.write(b"foobar"), 0)
  1203 
  1212 
  1204         self.assertEqual(len(dest.getvalue()), dest._write_count)
  1213         self.assertEqual(len(dest.getvalue()), dest._write_count)
  1205 
  1214 
  1206     def test_flush_repeated(self):
  1215     def test_flush_repeated(self):
  1207         cctx = zstd.ZstdCompressor(level=3)
  1216         cctx = zstd.ZstdCompressor(level=3)
  1208         dest = OpCountingBytesIO()
  1217         dest = OpCountingBytesIO()
  1209         with cctx.stream_writer(dest) as compressor:
  1218         with cctx.stream_writer(dest) as compressor:
  1210             self.assertEqual(compressor.write(b'foo'), 0)
  1219             self.assertEqual(compressor.write(b"foo"), 0)
  1211             self.assertEqual(dest._write_count, 0)
  1220             self.assertEqual(dest._write_count, 0)
  1212             self.assertEqual(compressor.flush(), 12)
  1221             self.assertEqual(compressor.flush(), 12)
  1213             self.assertEqual(dest._write_count, 1)
  1222             self.assertEqual(dest._write_count, 1)
  1214             self.assertEqual(compressor.write(b'bar'), 0)
  1223             self.assertEqual(compressor.write(b"bar"), 0)
  1215             self.assertEqual(dest._write_count, 1)
  1224             self.assertEqual(dest._write_count, 1)
  1216             self.assertEqual(compressor.flush(), 6)
  1225             self.assertEqual(compressor.flush(), 6)
  1217             self.assertEqual(dest._write_count, 2)
  1226             self.assertEqual(dest._write_count, 2)
  1218             self.assertEqual(compressor.write(b'baz'), 0)
  1227             self.assertEqual(compressor.write(b"baz"), 0)
  1219 
  1228 
  1220         self.assertEqual(dest._write_count, 3)
  1229         self.assertEqual(dest._write_count, 3)
  1221 
  1230 
  1222     def test_flush_empty_block(self):
  1231     def test_flush_empty_block(self):
  1223         cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
  1232         cctx = zstd.ZstdCompressor(level=3, write_checksum=True)
  1224         dest = OpCountingBytesIO()
  1233         dest = OpCountingBytesIO()
  1225         with cctx.stream_writer(dest) as compressor:
  1234         with cctx.stream_writer(dest) as compressor:
  1226             self.assertEqual(compressor.write(b'foobar' * 8192), 0)
  1235             self.assertEqual(compressor.write(b"foobar" * 8192), 0)
  1227             count = dest._write_count
  1236             count = dest._write_count
  1228             offset = dest.tell()
  1237             offset = dest.tell()
  1229             self.assertEqual(compressor.flush(), 23)
  1238             self.assertEqual(compressor.flush(), 23)
  1230             self.assertGreater(dest._write_count, count)
  1239             self.assertGreater(dest._write_count, count)
  1231             self.assertGreater(dest.tell(), offset)
  1240             self.assertGreater(dest.tell(), offset)
  1236         trailing = dest.getvalue()[offset:]
  1245         trailing = dest.getvalue()[offset:]
  1237         # 3 bytes block header + 4 bytes frame checksum
  1246         # 3 bytes block header + 4 bytes frame checksum
  1238         self.assertEqual(len(trailing), 7)
  1247         self.assertEqual(len(trailing), 7)
  1239 
  1248 
  1240         header = trailing[0:3]
  1249         header = trailing[0:3]
  1241         self.assertEqual(header, b'\x01\x00\x00')
  1250         self.assertEqual(header, b"\x01\x00\x00")
  1242 
  1251 
  1243     def test_flush_frame(self):
  1252     def test_flush_frame(self):
  1244         cctx = zstd.ZstdCompressor(level=3)
  1253         cctx = zstd.ZstdCompressor(level=3)
  1245         dest = OpCountingBytesIO()
  1254         dest = OpCountingBytesIO()
  1246 
  1255 
  1247         with cctx.stream_writer(dest) as compressor:
  1256         with cctx.stream_writer(dest) as compressor:
  1248             self.assertEqual(compressor.write(b'foobar' * 8192), 0)
  1257             self.assertEqual(compressor.write(b"foobar" * 8192), 0)
  1249             self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
  1258             self.assertEqual(compressor.flush(zstd.FLUSH_FRAME), 23)
  1250             compressor.write(b'biz' * 16384)
  1259             compressor.write(b"biz" * 16384)
  1251 
  1260 
  1252         self.assertEqual(dest.getvalue(),
  1261         self.assertEqual(
  1253                          # Frame 1.
  1262             dest.getvalue(),
  1254                          b'\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f'
  1263             # Frame 1.
  1255                          b'\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08'
  1264             b"\x28\xb5\x2f\xfd\x00\x58\x75\x00\x00\x30\x66\x6f\x6f"
  1256                          # Frame 2.
  1265             b"\x62\x61\x72\x01\x00\xf7\xbf\xe8\xa5\x08"
  1257                          b'\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a'
  1266             # Frame 2.
  1258                          b'\x01\x00\xfa\x3f\x75\x37\x04')
  1267             b"\x28\xb5\x2f\xfd\x00\x58\x5d\x00\x00\x18\x62\x69\x7a"
       
  1268             b"\x01\x00\xfa\x3f\x75\x37\x04",
       
  1269         )
  1259 
  1270 
  1260     def test_bad_flush_mode(self):
  1271     def test_bad_flush_mode(self):
  1261         cctx = zstd.ZstdCompressor()
  1272         cctx = zstd.ZstdCompressor()
  1262         dest = io.BytesIO()
  1273         dest = io.BytesIO()
  1263         with cctx.stream_writer(dest) as compressor:
  1274         with cctx.stream_writer(dest) as compressor:
  1264             with self.assertRaisesRegexp(ValueError, 'unknown flush_mode: 42'):
  1275             with self.assertRaisesRegex(ValueError, "unknown flush_mode: 42"):
  1265                 compressor.flush(flush_mode=42)
  1276                 compressor.flush(flush_mode=42)
  1266 
  1277 
  1267     def test_multithreaded(self):
  1278     def test_multithreaded(self):
  1268         dest = NonClosingBytesIO()
  1279         dest = NonClosingBytesIO()
  1269         cctx = zstd.ZstdCompressor(threads=2)
  1280         cctx = zstd.ZstdCompressor(threads=2)
  1270         with cctx.stream_writer(dest) as compressor:
  1281         with cctx.stream_writer(dest) as compressor:
  1271             compressor.write(b'a' * 1048576)
  1282             compressor.write(b"a" * 1048576)
  1272             compressor.write(b'b' * 1048576)
  1283             compressor.write(b"b" * 1048576)
  1273             compressor.write(b'c' * 1048576)
  1284             compressor.write(b"c" * 1048576)
  1274 
  1285 
  1275         self.assertEqual(len(dest.getvalue()), 295)
  1286         self.assertEqual(len(dest.getvalue()), 111)
  1276 
  1287 
  1277     def test_tell(self):
  1288     def test_tell(self):
  1278         dest = io.BytesIO()
  1289         dest = io.BytesIO()
  1279         cctx = zstd.ZstdCompressor()
  1290         cctx = zstd.ZstdCompressor()
  1280         with cctx.stream_writer(dest) as compressor:
  1291         with cctx.stream_writer(dest) as compressor:
  1281             self.assertEqual(compressor.tell(), 0)
  1292             self.assertEqual(compressor.tell(), 0)
  1282 
  1293 
  1283             for i in range(256):
  1294             for i in range(256):
  1284                 compressor.write(b'foo' * (i + 1))
  1295                 compressor.write(b"foo" * (i + 1))
  1285                 self.assertEqual(compressor.tell(), dest.tell())
  1296                 self.assertEqual(compressor.tell(), dest.tell())
  1286 
  1297 
  1287     def test_bad_size(self):
  1298     def test_bad_size(self):
  1288         cctx = zstd.ZstdCompressor()
  1299         cctx = zstd.ZstdCompressor()
  1289 
  1300 
  1290         dest = io.BytesIO()
  1301         dest = io.BytesIO()
  1291 
  1302 
  1292         with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
  1303         with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
  1293             with cctx.stream_writer(dest, size=2) as compressor:
  1304             with cctx.stream_writer(dest, size=2) as compressor:
  1294                 compressor.write(b'foo')
  1305                 compressor.write(b"foo")
  1295 
  1306 
  1296         # Test another operation.
  1307         # Test another operation.
  1297         with cctx.stream_writer(dest, size=42):
  1308         with cctx.stream_writer(dest, size=42):
  1298             pass
  1309             pass
  1299 
  1310 
  1300     def test_tarfile_compat(self):
  1311     def test_tarfile_compat(self):
  1301         dest = NonClosingBytesIO()
  1312         dest = NonClosingBytesIO()
  1302         cctx = zstd.ZstdCompressor()
  1313         cctx = zstd.ZstdCompressor()
  1303         with cctx.stream_writer(dest) as compressor:
  1314         with cctx.stream_writer(dest) as compressor:
  1304             with tarfile.open('tf', mode='w|', fileobj=compressor) as tf:
  1315             with tarfile.open("tf", mode="w|", fileobj=compressor) as tf:
  1305                 tf.add(__file__, 'test_compressor.py')
  1316                 tf.add(__file__, "test_compressor.py")
  1306 
  1317 
  1307         dest = io.BytesIO(dest.getvalue())
  1318         dest = io.BytesIO(dest.getvalue())
  1308 
  1319 
  1309         dctx = zstd.ZstdDecompressor()
  1320         dctx = zstd.ZstdDecompressor()
  1310         with dctx.stream_reader(dest) as reader:
  1321         with dctx.stream_reader(dest) as reader:
  1311             with tarfile.open(mode='r|', fileobj=reader) as tf:
  1322             with tarfile.open(mode="r|", fileobj=reader) as tf:
  1312                 for member in tf:
  1323                 for member in tf:
  1313                     self.assertEqual(member.name, 'test_compressor.py')
  1324                     self.assertEqual(member.name, "test_compressor.py")
  1314 
  1325 
  1315 
  1326 
  1316 @make_cffi
  1327 @make_cffi
  1317 class TestCompressor_read_to_iter(unittest.TestCase):
  1328 class TestCompressor_read_to_iter(TestCase):
  1318     def test_type_validation(self):
  1329     def test_type_validation(self):
  1319         cctx = zstd.ZstdCompressor()
  1330         cctx = zstd.ZstdCompressor()
  1320 
  1331 
  1321         # Object with read() works.
  1332         # Object with read() works.
  1322         for chunk in cctx.read_to_iter(io.BytesIO()):
  1333         for chunk in cctx.read_to_iter(io.BytesIO()):
  1323             pass
  1334             pass
  1324 
  1335 
  1325         # Buffer protocol works.
  1336         # Buffer protocol works.
  1326         for chunk in cctx.read_to_iter(b'foobar'):
  1337         for chunk in cctx.read_to_iter(b"foobar"):
  1327             pass
  1338             pass
  1328 
  1339 
  1329         with self.assertRaisesRegexp(ValueError, 'must pass an object with a read'):
  1340         with self.assertRaisesRegex(ValueError, "must pass an object with a read"):
  1330             for chunk in cctx.read_to_iter(True):
  1341             for chunk in cctx.read_to_iter(True):
  1331                 pass
  1342                 pass
  1332 
  1343 
  1333     def test_read_empty(self):
  1344     def test_read_empty(self):
  1334         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1345         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1335 
  1346 
  1336         source = io.BytesIO()
  1347         source = io.BytesIO()
  1337         it = cctx.read_to_iter(source)
  1348         it = cctx.read_to_iter(source)
  1338         chunks = list(it)
  1349         chunks = list(it)
  1339         self.assertEqual(len(chunks), 1)
  1350         self.assertEqual(len(chunks), 1)
  1340         compressed = b''.join(chunks)
  1351         compressed = b"".join(chunks)
  1341         self.assertEqual(compressed, b'\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00')
  1352         self.assertEqual(compressed, b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00")
  1342 
  1353 
  1343         # And again with the buffer protocol.
  1354         # And again with the buffer protocol.
  1344         it = cctx.read_to_iter(b'')
  1355         it = cctx.read_to_iter(b"")
  1345         chunks = list(it)
  1356         chunks = list(it)
  1346         self.assertEqual(len(chunks), 1)
  1357         self.assertEqual(len(chunks), 1)
  1347         compressed2 = b''.join(chunks)
  1358         compressed2 = b"".join(chunks)
  1348         self.assertEqual(compressed2, compressed)
  1359         self.assertEqual(compressed2, compressed)
  1349 
  1360 
  1350     def test_read_large(self):
  1361     def test_read_large(self):
  1351         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1362         cctx = zstd.ZstdCompressor(level=1, write_content_size=False)
  1352 
  1363 
  1353         source = io.BytesIO()
  1364         source = io.BytesIO()
  1354         source.write(b'f' * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
  1365         source.write(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE)
  1355         source.write(b'o')
  1366         source.write(b"o")
  1356         source.seek(0)
  1367         source.seek(0)
  1357 
  1368 
  1358         # Creating an iterator should not perform any compression until
  1369         # Creating an iterator should not perform any compression until
  1359         # first read.
  1370         # first read.
  1360         it = cctx.read_to_iter(source, size=len(source.getvalue()))
  1371         it = cctx.read_to_iter(source, size=len(source.getvalue()))
  1378         # And again for good measure.
  1389         # And again for good measure.
  1379         with self.assertRaises(StopIteration):
  1390         with self.assertRaises(StopIteration):
  1380             next(it)
  1391             next(it)
  1381 
  1392 
  1382         # We should get the same output as the one-shot compression mechanism.
  1393         # We should get the same output as the one-shot compression mechanism.
  1383         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
  1394         self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
  1384 
  1395 
  1385         params = zstd.get_frame_parameters(b''.join(chunks))
  1396         params = zstd.get_frame_parameters(b"".join(chunks))
  1386         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1397         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1387         self.assertEqual(params.window_size, 262144)
  1398         self.assertEqual(params.window_size, 262144)
  1388         self.assertEqual(params.dict_id, 0)
  1399         self.assertEqual(params.dict_id, 0)
  1389         self.assertFalse(params.has_checksum)
  1400         self.assertFalse(params.has_checksum)
  1390 
  1401 
  1391         # Now check the buffer protocol.
  1402         # Now check the buffer protocol.
  1392         it = cctx.read_to_iter(source.getvalue())
  1403         it = cctx.read_to_iter(source.getvalue())
  1393         chunks = list(it)
  1404         chunks = list(it)
  1394         self.assertEqual(len(chunks), 2)
  1405         self.assertEqual(len(chunks), 2)
  1395 
  1406 
  1396         params = zstd.get_frame_parameters(b''.join(chunks))
  1407         params = zstd.get_frame_parameters(b"".join(chunks))
  1397         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1408         self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN)
  1398         #self.assertEqual(params.window_size, 262144)
  1409         # self.assertEqual(params.window_size, 262144)
  1399         self.assertEqual(params.dict_id, 0)
  1410         self.assertEqual(params.dict_id, 0)
  1400         self.assertFalse(params.has_checksum)
  1411         self.assertFalse(params.has_checksum)
  1401 
  1412 
  1402         self.assertEqual(b''.join(chunks), cctx.compress(source.getvalue()))
  1413         self.assertEqual(b"".join(chunks), cctx.compress(source.getvalue()))
  1403 
  1414 
  1404     def test_read_write_size(self):
  1415     def test_read_write_size(self):
  1405         source = OpCountingBytesIO(b'foobarfoobar')
  1416         source = OpCountingBytesIO(b"foobarfoobar")
  1406         cctx = zstd.ZstdCompressor(level=3)
  1417         cctx = zstd.ZstdCompressor(level=3)
  1407         for chunk in cctx.read_to_iter(source, read_size=1, write_size=1):
  1418         for chunk in cctx.read_to_iter(source, read_size=1, write_size=1):
  1408             self.assertEqual(len(chunk), 1)
  1419             self.assertEqual(len(chunk), 1)
  1409 
  1420 
  1410         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
  1421         self.assertEqual(source._read_count, len(source.getvalue()) + 1)
  1411 
  1422 
  1412     def test_multithreaded(self):
  1423     def test_multithreaded(self):
  1413         source = io.BytesIO()
  1424         source = io.BytesIO()
  1414         source.write(b'a' * 1048576)
  1425         source.write(b"a" * 1048576)
  1415         source.write(b'b' * 1048576)
  1426         source.write(b"b" * 1048576)
  1416         source.write(b'c' * 1048576)
  1427         source.write(b"c" * 1048576)
  1417         source.seek(0)
  1428         source.seek(0)
  1418 
  1429 
  1419         cctx = zstd.ZstdCompressor(threads=2)
  1430         cctx = zstd.ZstdCompressor(threads=2)
  1420 
  1431 
  1421         compressed = b''.join(cctx.read_to_iter(source))
  1432         compressed = b"".join(cctx.read_to_iter(source))
  1422         self.assertEqual(len(compressed), 295)
  1433         self.assertEqual(len(compressed), 111)
  1423 
  1434 
  1424     def test_bad_size(self):
  1435     def test_bad_size(self):
  1425         cctx = zstd.ZstdCompressor()
  1436         cctx = zstd.ZstdCompressor()
  1426 
  1437 
  1427         source = io.BytesIO(b'a' * 42)
  1438         source = io.BytesIO(b"a" * 42)
  1428 
  1439 
  1429         with self.assertRaisesRegexp(zstd.ZstdError, 'Src size is incorrect'):
  1440         with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"):
  1430             b''.join(cctx.read_to_iter(source, size=2))
  1441             b"".join(cctx.read_to_iter(source, size=2))
  1431 
  1442 
  1432         # Test another operation on errored compressor.
  1443         # Test another operation on errored compressor.
  1433         b''.join(cctx.read_to_iter(source))
  1444         b"".join(cctx.read_to_iter(source))
  1434 
  1445 
  1435 
  1446 
  1436 @make_cffi
  1447 @make_cffi
  1437 class TestCompressor_chunker(unittest.TestCase):
  1448 class TestCompressor_chunker(TestCase):
  1438     def test_empty(self):
  1449     def test_empty(self):
  1439         cctx = zstd.ZstdCompressor(write_content_size=False)
  1450         cctx = zstd.ZstdCompressor(write_content_size=False)
  1440         chunker = cctx.chunker()
  1451         chunker = cctx.chunker()
  1441 
  1452 
  1442         it = chunker.compress(b'')
  1453         it = chunker.compress(b"")
  1443 
  1454 
  1444         with self.assertRaises(StopIteration):
  1455         with self.assertRaises(StopIteration):
  1445             next(it)
  1456             next(it)
  1446 
  1457 
  1447         it = chunker.finish()
  1458         it = chunker.finish()
  1448 
  1459 
  1449         self.assertEqual(next(it), b'\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00')
  1460         self.assertEqual(next(it), b"\x28\xb5\x2f\xfd\x00\x58\x01\x00\x00")
  1450 
  1461 
  1451         with self.assertRaises(StopIteration):
  1462         with self.assertRaises(StopIteration):
  1452             next(it)
  1463             next(it)
  1453 
  1464 
  1454     def test_simple_input(self):
  1465     def test_simple_input(self):
  1455         cctx = zstd.ZstdCompressor()
  1466         cctx = zstd.ZstdCompressor()
  1456         chunker = cctx.chunker()
  1467         chunker = cctx.chunker()
  1457 
  1468 
  1458         it = chunker.compress(b'foobar')
  1469         it = chunker.compress(b"foobar")
  1459 
  1470 
  1460         with self.assertRaises(StopIteration):
  1471         with self.assertRaises(StopIteration):
  1461             next(it)
  1472             next(it)
  1462 
  1473 
  1463         it = chunker.compress(b'baz' * 30)
  1474         it = chunker.compress(b"baz" * 30)
  1464 
  1475 
  1465         with self.assertRaises(StopIteration):
  1476         with self.assertRaises(StopIteration):
  1466             next(it)
  1477             next(it)
  1467 
  1478 
  1468         it = chunker.finish()
  1479         it = chunker.finish()
  1469 
  1480 
  1470         self.assertEqual(next(it),
  1481         self.assertEqual(
  1471                          b'\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f'
  1482             next(it),
  1472                          b'\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e')
  1483             b"\x28\xb5\x2f\xfd\x00\x58\x7d\x00\x00\x48\x66\x6f"
       
  1484             b"\x6f\x62\x61\x72\x62\x61\x7a\x01\x00\xe4\xe4\x8e",
       
  1485         )
  1473 
  1486 
  1474         with self.assertRaises(StopIteration):
  1487         with self.assertRaises(StopIteration):
  1475             next(it)
  1488             next(it)
  1476 
  1489 
  1477     def test_input_size(self):
  1490     def test_input_size(self):
  1478         cctx = zstd.ZstdCompressor()
  1491         cctx = zstd.ZstdCompressor()
  1479         chunker = cctx.chunker(size=1024)
  1492         chunker = cctx.chunker(size=1024)
  1480 
  1493 
  1481         it = chunker.compress(b'x' * 1000)
  1494         it = chunker.compress(b"x" * 1000)
  1482 
  1495 
  1483         with self.assertRaises(StopIteration):
  1496         with self.assertRaises(StopIteration):
  1484             next(it)
  1497             next(it)
  1485 
  1498 
  1486         it = chunker.compress(b'y' * 24)
  1499         it = chunker.compress(b"y" * 24)
  1487 
  1500 
  1488         with self.assertRaises(StopIteration):
  1501         with self.assertRaises(StopIteration):
  1489             next(it)
  1502             next(it)
  1490 
  1503 
  1491         chunks = list(chunker.finish())
  1504         chunks = list(chunker.finish())
  1492 
  1505 
  1493         self.assertEqual(chunks, [
  1506         self.assertEqual(
  1494             b'\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00'
  1507             chunks,
  1495             b'\xa0\x16\xe3\x2b\x80\x05'
  1508             [
  1496         ])
  1509                 b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00"
       
  1510                 b"\xa0\x16\xe3\x2b\x80\x05"
       
  1511             ],
       
  1512         )
  1497 
  1513 
  1498         dctx = zstd.ZstdDecompressor()
  1514         dctx = zstd.ZstdDecompressor()
  1499 
  1515 
  1500         self.assertEqual(dctx.decompress(b''.join(chunks)),
  1516         self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24))
  1501                          (b'x' * 1000) + (b'y' * 24))
       
  1502 
  1517 
  1503     def test_small_chunk_size(self):
  1518     def test_small_chunk_size(self):
  1504         cctx = zstd.ZstdCompressor()
  1519         cctx = zstd.ZstdCompressor()
  1505         chunker = cctx.chunker(chunk_size=1)
  1520         chunker = cctx.chunker(chunk_size=1)
  1506 
  1521 
  1507         chunks = list(chunker.compress(b'foo' * 1024))
  1522         chunks = list(chunker.compress(b"foo" * 1024))
  1508         self.assertEqual(chunks, [])
  1523         self.assertEqual(chunks, [])
  1509 
  1524 
  1510         chunks = list(chunker.finish())
  1525         chunks = list(chunker.finish())
  1511         self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
  1526         self.assertTrue(all(len(chunk) == 1 for chunk in chunks))
  1512 
  1527 
  1513         self.assertEqual(
  1528         self.assertEqual(
  1514             b''.join(chunks),
  1529             b"".join(chunks),
  1515             b'\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00'
  1530             b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00"
  1516             b'\xfa\xd3\x77\x43')
  1531             b"\xfa\xd3\x77\x43",
       
  1532         )
  1517 
  1533 
  1518         dctx = zstd.ZstdDecompressor()
  1534         dctx = zstd.ZstdDecompressor()
  1519         self.assertEqual(dctx.decompress(b''.join(chunks),
  1535         self.assertEqual(
  1520                                          max_output_size=10000),
  1536             dctx.decompress(b"".join(chunks), max_output_size=10000), b"foo" * 1024
  1521                          b'foo' * 1024)
  1537         )
  1522 
  1538 
  1523     def test_input_types(self):
  1539     def test_input_types(self):
  1524         cctx = zstd.ZstdCompressor()
  1540         cctx = zstd.ZstdCompressor()
  1525 
  1541 
  1526         mutable_array = bytearray(3)
  1542         mutable_array = bytearray(3)
  1527         mutable_array[:] = b'foo'
  1543         mutable_array[:] = b"foo"
  1528 
  1544 
  1529         sources = [
  1545         sources = [
  1530             memoryview(b'foo'),
  1546             memoryview(b"foo"),
  1531             bytearray(b'foo'),
  1547             bytearray(b"foo"),
  1532             mutable_array,
  1548             mutable_array,
  1533         ]
  1549         ]
  1534 
  1550 
  1535         for source in sources:
  1551         for source in sources:
  1536             chunker = cctx.chunker()
  1552             chunker = cctx.chunker()
  1537 
  1553 
  1538             self.assertEqual(list(chunker.compress(source)), [])
  1554             self.assertEqual(list(chunker.compress(source)), [])
  1539             self.assertEqual(list(chunker.finish()), [
  1555             self.assertEqual(
  1540                 b'\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f'
  1556                 list(chunker.finish()),
  1541             ])
  1557                 [b"\x28\xb5\x2f\xfd\x00\x58\x19\x00\x00\x66\x6f\x6f"],
       
  1558             )
  1542 
  1559 
  1543     def test_flush(self):
  1560     def test_flush(self):
  1544         cctx = zstd.ZstdCompressor()
  1561         cctx = zstd.ZstdCompressor()
  1545         chunker = cctx.chunker()
  1562         chunker = cctx.chunker()
  1546 
  1563 
  1547         self.assertEqual(list(chunker.compress(b'foo' * 1024)), [])
  1564         self.assertEqual(list(chunker.compress(b"foo" * 1024)), [])
  1548         self.assertEqual(list(chunker.compress(b'bar' * 1024)), [])
  1565         self.assertEqual(list(chunker.compress(b"bar" * 1024)), [])
  1549 
  1566 
  1550         chunks1 = list(chunker.flush())
  1567         chunks1 = list(chunker.flush())
  1551 
  1568 
  1552         self.assertEqual(chunks1, [
  1569         self.assertEqual(
  1553             b'\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72'
  1570             chunks1,
  1554             b'\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02'
  1571             [
  1555         ])
  1572                 b"\x28\xb5\x2f\xfd\x00\x58\x8c\x00\x00\x30\x66\x6f\x6f\x62\x61\x72"
       
  1573                 b"\x02\x00\xfa\x03\xfe\xd0\x9f\xbe\x1b\x02"
       
  1574             ],
       
  1575         )
  1556 
  1576 
  1557         self.assertEqual(list(chunker.flush()), [])
  1577         self.assertEqual(list(chunker.flush()), [])
  1558         self.assertEqual(list(chunker.flush()), [])
  1578         self.assertEqual(list(chunker.flush()), [])
  1559 
  1579 
  1560         self.assertEqual(list(chunker.compress(b'baz' * 1024)), [])
  1580         self.assertEqual(list(chunker.compress(b"baz" * 1024)), [])
  1561 
  1581 
  1562         chunks2 = list(chunker.flush())
  1582         chunks2 = list(chunker.flush())
  1563         self.assertEqual(len(chunks2), 1)
  1583         self.assertEqual(len(chunks2), 1)
  1564 
  1584 
  1565         chunks3 = list(chunker.finish())
  1585         chunks3 = list(chunker.finish())
  1566         self.assertEqual(len(chunks2), 1)
  1586         self.assertEqual(len(chunks2), 1)
  1567 
  1587 
  1568         dctx = zstd.ZstdDecompressor()
  1588         dctx = zstd.ZstdDecompressor()
  1569 
  1589 
  1570         self.assertEqual(dctx.decompress(b''.join(chunks1 + chunks2 + chunks3),
  1590         self.assertEqual(
  1571                                          max_output_size=10000),
  1591             dctx.decompress(
  1572                          (b'foo' * 1024) + (b'bar' * 1024) + (b'baz' * 1024))
  1592                 b"".join(chunks1 + chunks2 + chunks3), max_output_size=10000
       
  1593             ),
       
  1594             (b"foo" * 1024) + (b"bar" * 1024) + (b"baz" * 1024),
       
  1595         )
  1573 
  1596 
  1574     def test_compress_after_finish(self):
  1597     def test_compress_after_finish(self):
  1575         cctx = zstd.ZstdCompressor()
  1598         cctx = zstd.ZstdCompressor()
  1576         chunker = cctx.chunker()
  1599         chunker = cctx.chunker()
  1577 
  1600 
  1578         list(chunker.compress(b'foo'))
  1601         list(chunker.compress(b"foo"))
  1579         list(chunker.finish())
  1602         list(chunker.finish())
  1580 
  1603 
  1581         with self.assertRaisesRegexp(
  1604         with self.assertRaisesRegex(
  1582                 zstd.ZstdError,
  1605             zstd.ZstdError, r"cannot call compress\(\) after compression finished"
  1583                 r'cannot call compress\(\) after compression finished'):
  1606         ):
  1584             list(chunker.compress(b'foo'))
  1607             list(chunker.compress(b"foo"))
  1585 
  1608 
  1586     def test_flush_after_finish(self):
  1609     def test_flush_after_finish(self):
  1587         cctx = zstd.ZstdCompressor()
  1610         cctx = zstd.ZstdCompressor()
  1588         chunker = cctx.chunker()
  1611         chunker = cctx.chunker()
  1589 
  1612 
  1590         list(chunker.compress(b'foo'))
  1613         list(chunker.compress(b"foo"))
  1591         list(chunker.finish())
  1614         list(chunker.finish())
  1592 
  1615 
  1593         with self.assertRaisesRegexp(
  1616         with self.assertRaisesRegex(
  1594                 zstd.ZstdError,
  1617             zstd.ZstdError, r"cannot call flush\(\) after compression finished"
  1595                 r'cannot call flush\(\) after compression finished'):
  1618         ):
  1596             list(chunker.flush())
  1619             list(chunker.flush())
  1597 
  1620 
  1598     def test_finish_after_finish(self):
  1621     def test_finish_after_finish(self):
  1599         cctx = zstd.ZstdCompressor()
  1622         cctx = zstd.ZstdCompressor()
  1600         chunker = cctx.chunker()
  1623         chunker = cctx.chunker()
  1601 
  1624 
  1602         list(chunker.compress(b'foo'))
  1625         list(chunker.compress(b"foo"))
  1603         list(chunker.finish())
  1626         list(chunker.finish())
  1604 
  1627 
  1605         with self.assertRaisesRegexp(
  1628         with self.assertRaisesRegex(
  1606                 zstd.ZstdError,
  1629             zstd.ZstdError, r"cannot call finish\(\) after compression finished"
  1607                 r'cannot call finish\(\) after compression finished'):
  1630         ):
  1608             list(chunker.finish())
  1631             list(chunker.finish())
  1609 
  1632 
  1610 
  1633 
  1611 class TestCompressor_multi_compress_to_buffer(unittest.TestCase):
  1634 class TestCompressor_multi_compress_to_buffer(TestCase):
  1612     def test_invalid_inputs(self):
  1635     def test_invalid_inputs(self):
  1613         cctx = zstd.ZstdCompressor()
  1636         cctx = zstd.ZstdCompressor()
  1614 
  1637 
  1615         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1638         if not hasattr(cctx, "multi_compress_to_buffer"):
  1616             self.skipTest('multi_compress_to_buffer not available')
  1639             self.skipTest("multi_compress_to_buffer not available")
  1617 
  1640 
  1618         with self.assertRaises(TypeError):
  1641         with self.assertRaises(TypeError):
  1619             cctx.multi_compress_to_buffer(True)
  1642             cctx.multi_compress_to_buffer(True)
  1620 
  1643 
  1621         with self.assertRaises(TypeError):
  1644         with self.assertRaises(TypeError):
  1622             cctx.multi_compress_to_buffer((1, 2))
  1645             cctx.multi_compress_to_buffer((1, 2))
  1623 
  1646 
  1624         with self.assertRaisesRegexp(TypeError, 'item 0 not a bytes like object'):
  1647         with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"):
  1625             cctx.multi_compress_to_buffer([u'foo'])
  1648             cctx.multi_compress_to_buffer([u"foo"])
  1626 
  1649 
  1627     def test_empty_input(self):
  1650     def test_empty_input(self):
  1628         cctx = zstd.ZstdCompressor()
  1651         cctx = zstd.ZstdCompressor()
  1629 
  1652 
  1630         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1653         if not hasattr(cctx, "multi_compress_to_buffer"):
  1631             self.skipTest('multi_compress_to_buffer not available')
  1654             self.skipTest("multi_compress_to_buffer not available")
  1632 
  1655 
  1633         with self.assertRaisesRegexp(ValueError, 'no source elements found'):
  1656         with self.assertRaisesRegex(ValueError, "no source elements found"):
  1634             cctx.multi_compress_to_buffer([])
  1657             cctx.multi_compress_to_buffer([])
  1635 
  1658 
  1636         with self.assertRaisesRegexp(ValueError, 'source elements are empty'):
  1659         with self.assertRaisesRegex(ValueError, "source elements are empty"):
  1637             cctx.multi_compress_to_buffer([b'', b'', b''])
  1660             cctx.multi_compress_to_buffer([b"", b"", b""])
  1638 
  1661 
  1639     def test_list_input(self):
  1662     def test_list_input(self):
  1640         cctx = zstd.ZstdCompressor(write_checksum=True)
  1663         cctx = zstd.ZstdCompressor(write_checksum=True)
  1641 
  1664 
  1642         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1665         if not hasattr(cctx, "multi_compress_to_buffer"):
  1643             self.skipTest('multi_compress_to_buffer not available')
  1666             self.skipTest("multi_compress_to_buffer not available")
  1644 
  1667 
  1645         original = [b'foo' * 12, b'bar' * 6]
  1668         original = [b"foo" * 12, b"bar" * 6]
  1646         frames = [cctx.compress(c) for c in original]
  1669         frames = [cctx.compress(c) for c in original]
  1647         b = cctx.multi_compress_to_buffer(original)
  1670         b = cctx.multi_compress_to_buffer(original)
  1648 
  1671 
  1649         self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
  1672         self.assertIsInstance(b, zstd.BufferWithSegmentsCollection)
  1650 
  1673 
  1655         self.assertEqual(b[1].tobytes(), frames[1])
  1678         self.assertEqual(b[1].tobytes(), frames[1])
  1656 
  1679 
  1657     def test_buffer_with_segments_input(self):
  1680     def test_buffer_with_segments_input(self):
  1658         cctx = zstd.ZstdCompressor(write_checksum=True)
  1681         cctx = zstd.ZstdCompressor(write_checksum=True)
  1659 
  1682 
  1660         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1683         if not hasattr(cctx, "multi_compress_to_buffer"):
  1661             self.skipTest('multi_compress_to_buffer not available')
  1684             self.skipTest("multi_compress_to_buffer not available")
  1662 
  1685 
  1663         original = [b'foo' * 4, b'bar' * 6]
  1686         original = [b"foo" * 4, b"bar" * 6]
  1664         frames = [cctx.compress(c) for c in original]
  1687         frames = [cctx.compress(c) for c in original]
  1665 
  1688 
  1666         offsets = struct.pack('=QQQQ', 0, len(original[0]),
  1689         offsets = struct.pack(
  1667                                        len(original[0]), len(original[1]))
  1690             "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
  1668         segments = zstd.BufferWithSegments(b''.join(original), offsets)
  1691         )
       
  1692         segments = zstd.BufferWithSegments(b"".join(original), offsets)
  1669 
  1693 
  1670         result = cctx.multi_compress_to_buffer(segments)
  1694         result = cctx.multi_compress_to_buffer(segments)
  1671 
  1695 
  1672         self.assertEqual(len(result), 2)
  1696         self.assertEqual(len(result), 2)
  1673         self.assertEqual(result.size(), 47)
  1697         self.assertEqual(result.size(), 47)
  1676         self.assertEqual(result[1].tobytes(), frames[1])
  1700         self.assertEqual(result[1].tobytes(), frames[1])
  1677 
  1701 
  1678     def test_buffer_with_segments_collection_input(self):
  1702     def test_buffer_with_segments_collection_input(self):
  1679         cctx = zstd.ZstdCompressor(write_checksum=True)
  1703         cctx = zstd.ZstdCompressor(write_checksum=True)
  1680 
  1704 
  1681         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1705         if not hasattr(cctx, "multi_compress_to_buffer"):
  1682             self.skipTest('multi_compress_to_buffer not available')
  1706             self.skipTest("multi_compress_to_buffer not available")
  1683 
  1707 
  1684         original = [
  1708         original = [
  1685             b'foo1',
  1709             b"foo1",
  1686             b'foo2' * 2,
  1710             b"foo2" * 2,
  1687             b'foo3' * 3,
  1711             b"foo3" * 3,
  1688             b'foo4' * 4,
  1712             b"foo4" * 4,
  1689             b'foo5' * 5,
  1713             b"foo5" * 5,
  1690         ]
  1714         ]
  1691 
  1715 
  1692         frames = [cctx.compress(c) for c in original]
  1716         frames = [cctx.compress(c) for c in original]
  1693 
  1717 
  1694         b = b''.join([original[0], original[1]])
  1718         b = b"".join([original[0], original[1]])
  1695         b1 = zstd.BufferWithSegments(b, struct.pack('=QQQQ',
  1719         b1 = zstd.BufferWithSegments(
  1696                                                     0, len(original[0]),
  1720             b,
  1697                                                     len(original[0]), len(original[1])))
  1721             struct.pack(
  1698         b = b''.join([original[2], original[3], original[4]])
  1722                 "=QQQQ", 0, len(original[0]), len(original[0]), len(original[1])
  1699         b2 = zstd.BufferWithSegments(b, struct.pack('=QQQQQQ',
  1723             ),
  1700                                                     0, len(original[2]),
  1724         )
  1701                                                     len(original[2]), len(original[3]),
  1725         b = b"".join([original[2], original[3], original[4]])
  1702                                                     len(original[2]) + len(original[3]), len(original[4])))
  1726         b2 = zstd.BufferWithSegments(
       
  1727             b,
       
  1728             struct.pack(
       
  1729                 "=QQQQQQ",
       
  1730                 0,
       
  1731                 len(original[2]),
       
  1732                 len(original[2]),
       
  1733                 len(original[3]),
       
  1734                 len(original[2]) + len(original[3]),
       
  1735                 len(original[4]),
       
  1736             ),
       
  1737         )
  1703 
  1738 
  1704         c = zstd.BufferWithSegmentsCollection(b1, b2)
  1739         c = zstd.BufferWithSegmentsCollection(b1, b2)
  1705 
  1740 
  1706         result = cctx.multi_compress_to_buffer(c)
  1741         result = cctx.multi_compress_to_buffer(c)
  1707 
  1742 
  1712 
  1747 
  1713     def test_multiple_threads(self):
  1748     def test_multiple_threads(self):
  1714         # threads argument will cause multi-threaded ZSTD APIs to be used, which will
  1749         # threads argument will cause multi-threaded ZSTD APIs to be used, which will
  1715         # make output different.
  1750         # make output different.
  1716         refcctx = zstd.ZstdCompressor(write_checksum=True)
  1751         refcctx = zstd.ZstdCompressor(write_checksum=True)
  1717         reference = [refcctx.compress(b'x' * 64), refcctx.compress(b'y' * 64)]
  1752         reference = [refcctx.compress(b"x" * 64), refcctx.compress(b"y" * 64)]
  1718 
  1753 
  1719         cctx = zstd.ZstdCompressor(write_checksum=True)
  1754         cctx = zstd.ZstdCompressor(write_checksum=True)
  1720 
  1755 
  1721         if not hasattr(cctx, 'multi_compress_to_buffer'):
  1756         if not hasattr(cctx, "multi_compress_to_buffer"):
  1722             self.skipTest('multi_compress_to_buffer not available')
  1757             self.skipTest("multi_compress_to_buffer not available")
  1723 
  1758 
  1724         frames = []
  1759         frames = []
  1725         frames.extend(b'x' * 64 for i in range(256))
  1760         frames.extend(b"x" * 64 for i in range(256))
  1726         frames.extend(b'y' * 64 for i in range(256))
  1761         frames.extend(b"y" * 64 for i in range(256))
  1727 
  1762 
  1728         result = cctx.multi_compress_to_buffer(frames, threads=-1)
  1763         result = cctx.multi_compress_to_buffer(frames, threads=-1)
  1729 
  1764 
  1730         self.assertEqual(len(result), 512)
  1765         self.assertEqual(len(result), 512)
  1731         for i in range(512):
  1766         for i in range(512):