diff --git a/rust/treedirstate/src/dirstate.rs b/rust/treedirstate/src/dirstate.rs
--- a/rust/treedirstate/src/dirstate.rs
+++ b/rust/treedirstate/src/dirstate.rs
@@ -186,6 +186,21 @@
         self.tracked.get(self.store.store_view(), name)
     }
 
+    /// Look up a tracked file by its filtered (e.g. case-folded) name. `filter` is applied to
+    /// each path component of the tracked files; returns the real key of the file whose
+    /// filtered path equals `name`, or `None` if there is no such file.
+    pub fn get_tracked_filtered_key<F>(
+        &mut self,
+        name: KeyRef,
+        filter: &mut F,
+    ) -> Result<Option<Key>>
+    where
+        F: FnMut(KeyRef) -> Result<Key>,
+    {
+        self.tracked
+            .get_filtered_key(self.store.store_view(), name, filter)
+    }
+
     /// Visit all tracked files with a visitor.
     pub fn visit_tracked<F>(&mut self, visitor: &mut F) -> Result<()>
     where
diff --git a/rust/treedirstate/src/python.rs b/rust/treedirstate/src/python.rs
--- a/rust/treedirstate/src/python.rs
+++ b/rust/treedirstate/src/python.rs
@@ -5,11 +5,12 @@
 use cpython::exc;
 use dirstate::Dirstate;
 use errors::ErrorKind;
+use errors;
 use filestate::FileState;
 use std::cell::RefCell;
 use std::path::PathBuf;
 use store::BlockId;
-use tree::KeyRef;
+use tree::{Key, KeyRef};
 
 py_module_initializer!(
     rusttreedirstate,
@@ -347,4 +348,27 @@
         Ok(py.None())
     }
 
+    // Look up the tracked file whose name, with `casefolder` applied to each path component,
+    // equals `filename`. Returns the real (unfolded) name, or None if there is no match.
+    def getcasefoldedtracked(
+        &self,
+        filename: PyBytes,
+        casefolder: PyObject
+    ) -> PyResult<Option<PyObject>> {
+        let mut dirstate = self.dirstate(py).borrow_mut();
+        let mut filter = |filename: KeyRef| -> errors::Result<Key> {
+            let unfolded = PyBytes::new(py, filename);
+            let folded = casefolder.call(py, (unfolded,), None)
+                .map_err(|e| callback_error(py, e))?
+                .extract::<PyBytes>(py)
+                .map_err(|e| callback_error(py, e))?;
+            Ok(folded.data(py).to_vec())
+        };
+
+        dirstate
+            .get_tracked_filtered_key(filename.data(py), &mut filter)
+            .map(|o| o.map(|k| PyBytes::new(py, &k).into_object()))
+            .map_err(|e| PyErr::new::<exc::IOError, _>(py, e.description()))
+    }
+
 });
diff --git a/rust/treedirstate/src/tree.rs b/rust/treedirstate/src/tree.rs
--- a/rust/treedirstate/src/tree.rs
+++ b/rust/treedirstate/src/tree.rs
@@ -45,6 +45,11 @@
     /// then the ID must not be None, and the entries are yet to be loaded from the back-end
     /// store.
     entries: Option<NodeEntryMap<T>>,
+
+    /// A map from keys that have been filtered through a case-folding filter function to the
+    /// original key. This is used for case-folded look-ups. Filtered values are cached, so
+    /// only a single filter function can be used with a tree.
+    filtered_keys: Option<VecMap<Key, Key>>,
 }
 
 /// The root of the tree. The count of files in the tree is maintained for fast size
@@ -117,6 +122,7 @@
         Node {
             id: None,
             entries: Some(NodeEntryMap::new()),
+            filtered_keys: None,
         }
     }
 
@@ -126,6 +132,7 @@
         Node {
             id: Some(id),
             entries: None,
+            filtered_keys: None,
         }
     }
 
