I made this change upstream and it will make it into the next
release of python-zstandard. I figured I'd send it Mercurial's
way because it will allow us to drop this directory from the black
exclusion list.
- skip-blame blackening
| mharbison72 | |
| pulkit |
| hg-reviewers |
I made this change upstream and it will make it into the next
release of python-zstandard. I figured I'd send it Mercurial's
way because it will allow us to drop this directory from the black
exclusion list.
| No Linters Available |
| No Unit Test Coverage |
"contrib/examples/fix.hgrc" also skips "contrib/python-zstandard/**". The intent of this change seems to be to allow local edits to python-zstandard (otherwise we wouldn't start checking its format), so maybe that directory should be covered by `hg fix` too?
I'm -0 on auto-formatting third-party code at all, and that includes zstandard stuff. Why do we care if it's in the ignorelist?
| Path | Packages | |||
|---|---|---|---|---|
| M | black.toml (1 line) | |||
| M | contrib/python-zstandard/make_cffi.py (7 lines) | |||
| M | contrib/python-zstandard/setup.py (4 lines) | |||
| M | contrib/python-zstandard/setup_zstd.py (8 lines) | |||
| M | contrib/python-zstandard/tests/common.py (12 lines) | |||
| M | contrib/python-zstandard/tests/test_buffer_util.py (15 lines) | |||
| M | contrib/python-zstandard/tests/test_compressor.py (69 lines) | |||
| M | contrib/python-zstandard/tests/test_compressor_fuzzing.py (94 lines) | |||
| M | contrib/python-zstandard/tests/test_data_structures.py (28 lines) | |||
| M | contrib/python-zstandard/tests/test_data_structures_fuzzing.py (22 lines) | |||
| M | contrib/python-zstandard/tests/test_decompressor.py (90 lines) | |||
| M | contrib/python-zstandard/tests/test_decompressor_fuzzing.py (35 lines) | |||
| M | contrib/python-zstandard/tests/test_train_dictionary.py (20 lines) | |||
| M | contrib/python-zstandard/zstandard/cffi.py (310 lines) | |||
| M | tests/test-check-format.t (2 lines) |
| Commit | Parents | Author | Summary | Date |
|---|---|---|---|---|
| 7983a6688fe2 | 98349eddceef | Gregory Szorc | | Jan 18 2020, 12:53 AM |
| [tool.black] | [tool.black] | ||||
| line-length = 80 | line-length = 80 | ||||
| exclude = ''' | exclude = ''' | ||||
| build/ | build/ | ||||
| | wheelhouse/ | | wheelhouse/ | ||||
| | dist/ | | dist/ | ||||
| | packages/ | | packages/ | ||||
| | \.hg/ | | \.hg/ | ||||
| | \.mypy_cache/ | | \.mypy_cache/ | ||||
| | \.venv/ | | \.venv/ | ||||
| | mercurial/thirdparty/ | | mercurial/thirdparty/ | ||||
| | contrib/python-zstandard/ | |||||
| ''' | ''' | ||||
| skip-string-normalization = true | skip-string-normalization = true | ||||
| quiet = true | quiet = true | ||||
| "dictBuilder/fastcover.c", | "dictBuilder/fastcover.c", | ||||
| "dictBuilder/divsufsort.c", | "dictBuilder/divsufsort.c", | ||||
| "dictBuilder/zdict.c", | "dictBuilder/zdict.c", | ||||
| ) | ) | ||||
| ] | ] | ||||
| # Headers whose preprocessed output will be fed into cdef(). | # Headers whose preprocessed output will be fed into cdef(). | ||||
| HEADERS = [ | HEADERS = [ | ||||
| os.path.join(HERE, "zstd", *p) for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) | os.path.join(HERE, "zstd", *p) | ||||
| for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) | |||||
| ] | ] | ||||
| INCLUDE_DIRS = [ | INCLUDE_DIRS = [ | ||||
| os.path.join(HERE, d) | os.path.join(HERE, d) | ||||
| for d in ( | for d in ( | ||||
| "zstd", | "zstd", | ||||
| "zstd/common", | "zstd/common", | ||||
| "zstd/compress", | "zstd/compress", | ||||
| fd, input_file = tempfile.mkstemp(suffix=".h") | fd, input_file = tempfile.mkstemp(suffix=".h") | ||||
| os.write(fd, b"".join(lines)) | os.write(fd, b"".join(lines)) | ||||
| os.close(fd) | os.close(fd) | ||||
| try: | try: | ||||
| env = dict(os.environ) | env = dict(os.environ) | ||||
| if getattr(compiler, "_paths", None): | if getattr(compiler, "_paths", None): | ||||
| env["PATH"] = compiler._paths | env["PATH"] = compiler._paths | ||||
| process = subprocess.Popen(args + [input_file], stdout=subprocess.PIPE, env=env) | process = subprocess.Popen( | ||||
| args + [input_file], stdout=subprocess.PIPE, env=env | |||||
| ) | |||||
| output = process.communicate()[0] | output = process.communicate()[0] | ||||
| ret = process.poll() | ret = process.poll() | ||||
| if ret: | if ret: | ||||
| raise Exception("preprocessor exited with error") | raise Exception("preprocessor exited with error") | ||||
| return output | return output | ||||
| finally: | finally: | ||||
| os.unlink(input_file) | os.unlink(input_file) | ||||
| for line in fh: | for line in fh: | ||||
| if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): | if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): | ||||
| continue | continue | ||||
| version = line.split()[2][1:-1] | version = line.split()[2][1:-1] | ||||
| break | break | ||||
| if not version: | if not version: | ||||
| raise Exception("could not resolve package version; " "this should never happen") | raise Exception( | ||||
| "could not resolve package version; " "this should never happen" | |||||
| ) | |||||
| setup( | setup( | ||||
| name="zstandard", | name="zstandard", | ||||
| version=version, | version=version, | ||||
| description="Zstandard bindings for Python", | description="Zstandard bindings for Python", | ||||
| long_description=open("README.rst", "r").read(), | long_description=open("README.rst", "r").read(), | ||||
| url="https://github.com/indygreg/python-zstandard", | url="https://github.com/indygreg/python-zstandard", | ||||
| author="Gregory Szorc", | author="Gregory Szorc", | ||||
| """ | """ | ||||
| actual_root = os.path.abspath(os.path.dirname(__file__)) | actual_root = os.path.abspath(os.path.dirname(__file__)) | ||||
| root = root or actual_root | root = root or actual_root | ||||
| sources = set([os.path.join(actual_root, p) for p in ext_sources]) | sources = set([os.path.join(actual_root, p) for p in ext_sources]) | ||||
| if not system_zstd: | if not system_zstd: | ||||
| sources.update([os.path.join(actual_root, p) for p in zstd_sources]) | sources.update([os.path.join(actual_root, p) for p in zstd_sources]) | ||||
| if support_legacy: | if support_legacy: | ||||
| sources.update([os.path.join(actual_root, p) for p in zstd_sources_legacy]) | sources.update( | ||||
| [os.path.join(actual_root, p) for p in zstd_sources_legacy] | |||||
| ) | |||||
| sources = list(sources) | sources = list(sources) | ||||
| include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) | include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) | ||||
| if not system_zstd: | if not system_zstd: | ||||
| include_dirs.update([os.path.join(actual_root, d) for d in zstd_includes]) | include_dirs.update( | ||||
| [os.path.join(actual_root, d) for d in zstd_includes] | |||||
| ) | |||||
| if support_legacy: | if support_legacy: | ||||
| include_dirs.update( | include_dirs.update( | ||||
| [os.path.join(actual_root, d) for d in zstd_includes_legacy] | [os.path.join(actual_root, d) for d in zstd_includes_legacy] | ||||
| ) | ) | ||||
| include_dirs = list(include_dirs) | include_dirs = list(include_dirs) | ||||
| depends = [os.path.join(actual_root, p) for p in zstd_depends] | depends = [os.path.join(actual_root, p) for p in zstd_depends] | ||||
| mod = imp.load_module("zstandard_cffi", *mod_info) | mod = imp.load_module("zstandard_cffi", *mod_info) | ||||
| except ImportError: | except ImportError: | ||||
| return cls | return cls | ||||
| finally: | finally: | ||||
| os.environ.clear() | os.environ.clear() | ||||
| os.environ.update(old_env) | os.environ.update(old_env) | ||||
| if mod.backend != "cffi": | if mod.backend != "cffi": | ||||
| raise Exception("got the zstandard %s backend instead of cffi" % mod.backend) | raise Exception( | ||||
| "got the zstandard %s backend instead of cffi" % mod.backend | |||||
| ) | |||||
| # If CFFI version is available, dynamically construct test methods | # If CFFI version is available, dynamically construct test methods | ||||
| # that use it. | # that use it. | ||||
| for attr in dir(cls): | for attr in dir(cls): | ||||
| fn = getattr(cls, attr) | fn = getattr(cls, attr) | ||||
| if not inspect.ismethod(fn) and not inspect.isfunction(fn): | if not inspect.ismethod(fn) and not inspect.isfunction(fn): | ||||
| continue | continue | ||||
| globs["zstd"] = mod | globs["zstd"] = mod | ||||
| new_fn = types.FunctionType( | new_fn = types.FunctionType( | ||||
| fn.__func__.func_code, | fn.__func__.func_code, | ||||
| globs, | globs, | ||||
| name, | name, | ||||
| fn.__func__.func_defaults, | fn.__func__.func_defaults, | ||||
| fn.__func__.func_closure, | fn.__func__.func_closure, | ||||
| ) | ) | ||||
| new_method = types.UnboundMethodType(new_fn, fn.im_self, fn.im_class) | new_method = types.UnboundMethodType( | ||||
| new_fn, fn.im_self, fn.im_class | |||||
| ) | |||||
| setattr(cls, name, new_method) | setattr(cls, name, new_method) | ||||
| return cls | return cls | ||||
| class NonClosingBytesIO(io.BytesIO): | class NonClosingBytesIO(io.BytesIO): | ||||
| """BytesIO that saves the underlying buffer on close(). | """BytesIO that saves the underlying buffer on close(). | ||||
| hypothesis.settings.register_profile("default", default_settings) | hypothesis.settings.register_profile("default", default_settings) | ||||
| ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) | ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) | ||||
| hypothesis.settings.register_profile("ci", ci_settings) | hypothesis.settings.register_profile("ci", ci_settings) | ||||
| expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) | expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) | ||||
| hypothesis.settings.register_profile("expensive", expensive_settings) | hypothesis.settings.register_profile("expensive", expensive_settings) | ||||
| hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "default")) | hypothesis.settings.load_profile( | ||||
| os.environ.get("HYPOTHESIS_PROFILE", "default") | |||||
| ) | |||||
| self.assertEqual(b[0].offset, 0) | self.assertEqual(b[0].offset, 0) | ||||
| self.assertEqual(b[0].tobytes(), b"foo") | self.assertEqual(b[0].tobytes(), b"foo") | ||||
| def test_multiple(self): | def test_multiple(self): | ||||
| if not hasattr(zstd, "BufferWithSegments"): | if not hasattr(zstd, "BufferWithSegments"): | ||||
| self.skipTest("BufferWithSegments not available") | self.skipTest("BufferWithSegments not available") | ||||
| b = zstd.BufferWithSegments( | b = zstd.BufferWithSegments( | ||||
| b"foofooxfooxy", b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]) | b"foofooxfooxy", | ||||
| b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]), | |||||
| ) | ) | ||||
| self.assertEqual(len(b), 3) | self.assertEqual(len(b), 3) | ||||
| self.assertEqual(b.size, 12) | self.assertEqual(b.size, 12) | ||||
| self.assertEqual(b.tobytes(), b"foofooxfooxy") | self.assertEqual(b.tobytes(), b"foofooxfooxy") | ||||
| self.assertEqual(b[0].tobytes(), b"foo") | self.assertEqual(b[0].tobytes(), b"foo") | ||||
| self.assertEqual(b[1].tobytes(), b"foox") | self.assertEqual(b[1].tobytes(), b"foox") | ||||
| self.assertEqual(b[2].tobytes(), b"fooxy") | self.assertEqual(b[2].tobytes(), b"fooxy") | ||||
| class TestBufferWithSegmentsCollection(TestCase): | class TestBufferWithSegmentsCollection(TestCase): | ||||
| def test_empty_constructor(self): | def test_empty_constructor(self): | ||||
| if not hasattr(zstd, "BufferWithSegmentsCollection"): | if not hasattr(zstd, "BufferWithSegmentsCollection"): | ||||
| self.skipTest("BufferWithSegmentsCollection not available") | self.skipTest("BufferWithSegmentsCollection not available") | ||||
| with self.assertRaisesRegex(ValueError, "must pass at least 1 argument"): | with self.assertRaisesRegex( | ||||
| ValueError, "must pass at least 1 argument" | |||||
| ): | |||||
| zstd.BufferWithSegmentsCollection() | zstd.BufferWithSegmentsCollection() | ||||
| def test_argument_validation(self): | def test_argument_validation(self): | ||||
| if not hasattr(zstd, "BufferWithSegmentsCollection"): | if not hasattr(zstd, "BufferWithSegmentsCollection"): | ||||
| self.skipTest("BufferWithSegmentsCollection not available") | self.skipTest("BufferWithSegmentsCollection not available") | ||||
| with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): | with self.assertRaisesRegex( | ||||
| TypeError, "arguments must be BufferWithSegments" | |||||
| ): | |||||
| zstd.BufferWithSegmentsCollection(None) | zstd.BufferWithSegmentsCollection(None) | ||||
| with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): | with self.assertRaisesRegex( | ||||
| TypeError, "arguments must be BufferWithSegments" | |||||
| ): | |||||
| zstd.BufferWithSegmentsCollection( | zstd.BufferWithSegmentsCollection( | ||||
| zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None | zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None | ||||
| ) | ) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "ZstdBufferWithSegments cannot be empty" | ValueError, "ZstdBufferWithSegments cannot be empty" | ||||
| ): | ): | ||||
| zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) | zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) | ||||
| if sys.version_info[0] >= 3: | if sys.version_info[0] >= 3: | ||||
| next = lambda it: it.__next__() | next = lambda it: it.__next__() | ||||
| else: | else: | ||||
| next = lambda it: it.next() | next = lambda it: it.next() | ||||
| def multithreaded_chunk_size(level, source_size=0): | def multithreaded_chunk_size(level, source_size=0): | ||||
| params = zstd.ZstdCompressionParameters.from_level(level, source_size=source_size) | params = zstd.ZstdCompressionParameters.from_level( | ||||
| level, source_size=source_size | |||||
| ) | |||||
| return 1 << (params.window_log + 2) | return 1 << (params.window_log + 2) | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor(TestCase): | class TestCompressor(TestCase): | ||||
| def test_level_bounds(self): | def test_level_bounds(self): | ||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| cctx = zstd.ZstdCompressor(level=3, write_content_size=False) | cctx = zstd.ZstdCompressor(level=3, write_content_size=False) | ||||
| result = cctx.compress(b"".join(chunks)) | result = cctx.compress(b"".join(chunks)) | ||||
| self.assertEqual(len(result), 999) | self.assertEqual(len(result), 999) | ||||
| self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | ||||
| # This matches the test for read_to_iter() below. | # This matches the test for read_to_iter() below. | ||||
| cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
| result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o") | result = cctx.compress( | ||||
| b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o" | |||||
| ) | |||||
| self.assertEqual( | self.assertEqual( | ||||
| result, | result, | ||||
| b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" | b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" | ||||
| b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" | b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" | ||||
| b"\x02\x09\x00\x00\x6f", | b"\x02\x09\x00\x00\x6f", | ||||
| ) | ) | ||||
| def test_negative_level(self): | def test_negative_level(self): | ||||
| cctx = zstd.ZstdCompressor(level=-4) | cctx = zstd.ZstdCompressor(level=-4) | ||||
| result = cctx.compress(b"foo" * 256) | result = cctx.compress(b"foo" * 256) | ||||
| def test_no_magic(self): | def test_no_magic(self): | ||||
| params = zstd.ZstdCompressionParameters.from_level(1, format=zstd.FORMAT_ZSTD1) | params = zstd.ZstdCompressionParameters.from_level( | ||||
| 1, format=zstd.FORMAT_ZSTD1 | |||||
| ) | |||||
| cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
| magic = cctx.compress(b"foobar") | magic = cctx.compress(b"foobar") | ||||
| params = zstd.ZstdCompressionParameters.from_level( | params = zstd.ZstdCompressionParameters.from_level( | ||||
| 1, format=zstd.FORMAT_ZSTD1_MAGICLESS | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS | ||||
| ) | ) | ||||
| cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
| no_magic = cctx.compress(b"foobar") | no_magic = cctx.compress(b"foobar") | ||||
| result = cctx.compress(b"foo") | result = cctx.compress(b"foo") | ||||
| params = zstd.get_frame_parameters(result) | params = zstd.get_frame_parameters(result) | ||||
| self.assertEqual(params.content_size, 3) | self.assertEqual(params.content_size, 3) | ||||
| self.assertEqual(params.dict_id, d.dict_id()) | self.assertEqual(params.dict_id, d.dict_id()) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| result, | result, | ||||
| b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" b"\x66\x6f\x6f", | b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" | ||||
| b"\x66\x6f\x6f", | |||||
| ) | ) | ||||
| def test_multithreaded_compression_params(self): | def test_multithreaded_compression_params(self): | ||||
| params = zstd.ZstdCompressionParameters.from_level(0, threads=2) | params = zstd.ZstdCompressionParameters.from_level(0, threads=2) | ||||
| cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
| result = cctx.compress(b"foo") | result = cctx.compress(b"foo") | ||||
| params = zstd.get_frame_parameters(result) | params = zstd.get_frame_parameters(result) | ||||
| self.assertEqual(params.content_size, 3) | self.assertEqual(params.content_size, 3) | ||||
| self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f") | self.assertEqual( | ||||
| result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f" | |||||
| ) | |||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor_compressobj(TestCase): | class TestCompressor_compressobj(TestCase): | ||||
| def test_compressobj_empty(self): | def test_compressobj_empty(self): | ||||
| cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
| cobj = cctx.compressobj() | cobj = cctx.compressobj() | ||||
| self.assertEqual(cobj.compress(b""), b"") | self.assertEqual(cobj.compress(b""), b"") | ||||
| self.assertEqual(cobj.compress(b"foo"), b"") | self.assertEqual(cobj.compress(b"foo"), b"") | ||||
| self.assertEqual( | self.assertEqual( | ||||
| cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), | ||||
| b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", | b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", | ||||
| ) | ) | ||||
| self.assertEqual(cobj.compress(b"bar"), b"") | self.assertEqual(cobj.compress(b"bar"), b"") | ||||
| # 3 byte header plus content. | # 3 byte header plus content. | ||||
| self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar") | self.assertEqual( | ||||
| cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar" | |||||
| ) | |||||
| self.assertEqual(cobj.flush(), b"\x01\x00\x00") | self.assertEqual(cobj.flush(), b"\x01\x00\x00") | ||||
| def test_flush_empty_block(self): | def test_flush_empty_block(self): | ||||
| cctx = zstd.ZstdCompressor(write_checksum=True) | cctx = zstd.ZstdCompressor(write_checksum=True) | ||||
| cobj = cctx.compressobj() | cobj = cctx.compressobj() | ||||
| cobj.compress(b"foobar") | cobj.compress(b"foobar") | ||||
| cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | ||||
| source = io.BytesIO() | source = io.BytesIO() | ||||
| dest = io.BytesIO() | dest = io.BytesIO() | ||||
| cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
| r, w = cctx.copy_stream(source, dest) | r, w = cctx.copy_stream(source, dest) | ||||
| self.assertEqual(int(r), 0) | self.assertEqual(int(r), 0) | ||||
| self.assertEqual(w, 9) | self.assertEqual(w, 9) | ||||
| self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | self.assertEqual( | ||||
| dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00" | |||||
| ) | |||||
| def test_large_data(self): | def test_large_data(self): | ||||
| source = io.BytesIO() | source = io.BytesIO() | ||||
| for i in range(255): | for i in range(255): | ||||
| source.write(struct.Struct(">B").pack(i) * 16384) | source.write(struct.Struct(">B").pack(i) * 16384) | ||||
| source.seek(0) | source.seek(0) | ||||
| dest = io.BytesIO() | dest = io.BytesIO() | ||||
| cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
| cctx.copy_stream(source, no_checksum) | cctx.copy_stream(source, no_checksum) | ||||
| source.seek(0) | source.seek(0) | ||||
| with_checksum = io.BytesIO() | with_checksum = io.BytesIO() | ||||
| cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | ||||
| cctx.copy_stream(source, with_checksum) | cctx.copy_stream(source, with_checksum) | ||||
| self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) | self.assertEqual( | ||||
| len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||||
| ) | |||||
| no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | ||||
| with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | ||||
| self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
| self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
| self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor_stream_reader(TestCase): | class TestCompressor_stream_reader(TestCase): | ||||
| def test_context_manager(self): | def test_context_manager(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| with cctx.stream_reader(b"foo") as reader: | with cctx.stream_reader(b"foo") as reader: | ||||
| with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): | with self.assertRaisesRegex( | ||||
| ValueError, "cannot __enter__ multiple times" | |||||
| ): | |||||
| with reader as reader2: | with reader as reader2: | ||||
| pass | pass | ||||
| def test_no_context_manager(self): | def test_no_context_manager(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| reader = cctx.stream_reader(b"foo") | reader = cctx.stream_reader(b"foo") | ||||
| reader.read(4) | reader.read(4) | ||||
| reader.read(10) | reader.read(10) | ||||
| def test_bad_size(self): | def test_bad_size(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| source = io.BytesIO(b"foobar") | source = io.BytesIO(b"foobar") | ||||
| with cctx.stream_reader(source, size=2) as reader: | with cctx.stream_reader(source, size=2) as reader: | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "Src size is incorrect" | |||||
| ): | |||||
| reader.read(10) | reader.read(10) | ||||
| # Try another compression operation. | # Try another compression operation. | ||||
| with cctx.stream_reader(source, size=42): | with cctx.stream_reader(source, size=42): | ||||
| pass | pass | ||||
| def test_readall(self): | def test_readall(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | ||||
| self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
| self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
| self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
| self.assertTrue(with_params.has_checksum) | self.assertTrue(with_params.has_checksum) | ||||
| self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) | self.assertEqual( | ||||
| len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||||
| ) | |||||
| def test_write_content_size(self): | def test_write_content_size(self): | ||||
| no_size = NonClosingBytesIO() | no_size = NonClosingBytesIO() | ||||
| cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
| with cctx.stream_writer(no_size) as compressor: | with cctx.stream_writer(no_size) as compressor: | ||||
| self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
| with_size = NonClosingBytesIO() | with_size = NonClosingBytesIO() | ||||
| cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
| with cctx.stream_writer(with_size) as compressor: | with cctx.stream_writer(with_size) as compressor: | ||||
| self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
| # Source size is not known in streaming mode, so header not | # Source size is not known in streaming mode, so header not | ||||
| # written. | # written. | ||||
| self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | ||||
| # Declaring size will write the header. | # Declaring size will write the header. | ||||
| with_size = NonClosingBytesIO() | with_size = NonClosingBytesIO() | ||||
| with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor: | with cctx.stream_writer( | ||||
| with_size, size=len(b"foobar" * 256) | |||||
| ) as compressor: | |||||
| self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
| no_params = zstd.get_frame_parameters(no_size.getvalue()) | no_params = zstd.get_frame_parameters(no_size.getvalue()) | ||||
| with_params = zstd.get_frame_parameters(with_size.getvalue()) | with_params = zstd.get_frame_parameters(with_size.getvalue()) | ||||
| self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(with_params.content_size, 1536) | self.assertEqual(with_params.content_size, 1536) | ||||
| self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
| self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
| with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) | with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) | ||||
| self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
| self.assertEqual(with_params.dict_id, d.dict_id()) | self.assertEqual(with_params.dict_id, d.dict_id()) | ||||
| self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
| self.assertFalse(with_params.has_checksum) | self.assertFalse(with_params.has_checksum) | ||||
| self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4) | self.assertEqual( | ||||
| len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4 | |||||
| ) | |||||
| def test_memory_size(self): | def test_memory_size(self): | ||||
| cctx = zstd.ZstdCompressor(level=3) | cctx = zstd.ZstdCompressor(level=3) | ||||
| buffer = io.BytesIO() | buffer = io.BytesIO() | ||||
| with cctx.stream_writer(buffer) as compressor: | with cctx.stream_writer(buffer) as compressor: | ||||
| compressor.write(b"foo") | compressor.write(b"foo") | ||||
| size = compressor.memory_size() | size = compressor.memory_size() | ||||
| # Object with read() works. | # Object with read() works. | ||||
| for chunk in cctx.read_to_iter(io.BytesIO()): | for chunk in cctx.read_to_iter(io.BytesIO()): | ||||
| pass | pass | ||||
| # Buffer protocol works. | # Buffer protocol works. | ||||
| for chunk in cctx.read_to_iter(b"foobar"): | for chunk in cctx.read_to_iter(b"foobar"): | ||||
| pass | pass | ||||
| with self.assertRaisesRegex(ValueError, "must pass an object with a read"): | with self.assertRaisesRegex( | ||||
| ValueError, "must pass an object with a read" | |||||
| ): | |||||
| for chunk in cctx.read_to_iter(True): | for chunk in cctx.read_to_iter(True): | ||||
| pass | pass | ||||
| def test_read_empty(self): | def test_read_empty(self): | ||||
| cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
| source = io.BytesIO() | source = io.BytesIO() | ||||
| it = cctx.read_to_iter(source) | it = cctx.read_to_iter(source) | ||||
| [ | [ | ||||
| b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" | b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" | ||||
| b"\xa0\x16\xe3\x2b\x80\x05" | b"\xa0\x16\xe3\x2b\x80\x05" | ||||
| ], | ], | ||||
| ) | ) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24)) | self.assertEqual( | ||||
| dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24) | |||||
| ) | |||||
| def test_small_chunk_size(self): | def test_small_chunk_size(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| chunker = cctx.chunker(chunk_size=1) | chunker = cctx.chunker(chunk_size=1) | ||||
| chunks = list(chunker.compress(b"foo" * 1024)) | chunks = list(chunker.compress(b"foo" * 1024)) | ||||
| self.assertEqual(chunks, []) | self.assertEqual(chunks, []) | ||||
| chunks = list(chunker.finish()) | chunks = list(chunker.finish()) | ||||
| self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) | self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| b"".join(chunks), | b"".join(chunks), | ||||
| b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" | b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" | ||||
| b"\xfa\xd3\x77\x43", | b"\xfa\xd3\x77\x43", | ||||
| ) | ) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| self.assertEqual( | self.assertEqual( | ||||
| dctx.decompress(b"".join(chunks), max_output_size=10000), b"foo" * 1024 | dctx.decompress(b"".join(chunks), max_output_size=10000), | ||||
| b"foo" * 1024, | |||||
| ) | ) | ||||
| def test_input_types(self): | def test_input_types(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| mutable_array = bytearray(3) | mutable_array = bytearray(3) | ||||
| mutable_array[:] = b"foo" | mutable_array[:] = b"foo" | ||||
| def test_compress_after_finish(self): | def test_compress_after_finish(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| chunker = cctx.chunker() | chunker = cctx.chunker() | ||||
| list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
| list(chunker.finish()) | list(chunker.finish()) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, r"cannot call compress\(\) after compression finished" | zstd.ZstdError, | ||||
| r"cannot call compress\(\) after compression finished", | |||||
| ): | ): | ||||
| list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
| def test_flush_after_finish(self): | def test_flush_after_finish(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| chunker = cctx.chunker() | chunker = cctx.chunker() | ||||
| list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
| self.skipTest("multi_compress_to_buffer not available") | self.skipTest("multi_compress_to_buffer not available") | ||||
| with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
| cctx.multi_compress_to_buffer(True) | cctx.multi_compress_to_buffer(True) | ||||
| with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
| cctx.multi_compress_to_buffer((1, 2)) | cctx.multi_compress_to_buffer((1, 2)) | ||||
| with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): | with self.assertRaisesRegex( | ||||
| TypeError, "item 0 not a bytes like object" | |||||
| ): | |||||
| cctx.multi_compress_to_buffer([u"foo"]) | cctx.multi_compress_to_buffer([u"foo"]) | ||||
| def test_empty_input(self): | def test_empty_input(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| if not hasattr(cctx, "multi_compress_to_buffer"): | if not hasattr(cctx, "multi_compress_to_buffer"): | ||||
| self.skipTest("multi_compress_to_buffer not available") | self.skipTest("multi_compress_to_buffer not available") | ||||
| class TestCompressor_stream_reader_fuzzing(TestCase): | class TestCompressor_stream_reader_fuzzing(TestCase): | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_stream_source_read(self, original, level, source_read_size, read_size): | def test_stream_source_read( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_buffer_source_read(self, original, level, source_read_size, read_size): | def test_buffer_source_read( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_stream_source_readinto(self, original, level, source_read_size, read_size): | def test_stream_source_readinto( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| io.BytesIO(original), size=len(original), read_size=source_read_size | io.BytesIO(original), size=len(original), read_size=source_read_size | ||||
| ) as reader: | ) as reader: | ||||
| chunks = [] | chunks = [] | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_buffer_source_readinto(self, original, level, source_read_size, read_size): | def test_buffer_source_readinto( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| original, size=len(original), read_size=source_read_size | original, size=len(original), read_size=source_read_size | ||||
| ) as reader: | ) as reader: | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_stream_source_read1(self, original, level, source_read_size, read_size): | def test_stream_source_read1( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| -1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_buffer_source_read1(self, original, level, source_read_size, read_size): | def test_buffer_source_read1( | ||||
| self, original, level, source_read_size, read_size | |||||
| ): | |||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| with cctx.stream_reader( | with cctx.stream_reader( | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_stream_source_readinto1( | def test_stream_source_readinto1( | ||||
| self, original, level, source_read_size, read_size | self, original, level, source_read_size, read_size | ||||
| ): | ): | ||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
| read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
| 1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
| ), | |||||
| ) | ) | ||||
| def test_buffer_source_readinto1( | def test_buffer_source_readinto1( | ||||
| self, original, level, source_read_size, read_size | self, original, level, source_read_size, read_size | ||||
| ): | ): | ||||
| if read_size == 0: | if read_size == 0: | ||||
| read_size = -1 | read_size = -1 | ||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor_copy_stream_fuzzing(TestCase): | class TestCompressor_copy_stream_fuzzing(TestCase): | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| read_size=strategies.integers(min_value=1, max_value=1048576), | read_size=strategies.integers(min_value=1, max_value=1048576), | ||||
| write_size=strategies.integers(min_value=1, max_value=1048576), | write_size=strategies.integers(min_value=1, max_value=1048576), | ||||
| ) | ) | ||||
| def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
| self, original, level, read_size, write_size | |||||
| ): | |||||
| refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| source = io.BytesIO(original) | source = io.BytesIO(original) | ||||
| dest = io.BytesIO() | dest = io.BytesIO() | ||||
| cctx.copy_stream( | cctx.copy_stream( | ||||
| source, dest, size=len(original), read_size=read_size, write_size=write_size | source, | ||||
| dest, | |||||
| size=len(original), | |||||
| read_size=read_size, | |||||
| write_size=write_size, | |||||
| ) | ) | ||||
| self.assertEqual(dest.getvalue(), ref_frame) | self.assertEqual(dest.getvalue(), ref_frame) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor_compressobj_fuzzing(TestCase): | class TestCompressor_compressobj_fuzzing(TestCase): | ||||
| self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | ||||
| chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) | ||||
| compressed_chunks.append(chunk) | compressed_chunks.append(chunk) | ||||
| decompressed_chunks.append(dobj.decompress(chunk)) | decompressed_chunks.append(dobj.decompress(chunk)) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), | dctx.decompress( | ||||
| b"".join(compressed_chunks), max_output_size=len(original) | |||||
| ), | |||||
| original, | original, | ||||
| ) | ) | ||||
| self.assertEqual(b"".join(decompressed_chunks), original) | self.assertEqual(b"".join(decompressed_chunks), original) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressor_read_to_iter_fuzzing(TestCase): | class TestCompressor_read_to_iter_fuzzing(TestCase): | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| read_size=strategies.integers(min_value=1, max_value=4096), | read_size=strategies.integers(min_value=1, max_value=4096), | ||||
| write_size=strategies.integers(min_value=1, max_value=4096), | write_size=strategies.integers(min_value=1, max_value=4096), | ||||
| ) | ) | ||||
| def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
| self, original, level, read_size, write_size | |||||
| ): | |||||
| refcctx = zstd.ZstdCompressor(level=level) | refcctx = zstd.ZstdCompressor(level=level) | ||||
| ref_frame = refcctx.compress(original) | ref_frame = refcctx.compress(original) | ||||
| source = io.BytesIO(original) | source = io.BytesIO(original) | ||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| chunks = list( | chunks = list( | ||||
| cctx.read_to_iter( | cctx.read_to_iter( | ||||
| source, size=len(original), read_size=read_size, write_size=write_size | source, | ||||
| size=len(original), | |||||
| read_size=read_size, | |||||
| write_size=write_size, | |||||
| ) | ) | ||||
| ) | ) | ||||
| self.assertEqual(b"".join(chunks), ref_frame) | self.assertEqual(b"".join(chunks), ref_frame) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
| class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): | class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.lists( | original=strategies.lists( | ||||
| strategies.sampled_from(random_input_data()), min_size=1, max_size=1024 | strategies.sampled_from(random_input_data()), | ||||
| min_size=1, | |||||
| max_size=1024, | |||||
| ), | ), | ||||
| threads=strategies.integers(min_value=1, max_value=8), | threads=strategies.integers(min_value=1, max_value=8), | ||||
| use_dict=strategies.booleans(), | use_dict=strategies.booleans(), | ||||
| ) | ) | ||||
| def test_data_equivalence(self, original, threads, use_dict): | def test_data_equivalence(self, original, threads, use_dict): | ||||
| kwargs = {} | kwargs = {} | ||||
| # Use a content dictionary because it is cheap to create. | # Use a content dictionary because it is cheap to create. | ||||
| chunks.extend(chunker.compress(source)) | chunks.extend(chunker.compress(source)) | ||||
| i += input_size | i += input_size | ||||
| chunks.extend(chunker.finish()) | chunks.extend(chunker.finish()) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| self.assertEqual( | self.assertEqual( | ||||
| dctx.decompress(b"".join(chunks), max_output_size=len(original)), original | dctx.decompress(b"".join(chunks), max_output_size=len(original)), | ||||
| original, | |||||
| ) | ) | ||||
| self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) | self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| suppress_health_check=[ | suppress_health_check=[ | ||||
| hypothesis.HealthCheck.large_base_example, | hypothesis.HealthCheck.large_base_example, | ||||
| hypothesis.HealthCheck.too_slow, | hypothesis.HealthCheck.too_slow, | ||||
| ] | ] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | ||||
| input_sizes=strategies.data(), | input_sizes=strategies.data(), | ||||
| flushes=strategies.data(), | flushes=strategies.data(), | ||||
| ) | ) | ||||
| def test_flush_block(self, original, level, chunk_size, input_sizes, flushes): | def test_flush_block( | ||||
| self, original, level, chunk_size, input_sizes, flushes | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| chunker = cctx.chunker(chunk_size=chunk_size) | chunker = cctx.chunker(chunk_size=chunk_size) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| dobj = dctx.decompressobj() | dobj = dctx.decompressobj() | ||||
| compressed_chunks = [] | compressed_chunks = [] | ||||
| decompressed_chunks = [] | decompressed_chunks = [] | ||||
| self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | ||||
| chunks = list(chunker.finish()) | chunks = list(chunker.finish()) | ||||
| compressed_chunks.extend(chunks) | compressed_chunks.extend(chunks) | ||||
| decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), | dctx.decompress( | ||||
| b"".join(compressed_chunks), max_output_size=len(original) | |||||
| ), | |||||
| original, | original, | ||||
| ) | ) | ||||
| self.assertEqual(b"".join(decompressed_chunks), original) | self.assertEqual(b"".join(decompressed_chunks), original) | ||||
| self.assertEqual(p.compression_strategy, 1) | self.assertEqual(p.compression_strategy, 1) | ||||
| p = zstd.ZstdCompressionParameters(compression_level=2) | p = zstd.ZstdCompressionParameters(compression_level=2) | ||||
| self.assertEqual(p.compression_level, 2) | self.assertEqual(p.compression_level, 2) | ||||
| p = zstd.ZstdCompressionParameters(threads=4) | p = zstd.ZstdCompressionParameters(threads=4) | ||||
| self.assertEqual(p.threads, 4) | self.assertEqual(p.threads, 4) | ||||
| p = zstd.ZstdCompressionParameters(threads=2, job_size=1048576, overlap_log=6) | p = zstd.ZstdCompressionParameters( | ||||
| threads=2, job_size=1048576, overlap_log=6 | |||||
| ) | |||||
| self.assertEqual(p.threads, 2) | self.assertEqual(p.threads, 2) | ||||
| self.assertEqual(p.job_size, 1048576) | self.assertEqual(p.job_size, 1048576) | ||||
| self.assertEqual(p.overlap_log, 6) | self.assertEqual(p.overlap_log, 6) | ||||
| self.assertEqual(p.overlap_size_log, 6) | self.assertEqual(p.overlap_size_log, 6) | ||||
| p = zstd.ZstdCompressionParameters(compression_level=-1) | p = zstd.ZstdCompressionParameters(compression_level=-1) | ||||
| self.assertEqual(p.compression_level, -1) | self.assertEqual(p.compression_level, -1) | ||||
| p = zstd.ZstdCompressionParameters(strategy=3) | p = zstd.ZstdCompressionParameters(strategy=3) | ||||
| self.assertEqual(p.compression_strategy, 3) | self.assertEqual(p.compression_strategy, 3) | ||||
| def test_ldm_hash_rate_log(self): | def test_ldm_hash_rate_log(self): | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "cannot specify both ldm_hash_rate_log" | ValueError, "cannot specify both ldm_hash_rate_log" | ||||
| ): | ): | ||||
| zstd.ZstdCompressionParameters(ldm_hash_rate_log=8, ldm_hash_every_log=4) | zstd.ZstdCompressionParameters( | ||||
| ldm_hash_rate_log=8, ldm_hash_every_log=4 | |||||
| ) | |||||
| p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | ||||
| self.assertEqual(p.ldm_hash_every_log, 8) | self.assertEqual(p.ldm_hash_every_log, 8) | ||||
| p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) | p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) | ||||
| self.assertEqual(p.ldm_hash_every_log, 16) | self.assertEqual(p.ldm_hash_every_log, 16) | ||||
| def test_overlap_log(self): | def test_overlap_log(self): | ||||
| with self.assertRaisesRegex(ValueError, "cannot specify both overlap_log"): | with self.assertRaisesRegex( | ||||
| ValueError, "cannot specify both overlap_log" | |||||
| ): | |||||
| zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) | zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) | ||||
| p = zstd.ZstdCompressionParameters(overlap_log=2) | p = zstd.ZstdCompressionParameters(overlap_log=2) | ||||
| self.assertEqual(p.overlap_log, 2) | self.assertEqual(p.overlap_log, 2) | ||||
| self.assertEqual(p.overlap_size_log, 2) | self.assertEqual(p.overlap_size_log, 2) | ||||
| p = zstd.ZstdCompressionParameters(overlap_size_log=4) | p = zstd.ZstdCompressionParameters(overlap_size_log=4) | ||||
| self.assertEqual(p.overlap_log, 4) | self.assertEqual(p.overlap_log, 4) | ||||
| if zstd.backend == "cffi": | if zstd.backend == "cffi": | ||||
| with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
| zstd.get_frame_parameters(u"foobarbaz") | zstd.get_frame_parameters(u"foobarbaz") | ||||
| else: | else: | ||||
| with self.assertRaises(zstd.ZstdError): | with self.assertRaises(zstd.ZstdError): | ||||
| zstd.get_frame_parameters(u"foobarbaz") | zstd.get_frame_parameters(u"foobarbaz") | ||||
| def test_invalid_input_sizes(self): | def test_invalid_input_sizes(self): | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"): | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "not enough data for frame" | |||||
| ): | |||||
| zstd.get_frame_parameters(b"") | zstd.get_frame_parameters(b"") | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"): | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "not enough data for frame" | |||||
| ): | |||||
| zstd.get_frame_parameters(zstd.FRAME_HEADER) | zstd.get_frame_parameters(zstd.FRAME_HEADER) | ||||
| def test_invalid_frame(self): | def test_invalid_frame(self): | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | ||||
| zstd.get_frame_parameters(b"foobarbaz") | zstd.get_frame_parameters(b"foobarbaz") | ||||
| def test_attributes(self): | def test_attributes(self): | ||||
| params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") | ||||
| # Lowest 3rd bit indicates if checksum is present. | # Lowest 3rd bit indicates if checksum is present. | ||||
| params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") | ||||
| self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(params.window_size, 1024) | self.assertEqual(params.window_size, 1024) | ||||
| self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
| self.assertTrue(params.has_checksum) | self.assertTrue(params.has_checksum) | ||||
| # Upper 2 bits indicate content size. | # Upper 2 bits indicate content size. | ||||
| params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x40\x00\xff\x00") | params = zstd.get_frame_parameters( | ||||
| zstd.FRAME_HEADER + b"\x40\x00\xff\x00" | |||||
| ) | |||||
| self.assertEqual(params.content_size, 511) | self.assertEqual(params.content_size, 511) | ||||
| self.assertEqual(params.window_size, 1024) | self.assertEqual(params.window_size, 1024) | ||||
| self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
| self.assertFalse(params.has_checksum) | self.assertFalse(params.has_checksum) | ||||
| # Window descriptor is 2nd byte after frame header. | # Window descriptor is 2nd byte after frame header. | ||||
| params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") | ||||
| self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
| self.assertEqual(params.window_size, 262144) | self.assertEqual(params.window_size, 262144) | ||||
| self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
| self.assertFalse(params.has_checksum) | self.assertFalse(params.has_checksum) | ||||
| # Set multiple things. | # Set multiple things. | ||||
| params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00") | params = zstd.get_frame_parameters( | ||||
| zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00" | |||||
| ) | |||||
| self.assertEqual(params.content_size, 272) | self.assertEqual(params.content_size, 272) | ||||
| self.assertEqual(params.window_size, 262144) | self.assertEqual(params.window_size, 262144) | ||||
| self.assertEqual(params.dict_id, 15) | self.assertEqual(params.dict_id, 15) | ||||
| self.assertTrue(params.has_checksum) | self.assertTrue(params.has_checksum) | ||||
| def test_input_types(self): | def test_input_types(self): | ||||
| v = zstd.FRAME_HEADER + b"\x00\x00" | v = zstd.FRAME_HEADER + b"\x00\x00" | ||||
| s_windowlog = strategies.integers( | s_windowlog = strategies.integers( | ||||
| min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX | min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX | ||||
| ) | ) | ||||
| s_chainlog = strategies.integers( | s_chainlog = strategies.integers( | ||||
| min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX | min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX | ||||
| ) | ) | ||||
| s_hashlog = strategies.integers(min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX) | s_hashlog = strategies.integers( | ||||
| min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX | |||||
| ) | |||||
| s_searchlog = strategies.integers( | s_searchlog = strategies.integers( | ||||
| min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX | min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX | ||||
| ) | ) | ||||
| s_minmatch = strategies.integers( | s_minmatch = strategies.integers( | ||||
| min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX | min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX | ||||
| ) | ) | ||||
| s_targetlength = strategies.integers( | s_targetlength = strategies.integers( | ||||
| min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX | min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX | ||||
| s_chainlog, | s_chainlog, | ||||
| s_hashlog, | s_hashlog, | ||||
| s_searchlog, | s_searchlog, | ||||
| s_minmatch, | s_minmatch, | ||||
| s_targetlength, | s_targetlength, | ||||
| s_strategy, | s_strategy, | ||||
| ) | ) | ||||
| def test_valid_init( | def test_valid_init( | ||||
| self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy | self, | ||||
| windowlog, | |||||
| chainlog, | |||||
| hashlog, | |||||
| searchlog, | |||||
| minmatch, | |||||
| targetlength, | |||||
| strategy, | |||||
| ): | ): | ||||
| zstd.ZstdCompressionParameters( | zstd.ZstdCompressionParameters( | ||||
| window_log=windowlog, | window_log=windowlog, | ||||
| chain_log=chainlog, | chain_log=chainlog, | ||||
| hash_log=hashlog, | hash_log=hashlog, | ||||
| search_log=searchlog, | search_log=searchlog, | ||||
| min_match=minmatch, | min_match=minmatch, | ||||
| target_length=targetlength, | target_length=targetlength, | ||||
| strategy=strategy, | strategy=strategy, | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| s_windowlog, | s_windowlog, | ||||
| s_chainlog, | s_chainlog, | ||||
| s_hashlog, | s_hashlog, | ||||
| s_searchlog, | s_searchlog, | ||||
| s_minmatch, | s_minmatch, | ||||
| s_targetlength, | s_targetlength, | ||||
| s_strategy, | s_strategy, | ||||
| ) | ) | ||||
| def test_estimated_compression_context_size( | def test_estimated_compression_context_size( | ||||
| self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy | self, | ||||
| windowlog, | |||||
| chainlog, | |||||
| hashlog, | |||||
| searchlog, | |||||
| minmatch, | |||||
| targetlength, | |||||
| strategy, | |||||
| ): | ): | ||||
| if minmatch == zstd.MINMATCH_MIN and strategy in ( | if minmatch == zstd.MINMATCH_MIN and strategy in ( | ||||
| zstd.STRATEGY_FAST, | zstd.STRATEGY_FAST, | ||||
| zstd.STRATEGY_GREEDY, | zstd.STRATEGY_GREEDY, | ||||
| ): | ): | ||||
| minmatch += 1 | minmatch += 1 | ||||
| elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: | elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: | ||||
| minmatch -= 1 | minmatch -= 1 | ||||
| # Input size - 1 fails | # Input size - 1 fails | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "decompression error: did not decompress full frame" | zstd.ZstdError, "decompression error: did not decompress full frame" | ||||
| ): | ): | ||||
| dctx.decompress(compressed, max_output_size=len(source) - 1) | dctx.decompress(compressed, max_output_size=len(source) - 1) | ||||
| # Input size + 1 works | # Input size + 1 works | ||||
| decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) | decompressed = dctx.decompress( | ||||
| compressed, max_output_size=len(source) + 1 | |||||
| ) | |||||
| self.assertEqual(decompressed, source) | self.assertEqual(decompressed, source) | ||||
| # A much larger buffer works. | # A much larger buffer works. | ||||
| decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64) | decompressed = dctx.decompress( | ||||
| compressed, max_output_size=len(source) * 64 | |||||
| ) | |||||
| self.assertEqual(decompressed, source) | self.assertEqual(decompressed, source) | ||||
| def test_stupidly_large_output_buffer(self): | def test_stupidly_large_output_buffer(self): | ||||
| cctx = zstd.ZstdCompressor(write_content_size=False) | cctx = zstd.ZstdCompressor(write_content_size=False) | ||||
| compressed = cctx.compress(b"foobar" * 256) | compressed = cctx.compress(b"foobar" * 256) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| # Will get OverflowError on some Python distributions that can't | # Will get OverflowError on some Python distributions that can't | ||||
| # If we write a content size, the decompressor engages single pass | # If we write a content size, the decompressor engages single pass | ||||
| # mode and the window size doesn't come into play. | # mode and the window size doesn't come into play. | ||||
| cctx = zstd.ZstdCompressor(write_content_size=False) | cctx = zstd.ZstdCompressor(write_content_size=False) | ||||
| frame = cctx.compress(source) | frame = cctx.compress(source) | ||||
| dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "decompression error: Frame requires too much memory" | zstd.ZstdError, | ||||
| "decompression error: Frame requires too much memory", | |||||
| ): | ): | ||||
| dctx.decompress(frame, max_output_size=len(source)) | dctx.decompress(frame, max_output_size=len(source)) | ||||
| @make_cffi | @make_cffi | ||||
| class TestDecompressor_copy_stream(TestCase): | class TestDecompressor_copy_stream(TestCase): | ||||
| def test_no_read(self): | def test_no_read(self): | ||||
| source = object() | source = object() | ||||
| dest = io.BytesIO() | dest = io.BytesIO() | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| r, w = dctx.copy_stream(compressed, dest) | r, w = dctx.copy_stream(compressed, dest) | ||||
| self.assertEqual(r, len(compressed.getvalue())) | self.assertEqual(r, len(compressed.getvalue())) | ||||
| self.assertEqual(w, len(source.getvalue())) | self.assertEqual(w, len(source.getvalue())) | ||||
| def test_read_write_size(self): | def test_read_write_size(self): | ||||
| source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | source = OpCountingBytesIO( | ||||
| zstd.ZstdCompressor().compress(b"foobarfoobar") | |||||
| ) | |||||
| dest = OpCountingBytesIO() | dest = OpCountingBytesIO() | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | ||||
| self.assertEqual(r, len(source.getvalue())) | self.assertEqual(r, len(source.getvalue())) | ||||
| self.assertEqual(w, len(b"foobarfoobar")) | self.assertEqual(w, len(b"foobarfoobar")) | ||||
| self.assertEqual(source._read_count, len(source.getvalue()) + 1) | self.assertEqual(source._read_count, len(source.getvalue()) + 1) | ||||
| self.assertEqual(dest._write_count, len(dest.getvalue())) | self.assertEqual(dest._write_count, len(dest.getvalue())) | ||||
| @make_cffi | @make_cffi | ||||
| class TestDecompressor_stream_reader(TestCase): | class TestDecompressor_stream_reader(TestCase): | ||||
| def test_context_manager(self): | def test_context_manager(self): | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| with dctx.stream_reader(b"foo") as reader: | with dctx.stream_reader(b"foo") as reader: | ||||
| with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): | with self.assertRaisesRegex( | ||||
| ValueError, "cannot __enter__ multiple times" | |||||
| ): | |||||
| with reader as reader2: | with reader as reader2: | ||||
| pass | pass | ||||
| def test_not_implemented(self): | def test_not_implemented(self): | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| with dctx.stream_reader(b"foo") as reader: | with dctx.stream_reader(b"foo") as reader: | ||||
| with self.assertRaises(io.UnsupportedOperation): | with self.assertRaises(io.UnsupportedOperation): | ||||
| def test_illegal_seeks(self): | def test_illegal_seeks(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| frame = cctx.compress(b"foo" * 60) | frame = cctx.compress(b"foo" * 60) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| with dctx.stream_reader(frame) as reader: | with dctx.stream_reader(frame) as reader: | ||||
| with self.assertRaisesRegex(ValueError, "cannot seek to negative position"): | with self.assertRaisesRegex( | ||||
| ValueError, "cannot seek to negative position" | |||||
| ): | |||||
| reader.seek(-1, os.SEEK_SET) | reader.seek(-1, os.SEEK_SET) | ||||
| reader.read(1) | reader.read(1) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "cannot seek zstd decompression stream backwards" | ValueError, "cannot seek zstd decompression stream backwards" | ||||
| ): | ): | ||||
| reader.seek(0, os.SEEK_SET) | reader.seek(0, os.SEEK_SET) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "cannot seek zstd decompression stream backwards" | ValueError, "cannot seek zstd decompression stream backwards" | ||||
| ): | ): | ||||
| reader.seek(-1, os.SEEK_CUR) | reader.seek(-1, os.SEEK_CUR) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "zstd decompression streams cannot be seeked with SEEK_END" | ValueError, | ||||
| "zstd decompression streams cannot be seeked with SEEK_END", | |||||
| ): | ): | ||||
| reader.seek(0, os.SEEK_END) | reader.seek(0, os.SEEK_END) | ||||
| reader.close() | reader.close() | ||||
| with self.assertRaisesRegex(ValueError, "stream is closed"): | with self.assertRaisesRegex(ValueError, "stream is closed"): | ||||
| reader.seek(4, os.SEEK_SET) | reader.seek(4, os.SEEK_SET) | ||||
| self.assertEqual(b._read_count, 1) | self.assertEqual(b._read_count, 1) | ||||
| self.assertEqual(reader.read1(1), b"o") | self.assertEqual(reader.read1(1), b"o") | ||||
| self.assertEqual(b._read_count, 1) | self.assertEqual(b._read_count, 1) | ||||
| self.assertEqual(reader.read1(1), b"") | self.assertEqual(reader.read1(1), b"") | ||||
| self.assertEqual(b._read_count, 2) | self.assertEqual(b._read_count, 2) | ||||
| def test_read_lines(self): | def test_read_lines(self): | ||||
| cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
| source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024)) | source = b"\n".join( | ||||
| ("line %d" % i).encode("ascii") for i in range(1024) | |||||
| ) | |||||
| frame = cctx.compress(source) | frame = cctx.compress(source) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| reader = dctx.stream_reader(frame) | reader = dctx.stream_reader(frame) | ||||
| tr = io.TextIOWrapper(reader, encoding="utf-8") | tr = io.TextIOWrapper(reader, encoding="utf-8") | ||||
| lines = [] | lines = [] | ||||
| def test_reuse(self): | def test_reuse(self): | ||||
| data = zstd.ZstdCompressor(level=1).compress(b"foobar") | data = zstd.ZstdCompressor(level=1).compress(b"foobar") | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| dobj = dctx.decompressobj() | dobj = dctx.decompressobj() | ||||
| dobj.decompress(data) | dobj.decompress(data) | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"): | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "cannot use a decompressobj" | |||||
| ): | |||||
| dobj.decompress(data) | dobj.decompress(data) | ||||
| self.assertIsNone(dobj.flush()) | self.assertIsNone(dobj.flush()) | ||||
| def test_bad_write_size(self): | def test_bad_write_size(self): | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| with self.assertRaisesRegex(ValueError, "write_size must be positive"): | with self.assertRaisesRegex(ValueError, "write_size must be positive"): | ||||
| dctx.decompressobj(write_size=0) | dctx.decompressobj(write_size=0) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| # Object with read() works. | # Object with read() works. | ||||
| dctx.read_to_iter(io.BytesIO()) | dctx.read_to_iter(io.BytesIO()) | ||||
| # Buffer protocol works. | # Buffer protocol works. | ||||
| dctx.read_to_iter(b"foobar") | dctx.read_to_iter(b"foobar") | ||||
| with self.assertRaisesRegex(ValueError, "must pass an object with a read"): | with self.assertRaisesRegex( | ||||
| ValueError, "must pass an object with a read" | |||||
| ): | |||||
| b"".join(dctx.read_to_iter(True)) | b"".join(dctx.read_to_iter(True)) | ||||
| def test_empty_input(self): | def test_empty_input(self): | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| source = io.BytesIO() | source = io.BytesIO() | ||||
| it = dctx.read_to_iter(source) | it = dctx.read_to_iter(source) | ||||
| # TODO this is arguably wrong. Should get an error about missing frame foo. | # TODO this is arguably wrong. Should get an error about missing frame foo. | ||||
| chunks.append(next(it)) | chunks.append(next(it)) | ||||
| with self.assertRaises(StopIteration): | with self.assertRaises(StopIteration): | ||||
| next(it) | next(it) | ||||
| decompressed = b"".join(chunks) | decompressed = b"".join(chunks) | ||||
| self.assertEqual(decompressed, source.getvalue()) | self.assertEqual(decompressed, source.getvalue()) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless( | ||||
| "ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" | |||||
| ) | |||||
| def test_large_input(self): | def test_large_input(self): | ||||
| bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | ||||
| compressed = NonClosingBytesIO() | compressed = NonClosingBytesIO() | ||||
| input_size = 0 | input_size = 0 | ||||
| cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
| with cctx.stream_writer(compressed) as compressor: | with cctx.stream_writer(compressed) as compressor: | ||||
| while True: | while True: | ||||
| compressor.write(random.choice(bytes)) | compressor.write(random.choice(bytes)) | ||||
| input_size += 1 | input_size += 1 | ||||
| have_compressed = ( | have_compressed = ( | ||||
| len(compressed.getvalue()) | len(compressed.getvalue()) | ||||
| > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | ||||
| ) | ) | ||||
| have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | have_raw = ( | ||||
| input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | |||||
| ) | |||||
| if have_compressed and have_raw: | if have_compressed and have_raw: | ||||
| break | break | ||||
| compressed = io.BytesIO(compressed.getvalue()) | compressed = io.BytesIO(compressed.getvalue()) | ||||
| self.assertGreater( | self.assertGreater( | ||||
| len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | len(compressed.getvalue()), | ||||
| zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |||||
| ) | ) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| it = dctx.read_to_iter(compressed) | it = dctx.read_to_iter(compressed) | ||||
| chunks = [] | chunks = [] | ||||
| chunks.append(next(it)) | chunks.append(next(it)) | ||||
| chunks.append(next(it)) | chunks.append(next(it)) | ||||
| ) | ) | ||||
| self.assertEqual(simple, source.getvalue()) | self.assertEqual(simple, source.getvalue()) | ||||
| compressed = io.BytesIO(compressed.getvalue()) | compressed = io.BytesIO(compressed.getvalue()) | ||||
| streamed = b"".join(dctx.read_to_iter(compressed)) | streamed = b"".join(dctx.read_to_iter(compressed)) | ||||
| self.assertEqual(streamed, source.getvalue()) | self.assertEqual(streamed, source.getvalue()) | ||||
| def test_read_write_size(self): | def test_read_write_size(self): | ||||
| source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | source = OpCountingBytesIO( | ||||
| zstd.ZstdCompressor().compress(b"foobarfoobar") | |||||
| ) | |||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | ||||
| self.assertEqual(len(chunk), 1) | self.assertEqual(len(chunk), 1) | ||||
| self.assertEqual(source._read_count, len(source.getvalue())) | self.assertEqual(source._read_count, len(source.getvalue())) | ||||
| def test_magic_less(self): | def test_magic_less(self): | ||||
| params = zstd.CompressionParameters.from_level( | params = zstd.CompressionParameters.from_level( | ||||
| with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | ||||
| dctx.decompress_content_dict_chain([True]) | dctx.decompress_content_dict_chain([True]) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 0 is too small to contain a zstd frame" | ValueError, "chunk 0 is too small to contain a zstd frame" | ||||
| ): | ): | ||||
| dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | ||||
| with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"): | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 0 is not a valid zstd frame" | |||||
| ): | |||||
| dctx.decompress_content_dict_chain([b"foo" * 8]) | dctx.decompress_content_dict_chain([b"foo" * 8]) | ||||
| no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | ||||
| b"foo" * 64 | |||||
| ) | |||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 0 missing content size in frame" | ValueError, "chunk 0 missing content size in frame" | ||||
| ): | ): | ||||
| dctx.decompress_content_dict_chain([no_size]) | dctx.decompress_content_dict_chain([no_size]) | ||||
| # Corrupt first frame. | # Corrupt first frame. | ||||
| frame = zstd.ZstdCompressor().compress(b"foo" * 64) | frame = zstd.ZstdCompressor().compress(b"foo" * 64) | ||||
| with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | ||||
| dctx.decompress_content_dict_chain([initial, None]) | dctx.decompress_content_dict_chain([initial, None]) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 1 is too small to contain a zstd frame" | ValueError, "chunk 1 is too small to contain a zstd frame" | ||||
| ): | ): | ||||
| dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | ||||
| with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"): | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 1 is not a valid zstd frame" | |||||
| ): | |||||
| dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | ||||
| no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | ||||
| b"foo" * 64 | |||||
| ) | |||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "chunk 1 missing content size in frame" | ValueError, "chunk 1 missing content size in frame" | ||||
| ): | ): | ||||
| dctx.decompress_content_dict_chain([initial, no_size]) | dctx.decompress_content_dict_chain([initial, no_size]) | ||||
| # Corrupt second frame. | # Corrupt second frame. | ||||
| cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64)) | cctx = zstd.ZstdCompressor( | ||||
| dict_data=zstd.ZstdCompressionDict(b"foo" * 64) | |||||
| ) | |||||
| frame = cctx.compress(b"bar" * 64) | frame = cctx.compress(b"bar" * 64) | ||||
| frame = frame[0:12] + frame[15:] | frame = frame[0:12] + frame[15:] | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "chunk 1 did not decompress full frame" | zstd.ZstdError, "chunk 1 did not decompress full frame" | ||||
| ): | ): | ||||
| dctx.decompress_content_dict_chain([initial, frame]) | dctx.decompress_content_dict_chain([initial, frame]) | ||||
| self.skipTest("multi_decompress_to_buffer not available") | self.skipTest("multi_decompress_to_buffer not available") | ||||
| with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
| dctx.multi_decompress_to_buffer(True) | dctx.multi_decompress_to_buffer(True) | ||||
| with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
| dctx.multi_decompress_to_buffer((1, 2)) | dctx.multi_decompress_to_buffer((1, 2)) | ||||
| with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): | with self.assertRaisesRegex( | ||||
| TypeError, "item 0 not a bytes like object" | |||||
| ): | |||||
| dctx.multi_decompress_to_buffer([u"foo"]) | dctx.multi_decompress_to_buffer([u"foo"]) | ||||
| with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
| ValueError, "could not determine decompressed size of item 0" | ValueError, "could not determine decompressed size of item 0" | ||||
| ): | ): | ||||
| dctx.multi_decompress_to_buffer([b"foobarbaz"]) | dctx.multi_decompress_to_buffer([b"foobarbaz"]) | ||||
| def test_list_input(self): | def test_list_input(self): | ||||
| frames = [cctx.compress(d) for d in original] | frames = [cctx.compress(d) for d in original] | ||||
| sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| if not hasattr(dctx, "multi_decompress_to_buffer"): | if not hasattr(dctx, "multi_decompress_to_buffer"): | ||||
| self.skipTest("multi_decompress_to_buffer not available") | self.skipTest("multi_decompress_to_buffer not available") | ||||
| result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) | result = dctx.multi_decompress_to_buffer( | ||||
| frames, decompressed_sizes=sizes | |||||
| ) | |||||
| self.assertEqual(len(result), len(frames)) | self.assertEqual(len(result), len(frames)) | ||||
| self.assertEqual(result.size(), sum(map(len, original))) | self.assertEqual(result.size(), sum(map(len, original))) | ||||
| for i, data in enumerate(original): | for i, data in enumerate(original): | ||||
| self.assertEqual(result[i].tobytes(), data) | self.assertEqual(result[i].tobytes(), data) | ||||
| def test_buffer_with_segments_input(self): | def test_buffer_with_segments_input(self): | ||||
| self.assertEqual(len(decompressed), len(original)) | self.assertEqual(len(decompressed), len(original)) | ||||
| for i, data in enumerate(original): | for i, data in enumerate(original): | ||||
| self.assertEqual(data, decompressed[i].tobytes()) | self.assertEqual(data, decompressed[i].tobytes()) | ||||
| # And a manual mode. | # And a manual mode. | ||||
| b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | ||||
| b1 = zstd.BufferWithSegments( | b1 = zstd.BufferWithSegments( | ||||
| b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])) | b, | ||||
| struct.pack( | |||||
| "=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) | |||||
| ), | |||||
| ) | ) | ||||
| b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]) | b = b"".join( | ||||
| [frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] | |||||
| ) | |||||
| b2 = zstd.BufferWithSegments( | b2 = zstd.BufferWithSegments( | ||||
| b, | b, | ||||
| struct.pack( | struct.pack( | ||||
| "=QQQQQQ", | "=QQQQQQ", | ||||
| 0, | 0, | ||||
| len(frames[2]), | len(frames[2]), | ||||
| len(frames[2]), | len(frames[2]), | ||||
| len(frames[3]), | len(frames[3]), | ||||
| suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| streaming=strategies.booleans(), | streaming=strategies.booleans(), | ||||
| source_read_size=strategies.integers(1, 1048576), | source_read_size=strategies.integers(1, 1048576), | ||||
| ) | ) | ||||
| def test_stream_source_readall(self, original, level, streaming, source_read_size): | def test_stream_source_readall( | ||||
| self, original, level, streaming, source_read_size | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| if streaming: | if streaming: | ||||
| source = io.BytesIO() | source = io.BytesIO() | ||||
| writer = cctx.stream_writer(source) | writer = cctx.stream_writer(source) | ||||
| writer.write(original) | writer.write(original) | ||||
| writer.flush(zstd.FLUSH_FRAME) | writer.flush(zstd.FLUSH_FRAME) | ||||
| source.seek(0) | source.seek(0) | ||||
| ] | ] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| write_size=strategies.integers(min_value=1, max_value=8192), | write_size=strategies.integers(min_value=1, max_value=8192), | ||||
| input_sizes=strategies.data(), | input_sizes=strategies.data(), | ||||
| ) | ) | ||||
| def test_write_size_variance(self, original, level, write_size, input_sizes): | def test_write_size_variance( | ||||
| self, original, level, write_size, input_sizes | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
| dest = NonClosingBytesIO() | dest = NonClosingBytesIO() | ||||
| with dctx.stream_writer(dest, write_size=write_size) as decompressor: | with dctx.stream_writer(dest, write_size=write_size) as decompressor: | ||||
| ] | ] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| read_size=strategies.integers(min_value=1, max_value=8192), | read_size=strategies.integers(min_value=1, max_value=8192), | ||||
| write_size=strategies.integers(min_value=1, max_value=8192), | write_size=strategies.integers(min_value=1, max_value=8192), | ||||
| ) | ) | ||||
| def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
| self, original, level, read_size, write_size | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
| source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
| dest = io.BytesIO() | dest = io.BytesIO() | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| dctx.copy_stream(source, dest, read_size=read_size, write_size=write_size) | dctx.copy_stream( | ||||
| source, dest, read_size=read_size, write_size=write_size | |||||
| ) | |||||
| self.assertEqual(dest.getvalue(), original) | self.assertEqual(dest.getvalue(), original) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
| @make_cffi | @make_cffi | ||||
| class TestDecompressor_decompressobj_fuzzing(TestCase): | class TestDecompressor_decompressobj_fuzzing(TestCase): | ||||
| @hypothesis.settings( | @hypothesis.settings( | ||||
| hypothesis.HealthCheck.large_base_example, | hypothesis.HealthCheck.large_base_example, | ||||
| hypothesis.HealthCheck.too_slow, | hypothesis.HealthCheck.too_slow, | ||||
| ] | ] | ||||
| ) | ) | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| write_size=strategies.integers( | write_size=strategies.integers( | ||||
| min_value=1, max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE | min_value=1, | ||||
| max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |||||
| ), | ), | ||||
| chunk_sizes=strategies.data(), | chunk_sizes=strategies.data(), | ||||
| ) | ) | ||||
| def test_random_output_sizes(self, original, level, write_size, chunk_sizes): | def test_random_output_sizes( | ||||
| self, original, level, write_size, chunk_sizes | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
| source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| dobj = dctx.decompressobj(write_size=write_size) | dobj = dctx.decompressobj(write_size=write_size) | ||||
| @make_cffi | @make_cffi | ||||
| class TestDecompressor_read_to_iter_fuzzing(TestCase): | class TestDecompressor_read_to_iter_fuzzing(TestCase): | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
| level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
| read_size=strategies.integers(min_value=1, max_value=4096), | read_size=strategies.integers(min_value=1, max_value=4096), | ||||
| write_size=strategies.integers(min_value=1, max_value=4096), | write_size=strategies.integers(min_value=1, max_value=4096), | ||||
| ) | ) | ||||
| def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
| self, original, level, read_size, write_size | |||||
| ): | |||||
| cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
| frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
| source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
| dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
| chunks = list( | chunks = list( | ||||
| dctx.read_to_iter(source, read_size=read_size, write_size=write_size) | dctx.read_to_iter( | ||||
| source, read_size=read_size, write_size=write_size | |||||
| ) | |||||
| ) | ) | ||||
| self.assertEqual(b"".join(chunks), original) | self.assertEqual(b"".join(chunks), original) | ||||
| @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
| class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): | class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): | ||||
| @hypothesis.given( | @hypothesis.given( | ||||
| original=strategies.lists( | original=strategies.lists( | ||||
| strategies.sampled_from(random_input_data()), min_size=1, max_size=1024 | strategies.sampled_from(random_input_data()), | ||||
| min_size=1, | |||||
| max_size=1024, | |||||
| ), | ), | ||||
| threads=strategies.integers(min_value=1, max_value=8), | threads=strategies.integers(min_value=1, max_value=8), | ||||
| use_dict=strategies.booleans(), | use_dict=strategies.booleans(), | ||||
| ) | ) | ||||
| def test_data_equivalence(self, original, threads, use_dict): | def test_data_equivalence(self, original, threads, use_dict): | ||||
| kwargs = {} | kwargs = {} | ||||
| if use_dict: | if use_dict: | ||||
| kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) | ||||
| data = d.as_bytes() | data = d.as_bytes() | ||||
| self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") | self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") | ||||
| self.assertEqual(d.k, 64) | self.assertEqual(d.k, 64) | ||||
| self.assertEqual(d.d, 16) | self.assertEqual(d.d, 16) | ||||
| def test_set_dict_id(self): | def test_set_dict_id(self): | ||||
| d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42) | d = zstd.train_dictionary( | ||||
| 8192, generate_samples(), k=64, d=16, dict_id=42 | |||||
| ) | |||||
| self.assertEqual(d.dict_id(), 42) | self.assertEqual(d.dict_id(), 42) | ||||
| def test_optimize(self): | def test_optimize(self): | ||||
| d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16) | d = zstd.train_dictionary( | ||||
| 8192, generate_samples(), threads=-1, steps=1, d=16 | |||||
| ) | |||||
| # This varies by platform. | # This varies by platform. | ||||
| self.assertIn(d.k, (50, 2000)) | self.assertIn(d.k, (50, 2000)) | ||||
| self.assertEqual(d.d, 16) | self.assertEqual(d.d, 16) | ||||
| @make_cffi | @make_cffi | ||||
| class TestCompressionDict(TestCase): | class TestCompressionDict(TestCase): | ||||
| def test_bad_mode(self): | def test_bad_mode(self): | ||||
| with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): | with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): | ||||
| zstd.ZstdCompressionDict(b"foo", dict_type=42) | zstd.ZstdCompressionDict(b"foo", dict_type=42) | ||||
| def test_bad_precompute_compress(self): | def test_bad_precompute_compress(self): | ||||
| d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | ||||
| with self.assertRaisesRegex(ValueError, "must specify one of level or "): | with self.assertRaisesRegex( | ||||
| ValueError, "must specify one of level or " | |||||
| ): | |||||
| d.precompute_compress() | d.precompute_compress() | ||||
| with self.assertRaisesRegex(ValueError, "must only specify one of level or "): | with self.assertRaisesRegex( | ||||
| ValueError, "must only specify one of level or " | |||||
| ): | |||||
| d.precompute_compress( | d.precompute_compress( | ||||
| level=3, compression_params=zstd.CompressionParameters() | level=3, compression_params=zstd.CompressionParameters() | ||||
| ) | ) | ||||
| def test_precompute_compress_rawcontent(self): | def test_precompute_compress_rawcontent(self): | ||||
| d = zstd.ZstdCompressionDict( | d = zstd.ZstdCompressionDict( | ||||
| b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT | ||||
| ) | ) | ||||
| d.precompute_compress(level=1) | d.precompute_compress(level=1) | ||||
| d = zstd.ZstdCompressionDict( | d = zstd.ZstdCompressionDict( | ||||
| b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT | ||||
| ) | ) | ||||
| with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"): | with self.assertRaisesRegex( | ||||
| zstd.ZstdError, "unable to precompute dictionary" | |||||
| ): | |||||
| d.precompute_compress(level=1) | d.precompute_compress(level=1) | ||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_compressionLevel, compression_level | params, lib.ZSTD_c_compressionLevel, compression_level | ||||
| ) | ) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) | _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) | _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) | _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) | _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) | _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_targetLength, target_length) | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_targetLength, target_length | |||||
| ) | |||||
| if strategy != -1 and compression_strategy != -1: | if strategy != -1 and compression_strategy != -1: | ||||
| raise ValueError("cannot specify both compression_strategy and strategy") | raise ValueError( | ||||
| "cannot specify both compression_strategy and strategy" | |||||
| ) | |||||
| if compression_strategy != -1: | if compression_strategy != -1: | ||||
| strategy = compression_strategy | strategy = compression_strategy | ||||
| elif strategy == -1: | elif strategy == -1: | ||||
| strategy = 0 | strategy = 0 | ||||
| _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) | _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) | ||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_contentSizeFlag, write_content_size | params, lib.ZSTD_c_contentSizeFlag, write_content_size | ||||
| ) | ) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_checksumFlag, write_checksum) | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_checksumFlag, write_checksum | |||||
| ) | |||||
| _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) | _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) | _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) | ||||
| if overlap_log != -1 and overlap_size_log != -1: | if overlap_log != -1 and overlap_size_log != -1: | ||||
| raise ValueError("cannot specify both overlap_log and overlap_size_log") | raise ValueError( | ||||
| "cannot specify both overlap_log and overlap_size_log" | |||||
| ) | |||||
| if overlap_size_log != -1: | if overlap_size_log != -1: | ||||
| overlap_log = overlap_size_log | overlap_log = overlap_size_log | ||||
| elif overlap_log == -1: | elif overlap_log == -1: | ||||
| overlap_log = 0 | overlap_log = 0 | ||||
| _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) | _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_forceMaxWindow, force_max_window) | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_forceMaxWindow, force_max_window | |||||
| ) | |||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm | params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm | ||||
| ) | ) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) | _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) | ||||
| _set_compression_parameter(params, lib.ZSTD_c_ldmMinMatch, ldm_min_match) | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_ldmMinMatch, ldm_min_match | |||||
| ) | |||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log | params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log | ||||
| ) | ) | ||||
| if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: | if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: | ||||
| raise ValueError( | raise ValueError( | ||||
| "cannot specify both ldm_hash_rate_log and ldm_hash_every_log" | "cannot specify both ldm_hash_rate_log and ldm_hash_every_log" | ||||
| ) | ) | ||||
| if ldm_hash_every_log != -1: | if ldm_hash_every_log != -1: | ||||
| ldm_hash_rate_log = ldm_hash_every_log | ldm_hash_rate_log = ldm_hash_every_log | ||||
| elif ldm_hash_rate_log == -1: | elif ldm_hash_rate_log == -1: | ||||
| ldm_hash_rate_log = 0 | ldm_hash_rate_log = 0 | ||||
| _set_compression_parameter(params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log) | _set_compression_parameter( | ||||
| params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log | |||||
| ) | |||||
| @property | @property | ||||
| def format(self): | def format(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_format) | return _get_compression_parameter(self._params, lib.ZSTD_c_format) | ||||
| @property | @property | ||||
| def compression_level(self): | def compression_level(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_compressionLevel) | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_compressionLevel | |||||
| ) | |||||
| @property | @property | ||||
| def window_log(self): | def window_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) | ||||
| @property | @property | ||||
| def hash_log(self): | def hash_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) | return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) | ||||
| @property | @property | ||||
| def compression_strategy(self): | def compression_strategy(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) | return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) | ||||
| @property | @property | ||||
| def write_content_size(self): | def write_content_size(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_contentSizeFlag) | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_contentSizeFlag | |||||
| ) | |||||
| @property | @property | ||||
| def write_checksum(self): | def write_checksum(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) | return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) | ||||
| @property | @property | ||||
| def write_dict_id(self): | def write_dict_id(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) | return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) | ||||
| @property | @property | ||||
| def job_size(self): | def job_size(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) | return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) | ||||
| @property | @property | ||||
| def overlap_log(self): | def overlap_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) | ||||
| @property | @property | ||||
| def overlap_size_log(self): | def overlap_size_log(self): | ||||
| return self.overlap_log | return self.overlap_log | ||||
| @property | @property | ||||
| def force_max_window(self): | def force_max_window(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_forceMaxWindow) | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_forceMaxWindow | |||||
| ) | |||||
| @property | @property | ||||
| def enable_ldm(self): | def enable_ldm(self): | ||||
| return _get_compression_parameter( | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_enableLongDistanceMatching | self._params, lib.ZSTD_c_enableLongDistanceMatching | ||||
| ) | ) | ||||
| @property | @property | ||||
| def ldm_hash_log(self): | def ldm_hash_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) | ||||
| @property | @property | ||||
| def ldm_min_match(self): | def ldm_min_match(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) | ||||
| @property | @property | ||||
| def ldm_bucket_size_log(self): | def ldm_bucket_size_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_ldmBucketSizeLog) | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_ldmBucketSizeLog | |||||
| ) | |||||
| @property | @property | ||||
| def ldm_hash_rate_log(self): | def ldm_hash_rate_log(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashRateLog) | return _get_compression_parameter( | ||||
| self._params, lib.ZSTD_c_ldmHashRateLog | |||||
| ) | |||||
| @property | @property | ||||
| def ldm_hash_every_log(self): | def ldm_hash_every_log(self): | ||||
| return self.ldm_hash_rate_log | return self.ldm_hash_rate_log | ||||
| @property | @property | ||||
| def threads(self): | def threads(self): | ||||
| return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) | return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) | ||||
| def estimated_compression_context_size(self): | def estimated_compression_context_size(self): | ||||
| return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) | return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) | ||||
| CompressionParameters = ZstdCompressionParameters | CompressionParameters = ZstdCompressionParameters | ||||
| def estimate_decompression_context_size(): | def estimate_decompression_context_size(): | ||||
| return lib.ZSTD_estimateDCtxSize() | return lib.ZSTD_estimateDCtxSize() | ||||
| def _set_compression_parameter(params, param, value): | def _set_compression_parameter(params, param, value): | ||||
| zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) | zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "unable to set compression context parameter: %s" % _zstd_error(zresult) | "unable to set compression context parameter: %s" | ||||
| % _zstd_error(zresult) | |||||
| ) | ) | ||||
| def _get_compression_parameter(params, param): | def _get_compression_parameter(params, param): | ||||
| result = ffi.new("int *") | result = ffi.new("int *") | ||||
| zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) | zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "unable to get compression context parameter: %s" % _zstd_error(zresult) | "unable to get compression context parameter: %s" | ||||
| % _zstd_error(zresult) | |||||
| ) | ) | ||||
| return result[0] | return result[0] | ||||
| class ZstdCompressionWriter(object): | class ZstdCompressionWriter(object): | ||||
| def __init__(self, compressor, writer, source_size, write_size, write_return_read): | def __init__( | ||||
| self, compressor, writer, source_size, write_size, write_return_read | |||||
| ): | |||||
| self._compressor = compressor | self._compressor = compressor | ||||
| self._writer = writer | self._writer = writer | ||||
| self._write_size = write_size | self._write_size = write_size | ||||
| self._write_return_read = bool(write_return_read) | self._write_return_read = bool(write_return_read) | ||||
| self._entered = False | self._entered = False | ||||
| self._closed = False | self._closed = False | ||||
| self._bytes_compressed = 0 | self._bytes_compressed = 0 | ||||
| self._dst_buffer = ffi.new("char[]", write_size) | self._dst_buffer = ffi.new("char[]", write_size) | ||||
| self._out_buffer = ffi.new("ZSTD_outBuffer *") | self._out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
| self._out_buffer.dst = self._dst_buffer | self._out_buffer.dst = self._dst_buffer | ||||
| self._out_buffer.size = len(self._dst_buffer) | self._out_buffer.size = len(self._dst_buffer) | ||||
| self._out_buffer.pos = 0 | self._out_buffer.pos = 0 | ||||
| zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "error setting source size: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| def __enter__(self): | def __enter__(self): | ||||
| if self._closed: | if self._closed: | ||||
| raise ValueError("stream is closed") | raise ValueError("stream is closed") | ||||
| if self._entered: | if self._entered: | ||||
| raise ZstdError("cannot __enter__ multiple times") | raise ZstdError("cannot __enter__ multiple times") | ||||
| in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| out_buffer = self._out_buffer | out_buffer = self._out_buffer | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._compressor._cctx, | ||||
| out_buffer, | |||||
| in_buffer, | |||||
| lib.ZSTD_e_continue, | |||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
| ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
| ) | |||||
| total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
| self._bytes_compressed += out_buffer.pos | self._bytes_compressed += out_buffer.pos | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| if self._write_return_read: | if self._write_return_read: | ||||
| return in_buffer.pos | return in_buffer.pos | ||||
| else: | else: | ||||
| return total_write | return total_write | ||||
| in_buffer.size = 0 | in_buffer.size = 0 | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| while True: | while True: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, out_buffer, in_buffer, flush | self._compressor._cctx, out_buffer, in_buffer, flush | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
| ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
| ) | |||||
| total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
| self._bytes_compressed += out_buffer.pos | self._bytes_compressed += out_buffer.pos | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| if not zresult: | if not zresult: | ||||
| break | break | ||||
| return total_write | return total_write | ||||
| chunks = [] | chunks = [] | ||||
| while source.pos < len(data): | while source.pos < len(data): | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, self._out, source, lib.ZSTD_e_continue | self._compressor._cctx, self._out, source, lib.ZSTD_e_continue | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if self._out.pos: | if self._out.pos: | ||||
| chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | ||||
| self._out.pos = 0 | self._out.pos = 0 | ||||
| return b"".join(chunks) | return b"".join(chunks) | ||||
| def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): | def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): | ||||
| if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK): | if flush_mode not in ( | ||||
| COMPRESSOBJ_FLUSH_FINISH, | |||||
| COMPRESSOBJ_FLUSH_BLOCK, | |||||
| ): | |||||
| raise ValueError("flush mode not recognized") | raise ValueError("flush mode not recognized") | ||||
| if self._finished: | if self._finished: | ||||
| raise ZstdError("compressor object already finished") | raise ZstdError("compressor object already finished") | ||||
| if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: | if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: | ||||
| z_flush_mode = lib.ZSTD_e_flush | z_flush_mode = lib.ZSTD_e_flush | ||||
| elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: | elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: | ||||
| ) | ) | ||||
| if self._in.pos == self._in.size: | if self._in.pos == self._in.size: | ||||
| self._in.src = ffi.NULL | self._in.src = ffi.NULL | ||||
| self._in.size = 0 | self._in.size = 0 | ||||
| self._in.pos = 0 | self._in.pos = 0 | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if self._out.pos == self._out.size: | if self._out.pos == self._out.size: | ||||
| yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
| self._out.pos = 0 | self._out.pos = 0 | ||||
| def flush(self): | def flush(self): | ||||
| if self._finished: | if self._finished: | ||||
| raise ZstdError("cannot call flush() after compression finished") | raise ZstdError("cannot call flush() after compression finished") | ||||
| if self._in.src != ffi.NULL: | if self._in.src != ffi.NULL: | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "cannot call flush() before consuming output from " "previous operation" | "cannot call flush() before consuming output from " | ||||
| "previous operation" | |||||
| ) | ) | ||||
| while True: | while True: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if self._out.pos: | if self._out.pos: | ||||
| yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
| self._out.pos = 0 | self._out.pos = 0 | ||||
| if not zresult: | if not zresult: | ||||
| return | return | ||||
| def finish(self): | def finish(self): | ||||
| if self._finished: | if self._finished: | ||||
| raise ZstdError("cannot call finish() after compression finished") | raise ZstdError("cannot call finish() after compression finished") | ||||
| if self._in.src != ffi.NULL: | if self._in.src != ffi.NULL: | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "cannot call finish() before consuming output from " | "cannot call finish() before consuming output from " | ||||
| "previous operation" | "previous operation" | ||||
| ) | ) | ||||
| while True: | while True: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if self._out.pos: | if self._out.pos: | ||||
| yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
| self._out.pos = 0 | self._out.pos = 0 | ||||
| if not zresult: | if not zresult: | ||||
| self._finished = True | self._finished = True | ||||
| return | return | ||||
| def _compress_into_buffer(self, out_buffer): | def _compress_into_buffer(self, out_buffer): | ||||
| if self._in_buffer.pos >= self._in_buffer.size: | if self._in_buffer.pos >= self._in_buffer.size: | ||||
| return | return | ||||
| old_pos = out_buffer.pos | old_pos = out_buffer.pos | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_continue | self._compressor._cctx, | ||||
| out_buffer, | |||||
| self._in_buffer, | |||||
| lib.ZSTD_e_continue, | |||||
| ) | ) | ||||
| self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
| if self._in_buffer.pos == self._in_buffer.size: | if self._in_buffer.pos == self._in_buffer.size: | ||||
| self._in_buffer.src = ffi.NULL | self._in_buffer.src = ffi.NULL | ||||
| self._in_buffer.pos = 0 | self._in_buffer.pos = 0 | ||||
| self._in_buffer.size = 0 | self._in_buffer.size = 0 | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | ||||
| ) | ) | ||||
| self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) | raise ZstdError( | ||||
| "error ending compression stream: %s", _zstd_error(zresult) | |||||
| ) | |||||
| if zresult == 0: | if zresult == 0: | ||||
| self._finished_output = True | self._finished_output = True | ||||
| return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
| def read1(self, size=-1): | def read1(self, size=-1): | ||||
| if self._closed: | if self._closed: | ||||
| old_pos = out_buffer.pos | old_pos = out_buffer.pos | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | ||||
| ) | ) | ||||
| self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) | raise ZstdError( | ||||
| "error ending compression stream: %s", _zstd_error(zresult) | |||||
| ) | |||||
| if zresult == 0: | if zresult == 0: | ||||
| self._finished_output = True | self._finished_output = True | ||||
| return out_buffer.pos | return out_buffer.pos | ||||
| def readinto1(self, b): | def readinto1(self, b): | ||||
| if self._closed: | if self._closed: | ||||
| dict_data=None, | dict_data=None, | ||||
| compression_params=None, | compression_params=None, | ||||
| write_checksum=None, | write_checksum=None, | ||||
| write_content_size=None, | write_content_size=None, | ||||
| write_dict_id=None, | write_dict_id=None, | ||||
| threads=0, | threads=0, | ||||
| ): | ): | ||||
| if level > lib.ZSTD_maxCLevel(): | if level > lib.ZSTD_maxCLevel(): | ||||
| raise ValueError("level must be less than %d" % lib.ZSTD_maxCLevel()) | raise ValueError( | ||||
| "level must be less than %d" % lib.ZSTD_maxCLevel() | |||||
| ) | |||||
| if threads < 0: | if threads < 0: | ||||
| threads = _cpu_count() | threads = _cpu_count() | ||||
| if compression_params and write_checksum is not None: | if compression_params and write_checksum is not None: | ||||
| raise ValueError("cannot define compression_params and " "write_checksum") | raise ValueError( | ||||
| "cannot define compression_params and " "write_checksum" | |||||
| ) | |||||
| if compression_params and write_content_size is not None: | if compression_params and write_content_size is not None: | ||||
| raise ValueError( | raise ValueError( | ||||
| "cannot define compression_params and " "write_content_size" | "cannot define compression_params and " "write_content_size" | ||||
| ) | ) | ||||
| if compression_params and write_dict_id is not None: | if compression_params and write_dict_id is not None: | ||||
| raise ValueError("cannot define compression_params and " "write_dict_id") | raise ValueError( | ||||
| "cannot define compression_params and " "write_dict_id" | |||||
| ) | |||||
| if compression_params and threads: | if compression_params and threads: | ||||
| raise ValueError("cannot define compression_params and threads") | raise ValueError("cannot define compression_params and threads") | ||||
| if compression_params: | if compression_params: | ||||
| self._params = _make_cctx_params(compression_params) | self._params = _make_cctx_params(compression_params) | ||||
| else: | else: | ||||
| if write_dict_id is None: | if write_dict_id is None: | ||||
| write_dict_id = True | write_dict_id = True | ||||
| params = lib.ZSTD_createCCtxParams() | params = lib.ZSTD_createCCtxParams() | ||||
| if params == ffi.NULL: | if params == ffi.NULL: | ||||
| raise MemoryError() | raise MemoryError() | ||||
| self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | ||||
| _set_compression_parameter(self._params, lib.ZSTD_c_compressionLevel, level) | _set_compression_parameter( | ||||
| self._params, lib.ZSTD_c_compressionLevel, level | |||||
| ) | |||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| self._params, | self._params, | ||||
| lib.ZSTD_c_contentSizeFlag, | lib.ZSTD_c_contentSizeFlag, | ||||
| write_content_size if write_content_size is not None else 1, | write_content_size if write_content_size is not None else 1, | ||||
| ) | ) | ||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| self._params, lib.ZSTD_c_checksumFlag, 1 if write_checksum else 0 | self._params, | ||||
| lib.ZSTD_c_checksumFlag, | |||||
| 1 if write_checksum else 0, | |||||
| ) | ) | ||||
| _set_compression_parameter( | _set_compression_parameter( | ||||
| self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 | self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 | ||||
| ) | ) | ||||
| if threads: | if threads: | ||||
| _set_compression_parameter(self._params, lib.ZSTD_c_nbWorkers, threads) | _set_compression_parameter( | ||||
| self._params, lib.ZSTD_c_nbWorkers, threads | |||||
| ) | |||||
| cctx = lib.ZSTD_createCCtx() | cctx = lib.ZSTD_createCCtx() | ||||
| if cctx == ffi.NULL: | if cctx == ffi.NULL: | ||||
| raise MemoryError() | raise MemoryError() | ||||
| self._cctx = cctx | self._cctx = cctx | ||||
| self._dict_data = dict_data | self._dict_data = dict_data | ||||
| # We defer setting up garbage collection until after calling | # We defer setting up garbage collection until after calling | ||||
| # _setup_cctx() to ensure the memory size estimate is more accurate. | # _setup_cctx() to ensure the memory size estimate is more accurate. | ||||
| try: | try: | ||||
| self._setup_cctx() | self._setup_cctx() | ||||
| finally: | finally: | ||||
| self._cctx = ffi.gc( | self._cctx = ffi.gc( | ||||
| cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) | cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) | ||||
| ) | ) | ||||
def _setup_cctx(self):
    """Push the stored parameters and optional dictionary into the context.

    ``self._params`` is applied to ``self._cctx`` first. If
    ``self._dict_data`` is truthy, its ``_cdict`` is referenced when set;
    otherwise the raw dictionary bytes are loaded with ``ZSTD_dlm_byRef``.

    Raises:
        ZstdError: if libzstd reports an error for either step.
    """
    cctx = self._cctx

    status = lib.ZSTD_CCtx_setParametersUsingCCtxParams(cctx, self._params)
    if lib.ZSTD_isError(status):
        message = "could not set compression parameters: %s"
        raise ZstdError(message % _zstd_error(status))

    dictionary = self._dict_data
    if not dictionary:
        # No dictionary configured; the parameters are all there is to set.
        return

    if dictionary._cdict:
        status = lib.ZSTD_CCtx_refCDict(cctx, dictionary._cdict)
    else:
        status = lib.ZSTD_CCtx_loadDictionary_advanced(
            cctx,
            dictionary.as_bytes(),
            len(dictionary),
            lib.ZSTD_dlm_byRef,
            dictionary._dict_type,
        )

    if lib.ZSTD_isError(status):
        message = "could not load compression dictionary: %s"
        raise ZstdError(message % _zstd_error(status))
