contrib/python-zstandard/tests/test_train_dictionary.py
changeset 43999 de7838053207
parent 42941 69de49c4e39c
child 44232 5e84a96d865b
equal deleted inserted replaced
43998:873d0fecb9a3 43999:de7838053207
     2 import sys
     2 import sys
     3 import unittest
     3 import unittest
     4 
     4 
     5 import zstandard as zstd
     5 import zstandard as zstd
     6 
     6 
     7 from . common import (
     7 from .common import (
     8     generate_samples,
     8     generate_samples,
     9     make_cffi,
     9     make_cffi,
    10     random_input_data,
    10     random_input_data,
       
    11     TestCase,
    11 )
    12 )
    12 
    13 
    13 if sys.version_info[0] >= 3:
    14 if sys.version_info[0] >= 3:
    14     int_type = int
    15     int_type = int
    15 else:
    16 else:
    16     int_type = long
    17     int_type = long
    17 
    18 
    18 
    19 
    19 @make_cffi
    20 @make_cffi
    20 class TestTrainDictionary(unittest.TestCase):
    21 class TestTrainDictionary(TestCase):
    21     def test_no_args(self):
    22     def test_no_args(self):
    22         with self.assertRaises(TypeError):
    23         with self.assertRaises(TypeError):
    23             zstd.train_dictionary()
    24             zstd.train_dictionary()
    24 
    25 
    25     def test_bad_args(self):
    26     def test_bad_args(self):
    26         with self.assertRaises(TypeError):
    27         with self.assertRaises(TypeError):
    27             zstd.train_dictionary(8192, u'foo')
    28             zstd.train_dictionary(8192, u"foo")
    28 
    29 
    29         with self.assertRaises(ValueError):
    30         with self.assertRaises(ValueError):
    30             zstd.train_dictionary(8192, [u'foo'])
    31             zstd.train_dictionary(8192, [u"foo"])
    31 
    32 
    32     def test_no_params(self):
    33     def test_no_params(self):
    33         d = zstd.train_dictionary(8192, random_input_data())
    34         d = zstd.train_dictionary(8192, random_input_data())
    34         self.assertIsInstance(d.dict_id(), int_type)
    35         self.assertIsInstance(d.dict_id(), int_type)
    35 
    36 
    36         # The dictionary ID may be different across platforms.
    37         # The dictionary ID may be different across platforms.
    37         expected = b'\x37\xa4\x30\xec' + struct.pack('<I', d.dict_id())
    38         expected = b"\x37\xa4\x30\xec" + struct.pack("<I", d.dict_id())
    38 
    39 
    39         data = d.as_bytes()
    40         data = d.as_bytes()
    40         self.assertEqual(data[0:8], expected)
    41         self.assertEqual(data[0:8], expected)
    41 
    42 
    42     def test_basic(self):
    43     def test_basic(self):
    43         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
    44         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
    44         self.assertIsInstance(d.dict_id(), int_type)
    45         self.assertIsInstance(d.dict_id(), int_type)
    45 
    46 
    46         data = d.as_bytes()
    47         data = d.as_bytes()
    47         self.assertEqual(data[0:4], b'\x37\xa4\x30\xec')
    48         self.assertEqual(data[0:4], b"\x37\xa4\x30\xec")
    48 
    49 
    49         self.assertEqual(d.k, 64)
    50         self.assertEqual(d.k, 64)
    50         self.assertEqual(d.d, 16)
    51         self.assertEqual(d.d, 16)
    51 
    52 
    52     def test_set_dict_id(self):
    53     def test_set_dict_id(self):
    53         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16,
    54         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42)
    54                                   dict_id=42)
       
    55         self.assertEqual(d.dict_id(), 42)
    55         self.assertEqual(d.dict_id(), 42)
    56 
    56 
    57     def test_optimize(self):
    57     def test_optimize(self):
    58         d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1,
    58         d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16)
    59                                   d=16)
       
    60 
    59 
    61         # This varies by platform.
    60         # This varies by platform.
    62         self.assertIn(d.k, (50, 2000))
    61         self.assertIn(d.k, (50, 2000))
    63         self.assertEqual(d.d, 16)
    62         self.assertEqual(d.d, 16)
    64 
    63 
       
    64 
    65 @make_cffi
    65 @make_cffi
    66 class TestCompressionDict(unittest.TestCase):
    66 class TestCompressionDict(TestCase):
    67     def test_bad_mode(self):
    67     def test_bad_mode(self):
    68         with self.assertRaisesRegexp(ValueError, 'invalid dictionary load mode'):
    68         with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"):
    69             zstd.ZstdCompressionDict(b'foo', dict_type=42)
    69             zstd.ZstdCompressionDict(b"foo", dict_type=42)
    70 
    70 
    71     def test_bad_precompute_compress(self):
    71     def test_bad_precompute_compress(self):
    72         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
    72         d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16)
    73 
    73 
    74         with self.assertRaisesRegexp(ValueError, 'must specify one of level or '):
    74         with self.assertRaisesRegex(ValueError, "must specify one of level or "):
    75             d.precompute_compress()
    75             d.precompute_compress()
    76 
    76 
    77         with self.assertRaisesRegexp(ValueError, 'must only specify one of level or '):
    77         with self.assertRaisesRegex(ValueError, "must only specify one of level or "):
    78             d.precompute_compress(level=3,
    78             d.precompute_compress(
    79                                   compression_params=zstd.CompressionParameters())
    79                 level=3, compression_params=zstd.CompressionParameters()
       
    80             )
    80 
    81 
    81     def test_precompute_compress_rawcontent(self):
    82     def test_precompute_compress_rawcontent(self):
    82         d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
    83         d = zstd.ZstdCompressionDict(
    83                                      dict_type=zstd.DICT_TYPE_RAWCONTENT)
    84             b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT
       
    85         )
    84         d.precompute_compress(level=1)
    86         d.precompute_compress(level=1)
    85 
    87 
    86         d = zstd.ZstdCompressionDict(b'dictcontent' * 64,
    88         d = zstd.ZstdCompressionDict(
    87                                      dict_type=zstd.DICT_TYPE_FULLDICT)
    89             b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT
    88         with self.assertRaisesRegexp(zstd.ZstdError, 'unable to precompute dictionary'):
    90         )
       
    91         with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"):
    89             d.precompute_compress(level=1)
    92             d.precompute_compress(level=1)