diff --git a/rust/indexes/Cargo.toml b/rust/indexes/Cargo.toml
--- a/rust/indexes/Cargo.toml
+++ b/rust/indexes/Cargo.toml
@@ -12,3 +12,8 @@
 [dependencies]
 radixbuf = { path = "../radixbuf" }
 error-chain = "0.11"
+
+[dependencies.cpython]
+git = "https://github.com/dgrunwald/rust-cpython"
+default-features = false
+features = ["extension-module-2-7"]
diff --git a/rust/indexes/src/lib.rs b/rust/indexes/src/lib.rs
--- a/rust/indexes/src/lib.rs
+++ b/rust/indexes/src/lib.rs
@@ -11,3 +11,13 @@
 
 pub mod errors;
 pub mod nodemap;
+
+// Provides the `py_class!` and `py_module_initializer!` macros used by pyext.
+#[macro_use]
+extern crate cpython;
+
+// The Python-visible class is deliberately lowercase (`nodemap`), and its
+// `PyBuffer` fields are held only to keep memory alive, never read.
+#[allow(non_camel_case_types)]
+#[allow(dead_code)]
+pub mod pyext;
diff --git a/rust/indexes/src/pyext.rs b/rust/indexes/src/pyext.rs
new file mode 100644
--- /dev/null
+++ b/rust/indexes/src/pyext.rs
@@ -0,0 +1,102 @@
+// Copyright 2017 Facebook, Inc.
+//
+// This software may be used and distributed according to the terms of the
+// GNU General Public License version 2 or any later version.
+
+//! Python bindings for `NodeRevMap`, exported as the `indexes` extension
+//! module containing a single `nodemap` class.
+
+use cpython::{PyBytes, PyErr, PyObject, PyResult, Python};
+use cpython::buffer::PyBuffer;
+use cpython::exc::ValueError;
+use nodemap::NodeRevMap;
+use std::mem::{size_of, size_of_val};
+use std::slice;
+
+py_module_initializer!(indexes, initindexes, PyInit_indexes, |py, m| {
+    try!(m.add_class::<nodemap>(py));
+    Ok(())
+});
+
+// NOTE(review): the `?`s on `NodeRevMap` results below need a conversion from
+// the crate's error type into `PyErr`; this patch does not add one, so confirm
+// it is provided elsewhere.
+py_class!(class nodemap |py| {
+    // Borrows the memory exposed by the two buffers below. The `'static`
+    // lifetime is only sound because those buffers live in the same object.
+    data nodemap: NodeRevMap<'static>;
+
+    // Keep their references so Python won't free them.
+    // They are never read, hence `#[allow(dead_code)]` on this module.
+    data changelog_slice: PyBuffer;
+    data index_slice: PyBuffer;
+
+    // nodemap(changelog, index): both arguments must support the buffer
+    // protocol.
+    def __new__(_cls, changelog: &PyObject, index: &PyObject) -> PyResult<nodemap> {
+        let changelog_buf = PyBuffer::get(py, changelog)?;
+        let index_buf = PyBuffer::get(py, index)?;
+        // SAFETY: both buffers are moved into the instance created below,
+        // alongside the `NodeRevMap` that borrows them.
+        let changelog_slice = unsafe { pybuf_to_slice(py, &changelog_buf)? };
+        let index_slice = unsafe { pybuf_to_slice(py, &index_buf)? };
+        let nm = NodeRevMap::new(changelog_slice, index_slice)?;
+        nodemap::create_instance(py, nm, changelog_buf, index_buf)
+    }
+
+    // Look up the revision of a binary node; `None` when it is absent.
+    def __getitem__(&self, key: PyBytes) -> PyResult<Option<u32>> {
+        Ok(self.nodemap(py).node_to_rev(key.data(py))?)
+    }
+
+    def __contains__(&self, key: PyBytes) -> PyResult<bool> {
+        Ok(self.nodemap(py).node_to_rev(key.data(py))?.is_some())
+    }
+
+    // Resolve a hex prefix to a full node via `hex_prefix_to_node`.
+    def partialmatch(&self, hex: PyBytes) -> PyResult<Option<PyBytes>> {
+        Ok(self.nodemap(py).hex_prefix_to_node(hex.data(py))?.map(|b| PyBytes::new(py, b)))
+    }
+
+    // Return the output of `build_incrementally` as raw bytes.
+    def build(&self) -> PyResult<PyBytes> {
+        let buf = self.nodemap(py).build_incrementally()?;
+        let elems = &buf[..];
+        // SAFETY: `elems` is a live, initialized slice (presumably of `u32`,
+        // which the old `* 4` assumed); `size_of_val` takes the byte length
+        // from the real element type instead of hard-coding 4.
+        let bytes = unsafe {
+            slice::from_raw_parts(elems.as_ptr() as *const u8, size_of_val(elems))
+        };
+        Ok(PyBytes::new(py, bytes))
+    }
+
+    // NOTE(review): `u32` restored from context; confirm against `NodeRevMap::lag`.
+    def lag(&self) -> PyResult<u32> {
+        Ok(self.nodemap(py).lag())
+    }
+});
+
+/// Reinterpret the bytes of a Python buffer as a slice of `T`.
+///
+/// The length is `buf.len_bytes() / size_of::<T>()`; trailing bytes that do
+/// not fill a whole `T` are ignored.
+///
+/// # Errors
+///
+/// Raises `ValueError` if the buffer is not C-contiguous, because its bytes
+/// would then not form one flat array.
+///
+/// # Safety
+///
+/// The returned `'static` slice really borrows the memory exposed by `buf`:
+/// the caller must keep `buf` alive, and the data unmodified, for as long as
+/// the slice is used. The data must also be aligned for `T` and hold valid
+/// values of `T`; neither is checked here.
+unsafe fn pybuf_to_slice<T>(py: Python, buf: &PyBuffer) -> PyResult<&'static [T]> {
+    if !buf.is_c_contiguous() {
+        return Err(PyErr::new::<ValueError, _>(py, "buffer must be C-contiguous"));
+    }
+    let len = buf.len_bytes() / size_of::<T>();
+    Ok(slice::from_raw_parts(buf.buf_ptr() as *const T, len))
+}