def memory_size(self):
    """Return the value ``ZSTD_sizeof_CCtx`` reports for ``self._cctx``."""
    context = self._cctx
    return lib.ZSTD_sizeof_CCtx(context)
| def compress(self, data): | def compress(self, data): | ||||
| lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
| data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
| dest_size = lib.ZSTD_compressBound(len(data_buffer)) | dest_size = lib.ZSTD_compressBound(len(data_buffer)) | ||||
| out = new_nonzero("char[]", dest_size) | out = new_nonzero("char[]", dest_size) | ||||
| zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "error setting source size: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
| in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
| out_buffer.dst = out | out_buffer.dst = out | ||||
| out_buffer.size = dest_size | out_buffer.size = dest_size | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
def compressobj(self, size=-1):
    """Build a ``ZstdCompressionObj`` that compresses through this object.

    Args:
        size: source size pledged to libzstd. A negative value is replaced
            with ``ZSTD_CONTENTSIZE_UNKNOWN``.

    Returns:
        A new ``ZstdCompressionObj`` with its output buffer allocated and
        ``_finished`` set to ``False``.

    Raises:
        ZstdError: if libzstd rejects the pledged source size.
    """
    cctx = self._cctx
    lib.ZSTD_CCtx_reset(cctx, lib.ZSTD_reset_session_only)

    pledged = lib.ZSTD_CONTENTSIZE_UNKNOWN if size < 0 else size
    status = lib.ZSTD_CCtx_setPledgedSrcSize(cctx, pledged)
    if lib.ZSTD_isError(status):
        raise ZstdError("error setting source size: %s" % _zstd_error(status))

    capacity = COMPRESSION_RECOMMENDED_OUTPUT_SIZE

    obj = ZstdCompressionObj()
    out = ffi.new("ZSTD_outBuffer *")
    destination = ffi.new("char[]", capacity)

    # The destination array is stored on the object as well as in the
    # out-buffer struct.
    obj._out = out
    obj._dst_buffer = destination
    out.dst = destination
    out.size = capacity
    out.pos = 0

    obj._compressor = self
    obj._finished = False
    return obj
