diff --git a/mercurial/thirdparty/sha1dc/cext.c b/mercurial/thirdparty/sha1dc/cext.c
--- a/mercurial/thirdparty/sha1dc/cext.c
+++ b/mercurial/thirdparty/sha1dc/cext.c
@@ -25,8 +25,29 @@
 
+/* Hash a buffer obtained via the "s*"/"y*" argument formats, then release
+   it. SHA1DCUpdate() takes a flat pointer/length pair, so reject anything
+   that is not C-contiguous and at most one-dimensional.
+   Returns 0 on success, or -1 with a BufferError set. The buffer is
+   released on every path, so callers must not use it afterwards.
+ */
+static int pysha1ctx_update_buffer(pysha1ctx *self, Py_buffer *data)
+{
+	int ret = -1;
+
+	if (!PyBuffer_IsContiguous(data, 'C') || data->ndim > 1) {
+		PyErr_SetString(PyExc_BufferError,
+		                "buffer must be contiguous and single dimension");
+	} else {
+		SHA1DCUpdate(&(self->ctx), data->buf, data->len);
+		ret = 0;
+	}
+	PyBuffer_Release(data);
+	return ret;
+}
+
 static int pysha1ctx_init(pysha1ctx *self, PyObject *args)
 {
-	const char *data = NULL;
-	Py_ssize_t len;
+	Py_buffer data;
+	data.obj = NULL;
 
 	SHA1DCInit(&(self->ctx));
 	/* We don't want "safe" sha1s, wherein sha1dc can give you a
@@ -34,11 +55,12 @@
 	   collision. We just want to detect collisions.
 	 */
 	SHA1DCSetSafeHash(&(self->ctx), 0);
-	if (!PyArg_ParseTuple(args, PY23("|s#", "|y#"), &data, &len)) {
+	if (!PyArg_ParseTuple(args, PY23("|s*", "|y*"), &data)) {
 		return -1;
 	}
-	if (data) {
-		SHA1DCUpdate(&(self->ctx), data, len);
+	/* data.obj is still NULL if the optional argument was omitted. */
+	if (data.obj) {
+		return pysha1ctx_update_buffer(self, &data);
 	}
 	return 0;
 }
@@ -50,12 +72,13 @@
 
 static PyObject *pysha1ctx_update(pysha1ctx *self, PyObject *args)
 {
-	const char *data;
-	Py_ssize_t len;
-	if (!PyArg_ParseTuple(args, PY23("s#", "y#"), &data, &len)) {
+	Py_buffer data;
+	if (!PyArg_ParseTuple(args, PY23("s*", "y*"), &data)) {
 		return NULL;
 	}
-	SHA1DCUpdate(&(self->ctx), data, len);
+	if (pysha1ctx_update_buffer(self, &data) == -1) {
+		return NULL;
+	}
 	Py_RETURN_NONE;
 }
 
diff --git a/tests/test-hashutil.py b/tests/test-hashutil.py
--- a/tests/test-hashutil.py
+++ b/tests/test-hashutil.py
@@ -45,6 +45,26 @@
             h.digest(),
         )
 
+    def test_bytes_like_types(self):
+        h = self.hasher()
+        h.update(bytearray(b'foo'))
+        h.update(memoryview(b'baz'))
+        self.assertEqual(
+            '21eb6533733a5e4763acacd1d45a60c2e0e404e1', h.hexdigest()
+        )
+
+        h = self.hasher(bytearray(b'foo'))
+        h.update(b'baz')
+        self.assertEqual(
+            '21eb6533733a5e4763acacd1d45a60c2e0e404e1', h.hexdigest()
+        )
+
+        h = self.hasher(memoryview(b'foo'))
+        h.update(b'baz')
+        self.assertEqual(
+            '21eb6533733a5e4763acacd1d45a60c2e0e404e1', h.hexdigest()
+        )
+
 
 class hashlibtests(unittest.TestCase, hashertestsbase):
     hasher = hashlib.sha1