view contrib/python-zstandard/tests/common.py @ 37774:d6970628b95f

fix: use templater to substitute values in command string. bytes.format() isn't supported on Python 3. Luckily, our template syntax is similar so we can reuse it. We need a hack to disable \-escapes as '\' is a directory separator on Windows.
author Yuya Nishihara <yuya@tcha.org>
date Sat, 14 Apr 2018 00:30:39 +0900
parents b1fb341d8a61
children 675775c33ab6
line wrap: on
line source

import imp
import inspect
import io
import os
import types

try:
    import hypothesis
except ImportError:
    hypothesis = None


def make_cffi(cls):
    """Decorator to add CFFI versions of each test method.

    For every ``test_*`` function or method found on ``cls``, a copy named
    ``<name>_cffi`` is installed on the class. The copy runs the same code
    object, but its global ``zstd`` symbol is rebound to a second copy of
    the ``zstandard`` package imported with the ``cffi`` import policy.

    The module defining ``cls`` must ``import zstandard as zstd``. ``cls``
    is returned unchanged when that import already selected the cffi
    backend, or when the cffi variant cannot be imported.

    Raises ``Exception`` when the defining module has no ``zstd`` symbol,
    when that symbol has no ``backend`` attribute, or when the second import
    did not produce the cffi backend.
    """

    # The module containing this class definition should
    # `import zstandard as zstd`. Otherwise things may blow up.
    mod = inspect.getmodule(cls)
    if not hasattr(mod, 'zstd'):
        raise Exception('test module does not contain "zstd" symbol')

    if not hasattr(mod.zstd, 'backend'):
        raise Exception('zstd symbol does not have "backend" attribute; did '
                        'you `import zstandard as zstd`?')

    # If `import zstandard` already chose the cffi backend, there is nothing
    # for us to do: we only add the cffi variation if the default backend
    # is the C extension.
    if mod.zstd.backend == 'cffi':
        return cls

    # Temporarily force the import policy through the environment and put
    # the environment back exactly as it was, whatever happens.
    old_env = dict(os.environ)
    os.environ['PYTHON_ZSTANDARD_IMPORT_POLICY'] = 'cffi'
    try:
        try:
            mod_info = imp.find_module('zstandard')
            try:
                mod = imp.load_module('zstandard_cffi', *mod_info)
            finally:
                # find_module() hands back an open file for plain modules
                # (None for packages); closing it is the caller's job.
                if mod_info[0] is not None:
                    mod_info[0].close()
        except ImportError:
            return cls
    finally:
        os.environ.clear()
        os.environ.update(old_env)

    if mod.backend != 'cffi':
        raise Exception('got the zstandard %s backend instead of cffi' % mod.backend)

    # If CFFI version is available, dynamically construct test methods
    # that use it.

    for attr in dir(cls):
        fn = getattr(cls, attr)
        if not inspect.ismethod(fn) and not inspect.isfunction(fn):
            continue

        if not fn.__name__.startswith('test_'):
            continue

        name = '%s_cffi' % fn.__name__

        # Method objects (Python 2 unbound methods, and bound methods such
        # as classmethods on either major version) wrap the real function
        # in ``__func__``; plain functions are used as-is.
        func = getattr(fn, '__func__', fn)

        # Replace the "zstd" symbol with the CFFI module instance. Then copy
        # the function object and install it in a new attribute.
        globs = dict(func.__globals__)
        globs['zstd'] = mod
        new_fn = types.FunctionType(func.__code__, globs, name,
                                    func.__defaults__, func.__closure__)

        bound_to = getattr(fn, '__self__', None)
        if func is fn or bound_to is None:
            # Plain function, or a Python 2 unbound method: storing the
            # function on the class is enough for normal method binding.
            new_method = new_fn
        else:
            # Keep the copy bound to whatever the original was bound to.
            new_method = types.MethodType(new_fn, bound_to)

        setattr(cls, name, new_method)

    return cls


class OpCountingBytesIO(io.BytesIO):
    """``io.BytesIO`` that tallies how often ``read()`` and ``write()`` run.

    The tallies live in ``_read_count`` and ``_write_count``; each call to
    the corresponding method bumps its counter by one before delegating to
    the base class.
    """

    def __init__(self, *args, **kwargs):
        # Both tallies start at zero; all arguments go to BytesIO untouched.
        self._write_count = 0
        self._read_count = 0
        super(OpCountingBytesIO, self).__init__(*args, **kwargs)

    def read(self, *args):
        """Count the call, then return what ``BytesIO.read`` returns."""
        self._read_count = self._read_count + 1
        base = super(OpCountingBytesIO, self)
        return base.read(*args)

    def write(self, data):
        """Count the call, then return what ``BytesIO.write`` returns."""
        self._write_count = self._write_count + 1
        base = super(OpCountingBytesIO, self)
        return base.write(data)


# Cache filled on first use by random_input_data().
_source_files = []


def random_input_data():
    """Return the raw bytes of every non-empty file next to this module.

    The directory containing this file is walked recursively, visiting
    directories and files in sorted order. Files that cannot be opened or
    read (``OSError``) are skipped. The resulting list is cached at module
    level, so later calls return the same list object without touching
    the filesystem again.

    This is used for generating "random" data to feed into fuzzing, since
    it is faster than random content generation.
    """
    if not _source_files:
        base_dir = os.path.dirname(__file__)
        for dirpath, subdirs, filenames in os.walk(base_dir):
            # Sorting in place also fixes the order os.walk descends in.
            subdirs.sort()
            for filename in sorted(filenames):
                path = os.path.join(dirpath, filename)
                try:
                    with open(path, 'rb') as fh:
                        content = fh.read()
                except OSError:
                    continue
                if content:
                    _source_files.append(content)

    return _source_files


def generate_samples():
    """Build a fixed list of 384 byte-string samples.

    For each ``i`` in ``range(128)`` three entries are appended, using five
    seed words selected cyclically by ``k = i % 5``: the seed at ``k``, the
    same seed repeated ``i + 3`` times, and the seed at index ``-k``
    repeated ``i + 2`` times.
    """
    seeds = (b'foo', b'bar', b'abcdef', b'sometext', b'baz')
    seed_count = len(seeds)

    result = []
    for step in range(128):
        k = step % seed_count
        word = seeds[k]
        # Note -0 == 0, so k == 0 mirrors onto the first seed, not the last.
        mirrored = seeds[-k]
        result.extend((word, word * (step + 3), mirrored * (step + 2)))

    return result


# Configure hypothesis only when the optional import near the top of this
# module succeeded; ``hypothesis`` is None otherwise and this is skipped.
if hypothesis:
    # 'default' profile: a settings object constructed with no overrides.
    default_settings = hypothesis.settings()
    hypothesis.settings.register_profile('default', default_settings)

    # 'ci' profile: max_examples and max_iterations are both set to 2500.
    ci_settings = hypothesis.settings(max_examples=2500,
                                      max_iterations=2500)
    hypothesis.settings.register_profile('ci', ci_settings)

    # Activate the profile named by the HYPOTHESIS_PROFILE environment
    # variable, falling back to 'default' when the variable is not set.
    hypothesis.settings.load_profile(
        os.environ.get('HYPOTHESIS_PROFILE', 'default'))