def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE):
    """Return a ``ZstdCompressionChunker`` bound to this compressor.

    Args:
        size: source size pledged to libzstd. A negative value is replaced
            with ``ZSTD_CONTENTSIZE_UNKNOWN``.
        chunk_size: forwarded to ``ZstdCompressionChunker`` unchanged.

    Raises:
        ZstdError: if libzstd rejects the pledged source size.
    """
    cctx = self._cctx
    lib.ZSTD_CCtx_reset(cctx, lib.ZSTD_reset_session_only)

    pledged = lib.ZSTD_CONTENTSIZE_UNKNOWN if size < 0 else size
    status = lib.ZSTD_CCtx_setPledgedSrcSize(cctx, pledged)
    if lib.ZSTD_isError(status):
        raise ZstdError("error setting source size: %s" % _zstd_error(status))

    return ZstdCompressionChunker(self, chunk_size=chunk_size)
| def copy_stream( | def copy_stream( | ||||
| self, | self, | ||||
| ifh, | ifh, | ||||
| ofh, | ofh, | ||||
| size=-1, | size=-1, | ||||
| read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
| write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | ||||
| ): | ): | ||||
| if not hasattr(ifh, "read"): | if not hasattr(ifh, "read"): | ||||
| raise ValueError("first argument must have a read() method") | raise ValueError("first argument must have a read() method") | ||||
| if not hasattr(ofh, "write"): | if not hasattr(ofh, "write"): | ||||
| raise ValueError("second argument must have a write() method") | raise ValueError("second argument must have a write() method") | ||||
| lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
| if size < 0: | if size < 0: | ||||
| size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
| zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "error setting source size: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
| out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
| dst_buffer = ffi.new("char[]", write_size) | dst_buffer = ffi.new("char[]", write_size) | ||||
| out_buffer.dst = dst_buffer | out_buffer.dst = dst_buffer | ||||
| out_buffer.size = write_size | out_buffer.size = write_size | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ||||
| total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| # We've finished reading. Flush the compressor. | # We've finished reading. Flush the compressor. | ||||
| while True: | while True: | ||||
| except Exception: | except Exception: | ||||
| pass | pass | ||||
| if size < 0: | if size < 0: | ||||
| size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
| zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "error setting source size: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| return ZstdCompressionReader(self, source, read_size) | return ZstdCompressionReader(self, source, read_size) | ||||
def stream_writer(
    self,
    writer,
    size=-1,
    write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE,
    write_return_read=False,
):
    """Wrap ``writer`` in a ``ZstdCompressionWriter``.

    Args:
        writer: destination object; must expose a ``write`` attribute.
        size: source size handed to the writer. A negative value is
            replaced with ``ZSTD_CONTENTSIZE_UNKNOWN``.
        write_size: forwarded to ``ZstdCompressionWriter`` unchanged.
        write_return_read: forwarded to ``ZstdCompressionWriter`` unchanged.

    Raises:
        ValueError: if ``writer`` has no ``write`` attribute.
    """
    if not hasattr(writer, "write"):
        raise ValueError("must pass an object with a write() method")

    lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only)

    source_size = lib.ZSTD_CONTENTSIZE_UNKNOWN if size < 0 else size
    return ZstdCompressionWriter(
        self, writer, source_size, write_size, write_return_read
    )
