--- a/rust/hg-cpython/src/revlog.rs Mon Oct 30 21:25:28 2023 +0100
+++ b/rust/hg-cpython/src/revlog.rs Mon Oct 30 21:26:17 2023 +0100
@@ -965,6 +965,73 @@
}
}
+py_class!(pub class NodeTree |py| {
+ data nt: RefCell<CoreNodeTree>;
+ data index: RefCell<UnsafePyLeaked<PySharedIndex>>;
+
+ def __new__(_cls, index: PyObject) -> PyResult<NodeTree> {
+ let index = py_rust_index_to_graph(py, index)?;
+ let nt = CoreNodeTree::default(); // in-RAM, fully mutable
+ Self::create_instance(py, RefCell::new(nt), RefCell::new(index))
+ }
+
+ def insert(&self, rev: PyRevision) -> PyResult<PyObject> {
+ let leaked = self.index(py).borrow();
+ let index = &*unsafe { leaked.try_borrow(py)? };
+
+ let rev = UncheckedRevision(rev.0);
+ let rev = index
+ .check_revision(rev)
+ .ok_or_else(|| rev_not_in_index(py, rev))?;
+ if rev == NULL_REVISION {
+ return Err(rev_not_in_index(py, rev.into()))
+ }
+
+ let entry = index.inner.get_entry(rev).unwrap();
+ let mut nt = self.nt(py).borrow_mut();
+ nt.insert(index, entry.hash(), rev).map_err(|e| nodemap_error(py, e))?;
+
+ Ok(py.None())
+ }
+
+ /// Lookup by node hex prefix in the NodeTree, returning revision number.
+ ///
+ /// This is not part of the classical NodeTree API, but is good enough
+ /// for unit testing, as in `test-rust-revlog.py`.
+ def prefix_rev_lookup(
+ &self,
+ node_prefix: PyBytes
+ ) -> PyResult<Option<PyRevision>> {
+ let prefix = NodePrefix::from_hex(node_prefix.data(py))
+ .map_err(|_| PyErr::new::<ValueError, _>(
+ py,
+ format!("Invalid node or prefix {:?}",
+ node_prefix.as_object()))
+ )?;
+
+ let nt = self.nt(py).borrow();
+ let leaked = self.index(py).borrow();
+ let index = &*unsafe { leaked.try_borrow(py)? };
+
+ Ok(nt.find_bin(index, prefix)
+ .map_err(|e| nodemap_error(py, e))?
+ .map(|r| r.into())
+ )
+ }
+
+ def shortest(&self, node: PyBytes) -> PyResult<usize> {
+ let nt = self.nt(py).borrow();
+ let leaked = self.index(py).borrow();
+ let idx = &*unsafe { leaked.try_borrow(py)? };
+ match nt.unique_prefix_len_node(idx, &node_from_py_bytes(py, &node)?)
+ {
+ Ok(Some(l)) => Ok(l),
+ Ok(None) => Err(revlog_error(py)),
+ Err(e) => Err(nodemap_error(py, e)),
+ }
+ }
+});
+
fn revlog_error(py: Python) -> PyErr {
match py
.import("mercurial.error")
@@ -1033,6 +1100,7 @@
m.add(py, "__doc__", "RevLog - Rust implementations")?;
m.add_class::<MixedIndex>(py)?;
+ m.add_class::<NodeTree>(py)?;
let sys = PyModule::import(py, "sys")?;
let sys_modules: PyDict = sys.get(py, "modules")?.extract(py)?;