I made this change upstream (reformatting the code with black at an
80-character line length), and it will ship in the next release of
python-zstandard. I'm also sending it to Mercurial because it lets us
drop contrib/python-zstandard/ from the black exclusion list.
- skip-blame blackening
mharbison72 | |
pulkit |
hg-reviewers |
I made this change upstream (reformatting the code with black at an
80-character line length), and it will ship in the next release of
python-zstandard. I'm also sending it to Mercurial because it lets us
drop contrib/python-zstandard/ from the black exclusion list.
Automatic diff as part of commit; lint not applicable.
Automatic diff as part of commit; unit tests not applicable.
"contrib/examples/fix.hgrc" also skips "contrib/python-zstandard/**". It seems like the intent of this is to allow local edits to python-zstandard (or we wouldn't start checking the format), so maybe it should be covered by fix too?
I'm -0 on auto-formatting third-party code at all, and that includes zstandard stuff. Why do we care if it's in the ignorelist?
Path | Packages | |||
---|---|---|---|---|
M | black.toml (1 line) | |||
M | contrib/examples/fix.hgrc (2 lines) | |||
M | contrib/python-zstandard/make_cffi.py (7 lines) | |||
M | contrib/python-zstandard/setup.py (4 lines) | |||
M | contrib/python-zstandard/setup_zstd.py (8 lines) | |||
M | contrib/python-zstandard/tests/common.py (12 lines) | |||
M | contrib/python-zstandard/tests/test_buffer_util.py (15 lines) | |||
M | contrib/python-zstandard/tests/test_compressor.py (69 lines) | |||
M | contrib/python-zstandard/tests/test_compressor_fuzzing.py (94 lines) | |||
M | contrib/python-zstandard/tests/test_data_structures.py (28 lines) | |||
M | contrib/python-zstandard/tests/test_data_structures_fuzzing.py (22 lines) | |||
M | contrib/python-zstandard/tests/test_decompressor.py (90 lines) | |||
M | contrib/python-zstandard/tests/test_decompressor_fuzzing.py (35 lines) | |||
M | contrib/python-zstandard/tests/test_train_dictionary.py (20 lines) | |||
M | contrib/python-zstandard/zstandard/cffi.py (310 lines) | |||
M | tests/test-check-format.t (2 lines) |
[tool.black] | [tool.black] | ||||
line-length = 80 | line-length = 80 | ||||
exclude = ''' | exclude = ''' | ||||
build/ | build/ | ||||
| wheelhouse/ | | wheelhouse/ | ||||
| dist/ | | dist/ | ||||
| packages/ | | packages/ | ||||
| \.hg/ | | \.hg/ | ||||
| \.mypy_cache/ | | \.mypy_cache/ | ||||
| \.venv/ | | \.venv/ | ||||
| mercurial/thirdparty/ | | mercurial/thirdparty/ | ||||
| contrib/python-zstandard/ | |||||
''' | ''' | ||||
skip-string-normalization = true | skip-string-normalization = true | ||||
quiet = true | quiet = true |
[fix] | [fix] | ||||
clang-format:command = clang-format --style file | clang-format:command = clang-format --style file | ||||
clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist" | clang-format:pattern = set:(**.c or **.cc or **.h) and not "include:contrib/clang-format-ignorelist" | ||||
rustfmt:command = rustfmt +nightly | rustfmt:command = rustfmt +nightly | ||||
rustfmt:pattern = set:**.rs | rustfmt:pattern = set:**.rs | ||||
black:command = black --config=black.toml - | black:command = black --config=black.toml - | ||||
black:pattern = set:**.py - mercurial/thirdparty/** - "contrib/python-zstandard/**" | black:pattern = set:**.py - mercurial/thirdparty/** | ||||
# Mercurial doesn't have any Go code, but if we did this is how we | # Mercurial doesn't have any Go code, but if we did this is how we | ||||
# would configure `hg fix` for Go: | # would configure `hg fix` for Go: | ||||
go:command = gofmt | go:command = gofmt | ||||
go:pattern = set:**.go | go:pattern = set:**.go |
"dictBuilder/fastcover.c", | "dictBuilder/fastcover.c", | ||||
"dictBuilder/divsufsort.c", | "dictBuilder/divsufsort.c", | ||||
"dictBuilder/zdict.c", | "dictBuilder/zdict.c", | ||||
) | ) | ||||
] | ] | ||||
# Headers whose preprocessed output will be fed into cdef(). | # Headers whose preprocessed output will be fed into cdef(). | ||||
HEADERS = [ | HEADERS = [ | ||||
os.path.join(HERE, "zstd", *p) for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) | os.path.join(HERE, "zstd", *p) | ||||
for p in (("zstd.h",), ("dictBuilder", "zdict.h"),) | |||||
] | ] | ||||
INCLUDE_DIRS = [ | INCLUDE_DIRS = [ | ||||
os.path.join(HERE, d) | os.path.join(HERE, d) | ||||
for d in ( | for d in ( | ||||
"zstd", | "zstd", | ||||
"zstd/common", | "zstd/common", | ||||
"zstd/compress", | "zstd/compress", | ||||
fd, input_file = tempfile.mkstemp(suffix=".h") | fd, input_file = tempfile.mkstemp(suffix=".h") | ||||
os.write(fd, b"".join(lines)) | os.write(fd, b"".join(lines)) | ||||
os.close(fd) | os.close(fd) | ||||
try: | try: | ||||
env = dict(os.environ) | env = dict(os.environ) | ||||
if getattr(compiler, "_paths", None): | if getattr(compiler, "_paths", None): | ||||
env["PATH"] = compiler._paths | env["PATH"] = compiler._paths | ||||
process = subprocess.Popen(args + [input_file], stdout=subprocess.PIPE, env=env) | process = subprocess.Popen( | ||||
args + [input_file], stdout=subprocess.PIPE, env=env | |||||
) | |||||
output = process.communicate()[0] | output = process.communicate()[0] | ||||
ret = process.poll() | ret = process.poll() | ||||
if ret: | if ret: | ||||
raise Exception("preprocessor exited with error") | raise Exception("preprocessor exited with error") | ||||
return output | return output | ||||
finally: | finally: | ||||
os.unlink(input_file) | os.unlink(input_file) |
for line in fh: | for line in fh: | ||||
if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): | if not line.startswith("#define PYTHON_ZSTANDARD_VERSION"): | ||||
continue | continue | ||||
version = line.split()[2][1:-1] | version = line.split()[2][1:-1] | ||||
break | break | ||||
if not version: | if not version: | ||||
raise Exception("could not resolve package version; " "this should never happen") | raise Exception( | ||||
"could not resolve package version; " "this should never happen" | |||||
) | |||||
setup( | setup( | ||||
name="zstandard", | name="zstandard", | ||||
version=version, | version=version, | ||||
description="Zstandard bindings for Python", | description="Zstandard bindings for Python", | ||||
long_description=open("README.rst", "r").read(), | long_description=open("README.rst", "r").read(), | ||||
url="https://github.com/indygreg/python-zstandard", | url="https://github.com/indygreg/python-zstandard", | ||||
author="Gregory Szorc", | author="Gregory Szorc", |
""" | """ | ||||
actual_root = os.path.abspath(os.path.dirname(__file__)) | actual_root = os.path.abspath(os.path.dirname(__file__)) | ||||
root = root or actual_root | root = root or actual_root | ||||
sources = set([os.path.join(actual_root, p) for p in ext_sources]) | sources = set([os.path.join(actual_root, p) for p in ext_sources]) | ||||
if not system_zstd: | if not system_zstd: | ||||
sources.update([os.path.join(actual_root, p) for p in zstd_sources]) | sources.update([os.path.join(actual_root, p) for p in zstd_sources]) | ||||
if support_legacy: | if support_legacy: | ||||
sources.update([os.path.join(actual_root, p) for p in zstd_sources_legacy]) | sources.update( | ||||
[os.path.join(actual_root, p) for p in zstd_sources_legacy] | |||||
) | |||||
sources = list(sources) | sources = list(sources) | ||||
include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) | include_dirs = set([os.path.join(actual_root, d) for d in ext_includes]) | ||||
if not system_zstd: | if not system_zstd: | ||||
include_dirs.update([os.path.join(actual_root, d) for d in zstd_includes]) | include_dirs.update( | ||||
[os.path.join(actual_root, d) for d in zstd_includes] | |||||
) | |||||
if support_legacy: | if support_legacy: | ||||
include_dirs.update( | include_dirs.update( | ||||
[os.path.join(actual_root, d) for d in zstd_includes_legacy] | [os.path.join(actual_root, d) for d in zstd_includes_legacy] | ||||
) | ) | ||||
include_dirs = list(include_dirs) | include_dirs = list(include_dirs) | ||||
depends = [os.path.join(actual_root, p) for p in zstd_depends] | depends = [os.path.join(actual_root, p) for p in zstd_depends] | ||||
mod = imp.load_module("zstandard_cffi", *mod_info) | mod = imp.load_module("zstandard_cffi", *mod_info) | ||||
except ImportError: | except ImportError: | ||||
return cls | return cls | ||||
finally: | finally: | ||||
os.environ.clear() | os.environ.clear() | ||||
os.environ.update(old_env) | os.environ.update(old_env) | ||||
if mod.backend != "cffi": | if mod.backend != "cffi": | ||||
raise Exception("got the zstandard %s backend instead of cffi" % mod.backend) | raise Exception( | ||||
"got the zstandard %s backend instead of cffi" % mod.backend | |||||
) | |||||
# If CFFI version is available, dynamically construct test methods | # If CFFI version is available, dynamically construct test methods | ||||
# that use it. | # that use it. | ||||
for attr in dir(cls): | for attr in dir(cls): | ||||
fn = getattr(cls, attr) | fn = getattr(cls, attr) | ||||
if not inspect.ismethod(fn) and not inspect.isfunction(fn): | if not inspect.ismethod(fn) and not inspect.isfunction(fn): | ||||
continue | continue | ||||
globs["zstd"] = mod | globs["zstd"] = mod | ||||
new_fn = types.FunctionType( | new_fn = types.FunctionType( | ||||
fn.__func__.func_code, | fn.__func__.func_code, | ||||
globs, | globs, | ||||
name, | name, | ||||
fn.__func__.func_defaults, | fn.__func__.func_defaults, | ||||
fn.__func__.func_closure, | fn.__func__.func_closure, | ||||
) | ) | ||||
new_method = types.UnboundMethodType(new_fn, fn.im_self, fn.im_class) | new_method = types.UnboundMethodType( | ||||
new_fn, fn.im_self, fn.im_class | |||||
) | |||||
setattr(cls, name, new_method) | setattr(cls, name, new_method) | ||||
return cls | return cls | ||||
class NonClosingBytesIO(io.BytesIO): | class NonClosingBytesIO(io.BytesIO): | ||||
"""BytesIO that saves the underlying buffer on close(). | """BytesIO that saves the underlying buffer on close(). | ||||
hypothesis.settings.register_profile("default", default_settings) | hypothesis.settings.register_profile("default", default_settings) | ||||
ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) | ci_settings = hypothesis.settings(deadline=20000, max_examples=1000) | ||||
hypothesis.settings.register_profile("ci", ci_settings) | hypothesis.settings.register_profile("ci", ci_settings) | ||||
expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) | expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) | ||||
hypothesis.settings.register_profile("expensive", expensive_settings) | hypothesis.settings.register_profile("expensive", expensive_settings) | ||||
hypothesis.settings.load_profile(os.environ.get("HYPOTHESIS_PROFILE", "default")) | hypothesis.settings.load_profile( | ||||
os.environ.get("HYPOTHESIS_PROFILE", "default") | |||||
) |
self.assertEqual(b[0].offset, 0) | self.assertEqual(b[0].offset, 0) | ||||
self.assertEqual(b[0].tobytes(), b"foo") | self.assertEqual(b[0].tobytes(), b"foo") | ||||
def test_multiple(self): | def test_multiple(self): | ||||
if not hasattr(zstd, "BufferWithSegments"): | if not hasattr(zstd, "BufferWithSegments"): | ||||
self.skipTest("BufferWithSegments not available") | self.skipTest("BufferWithSegments not available") | ||||
b = zstd.BufferWithSegments( | b = zstd.BufferWithSegments( | ||||
b"foofooxfooxy", b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]) | b"foofooxfooxy", | ||||
b"".join([ss.pack(0, 3), ss.pack(3, 4), ss.pack(7, 5)]), | |||||
) | ) | ||||
self.assertEqual(len(b), 3) | self.assertEqual(len(b), 3) | ||||
self.assertEqual(b.size, 12) | self.assertEqual(b.size, 12) | ||||
self.assertEqual(b.tobytes(), b"foofooxfooxy") | self.assertEqual(b.tobytes(), b"foofooxfooxy") | ||||
self.assertEqual(b[0].tobytes(), b"foo") | self.assertEqual(b[0].tobytes(), b"foo") | ||||
self.assertEqual(b[1].tobytes(), b"foox") | self.assertEqual(b[1].tobytes(), b"foox") | ||||
self.assertEqual(b[2].tobytes(), b"fooxy") | self.assertEqual(b[2].tobytes(), b"fooxy") | ||||
class TestBufferWithSegmentsCollection(TestCase): | class TestBufferWithSegmentsCollection(TestCase): | ||||
def test_empty_constructor(self): | def test_empty_constructor(self): | ||||
if not hasattr(zstd, "BufferWithSegmentsCollection"): | if not hasattr(zstd, "BufferWithSegmentsCollection"): | ||||
self.skipTest("BufferWithSegmentsCollection not available") | self.skipTest("BufferWithSegmentsCollection not available") | ||||
with self.assertRaisesRegex(ValueError, "must pass at least 1 argument"): | with self.assertRaisesRegex( | ||||
ValueError, "must pass at least 1 argument" | |||||
): | |||||
zstd.BufferWithSegmentsCollection() | zstd.BufferWithSegmentsCollection() | ||||
def test_argument_validation(self): | def test_argument_validation(self): | ||||
if not hasattr(zstd, "BufferWithSegmentsCollection"): | if not hasattr(zstd, "BufferWithSegmentsCollection"): | ||||
self.skipTest("BufferWithSegmentsCollection not available") | self.skipTest("BufferWithSegmentsCollection not available") | ||||
with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): | with self.assertRaisesRegex( | ||||
TypeError, "arguments must be BufferWithSegments" | |||||
): | |||||
zstd.BufferWithSegmentsCollection(None) | zstd.BufferWithSegmentsCollection(None) | ||||
with self.assertRaisesRegex(TypeError, "arguments must be BufferWithSegments"): | with self.assertRaisesRegex( | ||||
TypeError, "arguments must be BufferWithSegments" | |||||
): | |||||
zstd.BufferWithSegmentsCollection( | zstd.BufferWithSegmentsCollection( | ||||
zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None | zstd.BufferWithSegments(b"foo", ss.pack(0, 3)), None | ||||
) | ) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "ZstdBufferWithSegments cannot be empty" | ValueError, "ZstdBufferWithSegments cannot be empty" | ||||
): | ): | ||||
zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) | zstd.BufferWithSegmentsCollection(zstd.BufferWithSegments(b"", b"")) |
if sys.version_info[0] >= 3: | if sys.version_info[0] >= 3: | ||||
next = lambda it: it.__next__() | next = lambda it: it.__next__() | ||||
else: | else: | ||||
next = lambda it: it.next() | next = lambda it: it.next() | ||||
def multithreaded_chunk_size(level, source_size=0): | def multithreaded_chunk_size(level, source_size=0): | ||||
params = zstd.ZstdCompressionParameters.from_level(level, source_size=source_size) | params = zstd.ZstdCompressionParameters.from_level( | ||||
level, source_size=source_size | |||||
) | |||||
return 1 << (params.window_log + 2) | return 1 << (params.window_log + 2) | ||||
@make_cffi | @make_cffi | ||||
class TestCompressor(TestCase): | class TestCompressor(TestCase): | ||||
def test_level_bounds(self): | def test_level_bounds(self): | ||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
cctx = zstd.ZstdCompressor(level=3, write_content_size=False) | cctx = zstd.ZstdCompressor(level=3, write_content_size=False) | ||||
result = cctx.compress(b"".join(chunks)) | result = cctx.compress(b"".join(chunks)) | ||||
self.assertEqual(len(result), 999) | self.assertEqual(len(result), 999) | ||||
self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | self.assertEqual(result[0:4], b"\x28\xb5\x2f\xfd") | ||||
# This matches the test for read_to_iter() below. | # This matches the test for read_to_iter() below. | ||||
cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
result = cctx.compress(b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o") | result = cctx.compress( | ||||
b"f" * zstd.COMPRESSION_RECOMMENDED_INPUT_SIZE + b"o" | |||||
) | |||||
self.assertEqual( | self.assertEqual( | ||||
result, | result, | ||||
b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" | b"\x28\xb5\x2f\xfd\x00\x40\x54\x00\x00" | ||||
b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" | b"\x10\x66\x66\x01\x00\xfb\xff\x39\xc0" | ||||
b"\x02\x09\x00\x00\x6f", | b"\x02\x09\x00\x00\x6f", | ||||
) | ) | ||||
def test_negative_level(self): | def test_negative_level(self): | ||||
cctx = zstd.ZstdCompressor(level=-4) | cctx = zstd.ZstdCompressor(level=-4) | ||||
result = cctx.compress(b"foo" * 256) | result = cctx.compress(b"foo" * 256) | ||||
def test_no_magic(self): | def test_no_magic(self): | ||||
params = zstd.ZstdCompressionParameters.from_level(1, format=zstd.FORMAT_ZSTD1) | params = zstd.ZstdCompressionParameters.from_level( | ||||
1, format=zstd.FORMAT_ZSTD1 | |||||
) | |||||
cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
magic = cctx.compress(b"foobar") | magic = cctx.compress(b"foobar") | ||||
params = zstd.ZstdCompressionParameters.from_level( | params = zstd.ZstdCompressionParameters.from_level( | ||||
1, format=zstd.FORMAT_ZSTD1_MAGICLESS | 1, format=zstd.FORMAT_ZSTD1_MAGICLESS | ||||
) | ) | ||||
cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
no_magic = cctx.compress(b"foobar") | no_magic = cctx.compress(b"foobar") | ||||
result = cctx.compress(b"foo") | result = cctx.compress(b"foo") | ||||
params = zstd.get_frame_parameters(result) | params = zstd.get_frame_parameters(result) | ||||
self.assertEqual(params.content_size, 3) | self.assertEqual(params.content_size, 3) | ||||
self.assertEqual(params.dict_id, d.dict_id()) | self.assertEqual(params.dict_id, d.dict_id()) | ||||
self.assertEqual( | self.assertEqual( | ||||
result, | result, | ||||
b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" b"\x66\x6f\x6f", | b"\x28\xb5\x2f\xfd\x23\x8f\x55\x0f\x70\x03\x19\x00\x00" | ||||
b"\x66\x6f\x6f", | |||||
) | ) | ||||
def test_multithreaded_compression_params(self): | def test_multithreaded_compression_params(self): | ||||
params = zstd.ZstdCompressionParameters.from_level(0, threads=2) | params = zstd.ZstdCompressionParameters.from_level(0, threads=2) | ||||
cctx = zstd.ZstdCompressor(compression_params=params) | cctx = zstd.ZstdCompressor(compression_params=params) | ||||
result = cctx.compress(b"foo") | result = cctx.compress(b"foo") | ||||
params = zstd.get_frame_parameters(result) | params = zstd.get_frame_parameters(result) | ||||
self.assertEqual(params.content_size, 3) | self.assertEqual(params.content_size, 3) | ||||
self.assertEqual(result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f") | self.assertEqual( | ||||
result, b"\x28\xb5\x2f\xfd\x20\x03\x19\x00\x00\x66\x6f\x6f" | |||||
) | |||||
@make_cffi | @make_cffi | ||||
class TestCompressor_compressobj(TestCase): | class TestCompressor_compressobj(TestCase): | ||||
def test_compressobj_empty(self): | def test_compressobj_empty(self): | ||||
cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
cobj = cctx.compressobj() | cobj = cctx.compressobj() | ||||
self.assertEqual(cobj.compress(b""), b"") | self.assertEqual(cobj.compress(b""), b"") | ||||
self.assertEqual(cobj.compress(b"foo"), b"") | self.assertEqual(cobj.compress(b"foo"), b"") | ||||
self.assertEqual( | self.assertEqual( | ||||
cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), | ||||
b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", | b"\x28\xb5\x2f\xfd\x00\x48\x18\x00\x00foo", | ||||
) | ) | ||||
self.assertEqual(cobj.compress(b"bar"), b"") | self.assertEqual(cobj.compress(b"bar"), b"") | ||||
# 3 byte header plus content. | # 3 byte header plus content. | ||||
self.assertEqual(cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar") | self.assertEqual( | ||||
cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK), b"\x18\x00\x00bar" | |||||
) | |||||
self.assertEqual(cobj.flush(), b"\x01\x00\x00") | self.assertEqual(cobj.flush(), b"\x01\x00\x00") | ||||
def test_flush_empty_block(self): | def test_flush_empty_block(self): | ||||
cctx = zstd.ZstdCompressor(write_checksum=True) | cctx = zstd.ZstdCompressor(write_checksum=True) | ||||
cobj = cctx.compressobj() | cobj = cctx.compressobj() | ||||
cobj.compress(b"foobar") | cobj.compress(b"foobar") | ||||
cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | cobj.flush(zstd.COMPRESSOBJ_FLUSH_BLOCK) | ||||
source = io.BytesIO() | source = io.BytesIO() | ||||
dest = io.BytesIO() | dest = io.BytesIO() | ||||
cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
r, w = cctx.copy_stream(source, dest) | r, w = cctx.copy_stream(source, dest) | ||||
self.assertEqual(int(r), 0) | self.assertEqual(int(r), 0) | ||||
self.assertEqual(w, 9) | self.assertEqual(w, 9) | ||||
self.assertEqual(dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00") | self.assertEqual( | ||||
dest.getvalue(), b"\x28\xb5\x2f\xfd\x00\x48\x01\x00\x00" | |||||
) | |||||
def test_large_data(self): | def test_large_data(self): | ||||
source = io.BytesIO() | source = io.BytesIO() | ||||
for i in range(255): | for i in range(255): | ||||
source.write(struct.Struct(">B").pack(i) * 16384) | source.write(struct.Struct(">B").pack(i) * 16384) | ||||
source.seek(0) | source.seek(0) | ||||
dest = io.BytesIO() | dest = io.BytesIO() | ||||
cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
cctx.copy_stream(source, no_checksum) | cctx.copy_stream(source, no_checksum) | ||||
source.seek(0) | source.seek(0) | ||||
with_checksum = io.BytesIO() | with_checksum = io.BytesIO() | ||||
cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | cctx = zstd.ZstdCompressor(level=1, write_checksum=True) | ||||
cctx.copy_stream(source, with_checksum) | cctx.copy_stream(source, with_checksum) | ||||
self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) | self.assertEqual( | ||||
len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||||
) | |||||
no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | no_params = zstd.get_frame_parameters(no_checksum.getvalue()) | ||||
with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | ||||
self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
@make_cffi | @make_cffi | ||||
class TestCompressor_stream_reader(TestCase): | class TestCompressor_stream_reader(TestCase): | ||||
def test_context_manager(self): | def test_context_manager(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
with cctx.stream_reader(b"foo") as reader: | with cctx.stream_reader(b"foo") as reader: | ||||
with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): | with self.assertRaisesRegex( | ||||
ValueError, "cannot __enter__ multiple times" | |||||
): | |||||
with reader as reader2: | with reader as reader2: | ||||
pass | pass | ||||
def test_no_context_manager(self): | def test_no_context_manager(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
reader = cctx.stream_reader(b"foo") | reader = cctx.stream_reader(b"foo") | ||||
reader.read(4) | reader.read(4) | ||||
reader.read(10) | reader.read(10) | ||||
def test_bad_size(self): | def test_bad_size(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
source = io.BytesIO(b"foobar") | source = io.BytesIO(b"foobar") | ||||
with cctx.stream_reader(source, size=2) as reader: | with cctx.stream_reader(source, size=2) as reader: | ||||
with self.assertRaisesRegex(zstd.ZstdError, "Src size is incorrect"): | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "Src size is incorrect" | |||||
): | |||||
reader.read(10) | reader.read(10) | ||||
# Try another compression operation. | # Try another compression operation. | ||||
with cctx.stream_reader(source, size=42): | with cctx.stream_reader(source, size=42): | ||||
pass | pass | ||||
def test_readall(self): | def test_readall(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | with_params = zstd.get_frame_parameters(with_checksum.getvalue()) | ||||
self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
self.assertTrue(with_params.has_checksum) | self.assertTrue(with_params.has_checksum) | ||||
self.assertEqual(len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4) | self.assertEqual( | ||||
len(with_checksum.getvalue()), len(no_checksum.getvalue()) + 4 | |||||
) | |||||
def test_write_content_size(self): | def test_write_content_size(self): | ||||
no_size = NonClosingBytesIO() | no_size = NonClosingBytesIO() | ||||
cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
with cctx.stream_writer(no_size) as compressor: | with cctx.stream_writer(no_size) as compressor: | ||||
self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
with_size = NonClosingBytesIO() | with_size = NonClosingBytesIO() | ||||
cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
with cctx.stream_writer(with_size) as compressor: | with cctx.stream_writer(with_size) as compressor: | ||||
self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
# Source size is not known in streaming mode, so header not | # Source size is not known in streaming mode, so header not | ||||
# written. | # written. | ||||
self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | self.assertEqual(len(with_size.getvalue()), len(no_size.getvalue())) | ||||
# Declaring size will write the header. | # Declaring size will write the header. | ||||
with_size = NonClosingBytesIO() | with_size = NonClosingBytesIO() | ||||
with cctx.stream_writer(with_size, size=len(b"foobar" * 256)) as compressor: | with cctx.stream_writer( | ||||
with_size, size=len(b"foobar" * 256) | |||||
) as compressor: | |||||
self.assertEqual(compressor.write(b"foobar" * 256), 0) | self.assertEqual(compressor.write(b"foobar" * 256), 0) | ||||
no_params = zstd.get_frame_parameters(no_size.getvalue()) | no_params = zstd.get_frame_parameters(no_size.getvalue()) | ||||
with_params = zstd.get_frame_parameters(with_size.getvalue()) | with_params = zstd.get_frame_parameters(with_size.getvalue()) | ||||
self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(with_params.content_size, 1536) | self.assertEqual(with_params.content_size, 1536) | ||||
self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
self.assertEqual(with_params.dict_id, 0) | self.assertEqual(with_params.dict_id, 0) | ||||
with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) | with_params = zstd.get_frame_parameters(with_dict_id.getvalue()) | ||||
self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(no_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(with_params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(no_params.dict_id, 0) | self.assertEqual(no_params.dict_id, 0) | ||||
self.assertEqual(with_params.dict_id, d.dict_id()) | self.assertEqual(with_params.dict_id, d.dict_id()) | ||||
self.assertFalse(no_params.has_checksum) | self.assertFalse(no_params.has_checksum) | ||||
self.assertFalse(with_params.has_checksum) | self.assertFalse(with_params.has_checksum) | ||||
self.assertEqual(len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4) | self.assertEqual( | ||||
len(with_dict_id.getvalue()), len(no_dict_id.getvalue()) + 4 | |||||
) | |||||
def test_memory_size(self): | def test_memory_size(self): | ||||
cctx = zstd.ZstdCompressor(level=3) | cctx = zstd.ZstdCompressor(level=3) | ||||
buffer = io.BytesIO() | buffer = io.BytesIO() | ||||
with cctx.stream_writer(buffer) as compressor: | with cctx.stream_writer(buffer) as compressor: | ||||
compressor.write(b"foo") | compressor.write(b"foo") | ||||
size = compressor.memory_size() | size = compressor.memory_size() | ||||
# Object with read() works. | # Object with read() works. | ||||
for chunk in cctx.read_to_iter(io.BytesIO()): | for chunk in cctx.read_to_iter(io.BytesIO()): | ||||
pass | pass | ||||
# Buffer protocol works. | # Buffer protocol works. | ||||
for chunk in cctx.read_to_iter(b"foobar"): | for chunk in cctx.read_to_iter(b"foobar"): | ||||
pass | pass | ||||
with self.assertRaisesRegex(ValueError, "must pass an object with a read"): | with self.assertRaisesRegex( | ||||
ValueError, "must pass an object with a read" | |||||
): | |||||
for chunk in cctx.read_to_iter(True): | for chunk in cctx.read_to_iter(True): | ||||
pass | pass | ||||
def test_read_empty(self): | def test_read_empty(self): | ||||
cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | cctx = zstd.ZstdCompressor(level=1, write_content_size=False) | ||||
source = io.BytesIO() | source = io.BytesIO() | ||||
it = cctx.read_to_iter(source) | it = cctx.read_to_iter(source) | ||||
[ | [ | ||||
b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" | b"\x28\xb5\x2f\xfd\x60\x00\x03\x65\x00\x00\x18\x78\x78\x79\x02\x00" | ||||
b"\xa0\x16\xe3\x2b\x80\x05" | b"\xa0\x16\xe3\x2b\x80\x05" | ||||
], | ], | ||||
) | ) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
self.assertEqual(dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24)) | self.assertEqual( | ||||
dctx.decompress(b"".join(chunks)), (b"x" * 1000) + (b"y" * 24) | |||||
) | |||||
def test_small_chunk_size(self): | def test_small_chunk_size(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
chunker = cctx.chunker(chunk_size=1) | chunker = cctx.chunker(chunk_size=1) | ||||
chunks = list(chunker.compress(b"foo" * 1024)) | chunks = list(chunker.compress(b"foo" * 1024)) | ||||
self.assertEqual(chunks, []) | self.assertEqual(chunks, []) | ||||
chunks = list(chunker.finish()) | chunks = list(chunker.finish()) | ||||
self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) | self.assertTrue(all(len(chunk) == 1 for chunk in chunks)) | ||||
self.assertEqual( | self.assertEqual( | ||||
b"".join(chunks), | b"".join(chunks), | ||||
b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" | b"\x28\xb5\x2f\xfd\x00\x58\x55\x00\x00\x18\x66\x6f\x6f\x01\x00" | ||||
b"\xfa\xd3\x77\x43", | b"\xfa\xd3\x77\x43", | ||||
) | ) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
self.assertEqual( | self.assertEqual( | ||||
dctx.decompress(b"".join(chunks), max_output_size=10000), b"foo" * 1024 | dctx.decompress(b"".join(chunks), max_output_size=10000), | ||||
b"foo" * 1024, | |||||
) | ) | ||||
def test_input_types(self): | def test_input_types(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
mutable_array = bytearray(3) | mutable_array = bytearray(3) | ||||
mutable_array[:] = b"foo" | mutable_array[:] = b"foo" | ||||
def test_compress_after_finish(self): | def test_compress_after_finish(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
chunker = cctx.chunker() | chunker = cctx.chunker() | ||||
list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
list(chunker.finish()) | list(chunker.finish()) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
zstd.ZstdError, r"cannot call compress\(\) after compression finished" | zstd.ZstdError, | ||||
r"cannot call compress\(\) after compression finished", | |||||
): | ): | ||||
list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
def test_flush_after_finish(self): | def test_flush_after_finish(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
chunker = cctx.chunker() | chunker = cctx.chunker() | ||||
list(chunker.compress(b"foo")) | list(chunker.compress(b"foo")) | ||||
self.skipTest("multi_compress_to_buffer not available") | self.skipTest("multi_compress_to_buffer not available") | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
cctx.multi_compress_to_buffer(True) | cctx.multi_compress_to_buffer(True) | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
cctx.multi_compress_to_buffer((1, 2)) | cctx.multi_compress_to_buffer((1, 2)) | ||||
with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): | with self.assertRaisesRegex( | ||||
TypeError, "item 0 not a bytes like object" | |||||
): | |||||
cctx.multi_compress_to_buffer([u"foo"]) | cctx.multi_compress_to_buffer([u"foo"]) | ||||
def test_empty_input(self): | def test_empty_input(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
if not hasattr(cctx, "multi_compress_to_buffer"): | if not hasattr(cctx, "multi_compress_to_buffer"): | ||||
self.skipTest("multi_compress_to_buffer not available") | self.skipTest("multi_compress_to_buffer not available") | ||||
class TestCompressor_stream_reader_fuzzing(TestCase): | class TestCompressor_stream_reader_fuzzing(TestCase): | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_stream_source_read(self, original, level, source_read_size, read_size): | def test_stream_source_read( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_buffer_source_read(self, original, level, source_read_size, read_size): | def test_buffer_source_read( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_stream_source_readinto(self, original, level, source_read_size, read_size): | def test_stream_source_readinto( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
io.BytesIO(original), size=len(original), read_size=source_read_size | io.BytesIO(original), size=len(original), read_size=source_read_size | ||||
) as reader: | ) as reader: | ||||
chunks = [] | chunks = [] | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_buffer_source_readinto(self, original, level, source_read_size, read_size): | def test_buffer_source_readinto( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
original, size=len(original), read_size=source_read_size | original, size=len(original), read_size=source_read_size | ||||
) as reader: | ) as reader: | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_stream_source_read1(self, original, level, source_read_size, read_size): | def test_stream_source_read1( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
-1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_buffer_source_read1(self, original, level, source_read_size, read_size): | def test_buffer_source_read1( | ||||
self, original, level, source_read_size, read_size | |||||
): | |||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
with cctx.stream_reader( | with cctx.stream_reader( | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_stream_source_readinto1( | def test_stream_source_readinto1( | ||||
self, original, level, source_read_size, read_size | self, original, level, source_read_size, read_size | ||||
): | ): | ||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
source_read_size=strategies.integers(1, 16384), | source_read_size=strategies.integers(1, 16384), | ||||
read_size=strategies.integers(1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE), | read_size=strategies.integers( | ||||
1, zstd.COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
), | |||||
) | ) | ||||
def test_buffer_source_readinto1( | def test_buffer_source_readinto1( | ||||
self, original, level, source_read_size, read_size | self, original, level, source_read_size, read_size | ||||
): | ): | ||||
if read_size == 0: | if read_size == 0: | ||||
read_size = -1 | read_size = -1 | ||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
@make_cffi | @make_cffi | ||||
class TestCompressor_copy_stream_fuzzing(TestCase): | class TestCompressor_copy_stream_fuzzing(TestCase): | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
read_size=strategies.integers(min_value=1, max_value=1048576), | read_size=strategies.integers(min_value=1, max_value=1048576), | ||||
write_size=strategies.integers(min_value=1, max_value=1048576), | write_size=strategies.integers(min_value=1, max_value=1048576), | ||||
) | ) | ||||
def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
self, original, level, read_size, write_size | |||||
): | |||||
refctx = zstd.ZstdCompressor(level=level) | refctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refctx.compress(original) | ref_frame = refctx.compress(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
source = io.BytesIO(original) | source = io.BytesIO(original) | ||||
dest = io.BytesIO() | dest = io.BytesIO() | ||||
cctx.copy_stream( | cctx.copy_stream( | ||||
source, dest, size=len(original), read_size=read_size, write_size=write_size | source, | ||||
dest, | |||||
size=len(original), | |||||
read_size=read_size, | |||||
write_size=write_size, | |||||
) | ) | ||||
self.assertEqual(dest.getvalue(), ref_frame) | self.assertEqual(dest.getvalue(), ref_frame) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
@make_cffi | @make_cffi | ||||
class TestCompressor_compressobj_fuzzing(TestCase): | class TestCompressor_compressobj_fuzzing(TestCase): | ||||
self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | ||||
chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) | chunk = cobj.flush(zstd.COMPRESSOBJ_FLUSH_FINISH) | ||||
compressed_chunks.append(chunk) | compressed_chunks.append(chunk) | ||||
decompressed_chunks.append(dobj.decompress(chunk)) | decompressed_chunks.append(dobj.decompress(chunk)) | ||||
self.assertEqual( | self.assertEqual( | ||||
dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), | dctx.decompress( | ||||
b"".join(compressed_chunks), max_output_size=len(original) | |||||
), | |||||
original, | original, | ||||
) | ) | ||||
self.assertEqual(b"".join(decompressed_chunks), original) | self.assertEqual(b"".join(decompressed_chunks), original) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
@make_cffi | @make_cffi | ||||
class TestCompressor_read_to_iter_fuzzing(TestCase): | class TestCompressor_read_to_iter_fuzzing(TestCase): | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
read_size=strategies.integers(min_value=1, max_value=4096), | read_size=strategies.integers(min_value=1, max_value=4096), | ||||
write_size=strategies.integers(min_value=1, max_value=4096), | write_size=strategies.integers(min_value=1, max_value=4096), | ||||
) | ) | ||||
def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
self, original, level, read_size, write_size | |||||
): | |||||
refcctx = zstd.ZstdCompressor(level=level) | refcctx = zstd.ZstdCompressor(level=level) | ||||
ref_frame = refcctx.compress(original) | ref_frame = refcctx.compress(original) | ||||
source = io.BytesIO(original) | source = io.BytesIO(original) | ||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
chunks = list( | chunks = list( | ||||
cctx.read_to_iter( | cctx.read_to_iter( | ||||
source, size=len(original), read_size=read_size, write_size=write_size | source, | ||||
size=len(original), | |||||
read_size=read_size, | |||||
write_size=write_size, | |||||
) | ) | ||||
) | ) | ||||
self.assertEqual(b"".join(chunks), ref_frame) | self.assertEqual(b"".join(chunks), ref_frame) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): | class TestCompressor_multi_compress_to_buffer_fuzzing(TestCase): | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.lists( | original=strategies.lists( | ||||
strategies.sampled_from(random_input_data()), min_size=1, max_size=1024 | strategies.sampled_from(random_input_data()), | ||||
min_size=1, | |||||
max_size=1024, | |||||
), | ), | ||||
threads=strategies.integers(min_value=1, max_value=8), | threads=strategies.integers(min_value=1, max_value=8), | ||||
use_dict=strategies.booleans(), | use_dict=strategies.booleans(), | ||||
) | ) | ||||
def test_data_equivalence(self, original, threads, use_dict): | def test_data_equivalence(self, original, threads, use_dict): | ||||
kwargs = {} | kwargs = {} | ||||
# Use a content dictionary because it is cheap to create. | # Use a content dictionary because it is cheap to create. | ||||
chunks.extend(chunker.compress(source)) | chunks.extend(chunker.compress(source)) | ||||
i += input_size | i += input_size | ||||
chunks.extend(chunker.finish()) | chunks.extend(chunker.finish()) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
self.assertEqual( | self.assertEqual( | ||||
dctx.decompress(b"".join(chunks), max_output_size=len(original)), original | dctx.decompress(b"".join(chunks), max_output_size=len(original)), | ||||
original, | |||||
) | ) | ||||
self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) | self.assertTrue(all(len(chunk) == chunk_size for chunk in chunks[:-1])) | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
suppress_health_check=[ | suppress_health_check=[ | ||||
hypothesis.HealthCheck.large_base_example, | hypothesis.HealthCheck.large_base_example, | ||||
hypothesis.HealthCheck.too_slow, | hypothesis.HealthCheck.too_slow, | ||||
] | ] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | chunk_size=strategies.integers(min_value=1, max_value=32 * 1048576), | ||||
input_sizes=strategies.data(), | input_sizes=strategies.data(), | ||||
flushes=strategies.data(), | flushes=strategies.data(), | ||||
) | ) | ||||
def test_flush_block(self, original, level, chunk_size, input_sizes, flushes): | def test_flush_block( | ||||
self, original, level, chunk_size, input_sizes, flushes | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
chunker = cctx.chunker(chunk_size=chunk_size) | chunker = cctx.chunker(chunk_size=chunk_size) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
dobj = dctx.decompressobj() | dobj = dctx.decompressobj() | ||||
compressed_chunks = [] | compressed_chunks = [] | ||||
decompressed_chunks = [] | decompressed_chunks = [] | ||||
self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | self.assertEqual(b"".join(decompressed_chunks), original[0:i]) | ||||
chunks = list(chunker.finish()) | chunks = list(chunker.finish()) | ||||
compressed_chunks.extend(chunks) | compressed_chunks.extend(chunks) | ||||
decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | decompressed_chunks.append(dobj.decompress(b"".join(chunks))) | ||||
self.assertEqual( | self.assertEqual( | ||||
dctx.decompress(b"".join(compressed_chunks), max_output_size=len(original)), | dctx.decompress( | ||||
b"".join(compressed_chunks), max_output_size=len(original) | |||||
), | |||||
original, | original, | ||||
) | ) | ||||
self.assertEqual(b"".join(decompressed_chunks), original) | self.assertEqual(b"".join(decompressed_chunks), original) |
self.assertEqual(p.compression_strategy, 1) | self.assertEqual(p.compression_strategy, 1) | ||||
p = zstd.ZstdCompressionParameters(compression_level=2) | p = zstd.ZstdCompressionParameters(compression_level=2) | ||||
self.assertEqual(p.compression_level, 2) | self.assertEqual(p.compression_level, 2) | ||||
p = zstd.ZstdCompressionParameters(threads=4) | p = zstd.ZstdCompressionParameters(threads=4) | ||||
self.assertEqual(p.threads, 4) | self.assertEqual(p.threads, 4) | ||||
p = zstd.ZstdCompressionParameters(threads=2, job_size=1048576, overlap_log=6) | p = zstd.ZstdCompressionParameters( | ||||
threads=2, job_size=1048576, overlap_log=6 | |||||
) | |||||
self.assertEqual(p.threads, 2) | self.assertEqual(p.threads, 2) | ||||
self.assertEqual(p.job_size, 1048576) | self.assertEqual(p.job_size, 1048576) | ||||
self.assertEqual(p.overlap_log, 6) | self.assertEqual(p.overlap_log, 6) | ||||
self.assertEqual(p.overlap_size_log, 6) | self.assertEqual(p.overlap_size_log, 6) | ||||
p = zstd.ZstdCompressionParameters(compression_level=-1) | p = zstd.ZstdCompressionParameters(compression_level=-1) | ||||
self.assertEqual(p.compression_level, -1) | self.assertEqual(p.compression_level, -1) | ||||
p = zstd.ZstdCompressionParameters(strategy=3) | p = zstd.ZstdCompressionParameters(strategy=3) | ||||
self.assertEqual(p.compression_strategy, 3) | self.assertEqual(p.compression_strategy, 3) | ||||
def test_ldm_hash_rate_log(self): | def test_ldm_hash_rate_log(self): | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "cannot specify both ldm_hash_rate_log" | ValueError, "cannot specify both ldm_hash_rate_log" | ||||
): | ): | ||||
zstd.ZstdCompressionParameters(ldm_hash_rate_log=8, ldm_hash_every_log=4) | zstd.ZstdCompressionParameters( | ||||
ldm_hash_rate_log=8, ldm_hash_every_log=4 | |||||
) | |||||
p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | p = zstd.ZstdCompressionParameters(ldm_hash_rate_log=8) | ||||
self.assertEqual(p.ldm_hash_every_log, 8) | self.assertEqual(p.ldm_hash_every_log, 8) | ||||
p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) | p = zstd.ZstdCompressionParameters(ldm_hash_every_log=16) | ||||
self.assertEqual(p.ldm_hash_every_log, 16) | self.assertEqual(p.ldm_hash_every_log, 16) | ||||
def test_overlap_log(self): | def test_overlap_log(self): | ||||
with self.assertRaisesRegex(ValueError, "cannot specify both overlap_log"): | with self.assertRaisesRegex( | ||||
ValueError, "cannot specify both overlap_log" | |||||
): | |||||
zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) | zstd.ZstdCompressionParameters(overlap_log=1, overlap_size_log=9) | ||||
p = zstd.ZstdCompressionParameters(overlap_log=2) | p = zstd.ZstdCompressionParameters(overlap_log=2) | ||||
self.assertEqual(p.overlap_log, 2) | self.assertEqual(p.overlap_log, 2) | ||||
self.assertEqual(p.overlap_size_log, 2) | self.assertEqual(p.overlap_size_log, 2) | ||||
p = zstd.ZstdCompressionParameters(overlap_size_log=4) | p = zstd.ZstdCompressionParameters(overlap_size_log=4) | ||||
self.assertEqual(p.overlap_log, 4) | self.assertEqual(p.overlap_log, 4) | ||||
if zstd.backend == "cffi": | if zstd.backend == "cffi": | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
zstd.get_frame_parameters(u"foobarbaz") | zstd.get_frame_parameters(u"foobarbaz") | ||||
else: | else: | ||||
with self.assertRaises(zstd.ZstdError): | with self.assertRaises(zstd.ZstdError): | ||||
zstd.get_frame_parameters(u"foobarbaz") | zstd.get_frame_parameters(u"foobarbaz") | ||||
def test_invalid_input_sizes(self): | def test_invalid_input_sizes(self): | ||||
with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"): | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "not enough data for frame" | |||||
): | |||||
zstd.get_frame_parameters(b"") | zstd.get_frame_parameters(b"") | ||||
with self.assertRaisesRegex(zstd.ZstdError, "not enough data for frame"): | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "not enough data for frame" | |||||
): | |||||
zstd.get_frame_parameters(zstd.FRAME_HEADER) | zstd.get_frame_parameters(zstd.FRAME_HEADER) | ||||
def test_invalid_frame(self): | def test_invalid_frame(self): | ||||
with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | with self.assertRaisesRegex(zstd.ZstdError, "Unknown frame descriptor"): | ||||
zstd.get_frame_parameters(b"foobarbaz") | zstd.get_frame_parameters(b"foobarbaz") | ||||
def test_attributes(self): | def test_attributes(self): | ||||
params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x00") | ||||
# Lowest 3rd bit indicates if checksum is present. | # Lowest 3rd bit indicates if checksum is present. | ||||
params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x04\x00") | ||||
self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(params.window_size, 1024) | self.assertEqual(params.window_size, 1024) | ||||
self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
self.assertTrue(params.has_checksum) | self.assertTrue(params.has_checksum) | ||||
# Upper 2 bits indicate content size. | # Upper 2 bits indicate content size. | ||||
params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x40\x00\xff\x00") | params = zstd.get_frame_parameters( | ||||
zstd.FRAME_HEADER + b"\x40\x00\xff\x00" | |||||
) | |||||
self.assertEqual(params.content_size, 511) | self.assertEqual(params.content_size, 511) | ||||
self.assertEqual(params.window_size, 1024) | self.assertEqual(params.window_size, 1024) | ||||
self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
self.assertFalse(params.has_checksum) | self.assertFalse(params.has_checksum) | ||||
# Window descriptor is 2nd byte after frame header. | # Window descriptor is 2nd byte after frame header. | ||||
params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") | params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x00\x40") | ||||
self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | self.assertEqual(params.content_size, zstd.CONTENTSIZE_UNKNOWN) | ||||
self.assertEqual(params.window_size, 262144) | self.assertEqual(params.window_size, 262144) | ||||
self.assertEqual(params.dict_id, 0) | self.assertEqual(params.dict_id, 0) | ||||
self.assertFalse(params.has_checksum) | self.assertFalse(params.has_checksum) | ||||
# Set multiple things. | # Set multiple things. | ||||
params = zstd.get_frame_parameters(zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00") | params = zstd.get_frame_parameters( | ||||
zstd.FRAME_HEADER + b"\x45\x40\x0f\x10\x00" | |||||
) | |||||
self.assertEqual(params.content_size, 272) | self.assertEqual(params.content_size, 272) | ||||
self.assertEqual(params.window_size, 262144) | self.assertEqual(params.window_size, 262144) | ||||
self.assertEqual(params.dict_id, 15) | self.assertEqual(params.dict_id, 15) | ||||
self.assertTrue(params.has_checksum) | self.assertTrue(params.has_checksum) | ||||
def test_input_types(self): | def test_input_types(self): | ||||
v = zstd.FRAME_HEADER + b"\x00\x00" | v = zstd.FRAME_HEADER + b"\x00\x00" | ||||
s_windowlog = strategies.integers( | s_windowlog = strategies.integers( | ||||
min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX | min_value=zstd.WINDOWLOG_MIN, max_value=zstd.WINDOWLOG_MAX | ||||
) | ) | ||||
s_chainlog = strategies.integers( | s_chainlog = strategies.integers( | ||||
min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX | min_value=zstd.CHAINLOG_MIN, max_value=zstd.CHAINLOG_MAX | ||||
) | ) | ||||
s_hashlog = strategies.integers(min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX) | s_hashlog = strategies.integers( | ||||
min_value=zstd.HASHLOG_MIN, max_value=zstd.HASHLOG_MAX | |||||
) | |||||
s_searchlog = strategies.integers( | s_searchlog = strategies.integers( | ||||
min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX | min_value=zstd.SEARCHLOG_MIN, max_value=zstd.SEARCHLOG_MAX | ||||
) | ) | ||||
s_minmatch = strategies.integers( | s_minmatch = strategies.integers( | ||||
min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX | min_value=zstd.MINMATCH_MIN, max_value=zstd.MINMATCH_MAX | ||||
) | ) | ||||
s_targetlength = strategies.integers( | s_targetlength = strategies.integers( | ||||
min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX | min_value=zstd.TARGETLENGTH_MIN, max_value=zstd.TARGETLENGTH_MAX | ||||
s_chainlog, | s_chainlog, | ||||
s_hashlog, | s_hashlog, | ||||
s_searchlog, | s_searchlog, | ||||
s_minmatch, | s_minmatch, | ||||
s_targetlength, | s_targetlength, | ||||
s_strategy, | s_strategy, | ||||
) | ) | ||||
def test_valid_init( | def test_valid_init( | ||||
self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy | self, | ||||
windowlog, | |||||
chainlog, | |||||
hashlog, | |||||
searchlog, | |||||
minmatch, | |||||
targetlength, | |||||
strategy, | |||||
): | ): | ||||
zstd.ZstdCompressionParameters( | zstd.ZstdCompressionParameters( | ||||
window_log=windowlog, | window_log=windowlog, | ||||
chain_log=chainlog, | chain_log=chainlog, | ||||
hash_log=hashlog, | hash_log=hashlog, | ||||
search_log=searchlog, | search_log=searchlog, | ||||
min_match=minmatch, | min_match=minmatch, | ||||
target_length=targetlength, | target_length=targetlength, | ||||
strategy=strategy, | strategy=strategy, | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
s_windowlog, | s_windowlog, | ||||
s_chainlog, | s_chainlog, | ||||
s_hashlog, | s_hashlog, | ||||
s_searchlog, | s_searchlog, | ||||
s_minmatch, | s_minmatch, | ||||
s_targetlength, | s_targetlength, | ||||
s_strategy, | s_strategy, | ||||
) | ) | ||||
def test_estimated_compression_context_size( | def test_estimated_compression_context_size( | ||||
self, windowlog, chainlog, hashlog, searchlog, minmatch, targetlength, strategy | self, | ||||
windowlog, | |||||
chainlog, | |||||
hashlog, | |||||
searchlog, | |||||
minmatch, | |||||
targetlength, | |||||
strategy, | |||||
): | ): | ||||
if minmatch == zstd.MINMATCH_MIN and strategy in ( | if minmatch == zstd.MINMATCH_MIN and strategy in ( | ||||
zstd.STRATEGY_FAST, | zstd.STRATEGY_FAST, | ||||
zstd.STRATEGY_GREEDY, | zstd.STRATEGY_GREEDY, | ||||
): | ): | ||||
minmatch += 1 | minmatch += 1 | ||||
elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: | elif minmatch == zstd.MINMATCH_MAX and strategy != zstd.STRATEGY_FAST: | ||||
minmatch -= 1 | minmatch -= 1 |
# Input size - 1 fails | # Input size - 1 fails | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "decompression error: did not decompress full frame" | zstd.ZstdError, "decompression error: did not decompress full frame" | ||||
): | ): | ||||
dctx.decompress(compressed, max_output_size=len(source) - 1) | dctx.decompress(compressed, max_output_size=len(source) - 1) | ||||
# Input size + 1 works | # Input size + 1 works | ||||
decompressed = dctx.decompress(compressed, max_output_size=len(source) + 1) | decompressed = dctx.decompress( | ||||
compressed, max_output_size=len(source) + 1 | |||||
) | |||||
self.assertEqual(decompressed, source) | self.assertEqual(decompressed, source) | ||||
# A much larger buffer works. | # A much larger buffer works. | ||||
decompressed = dctx.decompress(compressed, max_output_size=len(source) * 64) | decompressed = dctx.decompress( | ||||
compressed, max_output_size=len(source) * 64 | |||||
) | |||||
self.assertEqual(decompressed, source) | self.assertEqual(decompressed, source) | ||||
def test_stupidly_large_output_buffer(self): | def test_stupidly_large_output_buffer(self): | ||||
cctx = zstd.ZstdCompressor(write_content_size=False) | cctx = zstd.ZstdCompressor(write_content_size=False) | ||||
compressed = cctx.compress(b"foobar" * 256) | compressed = cctx.compress(b"foobar" * 256) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
# Will get OverflowError on some Python distributions that can't | # Will get OverflowError on some Python distributions that can't | ||||
# If we write a content size, the decompressor engages single pass | # If we write a content size, the decompressor engages single pass | ||||
# mode and the window size doesn't come into play. | # mode and the window size doesn't come into play. | ||||
cctx = zstd.ZstdCompressor(write_content_size=False) | cctx = zstd.ZstdCompressor(write_content_size=False) | ||||
frame = cctx.compress(source) | frame = cctx.compress(source) | ||||
dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | dctx = zstd.ZstdDecompressor(max_window_size=2 ** zstd.WINDOWLOG_MIN) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "decompression error: Frame requires too much memory" | zstd.ZstdError, | ||||
"decompression error: Frame requires too much memory", | |||||
): | ): | ||||
dctx.decompress(frame, max_output_size=len(source)) | dctx.decompress(frame, max_output_size=len(source)) | ||||
@make_cffi | @make_cffi | ||||
class TestDecompressor_copy_stream(TestCase): | class TestDecompressor_copy_stream(TestCase): | ||||
def test_no_read(self): | def test_no_read(self): | ||||
source = object() | source = object() | ||||
dest = io.BytesIO() | dest = io.BytesIO() | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
r, w = dctx.copy_stream(compressed, dest) | r, w = dctx.copy_stream(compressed, dest) | ||||
self.assertEqual(r, len(compressed.getvalue())) | self.assertEqual(r, len(compressed.getvalue())) | ||||
self.assertEqual(w, len(source.getvalue())) | self.assertEqual(w, len(source.getvalue())) | ||||
def test_read_write_size(self): | def test_read_write_size(self): | ||||
source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | source = OpCountingBytesIO( | ||||
zstd.ZstdCompressor().compress(b"foobarfoobar") | |||||
) | |||||
dest = OpCountingBytesIO() | dest = OpCountingBytesIO() | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | r, w = dctx.copy_stream(source, dest, read_size=1, write_size=1) | ||||
self.assertEqual(r, len(source.getvalue())) | self.assertEqual(r, len(source.getvalue())) | ||||
self.assertEqual(w, len(b"foobarfoobar")) | self.assertEqual(w, len(b"foobarfoobar")) | ||||
self.assertEqual(source._read_count, len(source.getvalue()) + 1) | self.assertEqual(source._read_count, len(source.getvalue()) + 1) | ||||
self.assertEqual(dest._write_count, len(dest.getvalue())) | self.assertEqual(dest._write_count, len(dest.getvalue())) | ||||
@make_cffi | @make_cffi | ||||
class TestDecompressor_stream_reader(TestCase): | class TestDecompressor_stream_reader(TestCase): | ||||
def test_context_manager(self): | def test_context_manager(self): | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
with dctx.stream_reader(b"foo") as reader: | with dctx.stream_reader(b"foo") as reader: | ||||
with self.assertRaisesRegex(ValueError, "cannot __enter__ multiple times"): | with self.assertRaisesRegex( | ||||
ValueError, "cannot __enter__ multiple times" | |||||
): | |||||
with reader as reader2: | with reader as reader2: | ||||
pass | pass | ||||
def test_not_implemented(self): | def test_not_implemented(self): | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
with dctx.stream_reader(b"foo") as reader: | with dctx.stream_reader(b"foo") as reader: | ||||
with self.assertRaises(io.UnsupportedOperation): | with self.assertRaises(io.UnsupportedOperation): | ||||
def test_illegal_seeks(self): | def test_illegal_seeks(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
frame = cctx.compress(b"foo" * 60) | frame = cctx.compress(b"foo" * 60) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
with dctx.stream_reader(frame) as reader: | with dctx.stream_reader(frame) as reader: | ||||
with self.assertRaisesRegex(ValueError, "cannot seek to negative position"): | with self.assertRaisesRegex( | ||||
ValueError, "cannot seek to negative position" | |||||
): | |||||
reader.seek(-1, os.SEEK_SET) | reader.seek(-1, os.SEEK_SET) | ||||
reader.read(1) | reader.read(1) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "cannot seek zstd decompression stream backwards" | ValueError, "cannot seek zstd decompression stream backwards" | ||||
): | ): | ||||
reader.seek(0, os.SEEK_SET) | reader.seek(0, os.SEEK_SET) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "cannot seek zstd decompression stream backwards" | ValueError, "cannot seek zstd decompression stream backwards" | ||||
): | ): | ||||
reader.seek(-1, os.SEEK_CUR) | reader.seek(-1, os.SEEK_CUR) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "zstd decompression streams cannot be seeked with SEEK_END" | ValueError, | ||||
"zstd decompression streams cannot be seeked with SEEK_END", | |||||
): | ): | ||||
reader.seek(0, os.SEEK_END) | reader.seek(0, os.SEEK_END) | ||||
reader.close() | reader.close() | ||||
with self.assertRaisesRegex(ValueError, "stream is closed"): | with self.assertRaisesRegex(ValueError, "stream is closed"): | ||||
reader.seek(4, os.SEEK_SET) | reader.seek(4, os.SEEK_SET) | ||||
self.assertEqual(b._read_count, 1) | self.assertEqual(b._read_count, 1) | ||||
self.assertEqual(reader.read1(1), b"o") | self.assertEqual(reader.read1(1), b"o") | ||||
self.assertEqual(b._read_count, 1) | self.assertEqual(b._read_count, 1) | ||||
self.assertEqual(reader.read1(1), b"") | self.assertEqual(reader.read1(1), b"") | ||||
self.assertEqual(b._read_count, 2) | self.assertEqual(b._read_count, 2) | ||||
def test_read_lines(self): | def test_read_lines(self): | ||||
cctx = zstd.ZstdCompressor() | cctx = zstd.ZstdCompressor() | ||||
source = b"\n".join(("line %d" % i).encode("ascii") for i in range(1024)) | source = b"\n".join( | ||||
("line %d" % i).encode("ascii") for i in range(1024) | |||||
) | |||||
frame = cctx.compress(source) | frame = cctx.compress(source) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
reader = dctx.stream_reader(frame) | reader = dctx.stream_reader(frame) | ||||
tr = io.TextIOWrapper(reader, encoding="utf-8") | tr = io.TextIOWrapper(reader, encoding="utf-8") | ||||
lines = [] | lines = [] | ||||
def test_reuse(self): | def test_reuse(self): | ||||
data = zstd.ZstdCompressor(level=1).compress(b"foobar") | data = zstd.ZstdCompressor(level=1).compress(b"foobar") | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
dobj = dctx.decompressobj() | dobj = dctx.decompressobj() | ||||
dobj.decompress(data) | dobj.decompress(data) | ||||
with self.assertRaisesRegex(zstd.ZstdError, "cannot use a decompressobj"): | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "cannot use a decompressobj" | |||||
): | |||||
dobj.decompress(data) | dobj.decompress(data) | ||||
self.assertIsNone(dobj.flush()) | self.assertIsNone(dobj.flush()) | ||||
def test_bad_write_size(self): | def test_bad_write_size(self): | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
with self.assertRaisesRegex(ValueError, "write_size must be positive"): | with self.assertRaisesRegex(ValueError, "write_size must be positive"): | ||||
dctx.decompressobj(write_size=0) | dctx.decompressobj(write_size=0) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
# Object with read() works. | # Object with read() works. | ||||
dctx.read_to_iter(io.BytesIO()) | dctx.read_to_iter(io.BytesIO()) | ||||
# Buffer protocol works. | # Buffer protocol works. | ||||
dctx.read_to_iter(b"foobar") | dctx.read_to_iter(b"foobar") | ||||
with self.assertRaisesRegex(ValueError, "must pass an object with a read"): | with self.assertRaisesRegex( | ||||
ValueError, "must pass an object with a read" | |||||
): | |||||
b"".join(dctx.read_to_iter(True)) | b"".join(dctx.read_to_iter(True)) | ||||
def test_empty_input(self): | def test_empty_input(self): | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
source = io.BytesIO() | source = io.BytesIO() | ||||
it = dctx.read_to_iter(source) | it = dctx.read_to_iter(source) | ||||
# TODO this is arguably wrong. Should get an error about missing frame foo. | # TODO this is arguably wrong. Should get an error about missing frame foo. | ||||
chunks.append(next(it)) | chunks.append(next(it)) | ||||
with self.assertRaises(StopIteration): | with self.assertRaises(StopIteration): | ||||
next(it) | next(it) | ||||
decompressed = b"".join(chunks) | decompressed = b"".join(chunks) | ||||
self.assertEqual(decompressed, source.getvalue()) | self.assertEqual(decompressed, source.getvalue()) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless( | ||||
"ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set" | |||||
) | |||||
def test_large_input(self): | def test_large_input(self): | ||||
bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | bytes = list(struct.Struct(">B").pack(i) for i in range(256)) | ||||
compressed = NonClosingBytesIO() | compressed = NonClosingBytesIO() | ||||
input_size = 0 | input_size = 0 | ||||
cctx = zstd.ZstdCompressor(level=1) | cctx = zstd.ZstdCompressor(level=1) | ||||
with cctx.stream_writer(compressed) as compressor: | with cctx.stream_writer(compressed) as compressor: | ||||
while True: | while True: | ||||
compressor.write(random.choice(bytes)) | compressor.write(random.choice(bytes)) | ||||
input_size += 1 | input_size += 1 | ||||
have_compressed = ( | have_compressed = ( | ||||
len(compressed.getvalue()) | len(compressed.getvalue()) | ||||
> zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | > zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | ||||
) | ) | ||||
have_raw = input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | have_raw = ( | ||||
input_size > zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE * 2 | |||||
) | |||||
if have_compressed and have_raw: | if have_compressed and have_raw: | ||||
break | break | ||||
compressed = io.BytesIO(compressed.getvalue()) | compressed = io.BytesIO(compressed.getvalue()) | ||||
self.assertGreater( | self.assertGreater( | ||||
len(compressed.getvalue()), zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE | len(compressed.getvalue()), | ||||
zstd.DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | |||||
) | ) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
it = dctx.read_to_iter(compressed) | it = dctx.read_to_iter(compressed) | ||||
chunks = [] | chunks = [] | ||||
chunks.append(next(it)) | chunks.append(next(it)) | ||||
chunks.append(next(it)) | chunks.append(next(it)) | ||||
) | ) | ||||
self.assertEqual(simple, source.getvalue()) | self.assertEqual(simple, source.getvalue()) | ||||
compressed = io.BytesIO(compressed.getvalue()) | compressed = io.BytesIO(compressed.getvalue()) | ||||
streamed = b"".join(dctx.read_to_iter(compressed)) | streamed = b"".join(dctx.read_to_iter(compressed)) | ||||
self.assertEqual(streamed, source.getvalue()) | self.assertEqual(streamed, source.getvalue()) | ||||
def test_read_write_size(self): | def test_read_write_size(self): | ||||
source = OpCountingBytesIO(zstd.ZstdCompressor().compress(b"foobarfoobar")) | source = OpCountingBytesIO( | ||||
zstd.ZstdCompressor().compress(b"foobarfoobar") | |||||
) | |||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | for chunk in dctx.read_to_iter(source, read_size=1, write_size=1): | ||||
self.assertEqual(len(chunk), 1) | self.assertEqual(len(chunk), 1) | ||||
self.assertEqual(source._read_count, len(source.getvalue())) | self.assertEqual(source._read_count, len(source.getvalue())) | ||||
def test_magic_less(self): | def test_magic_less(self): | ||||
params = zstd.CompressionParameters.from_level( | params = zstd.CompressionParameters.from_level( | ||||
with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | with self.assertRaisesRegex(ValueError, "chunk 0 must be bytes"): | ||||
dctx.decompress_content_dict_chain([True]) | dctx.decompress_content_dict_chain([True]) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "chunk 0 is too small to contain a zstd frame" | ValueError, "chunk 0 is too small to contain a zstd frame" | ||||
): | ): | ||||
dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | dctx.decompress_content_dict_chain([zstd.FRAME_HEADER]) | ||||
with self.assertRaisesRegex(ValueError, "chunk 0 is not a valid zstd frame"): | with self.assertRaisesRegex( | ||||
ValueError, "chunk 0 is not a valid zstd frame" | |||||
): | |||||
dctx.decompress_content_dict_chain([b"foo" * 8]) | dctx.decompress_content_dict_chain([b"foo" * 8]) | ||||
no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | ||||
b"foo" * 64 | |||||
) | |||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "chunk 0 missing content size in frame" | ValueError, "chunk 0 missing content size in frame" | ||||
): | ): | ||||
dctx.decompress_content_dict_chain([no_size]) | dctx.decompress_content_dict_chain([no_size]) | ||||
# Corrupt first frame. | # Corrupt first frame. | ||||
frame = zstd.ZstdCompressor().compress(b"foo" * 64) | frame = zstd.ZstdCompressor().compress(b"foo" * 64) | ||||
with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | with self.assertRaisesRegex(ValueError, "chunk 1 must be bytes"): | ||||
dctx.decompress_content_dict_chain([initial, None]) | dctx.decompress_content_dict_chain([initial, None]) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "chunk 1 is too small to contain a zstd frame" | ValueError, "chunk 1 is too small to contain a zstd frame" | ||||
): | ): | ||||
dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | dctx.decompress_content_dict_chain([initial, zstd.FRAME_HEADER]) | ||||
with self.assertRaisesRegex(ValueError, "chunk 1 is not a valid zstd frame"): | with self.assertRaisesRegex( | ||||
ValueError, "chunk 1 is not a valid zstd frame" | |||||
): | |||||
dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | dctx.decompress_content_dict_chain([initial, b"foo" * 8]) | ||||
no_size = zstd.ZstdCompressor(write_content_size=False).compress(b"foo" * 64) | no_size = zstd.ZstdCompressor(write_content_size=False).compress( | ||||
b"foo" * 64 | |||||
) | |||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "chunk 1 missing content size in frame" | ValueError, "chunk 1 missing content size in frame" | ||||
): | ): | ||||
dctx.decompress_content_dict_chain([initial, no_size]) | dctx.decompress_content_dict_chain([initial, no_size]) | ||||
# Corrupt second frame. | # Corrupt second frame. | ||||
cctx = zstd.ZstdCompressor(dict_data=zstd.ZstdCompressionDict(b"foo" * 64)) | cctx = zstd.ZstdCompressor( | ||||
dict_data=zstd.ZstdCompressionDict(b"foo" * 64) | |||||
) | |||||
frame = cctx.compress(b"bar" * 64) | frame = cctx.compress(b"bar" * 64) | ||||
frame = frame[0:12] + frame[15:] | frame = frame[0:12] + frame[15:] | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "chunk 1 did not decompress full frame" | zstd.ZstdError, "chunk 1 did not decompress full frame" | ||||
): | ): | ||||
dctx.decompress_content_dict_chain([initial, frame]) | dctx.decompress_content_dict_chain([initial, frame]) | ||||
self.skipTest("multi_decompress_to_buffer not available") | self.skipTest("multi_decompress_to_buffer not available") | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
dctx.multi_decompress_to_buffer(True) | dctx.multi_decompress_to_buffer(True) | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
dctx.multi_decompress_to_buffer((1, 2)) | dctx.multi_decompress_to_buffer((1, 2)) | ||||
with self.assertRaisesRegex(TypeError, "item 0 not a bytes like object"): | with self.assertRaisesRegex( | ||||
TypeError, "item 0 not a bytes like object" | |||||
): | |||||
dctx.multi_decompress_to_buffer([u"foo"]) | dctx.multi_decompress_to_buffer([u"foo"]) | ||||
with self.assertRaisesRegex( | with self.assertRaisesRegex( | ||||
ValueError, "could not determine decompressed size of item 0" | ValueError, "could not determine decompressed size of item 0" | ||||
): | ): | ||||
dctx.multi_decompress_to_buffer([b"foobarbaz"]) | dctx.multi_decompress_to_buffer([b"foobarbaz"]) | ||||
def test_list_input(self): | def test_list_input(self): | ||||
frames = [cctx.compress(d) for d in original] | frames = [cctx.compress(d) for d in original] | ||||
sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | sizes = struct.pack("=" + "Q" * len(original), *map(len, original)) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
if not hasattr(dctx, "multi_decompress_to_buffer"): | if not hasattr(dctx, "multi_decompress_to_buffer"): | ||||
self.skipTest("multi_decompress_to_buffer not available") | self.skipTest("multi_decompress_to_buffer not available") | ||||
result = dctx.multi_decompress_to_buffer(frames, decompressed_sizes=sizes) | result = dctx.multi_decompress_to_buffer( | ||||
frames, decompressed_sizes=sizes | |||||
) | |||||
self.assertEqual(len(result), len(frames)) | self.assertEqual(len(result), len(frames)) | ||||
self.assertEqual(result.size(), sum(map(len, original))) | self.assertEqual(result.size(), sum(map(len, original))) | ||||
for i, data in enumerate(original): | for i, data in enumerate(original): | ||||
self.assertEqual(result[i].tobytes(), data) | self.assertEqual(result[i].tobytes(), data) | ||||
def test_buffer_with_segments_input(self): | def test_buffer_with_segments_input(self): | ||||
self.assertEqual(len(decompressed), len(original)) | self.assertEqual(len(decompressed), len(original)) | ||||
for i, data in enumerate(original): | for i, data in enumerate(original): | ||||
self.assertEqual(data, decompressed[i].tobytes()) | self.assertEqual(data, decompressed[i].tobytes()) | ||||
# And a manual mode. | # And a manual mode. | ||||
b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | b = b"".join([frames[0].tobytes(), frames[1].tobytes()]) | ||||
b1 = zstd.BufferWithSegments( | b1 = zstd.BufferWithSegments( | ||||
b, struct.pack("=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1])) | b, | ||||
struct.pack( | |||||
"=QQQQ", 0, len(frames[0]), len(frames[0]), len(frames[1]) | |||||
), | |||||
) | ) | ||||
b = b"".join([frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()]) | b = b"".join( | ||||
[frames[2].tobytes(), frames[3].tobytes(), frames[4].tobytes()] | |||||
) | |||||
b2 = zstd.BufferWithSegments( | b2 = zstd.BufferWithSegments( | ||||
b, | b, | ||||
struct.pack( | struct.pack( | ||||
"=QQQQQQ", | "=QQQQQQ", | ||||
0, | 0, | ||||
len(frames[2]), | len(frames[2]), | ||||
len(frames[2]), | len(frames[2]), | ||||
len(frames[3]), | len(frames[3]), |
suppress_health_check=[hypothesis.HealthCheck.large_base_example] | suppress_health_check=[hypothesis.HealthCheck.large_base_example] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
streaming=strategies.booleans(), | streaming=strategies.booleans(), | ||||
source_read_size=strategies.integers(1, 1048576), | source_read_size=strategies.integers(1, 1048576), | ||||
) | ) | ||||
def test_stream_source_readall(self, original, level, streaming, source_read_size): | def test_stream_source_readall( | ||||
self, original, level, streaming, source_read_size | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
if streaming: | if streaming: | ||||
source = io.BytesIO() | source = io.BytesIO() | ||||
writer = cctx.stream_writer(source) | writer = cctx.stream_writer(source) | ||||
writer.write(original) | writer.write(original) | ||||
writer.flush(zstd.FLUSH_FRAME) | writer.flush(zstd.FLUSH_FRAME) | ||||
source.seek(0) | source.seek(0) | ||||
] | ] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
write_size=strategies.integers(min_value=1, max_value=8192), | write_size=strategies.integers(min_value=1, max_value=8192), | ||||
input_sizes=strategies.data(), | input_sizes=strategies.data(), | ||||
) | ) | ||||
def test_write_size_variance(self, original, level, write_size, input_sizes): | def test_write_size_variance( | ||||
self, original, level, write_size, input_sizes | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
dest = NonClosingBytesIO() | dest = NonClosingBytesIO() | ||||
with dctx.stream_writer(dest, write_size=write_size) as decompressor: | with dctx.stream_writer(dest, write_size=write_size) as decompressor: | ||||
] | ] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
read_size=strategies.integers(min_value=1, max_value=8192), | read_size=strategies.integers(min_value=1, max_value=8192), | ||||
write_size=strategies.integers(min_value=1, max_value=8192), | write_size=strategies.integers(min_value=1, max_value=8192), | ||||
) | ) | ||||
def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
self, original, level, read_size, write_size | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
dest = io.BytesIO() | dest = io.BytesIO() | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
dctx.copy_stream(source, dest, read_size=read_size, write_size=write_size) | dctx.copy_stream( | ||||
source, dest, read_size=read_size, write_size=write_size | |||||
) | |||||
self.assertEqual(dest.getvalue(), original) | self.assertEqual(dest.getvalue(), original) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
@make_cffi | @make_cffi | ||||
class TestDecompressor_decompressobj_fuzzing(TestCase): | class TestDecompressor_decompressobj_fuzzing(TestCase): | ||||
@hypothesis.settings( | @hypothesis.settings( | ||||
hypothesis.HealthCheck.large_base_example, | hypothesis.HealthCheck.large_base_example, | ||||
hypothesis.HealthCheck.too_slow, | hypothesis.HealthCheck.too_slow, | ||||
] | ] | ||||
) | ) | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
write_size=strategies.integers( | write_size=strategies.integers( | ||||
min_value=1, max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE | min_value=1, | ||||
max_value=4 * zstd.DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | |||||
), | ), | ||||
chunk_sizes=strategies.data(), | chunk_sizes=strategies.data(), | ||||
) | ) | ||||
def test_random_output_sizes(self, original, level, write_size, chunk_sizes): | def test_random_output_sizes( | ||||
self, original, level, write_size, chunk_sizes | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
dobj = dctx.decompressobj(write_size=write_size) | dobj = dctx.decompressobj(write_size=write_size) | ||||
@make_cffi | @make_cffi | ||||
class TestDecompressor_read_to_iter_fuzzing(TestCase): | class TestDecompressor_read_to_iter_fuzzing(TestCase): | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.sampled_from(random_input_data()), | original=strategies.sampled_from(random_input_data()), | ||||
level=strategies.integers(min_value=1, max_value=5), | level=strategies.integers(min_value=1, max_value=5), | ||||
read_size=strategies.integers(min_value=1, max_value=4096), | read_size=strategies.integers(min_value=1, max_value=4096), | ||||
write_size=strategies.integers(min_value=1, max_value=4096), | write_size=strategies.integers(min_value=1, max_value=4096), | ||||
) | ) | ||||
def test_read_write_size_variance(self, original, level, read_size, write_size): | def test_read_write_size_variance( | ||||
self, original, level, read_size, write_size | |||||
): | |||||
cctx = zstd.ZstdCompressor(level=level) | cctx = zstd.ZstdCompressor(level=level) | ||||
frame = cctx.compress(original) | frame = cctx.compress(original) | ||||
source = io.BytesIO(frame) | source = io.BytesIO(frame) | ||||
dctx = zstd.ZstdDecompressor() | dctx = zstd.ZstdDecompressor() | ||||
chunks = list( | chunks = list( | ||||
dctx.read_to_iter(source, read_size=read_size, write_size=write_size) | dctx.read_to_iter( | ||||
source, read_size=read_size, write_size=write_size | |||||
) | |||||
) | ) | ||||
self.assertEqual(b"".join(chunks), original) | self.assertEqual(b"".join(chunks), original) | ||||
@unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set") | ||||
class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): | class TestDecompressor_multi_decompress_to_buffer_fuzzing(TestCase): | ||||
@hypothesis.given( | @hypothesis.given( | ||||
original=strategies.lists( | original=strategies.lists( | ||||
strategies.sampled_from(random_input_data()), min_size=1, max_size=1024 | strategies.sampled_from(random_input_data()), | ||||
min_size=1, | |||||
max_size=1024, | |||||
), | ), | ||||
threads=strategies.integers(min_value=1, max_value=8), | threads=strategies.integers(min_value=1, max_value=8), | ||||
use_dict=strategies.booleans(), | use_dict=strategies.booleans(), | ||||
) | ) | ||||
def test_data_equivalence(self, original, threads, use_dict): | def test_data_equivalence(self, original, threads, use_dict): | ||||
kwargs = {} | kwargs = {} | ||||
if use_dict: | if use_dict: | ||||
kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) | kwargs["dict_data"] = zstd.ZstdCompressionDict(original[0]) |
data = d.as_bytes() | data = d.as_bytes() | ||||
self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") | self.assertEqual(data[0:4], b"\x37\xa4\x30\xec") | ||||
self.assertEqual(d.k, 64) | self.assertEqual(d.k, 64) | ||||
self.assertEqual(d.d, 16) | self.assertEqual(d.d, 16) | ||||
def test_set_dict_id(self): | def test_set_dict_id(self): | ||||
d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16, dict_id=42) | d = zstd.train_dictionary( | ||||
8192, generate_samples(), k=64, d=16, dict_id=42 | |||||
) | |||||
self.assertEqual(d.dict_id(), 42) | self.assertEqual(d.dict_id(), 42) | ||||
def test_optimize(self): | def test_optimize(self): | ||||
d = zstd.train_dictionary(8192, generate_samples(), threads=-1, steps=1, d=16) | d = zstd.train_dictionary( | ||||
8192, generate_samples(), threads=-1, steps=1, d=16 | |||||
) | |||||
# This varies by platform. | # This varies by platform. | ||||
self.assertIn(d.k, (50, 2000)) | self.assertIn(d.k, (50, 2000)) | ||||
self.assertEqual(d.d, 16) | self.assertEqual(d.d, 16) | ||||
@make_cffi | @make_cffi | ||||
class TestCompressionDict(TestCase): | class TestCompressionDict(TestCase): | ||||
def test_bad_mode(self): | def test_bad_mode(self): | ||||
with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): | with self.assertRaisesRegex(ValueError, "invalid dictionary load mode"): | ||||
zstd.ZstdCompressionDict(b"foo", dict_type=42) | zstd.ZstdCompressionDict(b"foo", dict_type=42) | ||||
def test_bad_precompute_compress(self): | def test_bad_precompute_compress(self): | ||||
d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | d = zstd.train_dictionary(8192, generate_samples(), k=64, d=16) | ||||
with self.assertRaisesRegex(ValueError, "must specify one of level or "): | with self.assertRaisesRegex( | ||||
ValueError, "must specify one of level or " | |||||
): | |||||
d.precompute_compress() | d.precompute_compress() | ||||
with self.assertRaisesRegex(ValueError, "must only specify one of level or "): | with self.assertRaisesRegex( | ||||
ValueError, "must only specify one of level or " | |||||
): | |||||
d.precompute_compress( | d.precompute_compress( | ||||
level=3, compression_params=zstd.CompressionParameters() | level=3, compression_params=zstd.CompressionParameters() | ||||
) | ) | ||||
def test_precompute_compress_rawcontent(self): | def test_precompute_compress_rawcontent(self): | ||||
d = zstd.ZstdCompressionDict( | d = zstd.ZstdCompressionDict( | ||||
b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_RAWCONTENT | ||||
) | ) | ||||
d.precompute_compress(level=1) | d.precompute_compress(level=1) | ||||
d = zstd.ZstdCompressionDict( | d = zstd.ZstdCompressionDict( | ||||
b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT | b"dictcontent" * 64, dict_type=zstd.DICT_TYPE_FULLDICT | ||||
) | ) | ||||
with self.assertRaisesRegex(zstd.ZstdError, "unable to precompute dictionary"): | with self.assertRaisesRegex( | ||||
zstd.ZstdError, "unable to precompute dictionary" | |||||
): | |||||
d.precompute_compress(level=1) | d.precompute_compress(level=1) |
_set_compression_parameter( | _set_compression_parameter( | ||||
params, lib.ZSTD_c_compressionLevel, compression_level | params, lib.ZSTD_c_compressionLevel, compression_level | ||||
) | ) | ||||
_set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) | _set_compression_parameter(params, lib.ZSTD_c_windowLog, window_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) | _set_compression_parameter(params, lib.ZSTD_c_hashLog, hash_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) | _set_compression_parameter(params, lib.ZSTD_c_chainLog, chain_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) | _set_compression_parameter(params, lib.ZSTD_c_searchLog, search_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) | _set_compression_parameter(params, lib.ZSTD_c_minMatch, min_match) | ||||
_set_compression_parameter(params, lib.ZSTD_c_targetLength, target_length) | _set_compression_parameter( | ||||
params, lib.ZSTD_c_targetLength, target_length | |||||
) | |||||
if strategy != -1 and compression_strategy != -1: | if strategy != -1 and compression_strategy != -1: | ||||
raise ValueError("cannot specify both compression_strategy and strategy") | raise ValueError( | ||||
"cannot specify both compression_strategy and strategy" | |||||
) | |||||
if compression_strategy != -1: | if compression_strategy != -1: | ||||
strategy = compression_strategy | strategy = compression_strategy | ||||
elif strategy == -1: | elif strategy == -1: | ||||
strategy = 0 | strategy = 0 | ||||
_set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) | _set_compression_parameter(params, lib.ZSTD_c_strategy, strategy) | ||||
_set_compression_parameter( | _set_compression_parameter( | ||||
params, lib.ZSTD_c_contentSizeFlag, write_content_size | params, lib.ZSTD_c_contentSizeFlag, write_content_size | ||||
) | ) | ||||
_set_compression_parameter(params, lib.ZSTD_c_checksumFlag, write_checksum) | _set_compression_parameter( | ||||
params, lib.ZSTD_c_checksumFlag, write_checksum | |||||
) | |||||
_set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) | _set_compression_parameter(params, lib.ZSTD_c_dictIDFlag, write_dict_id) | ||||
_set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) | _set_compression_parameter(params, lib.ZSTD_c_jobSize, job_size) | ||||
if overlap_log != -1 and overlap_size_log != -1: | if overlap_log != -1 and overlap_size_log != -1: | ||||
raise ValueError("cannot specify both overlap_log and overlap_size_log") | raise ValueError( | ||||
"cannot specify both overlap_log and overlap_size_log" | |||||
) | |||||
if overlap_size_log != -1: | if overlap_size_log != -1: | ||||
overlap_log = overlap_size_log | overlap_log = overlap_size_log | ||||
elif overlap_log == -1: | elif overlap_log == -1: | ||||
overlap_log = 0 | overlap_log = 0 | ||||
_set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) | _set_compression_parameter(params, lib.ZSTD_c_overlapLog, overlap_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_forceMaxWindow, force_max_window) | _set_compression_parameter( | ||||
params, lib.ZSTD_c_forceMaxWindow, force_max_window | |||||
) | |||||
_set_compression_parameter( | _set_compression_parameter( | ||||
params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm | params, lib.ZSTD_c_enableLongDistanceMatching, enable_ldm | ||||
) | ) | ||||
_set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) | _set_compression_parameter(params, lib.ZSTD_c_ldmHashLog, ldm_hash_log) | ||||
_set_compression_parameter(params, lib.ZSTD_c_ldmMinMatch, ldm_min_match) | _set_compression_parameter( | ||||
params, lib.ZSTD_c_ldmMinMatch, ldm_min_match | |||||
) | |||||
_set_compression_parameter( | _set_compression_parameter( | ||||
params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log | params, lib.ZSTD_c_ldmBucketSizeLog, ldm_bucket_size_log | ||||
) | ) | ||||
if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: | if ldm_hash_rate_log != -1 and ldm_hash_every_log != -1: | ||||
raise ValueError( | raise ValueError( | ||||
"cannot specify both ldm_hash_rate_log and ldm_hash_every_log" | "cannot specify both ldm_hash_rate_log and ldm_hash_every_log" | ||||
) | ) | ||||
if ldm_hash_every_log != -1: | if ldm_hash_every_log != -1: | ||||
ldm_hash_rate_log = ldm_hash_every_log | ldm_hash_rate_log = ldm_hash_every_log | ||||
elif ldm_hash_rate_log == -1: | elif ldm_hash_rate_log == -1: | ||||
ldm_hash_rate_log = 0 | ldm_hash_rate_log = 0 | ||||
_set_compression_parameter(params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log) | _set_compression_parameter( | ||||
params, lib.ZSTD_c_ldmHashRateLog, ldm_hash_rate_log | |||||
) | |||||
@property | @property | ||||
def format(self): | def format(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_format) | return _get_compression_parameter(self._params, lib.ZSTD_c_format) | ||||
@property | @property | ||||
def compression_level(self): | def compression_level(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_compressionLevel) | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_compressionLevel | |||||
) | |||||
@property | @property | ||||
def window_log(self): | def window_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_windowLog) | ||||
@property | @property | ||||
def hash_log(self): | def hash_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_hashLog) | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) | return _get_compression_parameter(self._params, lib.ZSTD_c_targetLength) | ||||
@property | @property | ||||
def compression_strategy(self): | def compression_strategy(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) | return _get_compression_parameter(self._params, lib.ZSTD_c_strategy) | ||||
@property | @property | ||||
def write_content_size(self): | def write_content_size(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_contentSizeFlag) | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_contentSizeFlag | |||||
) | |||||
@property | @property | ||||
def write_checksum(self): | def write_checksum(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) | return _get_compression_parameter(self._params, lib.ZSTD_c_checksumFlag) | ||||
@property | @property | ||||
def write_dict_id(self): | def write_dict_id(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) | return _get_compression_parameter(self._params, lib.ZSTD_c_dictIDFlag) | ||||
@property | @property | ||||
def job_size(self): | def job_size(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) | return _get_compression_parameter(self._params, lib.ZSTD_c_jobSize) | ||||
@property | @property | ||||
def overlap_log(self): | def overlap_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_overlapLog) | ||||
@property | @property | ||||
def overlap_size_log(self): | def overlap_size_log(self): | ||||
return self.overlap_log | return self.overlap_log | ||||
@property | @property | ||||
def force_max_window(self): | def force_max_window(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_forceMaxWindow) | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_forceMaxWindow | |||||
) | |||||
@property | @property | ||||
def enable_ldm(self): | def enable_ldm(self): | ||||
return _get_compression_parameter( | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_enableLongDistanceMatching | self._params, lib.ZSTD_c_enableLongDistanceMatching | ||||
) | ) | ||||
@property | @property | ||||
def ldm_hash_log(self): | def ldm_hash_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashLog) | ||||
@property | @property | ||||
def ldm_min_match(self): | def ldm_min_match(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) | return _get_compression_parameter(self._params, lib.ZSTD_c_ldmMinMatch) | ||||
@property | @property | ||||
def ldm_bucket_size_log(self): | def ldm_bucket_size_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_ldmBucketSizeLog) | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_ldmBucketSizeLog | |||||
) | |||||
@property | @property | ||||
def ldm_hash_rate_log(self): | def ldm_hash_rate_log(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_ldmHashRateLog) | return _get_compression_parameter( | ||||
self._params, lib.ZSTD_c_ldmHashRateLog | |||||
) | |||||
@property | @property | ||||
def ldm_hash_every_log(self): | def ldm_hash_every_log(self): | ||||
return self.ldm_hash_rate_log | return self.ldm_hash_rate_log | ||||
@property | @property | ||||
def threads(self): | def threads(self): | ||||
return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) | return _get_compression_parameter(self._params, lib.ZSTD_c_nbWorkers) | ||||
def estimated_compression_context_size(self): | def estimated_compression_context_size(self): | ||||
return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) | return lib.ZSTD_estimateCCtxSize_usingCCtxParams(self._params) | ||||
CompressionParameters = ZstdCompressionParameters | CompressionParameters = ZstdCompressionParameters | ||||
def estimate_decompression_context_size(): | def estimate_decompression_context_size(): | ||||
return lib.ZSTD_estimateDCtxSize() | return lib.ZSTD_estimateDCtxSize() | ||||
def _set_compression_parameter(params, param, value): | def _set_compression_parameter(params, param, value): | ||||
zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) | zresult = lib.ZSTD_CCtxParams_setParameter(params, param, value) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"unable to set compression context parameter: %s" % _zstd_error(zresult) | "unable to set compression context parameter: %s" | ||||
% _zstd_error(zresult) | |||||
) | ) | ||||
def _get_compression_parameter(params, param): | def _get_compression_parameter(params, param): | ||||
result = ffi.new("int *") | result = ffi.new("int *") | ||||
zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) | zresult = lib.ZSTD_CCtxParams_getParameter(params, param, result) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"unable to get compression context parameter: %s" % _zstd_error(zresult) | "unable to get compression context parameter: %s" | ||||
% _zstd_error(zresult) | |||||
) | ) | ||||
return result[0] | return result[0] | ||||
class ZstdCompressionWriter(object): | class ZstdCompressionWriter(object): | ||||
def __init__(self, compressor, writer, source_size, write_size, write_return_read): | def __init__( | ||||
self, compressor, writer, source_size, write_size, write_return_read | |||||
): | |||||
self._compressor = compressor | self._compressor = compressor | ||||
self._writer = writer | self._writer = writer | ||||
self._write_size = write_size | self._write_size = write_size | ||||
self._write_return_read = bool(write_return_read) | self._write_return_read = bool(write_return_read) | ||||
self._entered = False | self._entered = False | ||||
self._closed = False | self._closed = False | ||||
self._bytes_compressed = 0 | self._bytes_compressed = 0 | ||||
self._dst_buffer = ffi.new("char[]", write_size) | self._dst_buffer = ffi.new("char[]", write_size) | ||||
self._out_buffer = ffi.new("ZSTD_outBuffer *") | self._out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
self._out_buffer.dst = self._dst_buffer | self._out_buffer.dst = self._dst_buffer | ||||
self._out_buffer.size = len(self._dst_buffer) | self._out_buffer.size = len(self._dst_buffer) | ||||
self._out_buffer.pos = 0 | self._out_buffer.pos = 0 | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(compressor._cctx, source_size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
def __enter__(self): | def __enter__(self): | ||||
if self._closed: | if self._closed: | ||||
raise ValueError("stream is closed") | raise ValueError("stream is closed") | ||||
if self._entered: | if self._entered: | ||||
raise ZstdError("cannot __enter__ multiple times") | raise ZstdError("cannot __enter__ multiple times") | ||||
in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
out_buffer = self._out_buffer | out_buffer = self._out_buffer | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._compressor._cctx, | ||||
out_buffer, | |||||
in_buffer, | |||||
lib.ZSTD_e_continue, | |||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
) | |||||
total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
self._bytes_compressed += out_buffer.pos | self._bytes_compressed += out_buffer.pos | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
if self._write_return_read: | if self._write_return_read: | ||||
return in_buffer.pos | return in_buffer.pos | ||||
else: | else: | ||||
return total_write | return total_write | ||||
in_buffer.size = 0 | in_buffer.size = 0 | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
while True: | while True: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, out_buffer, in_buffer, flush | self._compressor._cctx, out_buffer, in_buffer, flush | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
) | |||||
total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
self._bytes_compressed += out_buffer.pos | self._bytes_compressed += out_buffer.pos | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
if not zresult: | if not zresult: | ||||
break | break | ||||
return total_write | return total_write | ||||
chunks = [] | chunks = [] | ||||
while source.pos < len(data): | while source.pos < len(data): | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, self._out, source, lib.ZSTD_e_continue | self._compressor._cctx, self._out, source, lib.ZSTD_e_continue | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if self._out.pos: | if self._out.pos: | ||||
chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | chunks.append(ffi.buffer(self._out.dst, self._out.pos)[:]) | ||||
self._out.pos = 0 | self._out.pos = 0 | ||||
return b"".join(chunks) | return b"".join(chunks) | ||||
def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): | def flush(self, flush_mode=COMPRESSOBJ_FLUSH_FINISH): | ||||
if flush_mode not in (COMPRESSOBJ_FLUSH_FINISH, COMPRESSOBJ_FLUSH_BLOCK): | if flush_mode not in ( | ||||
COMPRESSOBJ_FLUSH_FINISH, | |||||
COMPRESSOBJ_FLUSH_BLOCK, | |||||
): | |||||
raise ValueError("flush mode not recognized") | raise ValueError("flush mode not recognized") | ||||
if self._finished: | if self._finished: | ||||
raise ZstdError("compressor object already finished") | raise ZstdError("compressor object already finished") | ||||
if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: | if flush_mode == COMPRESSOBJ_FLUSH_BLOCK: | ||||
z_flush_mode = lib.ZSTD_e_flush | z_flush_mode = lib.ZSTD_e_flush | ||||
elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: | elif flush_mode == COMPRESSOBJ_FLUSH_FINISH: | ||||
) | ) | ||||
if self._in.pos == self._in.size: | if self._in.pos == self._in.size: | ||||
self._in.src = ffi.NULL | self._in.src = ffi.NULL | ||||
self._in.size = 0 | self._in.size = 0 | ||||
self._in.pos = 0 | self._in.pos = 0 | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if self._out.pos == self._out.size: | if self._out.pos == self._out.size: | ||||
yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
self._out.pos = 0 | self._out.pos = 0 | ||||
def flush(self): | def flush(self): | ||||
if self._finished: | if self._finished: | ||||
raise ZstdError("cannot call flush() after compression finished") | raise ZstdError("cannot call flush() after compression finished") | ||||
if self._in.src != ffi.NULL: | if self._in.src != ffi.NULL: | ||||
raise ZstdError( | raise ZstdError( | ||||
"cannot call flush() before consuming output from " "previous operation" | "cannot call flush() before consuming output from " | ||||
"previous operation" | |||||
) | ) | ||||
while True: | while True: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_flush | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if self._out.pos: | if self._out.pos: | ||||
yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
self._out.pos = 0 | self._out.pos = 0 | ||||
if not zresult: | if not zresult: | ||||
return | return | ||||
def finish(self): | def finish(self): | ||||
if self._finished: | if self._finished: | ||||
raise ZstdError("cannot call finish() after compression finished") | raise ZstdError("cannot call finish() after compression finished") | ||||
if self._in.src != ffi.NULL: | if self._in.src != ffi.NULL: | ||||
raise ZstdError( | raise ZstdError( | ||||
"cannot call finish() before consuming output from " | "cannot call finish() before consuming output from " | ||||
"previous operation" | "previous operation" | ||||
) | ) | ||||
while True: | while True: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end | self._compressor._cctx, self._out, self._in, lib.ZSTD_e_end | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if self._out.pos: | if self._out.pos: | ||||
yield ffi.buffer(self._out.dst, self._out.pos)[:] | yield ffi.buffer(self._out.dst, self._out.pos)[:] | ||||
self._out.pos = 0 | self._out.pos = 0 | ||||
if not zresult: | if not zresult: | ||||
self._finished = True | self._finished = True | ||||
return | return | ||||
def _compress_into_buffer(self, out_buffer): | def _compress_into_buffer(self, out_buffer): | ||||
if self._in_buffer.pos >= self._in_buffer.size: | if self._in_buffer.pos >= self._in_buffer.size: | ||||
return | return | ||||
old_pos = out_buffer.pos | old_pos = out_buffer.pos | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_continue | self._compressor._cctx, | ||||
out_buffer, | |||||
self._in_buffer, | |||||
lib.ZSTD_e_continue, | |||||
) | ) | ||||
self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
if self._in_buffer.pos == self._in_buffer.size: | if self._in_buffer.pos == self._in_buffer.size: | ||||
self._in_buffer.src = ffi.NULL | self._in_buffer.src = ffi.NULL | ||||
self._in_buffer.pos = 0 | self._in_buffer.pos = 0 | ||||
self._in_buffer.size = 0 | self._in_buffer.size = 0 | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | ||||
) | ) | ||||
self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) | raise ZstdError( | ||||
"error ending compression stream: %s", _zstd_error(zresult) | |||||
) | |||||
if zresult == 0: | if zresult == 0: | ||||
self._finished_output = True | self._finished_output = True | ||||
return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
def read1(self, size=-1): | def read1(self, size=-1): | ||||
if self._closed: | if self._closed: | ||||
old_pos = out_buffer.pos | old_pos = out_buffer.pos | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | self._compressor._cctx, out_buffer, self._in_buffer, lib.ZSTD_e_end | ||||
) | ) | ||||
self._bytes_compressed += out_buffer.pos - old_pos | self._bytes_compressed += out_buffer.pos - old_pos | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error ending compression stream: %s", _zstd_error(zresult)) | raise ZstdError( | ||||
"error ending compression stream: %s", _zstd_error(zresult) | |||||
) | |||||
if zresult == 0: | if zresult == 0: | ||||
self._finished_output = True | self._finished_output = True | ||||
return out_buffer.pos | return out_buffer.pos | ||||
def readinto1(self, b): | def readinto1(self, b): | ||||
if self._closed: | if self._closed: | ||||
dict_data=None, | dict_data=None, | ||||
compression_params=None, | compression_params=None, | ||||
write_checksum=None, | write_checksum=None, | ||||
write_content_size=None, | write_content_size=None, | ||||
write_dict_id=None, | write_dict_id=None, | ||||
threads=0, | threads=0, | ||||
): | ): | ||||
if level > lib.ZSTD_maxCLevel(): | if level > lib.ZSTD_maxCLevel(): | ||||
raise ValueError("level must be less than %d" % lib.ZSTD_maxCLevel()) | raise ValueError( | ||||
"level must be less than %d" % lib.ZSTD_maxCLevel() | |||||
) | |||||
if threads < 0: | if threads < 0: | ||||
threads = _cpu_count() | threads = _cpu_count() | ||||
if compression_params and write_checksum is not None: | if compression_params and write_checksum is not None: | ||||
raise ValueError("cannot define compression_params and " "write_checksum") | raise ValueError( | ||||
"cannot define compression_params and " "write_checksum" | |||||
) | |||||
if compression_params and write_content_size is not None: | if compression_params and write_content_size is not None: | ||||
raise ValueError( | raise ValueError( | ||||
"cannot define compression_params and " "write_content_size" | "cannot define compression_params and " "write_content_size" | ||||
) | ) | ||||
if compression_params and write_dict_id is not None: | if compression_params and write_dict_id is not None: | ||||
raise ValueError("cannot define compression_params and " "write_dict_id") | raise ValueError( | ||||
"cannot define compression_params and " "write_dict_id" | |||||
) | |||||
if compression_params and threads: | if compression_params and threads: | ||||
raise ValueError("cannot define compression_params and threads") | raise ValueError("cannot define compression_params and threads") | ||||
if compression_params: | if compression_params: | ||||
self._params = _make_cctx_params(compression_params) | self._params = _make_cctx_params(compression_params) | ||||
else: | else: | ||||
if write_dict_id is None: | if write_dict_id is None: | ||||
write_dict_id = True | write_dict_id = True | ||||
params = lib.ZSTD_createCCtxParams() | params = lib.ZSTD_createCCtxParams() | ||||
if params == ffi.NULL: | if params == ffi.NULL: | ||||
raise MemoryError() | raise MemoryError() | ||||
self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | self._params = ffi.gc(params, lib.ZSTD_freeCCtxParams) | ||||
_set_compression_parameter(self._params, lib.ZSTD_c_compressionLevel, level) | _set_compression_parameter( | ||||
self._params, lib.ZSTD_c_compressionLevel, level | |||||
) | |||||
_set_compression_parameter( | _set_compression_parameter( | ||||
self._params, | self._params, | ||||
lib.ZSTD_c_contentSizeFlag, | lib.ZSTD_c_contentSizeFlag, | ||||
write_content_size if write_content_size is not None else 1, | write_content_size if write_content_size is not None else 1, | ||||
) | ) | ||||
_set_compression_parameter( | _set_compression_parameter( | ||||
self._params, lib.ZSTD_c_checksumFlag, 1 if write_checksum else 0 | self._params, | ||||
lib.ZSTD_c_checksumFlag, | |||||
1 if write_checksum else 0, | |||||
) | ) | ||||
_set_compression_parameter( | _set_compression_parameter( | ||||
self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 | self._params, lib.ZSTD_c_dictIDFlag, 1 if write_dict_id else 0 | ||||
) | ) | ||||
if threads: | if threads: | ||||
_set_compression_parameter(self._params, lib.ZSTD_c_nbWorkers, threads) | _set_compression_parameter( | ||||
self._params, lib.ZSTD_c_nbWorkers, threads | |||||
) | |||||
cctx = lib.ZSTD_createCCtx() | cctx = lib.ZSTD_createCCtx() | ||||
if cctx == ffi.NULL: | if cctx == ffi.NULL: | ||||
raise MemoryError() | raise MemoryError() | ||||
self._cctx = cctx | self._cctx = cctx | ||||
self._dict_data = dict_data | self._dict_data = dict_data | ||||
# We defer setting up garbage collection until after calling | # We defer setting up garbage collection until after calling | ||||
# _setup_cctx() to ensure the memory size estimate is more accurate. | # _setup_cctx() to ensure the memory size estimate is more accurate. | ||||
try: | try: | ||||
self._setup_cctx() | self._setup_cctx() | ||||
finally: | finally: | ||||
self._cctx = ffi.gc( | self._cctx = ffi.gc( | ||||
cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) | cctx, lib.ZSTD_freeCCtx, size=lib.ZSTD_sizeof_CCtx(cctx) | ||||
) | ) | ||||
def _setup_cctx(self): | def _setup_cctx(self): | ||||
zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams(self._cctx, self._params) | zresult = lib.ZSTD_CCtx_setParametersUsingCCtxParams( | ||||
self._cctx, self._params | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"could not set compression parameters: %s" % _zstd_error(zresult) | "could not set compression parameters: %s" | ||||
% _zstd_error(zresult) | |||||
) | ) | ||||
dict_data = self._dict_data | dict_data = self._dict_data | ||||
if dict_data: | if dict_data: | ||||
if dict_data._cdict: | if dict_data._cdict: | ||||
zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict) | zresult = lib.ZSTD_CCtx_refCDict(self._cctx, dict_data._cdict) | ||||
else: | else: | ||||
zresult = lib.ZSTD_CCtx_loadDictionary_advanced( | zresult = lib.ZSTD_CCtx_loadDictionary_advanced( | ||||
self._cctx, | self._cctx, | ||||
dict_data.as_bytes(), | dict_data.as_bytes(), | ||||
len(dict_data), | len(dict_data), | ||||
lib.ZSTD_dlm_byRef, | lib.ZSTD_dlm_byRef, | ||||
dict_data._dict_type, | dict_data._dict_type, | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"could not load compression dictionary: %s" % _zstd_error(zresult) | "could not load compression dictionary: %s" | ||||
% _zstd_error(zresult) | |||||
) | ) | ||||
def memory_size(self): | def memory_size(self): | ||||
return lib.ZSTD_sizeof_CCtx(self._cctx) | return lib.ZSTD_sizeof_CCtx(self._cctx) | ||||
def compress(self, data): | def compress(self, data): | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
dest_size = lib.ZSTD_compressBound(len(data_buffer)) | dest_size = lib.ZSTD_compressBound(len(data_buffer)) | ||||
out = new_nonzero("char[]", dest_size) | out = new_nonzero("char[]", dest_size) | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, len(data_buffer)) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
out_buffer.dst = out | out_buffer.dst = out | ||||
out_buffer.size = dest_size | out_buffer.size = dest_size | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
def compressobj(self, size=-1): | def compressobj(self, size=-1): | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
cobj = ZstdCompressionObj() | cobj = ZstdCompressionObj() | ||||
cobj._out = ffi.new("ZSTD_outBuffer *") | cobj._out = ffi.new("ZSTD_outBuffer *") | ||||
cobj._dst_buffer = ffi.new("char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE) | cobj._dst_buffer = ffi.new( | ||||
"char[]", COMPRESSION_RECOMMENDED_OUTPUT_SIZE | |||||
) | |||||
cobj._out.dst = cobj._dst_buffer | cobj._out.dst = cobj._dst_buffer | ||||
cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE | cobj._out.size = COMPRESSION_RECOMMENDED_OUTPUT_SIZE | ||||
cobj._out.pos = 0 | cobj._out.pos = 0 | ||||
cobj._compressor = self | cobj._compressor = self | ||||
cobj._finished = False | cobj._finished = False | ||||
return cobj | return cobj | ||||
def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): | def chunker(self, size=-1, chunk_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE): | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
return ZstdCompressionChunker(self, chunk_size=chunk_size) | return ZstdCompressionChunker(self, chunk_size=chunk_size) | ||||
def copy_stream( | def copy_stream( | ||||
self, | self, | ||||
ifh, | ifh, | ||||
ofh, | ofh, | ||||
size=-1, | size=-1, | ||||
read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | ||||
): | ): | ||||
if not hasattr(ifh, "read"): | if not hasattr(ifh, "read"): | ||||
raise ValueError("first argument must have a read() method") | raise ValueError("first argument must have a read() method") | ||||
if not hasattr(ofh, "write"): | if not hasattr(ofh, "write"): | ||||
raise ValueError("second argument must have a write() method") | raise ValueError("second argument must have a write() method") | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
dst_buffer = ffi.new("char[]", write_size) | dst_buffer = ffi.new("char[]", write_size) | ||||
out_buffer.dst = dst_buffer | out_buffer.dst = dst_buffer | ||||
out_buffer.size = write_size | out_buffer.size = write_size | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ||||
total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
# We've finished reading. Flush the compressor. | # We've finished reading. Flush the compressor. | ||||
while True: | while True: | ||||
except Exception: | except Exception: | ||||
pass | pass | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
return ZstdCompressionReader(self, source, read_size) | return ZstdCompressionReader(self, source, read_size) | ||||
def stream_writer( | def stream_writer( | ||||
self, | self, | ||||
writer, | writer, | ||||
size=-1, | size=-1, | ||||
write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | write_size=COMPRESSION_RECOMMENDED_OUTPUT_SIZE, | ||||
write_return_read=False, | write_return_read=False, | ||||
): | ): | ||||
if not hasattr(writer, "write"): | if not hasattr(writer, "write"): | ||||
raise ValueError("must pass an object with a write() method") | raise ValueError("must pass an object with a write() method") | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
return ZstdCompressionWriter(self, writer, size, write_size, write_return_read) | return ZstdCompressionWriter( | ||||
self, writer, size, write_size, write_return_read | |||||
) | |||||
write_to = stream_writer | write_to = stream_writer | ||||
def read_to_iter( | def read_to_iter( | ||||
self, | self, | ||||
reader, | reader, | ||||
size=-1, | size=-1, | ||||
read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=COMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | lib.ZSTD_CCtx_reset(self._cctx, lib.ZSTD_reset_session_only) | ||||
if size < 0: | if size < 0: | ||||
size = lib.ZSTD_CONTENTSIZE_UNKNOWN | size = lib.ZSTD_CONTENTSIZE_UNKNOWN | ||||
zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | zresult = lib.ZSTD_CCtx_setPledgedSrcSize(self._cctx, size) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("error setting source size: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"error setting source size: %s" % _zstd_error(zresult) | |||||
) | |||||
in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
in_buffer.src = ffi.NULL | in_buffer.src = ffi.NULL | ||||
in_buffer.size = 0 | in_buffer.size = 0 | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
in_buffer.size = len(read_buffer) | in_buffer.size = len(read_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
zresult = lib.ZSTD_compressStream2( | zresult = lib.ZSTD_compressStream2( | ||||
self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | self._cctx, out_buffer, in_buffer, lib.ZSTD_e_continue | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd compress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd compress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
yield data | yield data | ||||
assert out_buffer.pos == 0 | assert out_buffer.pos == 0 | ||||
def get_frame_parameters(data): | def get_frame_parameters(data): | ||||
params = ffi.new("ZSTD_frameHeader *") | params = ffi.new("ZSTD_frameHeader *") | ||||
data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) | zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("cannot get frame parameters: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"cannot get frame parameters: %s" % _zstd_error(zresult) | |||||
) | |||||
if zresult: | if zresult: | ||||
raise ZstdError("not enough data for frame parameters; need %d bytes" % zresult) | raise ZstdError( | ||||
"not enough data for frame parameters; need %d bytes" % zresult | |||||
) | |||||
return FrameParameters(params[0]) | return FrameParameters(params[0]) | ||||
class ZstdCompressionDict(object): | class ZstdCompressionDict(object): | ||||
def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0): | def __init__(self, data, dict_type=DICT_TYPE_AUTO, k=0, d=0): | ||||
assert isinstance(data, bytes_type) | assert isinstance(data, bytes_type) | ||||
self._data = data | self._data = data | ||||
self.k = k | self.k = k | ||||
self.d = d | self.d = d | ||||
if dict_type not in (DICT_TYPE_AUTO, DICT_TYPE_RAWCONTENT, DICT_TYPE_FULLDICT): | if dict_type not in ( | ||||
DICT_TYPE_AUTO, | |||||
DICT_TYPE_RAWCONTENT, | |||||
DICT_TYPE_FULLDICT, | |||||
): | |||||
raise ValueError( | raise ValueError( | ||||
"invalid dictionary load mode: %d; must use " "DICT_TYPE_* constants" | "invalid dictionary load mode: %d; must use " | ||||
"DICT_TYPE_* constants" | |||||
) | ) | ||||
self._dict_type = dict_type | self._dict_type = dict_type | ||||
self._cdict = None | self._cdict = None | ||||
def __len__(self): | def __len__(self): | ||||
return len(self._data) | return len(self._data) | ||||
def dict_id(self): | def dict_id(self): | ||||
return int_type(lib.ZDICT_getDictID(self._data, len(self._data))) | return int_type(lib.ZDICT_getDictID(self._data, len(self._data))) | ||||
def as_bytes(self): | def as_bytes(self): | ||||
return self._data | return self._data | ||||
def precompute_compress(self, level=0, compression_params=None): | def precompute_compress(self, level=0, compression_params=None): | ||||
if level and compression_params: | if level and compression_params: | ||||
raise ValueError("must only specify one of level or " "compression_params") | raise ValueError( | ||||
"must only specify one of level or " "compression_params" | |||||
) | |||||
if not level and not compression_params: | if not level and not compression_params: | ||||
raise ValueError("must specify one of level or compression_params") | raise ValueError("must specify one of level or compression_params") | ||||
if level: | if level: | ||||
cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) | cparams = lib.ZSTD_getCParams(level, 0, len(self._data)) | ||||
else: | else: | ||||
cparams = ffi.new("ZSTD_compressionParameters") | cparams = ffi.new("ZSTD_compressionParameters") | ||||
lib.ZSTD_dlm_byRef, | lib.ZSTD_dlm_byRef, | ||||
self._dict_type, | self._dict_type, | ||||
lib.ZSTD_defaultCMem, | lib.ZSTD_defaultCMem, | ||||
) | ) | ||||
if ddict == ffi.NULL: | if ddict == ffi.NULL: | ||||
raise ZstdError("could not create decompression dict") | raise ZstdError("could not create decompression dict") | ||||
ddict = ffi.gc(ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict)) | ddict = ffi.gc( | ||||
ddict, lib.ZSTD_freeDDict, size=lib.ZSTD_sizeof_DDict(ddict) | |||||
) | |||||
self.__dict__["_ddict"] = ddict | self.__dict__["_ddict"] = ddict | ||||
return ddict | return ddict | ||||
def train_dictionary( | def train_dictionary( | ||||
dict_size, | dict_size, | ||||
samples, | samples, | ||||
chunks = [] | chunks = [] | ||||
while True: | while True: | ||||
zresult = lib.ZSTD_decompressStream( | zresult = lib.ZSTD_decompressStream( | ||||
self._decompressor._dctx, out_buffer, in_buffer | self._decompressor._dctx, out_buffer, in_buffer | ||||
) | ) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd decompressor error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd decompressor error: %s" % _zstd_error(zresult) | |||||
) | |||||
if zresult == 0: | if zresult == 0: | ||||
self._finished = True | self._finished = True | ||||
self._decompressor = None | self._decompressor = None | ||||
if out_buffer.pos: | if out_buffer.pos: | ||||
chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | chunks.append(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | ||||
def seek(self, pos, whence=os.SEEK_SET): | def seek(self, pos, whence=os.SEEK_SET): | ||||
if self._closed: | if self._closed: | ||||
raise ValueError("stream is closed") | raise ValueError("stream is closed") | ||||
read_amount = 0 | read_amount = 0 | ||||
if whence == os.SEEK_SET: | if whence == os.SEEK_SET: | ||||
if pos < 0: | if pos < 0: | ||||
raise ValueError("cannot seek to negative position with SEEK_SET") | raise ValueError( | ||||
"cannot seek to negative position with SEEK_SET" | |||||
) | |||||
if pos < self._bytes_decompressed: | if pos < self._bytes_decompressed: | ||||
raise ValueError("cannot seek zstd decompression stream " "backwards") | raise ValueError( | ||||
"cannot seek zstd decompression stream " "backwards" | |||||
) | |||||
read_amount = pos - self._bytes_decompressed | read_amount = pos - self._bytes_decompressed | ||||
elif whence == os.SEEK_CUR: | elif whence == os.SEEK_CUR: | ||||
if pos < 0: | if pos < 0: | ||||
raise ValueError("cannot seek zstd decompression stream " "backwards") | raise ValueError( | ||||
"cannot seek zstd decompression stream " "backwards" | |||||
) | |||||
read_amount = pos | read_amount = pos | ||||
elif whence == os.SEEK_END: | elif whence == os.SEEK_END: | ||||
raise ValueError( | raise ValueError( | ||||
"zstd decompression streams cannot be seeked " "with SEEK_END" | "zstd decompression streams cannot be seeked " "with SEEK_END" | ||||
) | ) | ||||
while read_amount: | while read_amount: | ||||
result = self.read(min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE)) | result = self.read( | ||||
min(read_amount, DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE) | |||||
) | |||||
if not result: | if not result: | ||||
break | break | ||||
read_amount -= len(result) | read_amount -= len(result) | ||||
return self._bytes_decompressed | return self._bytes_decompressed | ||||
out_buffer.size = len(dst_buffer) | out_buffer.size = len(dst_buffer) | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
dctx = self._decompressor._dctx | dctx = self._decompressor._dctx | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream(dctx, out_buffer, in_buffer) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd decompress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
self._writer.write(ffi.buffer(out_buffer.dst, out_buffer.pos)[:]) | self._writer.write( | ||||
ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | |||||
) | |||||
total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
if self._write_return_read: | if self._write_return_read: | ||||
return in_buffer.pos | return in_buffer.pos | ||||
else: | else: | ||||
return total_write | return total_write | ||||
def memory_size(self): | def memory_size(self): | ||||
return lib.ZSTD_sizeof_DCtx(self._dctx) | return lib.ZSTD_sizeof_DCtx(self._dctx) | ||||
def decompress(self, data, max_output_size=0): | def decompress(self, data, max_output_size=0): | ||||
self._ensure_dctx() | self._ensure_dctx() | ||||
data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
output_size = lib.ZSTD_getFrameContentSize(data_buffer, len(data_buffer)) | output_size = lib.ZSTD_getFrameContentSize( | ||||
data_buffer, len(data_buffer) | |||||
) | |||||
if output_size == lib.ZSTD_CONTENTSIZE_ERROR: | if output_size == lib.ZSTD_CONTENTSIZE_ERROR: | ||||
raise ZstdError("error determining content size from frame header") | raise ZstdError("error determining content size from frame header") | ||||
elif output_size == 0: | elif output_size == 0: | ||||
return b"" | return b"" | ||||
elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: | elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: | ||||
if not max_output_size: | if not max_output_size: | ||||
raise ZstdError("could not determine content size in frame header") | raise ZstdError( | ||||
"could not determine content size in frame header" | |||||
) | |||||
result_buffer = ffi.new("char[]", max_output_size) | result_buffer = ffi.new("char[]", max_output_size) | ||||
result_size = max_output_size | result_size = max_output_size | ||||
output_size = 0 | output_size = 0 | ||||
else: | else: | ||||
result_buffer = ffi.new("char[]", output_size) | result_buffer = ffi.new("char[]", output_size) | ||||
result_size = output_size | result_size = output_size | ||||
out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
out_buffer.dst = result_buffer | out_buffer.dst = result_buffer | ||||
out_buffer.size = result_size | out_buffer.size = result_size | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
in_buffer.src = data_buffer | in_buffer.src = data_buffer | ||||
in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("decompression error: %s" % _zstd_error(zresult)) | raise ZstdError("decompression error: %s" % _zstd_error(zresult)) | ||||
elif zresult: | elif zresult: | ||||
raise ZstdError("decompression error: did not decompress full frame") | raise ZstdError( | ||||
"decompression error: did not decompress full frame" | |||||
) | |||||
elif output_size and out_buffer.pos != output_size: | elif output_size and out_buffer.pos != output_size: | ||||
raise ZstdError( | raise ZstdError( | ||||
"decompression error: decompressed %d bytes; expected %d" | "decompression error: decompressed %d bytes; expected %d" | ||||
% (zresult, output_size) | % (zresult, output_size) | ||||
) | ) | ||||
return ffi.buffer(result_buffer, out_buffer.pos)[:] | return ffi.buffer(result_buffer, out_buffer.pos)[:] | ||||
def stream_reader( | def stream_reader( | ||||
self, | self, | ||||
source, | source, | ||||
read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
read_across_frames=False, | read_across_frames=False, | ||||
): | ): | ||||
self._ensure_dctx() | self._ensure_dctx() | ||||
return ZstdDecompressionReader(self, source, read_size, read_across_frames) | return ZstdDecompressionReader( | ||||
self, source, read_size, read_across_frames | |||||
) | |||||
def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): | def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE): | ||||
if write_size < 1: | if write_size < 1: | ||||
raise ValueError("write_size must be positive") | raise ValueError("write_size must be positive") | ||||
self._ensure_dctx() | self._ensure_dctx() | ||||
return ZstdDecompressionObj(self, write_size=write_size) | return ZstdDecompressionObj(self, write_size=write_size) | ||||
read_buffer = ffi.from_buffer(read_result) | read_buffer = ffi.from_buffer(read_result) | ||||
in_buffer.src = read_buffer | in_buffer.src = read_buffer | ||||
in_buffer.size = len(read_buffer) | in_buffer.size = len(read_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
assert out_buffer.pos == 0 | assert out_buffer.pos == 0 | ||||
zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
self._dctx, out_buffer, in_buffer | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("zstd decompress error: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"zstd decompress error: %s" % _zstd_error(zresult) | |||||
) | |||||
if out_buffer.pos: | if out_buffer.pos: | ||||
data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | data = ffi.buffer(out_buffer.dst, out_buffer.pos)[:] | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
yield data | yield data | ||||
if zresult == 0: | if zresult == 0: | ||||
return | return | ||||
self, | self, | ||||
writer, | writer, | ||||
write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE, | ||||
write_return_read=False, | write_return_read=False, | ||||
): | ): | ||||
if not hasattr(writer, "write"): | if not hasattr(writer, "write"): | ||||
raise ValueError("must pass an object with a write() method") | raise ValueError("must pass an object with a write() method") | ||||
return ZstdDecompressionWriter(self, writer, write_size, write_return_read) | return ZstdDecompressionWriter( | ||||
self, writer, write_size, write_return_read | |||||
) | |||||
write_to = stream_writer | write_to = stream_writer | ||||
def copy_stream( | def copy_stream( | ||||
self, | self, | ||||
ifh, | ifh, | ||||
ofh, | ofh, | ||||
read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE, | ||||
data_buffer = ffi.from_buffer(data) | data_buffer = ffi.from_buffer(data) | ||||
total_read += len(data_buffer) | total_read += len(data_buffer) | ||||
in_buffer.src = data_buffer | in_buffer.src = data_buffer | ||||
in_buffer.size = len(data_buffer) | in_buffer.size = len(data_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
# Flush all read data to output. | # Flush all read data to output. | ||||
while in_buffer.pos < in_buffer.size: | while in_buffer.pos < in_buffer.size: | ||||
zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
self._dctx, out_buffer, in_buffer | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"zstd decompressor error: %s" % _zstd_error(zresult) | "zstd decompressor error: %s" % _zstd_error(zresult) | ||||
) | ) | ||||
if out_buffer.pos: | if out_buffer.pos: | ||||
ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ofh.write(ffi.buffer(out_buffer.dst, out_buffer.pos)) | ||||
total_write += out_buffer.pos | total_write += out_buffer.pos | ||||
# First chunk should not be using a dictionary. We handle it specially. | # First chunk should not be using a dictionary. We handle it specially. | ||||
chunk = frames[0] | chunk = frames[0] | ||||
if not isinstance(chunk, bytes_type): | if not isinstance(chunk, bytes_type): | ||||
raise ValueError("chunk 0 must be bytes") | raise ValueError("chunk 0 must be bytes") | ||||
# All chunks should be zstd frames and should have content size set. | # All chunks should be zstd frames and should have content size set. | ||||
chunk_buffer = ffi.from_buffer(chunk) | chunk_buffer = ffi.from_buffer(chunk) | ||||
params = ffi.new("ZSTD_frameHeader *") | params = ffi.new("ZSTD_frameHeader *") | ||||
zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) | zresult = lib.ZSTD_getFrameHeader( | ||||
params, chunk_buffer, len(chunk_buffer) | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ValueError("chunk 0 is not a valid zstd frame") | raise ValueError("chunk 0 is not a valid zstd frame") | ||||
elif zresult: | elif zresult: | ||||
raise ValueError("chunk 0 is too small to contain a zstd frame") | raise ValueError("chunk 0 is too small to contain a zstd frame") | ||||
if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | ||||
raise ValueError("chunk 0 missing content size in frame") | raise ValueError("chunk 0 missing content size in frame") | ||||
self._ensure_dctx(load_dict=False) | self._ensure_dctx(load_dict=False) | ||||
last_buffer = ffi.new("char[]", params.frameContentSize) | last_buffer = ffi.new("char[]", params.frameContentSize) | ||||
out_buffer = ffi.new("ZSTD_outBuffer *") | out_buffer = ffi.new("ZSTD_outBuffer *") | ||||
out_buffer.dst = last_buffer | out_buffer.dst = last_buffer | ||||
out_buffer.size = len(last_buffer) | out_buffer.size = len(last_buffer) | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
in_buffer = ffi.new("ZSTD_inBuffer *") | in_buffer = ffi.new("ZSTD_inBuffer *") | ||||
in_buffer.src = chunk_buffer | in_buffer.src = chunk_buffer | ||||
in_buffer.size = len(chunk_buffer) | in_buffer.size = len(chunk_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("could not decompress chunk 0: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"could not decompress chunk 0: %s" % _zstd_error(zresult) | |||||
) | |||||
elif zresult: | elif zresult: | ||||
raise ZstdError("chunk 0 did not decompress full frame") | raise ZstdError("chunk 0 did not decompress full frame") | ||||
# Special case of chain length of 1 | # Special case of chain length of 1 | ||||
if len(frames) == 1: | if len(frames) == 1: | ||||
return ffi.buffer(last_buffer, len(last_buffer))[:] | return ffi.buffer(last_buffer, len(last_buffer))[:] | ||||
i = 1 | i = 1 | ||||
while i < len(frames): | while i < len(frames): | ||||
chunk = frames[i] | chunk = frames[i] | ||||
if not isinstance(chunk, bytes_type): | if not isinstance(chunk, bytes_type): | ||||
raise ValueError("chunk %d must be bytes" % i) | raise ValueError("chunk %d must be bytes" % i) | ||||
chunk_buffer = ffi.from_buffer(chunk) | chunk_buffer = ffi.from_buffer(chunk) | ||||
zresult = lib.ZSTD_getFrameHeader(params, chunk_buffer, len(chunk_buffer)) | zresult = lib.ZSTD_getFrameHeader( | ||||
params, chunk_buffer, len(chunk_buffer) | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ValueError("chunk %d is not a valid zstd frame" % i) | raise ValueError("chunk %d is not a valid zstd frame" % i) | ||||
elif zresult: | elif zresult: | ||||
raise ValueError("chunk %d is too small to contain a zstd frame" % i) | raise ValueError( | ||||
"chunk %d is too small to contain a zstd frame" % i | |||||
) | |||||
if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | if params.frameContentSize == lib.ZSTD_CONTENTSIZE_UNKNOWN: | ||||
raise ValueError("chunk %d missing content size in frame" % i) | raise ValueError("chunk %d missing content size in frame" % i) | ||||
dest_buffer = ffi.new("char[]", params.frameContentSize) | dest_buffer = ffi.new("char[]", params.frameContentSize) | ||||
out_buffer.dst = dest_buffer | out_buffer.dst = dest_buffer | ||||
out_buffer.size = len(dest_buffer) | out_buffer.size = len(dest_buffer) | ||||
out_buffer.pos = 0 | out_buffer.pos = 0 | ||||
in_buffer.src = chunk_buffer | in_buffer.src = chunk_buffer | ||||
in_buffer.size = len(chunk_buffer) | in_buffer.size = len(chunk_buffer) | ||||
in_buffer.pos = 0 | in_buffer.pos = 0 | ||||
zresult = lib.ZSTD_decompressStream(self._dctx, out_buffer, in_buffer) | zresult = lib.ZSTD_decompressStream( | ||||
self._dctx, out_buffer, in_buffer | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"could not decompress chunk %d: %s" % _zstd_error(zresult) | "could not decompress chunk %d: %s" % _zstd_error(zresult) | ||||
) | ) | ||||
elif zresult: | elif zresult: | ||||
raise ZstdError("chunk %d did not decompress full frame" % i) | raise ZstdError("chunk %d did not decompress full frame" % i) | ||||
last_buffer = dest_buffer | last_buffer = dest_buffer | ||||
i += 1 | i += 1 | ||||
return ffi.buffer(last_buffer, len(last_buffer))[:] | return ffi.buffer(last_buffer, len(last_buffer))[:] | ||||
def _ensure_dctx(self, load_dict=True): | def _ensure_dctx(self, load_dict=True): | ||||
lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) | lib.ZSTD_DCtx_reset(self._dctx, lib.ZSTD_reset_session_only) | ||||
if self._max_window_size: | if self._max_window_size: | ||||
zresult = lib.ZSTD_DCtx_setMaxWindowSize(self._dctx, self._max_window_size) | zresult = lib.ZSTD_DCtx_setMaxWindowSize( | ||||
self._dctx, self._max_window_size | |||||
) | |||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"unable to set max window size: %s" % _zstd_error(zresult) | "unable to set max window size: %s" % _zstd_error(zresult) | ||||
) | ) | ||||
zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) | zresult = lib.ZSTD_DCtx_setFormat(self._dctx, self._format) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError("unable to set decoding format: %s" % _zstd_error(zresult)) | raise ZstdError( | ||||
"unable to set decoding format: %s" % _zstd_error(zresult) | |||||
) | |||||
if self._dict_data and load_dict: | if self._dict_data and load_dict: | ||||
zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) | zresult = lib.ZSTD_DCtx_refDDict(self._dctx, self._dict_data._ddict) | ||||
if lib.ZSTD_isError(zresult): | if lib.ZSTD_isError(zresult): | ||||
raise ZstdError( | raise ZstdError( | ||||
"unable to reference prepared dictionary: %s" % _zstd_error(zresult) | "unable to reference prepared dictionary: %s" | ||||
% _zstd_error(zresult) | |||||
) | ) |
#require black | #require black | ||||
$ cd $RUNTESTDIR/.. | $ cd $RUNTESTDIR/.. | ||||
$ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/** - "contrib/python-zstandard/**"'` | $ black --config=black.toml --check --diff `hg files 'set:(**.py + grep("^#!.*python")) - mercurial/thirdparty/**'` | ||||