| write_to = stream_writer | write_to = stream_writer | ||||
| def read_to_iter( | def read_to_iter( | ||||
| self, | self, | ||||
| reader, | reader, | ||||
| size=-1, | size=-1, | ||||
| read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
| lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
| if size < 0: | if size < 0: | ||||
| size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
| zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "error setting source size: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
| out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
| in_buffer.src = ffi.NULL | in_buffer.src = ffi.NULL | ||||
| in_buffer.size = 0 | in_buffer.size = 0 | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| in_buffer.size = len(read_buffer) | in_buffer.size = len(read_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
| self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd compress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| yield data | yield data | ||||
| assert out_buffer.pos == 0 | assert out_buffer.pos == 0 | ||||
def get_frame_parameters(data):
    """Parse the zstd frame header at the start of ``data``.

    Args:
        data: object accepted by ``ffi.from_buffer`` that holds the frame.

    Returns:
        A ``FrameParameters`` built from the parsed header struct.

    Raises:
        ZstdError: if libzstd reports an error, or if it returns a non-zero
            value, which is reported as the number of bytes needed.
    """
    header = ffi.new("ZSTD_frameHeader *")
    source = ffi.from_buffer(data)

    status = lib.ZSTD_getFrameHeader(header, source, len(source))
    if lib.ZSTD_isError(status):
        raise ZstdError("cannot get frame parameters: %s" % _zstd_error(status))

    if not status:
        return FrameParameters(header[0])

    raise ZstdError(
        "not enough data for frame parameters; need %d bytes" % status
    )
| class ZstdCompressionDict(object): | class ZstdCompressionDict(object): | ||||
def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0):
    """Store dictionary bytes and the mode used to load them.

    Args:
        data: dictionary content; asserted to be ``bytes_type``.
        dict_type: one of ``DICT_TYPE_AUTO``, ``DICT_TYPE_RAWCONTENT`` or
            ``DICT_TYPE_FULLDICT``.
        k: stored on the instance as ``self.k``.
        d: stored on the instance as ``self.d``.

    Raises:
        ValueError: if ``dict_type`` is not one of the accepted constants.
    """
    assert isinstance(data, bytes_type)

    self._data = data
    self.k = k
    self.d = d

    if dict_type not in (
        DICT_TYPE_AUTO,
        DICT_TYPE_RAWCONTENT,
        DICT_TYPE_FULLDICT,
    ):
        # The message used to be raised with a literal "%d" in it because
        # the format string was never interpolated. %r with a 1-tuple
        # reports any rejected value without raising a second error.
        raise ValueError(
            "invalid dictionary load mode: %r; must use "
            "DICT_TYPE_* constants" % (dict_type,)
        )

    self._dict_type = dict_type
    self._cdict = None
def __len__(self):
    """Return the length of the stored dictionary data."""
    payload = self._data
    return len(payload)
def dict_id(self):
    """Return ``ZDICT_getDictID`` for the stored data, as ``int_type``."""
    payload = self._data
    raw_id = lib.ZDICT_getDictID(payload, len(payload))
    return int_type(raw_id)
def as_bytes(self):
    """Return the stored dictionary data object itself (no copy is made)."""
    payload = self._data
    return payload
| def precompute_compress(self, level=0, compression_params=None): | def precompute_compress(self, level=0, compression_params=None): | ||||
| if level and compression_params: | if level and compression_params: | ||||
| raise ValueError("must only specify one of level or " "compression_params") | raise ValueError( | ||||
| "must only specify one of level or " "compression_params" | |||||
| ) | |||||
| if not level and not compression_params: | if not level and not compression_params: | ||||
| raise ValueError("must specify one of level or compression_params") | raise ValueError("must specify one of level or compression_params") | ||||
| if level: | if level: | ||||
| cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) | cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) | ||||
| else: | else: | ||||
| cparams = ffi.new("ZSTD_compressionParameters") | cparams = ffi.new("ZSTD_compressionParameters") | ||||
| lib.ZSTD_dlm_byRef, | lib.ZSTD_dlm_byRef, | ||||
| self._dict_type, | self._dict_type, | ||||
| lib.ZSTD_defaultCMem, | lib.ZSTD_defaultCMem, | ||||
| ) | ) | ||||
| if ddict == ffi.NULL: | if ddict == ffi.NULL: | ||||
| raise ZstdError("could not create decompression dict") | raise ZstdError("could not create decompression dict") | ||||
| ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict)) | ddict = ffi.gc( | ||||
| ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict) | |||||
| ) | |||||
| self.__dict__["_ddict"] = ddict | self.__dict__["_ddict"] = ddict | ||||
| return ddict | return ddict | ||||
| def train_dictionary( | def train_dictionary( | ||||
| dict_size, | dict_size, | ||||
| samples, | samples, | ||||
| chunks = [] | chunks = [] | ||||
| while True: | while True: | ||||
| zresult = lib.ZSTD_decompressStream( | zresult = lib.ZSTD_decompressStream( | ||||
| self._decompressor._dctx, out_buffer, in_buffer | self._decompressor._dctx, out_buffer, in_buffer | ||||
| ) | ) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd decompressor error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd decompressor error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if zresult == 0: | if zresult == 0: | ||||
| self._finished = True | self._finished = True | ||||
| self._decompressor = None | self._decompressor = None | ||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | ||||
def seek(self, pos, whence=os.SEEK_SET):
    """Advance the stream by reading and discarding decompressed data.

    Args:
        pos: target offset; interpreted according to ``whence``.
        whence: ``os.SEEK_SET`` (absolute offset) or ``os.SEEK_CUR``
            (relative offset). ``os.SEEK_END`` is rejected. Any other
            value leaves the position unchanged.

    Returns:
        ``self._bytes_decompressed`` after any reads have been performed.

    Raises:
        ValueError: if the stream is closed, the seek would move
            backwards, or ``whence`` is ``os.SEEK_END``.
    """
    if self._closed:
        raise ValueError("stream is closed")

    backwards = "cannot seek zstd decompression stream backwards"
    remaining = 0

    if whence == os.SEEK_SET:
        if pos < 0:
            raise ValueError("cannot seek to negative position with SEEK_SET")
        if pos < self._bytes_decompressed:
            raise ValueError(backwards)
        remaining = pos - self._bytes_decompressed
    elif whence == os.SEEK_CUR:
        if pos < 0:
            raise ValueError(backwards)
        remaining = pos
    elif whence == os.SEEK_END:
        raise ValueError(
            "zstd decompression streams cannot be seeked with SEEK_END"
        )

    # Seeking forward is done by reading and discarding output in bounded
    # pieces until the target is reached or read() returns nothing.
    while remaining:
        step = min(remaining, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)
        discarded = self.read(step)
        if not discarded:
            break
        remaining -= len(discarded)

    return self._bytes_decompressed
| out_buffer.size = len(dst_buffer) | out_buffer.size = len(dst_buffer) | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| dctx = self._decompressor._dctx | dctx = self._decompressor._dctx | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd decompress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
| ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
| ) | |||||
| total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| if self._write_return_read: | if self._write_return_read: | ||||
| return in_buffer.pos | return in_buffer.pos | ||||
| else: | else: | ||||
| return total_write | return total_write | ||||
def memory_size(self):
    """Return the value ``ZSTD_sizeof_DCtx`` reports for ``self._dctx``."""
    context = self._dctx
    return lib.ZSTD_sizeof_DCtx(context)
def decompress(self, data, max_output_size=0):
    """Decompress a single zstd frame held in ``data`` and return bytes.

    Args:
        data: object accepted by ``ffi.from_buffer`` holding the frame.
        max_output_size: output buffer size to use when the frame header
            does not record the content size. Must be non-zero in that
            case.

    Returns:
        The decompressed bytes. ``b""`` is returned when the frame header
        reports a content size of 0.

    Raises:
        ZstdError: if the content size cannot be determined, libzstd
            reports an error, the frame is not fully decompressed, or the
            amount written differs from the size in the header.
    """
    self._ensure_dctx()

    data_buffer = ffi.from_buffer(data)

    output_size = lib.ZSTD_getFrameContentSize(
        data_buffer, len(data_buffer)
    )

    if output_size == lib.ZSTD_CONTENTSIZE_ERROR:
        raise ZstdError("error determining content size from frame header")
    elif output_size == 0:
        return b""
    elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN:
        if not max_output_size:
            raise ZstdError(
                "could not determine content size in frame header"
            )

        result_buffer = ffi.new("char[]", max_output_size)
        result_size = max_output_size
        # Zero disables the exact-size check below, because the expected
        # size is unknown.
        output_size = 0
    else:
        result_buffer = ffi.new("char[]", output_size)
        result_size = output_size

    out_buffer = ffi.new("ZSTD_outBuffer *")
    out_buffer.dst = result_buffer
    out_buffer.size = result_size
    out_buffer.pos = 0

    in_buffer = ffi.new("ZSTD_inBuffer *")
    in_buffer.src = data_buffer
    in_buffer.size = len(data_buffer)
    in_buffer.pos = 0

    zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer)
    if lib.ZSTD_isError(zresult):
        raise ZstdError("decompression error: %s" % _zstd_error(zresult))
    elif zresult:
        raise ZstdError(
            "decompression error: did not decompress full frame"
        )
    elif output_size and out_buffer.pos != output_size:
        # zresult is always 0 in this branch, so report the number of
        # bytes actually written rather than the status code.
        raise ZstdError(
            "decompression error: decompressed %d bytes; expected %d"
            % (out_buffer.pos, output_size)
        )

    return ffi.buffer(result_buffer, out_buffer.pos)[:]