@@ -441,6 +448,7 @@
         };
         if let Some((new_key, new_entry)) = new_entry {
             self.load_entries(store)?.insert(new_key, new_entry);
+            self.filtered_keys = None;
         }
         self.id = None;
         Ok(file_added)
@@ -465,6 +473,7 @@
         };
         if let Some(entry) = remove_entry {
             self.load_entries(store)?.remove(entry);
+            self.filtered_keys = None;
             self.id = None;
         }
         if file_removed {
@@ -472,6 +481,61 @@
         }
         Ok((file_removed, self.load_entries(store)?.is_empty()))
     }
+
+    /// Performs a key lookup using filtered keys.
+    ///
+    /// Applies the filter function to each key in the node, then returns the real key that
+    /// matches the name provided. The name may contain a path, in which case the subdirectories
+    /// of this node are also queried.
+    ///
+    /// Returns a reversed vector of key references for each path element, or None if no file
+    /// matches the requested name after filtering.
+    fn get_filtered_key<'a, F>(
+        &'a mut self,
+        store: &StoreView,
+        name: KeyRef,
+        filter: &mut F,
+    ) -> Result<Option<Vec<KeyRef<'a>>>>
+    where
+        F: FnMut(KeyRef) -> Result<Key>,
+    {
+        let (elem, path) = split_key(name);
+        if self.filtered_keys.is_none() {
+            // Build this node's filtered-key map on first use. If several keys filter to the
+            // same value, only one of them is kept.
+            let new_map = {
+                let entries = self.load_entries(store)?;
+                let mut new_map = VecMap::with_capacity(entries.len());
+                for (k, _v) in entries.iter() {
+                    new_map.insert(filter(k)?, k.to_vec());
+                }
+                new_map
+            };
+            self.filtered_keys = Some(new_map);
+        }
+        if let Some(path) = path {
+            // `filtered_keys` is only ever set after `load_entries`, so `entries` is loaded.
+            if let Some(mapped_elem) = self.filtered_keys.as_ref().unwrap().get(elem) {
+                if let Some(&mut NodeEntry::Directory(ref mut node)) =
+                    self.entries.as_mut().unwrap().get_mut(mapped_elem)
+                {
+                    if let Some(mut mapped_path) = node.get_filtered_key(store, path, filter)? {
+                        mapped_path.push(mapped_elem);
+                        return Ok(Some(mapped_path));
+                    }
+                }
+            }
+            Ok(None)
+        } else {
+            Ok(
+                self.filtered_keys
+                    .as_ref()
+                    .unwrap()
+                    .get(elem)
+                    .map(|e| vec![e.as_ref()]),
+            )
+        }
+    }
 }
 
 impl<T: Storable + Clone> Tree<T> {
@@ -572,6 +636,28 @@
         }
         Ok(removed)
     }
+
+    /// Looks up the real key of a file whose path, after applying `filter` to each path
+    /// element, equals `name`. Returns `None` if no file matches.
+    ///
+    /// Filtered keys are cached in each node, so the same filter function must be used for
+    /// every call on a given tree.
+    pub fn get_filtered_key<F>(
+        &mut self,
+        store: &StoreView,
+        name: KeyRef,
+        filter: &mut F,
+    ) -> Result<Option<Key>>
+    where
+        F: FnMut(KeyRef) -> Result<Key>,
+    {
+        Ok(self.root.get_filtered_key(store, name, filter)?.map(
+            |mut path| {
+                path.reverse();
+                path.concat()
+            },
+        ))
+    }
 }
 
 #[cfg(test)]
@@ -579,8 +665,9 @@
     use store::NullStore;
     use store::tests::MapStore;
 
-    use tree::{KeyRef, Tree};
+    use tree::{Key, KeyRef, Tree};
     use filestate::FileState;
+    use errors::*;
 
     // Test files in order. Note lexicographic ordering of file9 and file10.
     static TEST_FILES: [(&[u8], u32, i32, i32); 16] = [
@@ -794,4 +881,34 @@
             ]
         );
     }
