diff --git a/rust/hg-core/src/ancestors.rs b/rust/hg-core/src/ancestors.rs
--- a/rust/hg-core/src/ancestors.rs
+++ b/rust/hg-core/src/ancestors.rs
@@ -10,6 +10,7 @@
 use super::{Graph, GraphError, Revision, NULL_REVISION};
 use std::cmp::max;
 use std::collections::{BinaryHeap, HashSet};
+use crate::dagops;
 
 /// Iterator over the ancestors of a given list of revisions
 /// This is a generic type, defined and implemented for any Graph, so that
@@ -229,6 +230,19 @@
         &self.bases
     }
 
+    /// Computes the relative heads of current bases.
+    ///
+    /// The object is still usable after this.
+    pub fn bases_heads(&self) -> Result<HashSet<Revision>, GraphError> {
+        dagops::heads(&self.graph, self.bases.iter())
+    }
+
+    /// Consumes the object and returns the relative heads of its bases.
+    pub fn into_bases_heads(mut self) -> Result<HashSet<Revision>, GraphError> {
+        dagops::retain_heads(&self.graph, &mut self.bases)?;
+        Ok(self.bases)
+    }
+
     pub fn add_bases(
         &mut self,
         new_bases: impl IntoIterator<Item = Revision>,
@@ -556,8 +570,8 @@
     }
 
     #[test]
-    /// Test constructor, add/get bases
-    fn test_missing_bases() {
+    /// Test constructor, add/get bases and heads
+    fn test_missing_bases() -> Result<(), GraphError> {
         let mut missing_ancestors =
             MissingAncestors::new(SampleGraph, [5, 3, 1, 3].iter().cloned());
         let mut as_vec: Vec<Revision> =
@@ -569,6 +583,11 @@
         as_vec = missing_ancestors.get_bases().iter().cloned().collect();
         as_vec.sort();
         assert_eq!(as_vec, [1, 3, 5, 7, 8]);
+
+        as_vec = missing_ancestors.bases_heads()?.iter().cloned().collect();
+        as_vec.sort();
+        assert_eq!(as_vec, [3, 5, 7, 8]);
+        Ok(())
     }
 
     fn assert_missing_remove(
diff --git a/rust/hg-cpython/src/ancestors.rs b/rust/hg-cpython/src/ancestors.rs
--- a/rust/hg-cpython/src/ancestors.rs
+++ b/rust/hg-cpython/src/ancestors.rs
@@ -166,6 +166,11 @@
         py_set(py, self.inner(py).borrow().get_bases())
     }
 
+    def basesheads(&self) -> PyResult<PyObject> {
+        let inner = self.inner(py).borrow();
+        py_set(py, &inner.bases_heads().map_err(|e| GraphError::pynew(py, e))?)
+    }
+
     def removeancestorsfrom(&self, revs: PyObject) -> PyResult<PyObject> {
         let mut inner = self.inner(py).borrow_mut();
         // this is very lame: we convert to a Rust set, update it in place
diff --git a/tests/test-rust-ancestor.py b/tests/test-rust-ancestor.py
--- a/tests/test-rust-ancestor.py
+++ b/tests/test-rust-ancestor.py
@@ -114,6 +114,7 @@
         missanc.addbases({2})
         self.assertEqual(missanc.bases(), {1, 2})
         self.assertEqual(missanc.missingancestors([3]), [3])
+        self.assertEqual(missanc.basesheads(), {2})
 
     def testmissingancestorsremove(self):
         idx = self.parseindex()