def stream_reader(
    self,
    source,
    read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
    read_across_frames=False,
):
    """Return a ``ZstdDecompressionReader`` that reads from ``source``.

    The decompression context is prepared with ``_ensure_dctx()`` before
    the reader is constructed. All arguments are forwarded to the reader
    unchanged.
    """
    self._ensure_dctx()

    reader = ZstdDecompressionReader(
        self, source, read_size, read_across_frames
    )
    return reader
def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
    """Return a ``ZstdDecompressionObj`` bound to this decompressor.

    Args:
        write_size: forwarded to ``ZstdDecompressionObj``; must be >= 1.

    Raises:
        ValueError: if ``write_size`` is less than 1.
    """
    if write_size < 1:
        raise ValueError("write_size must be positive")

    # Prepare the context only after the argument has been validated.
    self._ensure_dctx()

    obj = ZstdDecompressionObj(self, write_size=write_size)
    return obj
| read_buffer = ffi.from_buffer(read_result) | read_buffer = ffi.from_buffer(read_result) | ||||
| in_buffer.src = read_buffer | in_buffer.src = read_buffer | ||||
| in_buffer.size = len(read_buffer) | in_buffer.size = len(read_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| assert out_buffer.pos == 0 | assert out_buffer.pos == 0 | ||||
| zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
| self._dctx, out_buffer, in_buffer | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "zstd decompress error: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| yield data | yield data | ||||
| if zresult == 0: | if zresult == 0: | ||||
| return | return | ||||
| self, | self, | ||||
| writer, | writer, | ||||
| write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | ||||
| write_return_read=False, | write_return_read=False, | ||||
| ): | ): | ||||
| if not hasattr(writer, "write"): | if not hasattr(writer, "write"): | ||||
| raise ValueError("must pass an object with a write() method") | raise ValueError("must pass an object with a write() method") | ||||
| return ZstdDecompressionWriter(self, writer, write_size, write_return_read) | return ZstdDecompressionWriter( | ||||
| self, writer, write_size, write_return_read | |||||
| ) | |||||
| write_to = stream_writer | write_to = stream_writer | ||||
| def copy_stream( | def copy_stream( | ||||
| self, | self, | ||||
| ifh, | ifh, | ||||
| ofh, | ofh, | ||||
| read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
| data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
| total_read += len(data_buffer) | total_read += len(data_buffer) | ||||
| in_buffer.src = data_buffer | in_buffer.src = data_buffer | ||||
| in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| # Flush all read data to output. | # Flush all read data to output. | ||||
| while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
| zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
| self._dctx, out_buffer, in_buffer | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "zstd decompressor error: %s" % _zstd_error(zresult) | "zstd decompressor error: %s" % _zstd_error(zresult) | ||||
| ) | ) | ||||
| if out_buffer.pos: | if out_buffer.pos: | ||||
| ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ||||
| total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
| # First chunk should not be using a dictionary. We handle it specially. | # First chunk should not be using a dictionary. We handle it specially. | ||||
| chunk = frames[0] | chunk = frames[0] | ||||
| if not isinstance(chunk, bytes_type): | if not isinstance(chunk, bytes_type): | ||||
| raise ValueError("chunk 0 must be bytes") | raise ValueError("chunk 0 must be bytes") | ||||
| # All chunks should be zstd frames and should have content size set. | # All chunks should be zstd frames and should have content size set. | ||||
| chunk_buffer = ffi.from_buffer(chunk) | chunk_buffer = ffi.from_buffer(chunk) | ||||
| params = ffi.new("ZSTD_frameHeader *") | params = ffi.new("ZSTD_frameHeader *") | ||||
| zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) | zresult = lib.ZSTD_getFrameHeader( | ||||
| params, chunk_buffer, len(chunk_buffer) | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ValueError("chunk 0 is not a valid zstd frame") | raise ValueError("chunk 0 is not a valid zstd frame") | ||||
| elif zresult: | elif zresult: | ||||
| raise ValueError("chunk 0 is too small to contain a zstd frame") | raise ValueError("chunk 0 is too small to contain a zstd frame") | ||||
| if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | ||||
| raise ValueError("chunk 0 missing content size in frame") | raise ValueError("chunk 0 missing content size in frame") | ||||
| self._ensure_dctx(load_dict=False) | self._ensure_dctx(load_dict=False) | ||||
| last_buffer = ffi.new("char[]", params.frameContentSize) | last_buffer = ffi.new("char[]", params.frameContentSize) | ||||
| out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
| out_buffer.dst = last_buffer | out_buffer.dst = last_buffer | ||||
| out_buffer.size = len(last_buffer) | out_buffer.size = len(last_buffer) | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
| in_buffer.src = chunk_buffer | in_buffer.src = chunk_buffer | ||||
| in_buffer.size = len(chunk_buffer) | in_buffer.size = len(chunk_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("could not decompress chunk 0: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "could not decompress chunk 0: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| elif zresult: | elif zresult: | ||||
| raise ZstdError("chunk 0 did not decompress full frame") | raise ZstdError("chunk 0 did not decompress full frame") | ||||
| # Special case of chain length of 1 | # Special case of chain length of 1 | ||||
| if len(frames) == 1: | if len(frames) == 1: | ||||
| return ffi.buffer(last_buffer, len(last_buffer))[:] | return ffi.buffer(last_buffer, len(last_buffer))[:] | ||||
| i = 1 | i = 1 | ||||
| while i < len(frames): | while i < len(frames): | ||||
| chunk = frames[i] | chunk = frames[i] | ||||
| if not isinstance(chunk, bytes_type): | if not isinstance(chunk, bytes_type): | ||||
| raise ValueError("chunk %d must be bytes" % i) | raise ValueError("chunk %d must be bytes" % i) | ||||
| chunk_buffer = ffi.from_buffer(chunk) | chunk_buffer = ffi.from_buffer(chunk) | ||||
| zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) | zresult = lib.ZSTD_getFrameHeader( | ||||
| params, chunk_buffer, len(chunk_buffer) | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ValueError("chunk %d is not a valid zstd frame" % i) | raise ValueError("chunk %d is not a valid zstd frame" % i) | ||||
| elif zresult: | elif zresult: | ||||
| raise ValueError("chunk %d is too small to contain a zstd frame" % i) | raise ValueError( | ||||
| "chunk %d is too small to contain a zstd frame" % i | |||||
| ) | |||||
| if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | ||||
| raise ValueError("chunk %d missing content size in frame" % i) | raise ValueError("chunk %d missing content size in frame" % i) | ||||
| dest_buffer = ffi.new("char[]", params.frameContentSize) | dest_buffer = ffi.new("char[]", params.frameContentSize) | ||||
| out_buffer.dst = dest_buffer | out_buffer.dst = dest_buffer | ||||
| out_buffer.size = len(dest_buffer) | out_buffer.size = len(dest_buffer) | ||||
| out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
| in_buffer.src = chunk_buffer | in_buffer.src = chunk_buffer | ||||
| in_buffer.size = len(chunk_buffer) | in_buffer.size = len(chunk_buffer) | ||||
| in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
| zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
| self._dctx, out_buffer, in_buffer | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "could not decompress chunk %d: %s" % _zstd_error(zresult) | "could not decompress chunk %d: %s" % _zstd_error(zresult) | ||||
| ) | ) | ||||
| elif zresult: | elif zresult: | ||||
| raise ZstdError("chunk %d did not decompress full frame" % i) | raise ZstdError("chunk %d did not decompress full frame" % i) | ||||
| last_buffer = dest_buffer | last_buffer = dest_buffer | ||||
| i += 1 | i += 1 | ||||
| return ffi.buffer(last_buffer, len(last_buffer))[:] | return ffi.buffer(last_buffer, len(last_buffer))[:] | ||||
| def _ensure_dctx(self, load_dict=True): | def _ensure_dctx(self, load_dict=True): | ||||
| lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) | lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) | ||||
| if self._max_window_size: | if self._max_window_size: | ||||
| zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx, self._max_window_size) | zresult = lib.ZSTD_DCtx_setMaxWindowSize( | ||||
| self._dctx, self._max_window_size | |||||
| ) | |||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "unable to set max window size: %s" % _zstd_error(zresult) | "unable to set max window size: %s" % _zstd_error(zresult) | ||||
| ) | ) | ||||
| zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) | zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError("unable to set decoding format: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
| "unable to set decoding format: %s" % _zstd_error(zresult) | |||||
| ) | |||||
| if self._dict_data and load_dict: | if self._dict_data and load_dict: | ||||
| zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) | zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) | ||||
| if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
| raise ZstdError( | raise ZstdError( | ||||
| "unable to reference prepared dictionary: %s" % _zstd_error(zresult) | "unable to reference prepared dictionary: %s" | ||||
| % _zstd_error(zresult) | |||||
| ) | ) | ||||
| #require black | #require black | ||||
| $ cd $RUNTESTDIR/.. | $ cd $RUNTESTDIR/.. | ||||
| $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/** - "contrib/python-zstandard/**"'` | $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/**'` | ||||