+
+    #[test]
+    fn filtered_keys() {
+        let ms = MapStore::new();
+        let mut t = Tree::new();
+        populate(&mut t, &ms);
+
+        // Define a mapping function that upper-cases 'a' characters:
+        fn map_upper_a(k: KeyRef) -> Result<Key> {
+            Ok(
+                k.iter()
+                    .map(|c| if *c == b'a' { b'A' } else { *c })
+                    .collect(),
+            )
+        }
+
+        // Look-up with normalized name should give non-normalized version.
+        assert_eq!(
+            t.get_filtered_key(&ms, b"dirA/subdirA/file1", &mut map_upper_a)
+                .expect("should succeed"),
+            Some(b"dirA/subdira/file1".to_vec())
+        );
+
+        // Look-up with non-normalized name should match nothing.
+        assert_eq!(
+            t.get_filtered_key(&ms, b"dirA/subdira/file1", &mut map_upper_a)
+                .expect("should succeed"),
+            None
+        );
+    }
 }
diff --git a/treedirstate/__init__.py b/treedirstate/__init__.py
--- a/treedirstate/__init__.py
+++ b/treedirstate/__init__.py
@@ -7,12 +7,14 @@
 from __future__ import absolute_import
 from mercurial import (
     dirstate,
+    encoding,
     error,
     extensions,
     localrepo,
     node,
     pycompat,
     registrar,
+    scmutil,
     txnutil,
     util,
 )
@@ -146,6 +148,9 @@
         return (self._rmap.gettracked(filename, None) or
                 self._rmap.getremoved(filename, default))
 
+    def getcasefoldedtracked(self, filename, foldfunc):
+        return self._rmap.getcasefoldedtracked(filename, foldfunc)
+
     def __getitem__(self, filename):
         item = (self._rmap.gettracked(filename, None) or
                 self._rmap.getremoved(filename, None))
@@ -452,6 +457,44 @@
     ds._mapcls = treedirstatemap
     return ds
 
+class casecollisionauditor(object):
+    """Warn about, or abort on, files whose case-folded name collides.
+
+    A new file collides if its folded name matches a tracked file (looked up
+    with the dirstate map's getcasefoldedtracked) or a file seen earlier.
+    """
+    def __init__(self, ui, abort, dirstate):
+        self._ui = ui
+        self._abort = abort
+        self._dirstate = dirstate
+        # The purpose of _newfiles is so that we don't complain about
+        # case collisions if someone were to call this object with the
+        # same filename twice.
+        self._newfiles = set()
+        self._newfilesfolded = set()
+
+    def __call__(self, f):
+        if f in self._newfiles:
+            return
+        fl = encoding.lower(f)
+        if (f not in self._dirstate and
+            (fl in self._newfilesfolded or
+             self._dirstate._map.getcasefoldedtracked(fl, encoding.lower))):
+            msg = _('possible case-folding collision for %s') % f
+            if self._abort:
+                raise error.Abort(msg)
+            self._ui.warn(_("warning: %s\n") % msg)
+        self._newfiles.add(f)
+        self._newfilesfolded.add(fl)
+
+def wrapcca(orig, ui, abort, dirstate):
+    """Use casecollisionauditor when the dirstate map supports case-folded
+    lookups, falling back to the original implementation otherwise."""
+    if util.safehasattr(dirstate._map, 'getcasefoldedtracked'):
+        return casecollisionauditor(ui, abort, dirstate)
+    else:
+        return orig(ui, abort, dirstate)
+
 def wrapnewreporequirements(orig, repo):
     reqs = orig(repo)
     if useinnewrepos:
@@ -474,6 +517,7 @@
     localrepo.localrepository.featuresetupfuncs.add(featuresetup)
     extensions.wrapfilecache(localrepo.localrepository, 'dirstate',
                              wrapdirstate)
+    extensions.wrapfunction(scmutil, 'casecollisionauditor', wrapcca)
 
 def reposetup(ui, repo):
     ui.log('treedirstate_enabled', '',