rust-cpython: add generation counter to leaked reference
authorYuya Nishihara <yuya@tcha.org>
Sat, 05 Oct 2019 08:27:57 -0400
changeset 43476 0836efe4967b
parent 43475 945d4dba5e78
child 43477 ed50f2c31a4c
rust-cpython: add generation counter to leaked reference This counter increments on borrow_mut() to invalidate existing leaked references. This is modeled after the iterator invalidation in Python. The other checks will be adjusted by the subsequent patches.
rust/hg-cpython/src/ref_sharing.rs
--- a/rust/hg-cpython/src/ref_sharing.rs	Sat Oct 12 19:26:23 2019 +0900
+++ b/rust/hg-cpython/src/ref_sharing.rs	Sat Oct 05 08:27:57 2019 -0400
@@ -23,15 +23,33 @@
 //! Macros for use in the `hg-cpython` bridge library.
 
 use crate::exceptions::AlreadyBorrowed;
-use cpython::{PyClone, PyObject, PyResult, Python};
+use cpython::{exc, PyClone, PyErr, PyObject, PyResult, Python};
 use std::cell::{Cell, Ref, RefCell, RefMut};
 use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
 
 /// Manages the shared state between Python and Rust
+///
+/// `PySharedState` is owned by `PySharedRefCell`, and is shared across its
+/// derived references. The consistency of these references are guaranteed
+/// as follows:
+///
+/// - The immutability of `py_class!` object fields. Any mutation of
+///   `PySharedRefCell` is allowed only through its `borrow_mut()`.
+/// - The `py: Python<'_>` token, which makes sure that any data access is
+///   synchronized by the GIL.
+/// - The `generation` counter, which increments on `borrow_mut()`. `PyLeaked`
+///   reference is valid only if the `current_generation()` equals to the
+///   `generation` at the time of `leak_immutable()`.
 #[derive(Debug, Default)]
 struct PySharedState {
     leak_count: Cell<usize>,
     mutably_borrowed: Cell<bool>,
+    // The counter variable could be Cell<usize> since any operation on
+    // PySharedState is synchronized by the GIL, but being "atomic" makes
+    // PySharedState inherently Sync. The ordering requirement doesn't
+    // matter thanks to the GIL.
+    generation: AtomicUsize,
 }
 
 // &PySharedState can be Send because any access to inner cells is
@@ -54,6 +72,10 @@
         match self.leak_count.get() {
             0 => {
                 self.mutably_borrowed.replace(true);
+                // Note that this wraps around to the same value if mutably
+                // borrowed more than usize::MAX times, which wouldn't happen
+                // in practice.
+                self.generation.fetch_add(1, Ordering::Relaxed);
                 Ok(PyRefMut::new(py, pyrefmut, self))
             }
             // TODO
@@ -118,6 +140,10 @@
             self.leak_count.replace(count - 1);
         }
     }
+
+    fn current_generation(&self, _py: Python) -> usize {
+        self.generation.load(Ordering::Relaxed)
+    }
 }
 
 /// `RefCell` wrapper to be safely used in conjunction with `PySharedState`.
@@ -308,14 +334,20 @@
 }
 
 /// Manage immutable references to `PyObject` leaked into Python iterators.
+///
+/// This reference will be invalidated once the original value is mutably
+/// borrowed.
 pub struct PyLeaked<T> {
     inner: PyObject,
     data: Option<T>,
     py_shared_state: &'static PySharedState,
+    /// Generation counter of data `T` captured when PyLeaked is created.
+    generation: usize,
 }
 
 // DO NOT implement Deref for PyLeaked<T>! Dereferencing PyLeaked
-// without taking Python GIL wouldn't be safe.
+// without taking Python GIL wouldn't be safe. Also, the underling reference
+// is invalid if generation != py_shared_state.generation.
 
 impl<T> PyLeaked<T> {
     /// # Safety
@@ -331,14 +363,18 @@
             inner: inner.clone_ref(py),
             data: Some(data),
             py_shared_state,
+            generation: py_shared_state.current_generation(py),
         }
     }
 
     /// Immutably borrows the wrapped value.
+    ///
+    /// Borrowing fails if the underlying reference has been invalidated.
     pub fn try_borrow<'a>(
         &'a self,
         py: Python<'a>,
     ) -> PyResult<PyLeakedRef<'a, T>> {
+        self.validate_generation(py)?;
         Ok(PyLeakedRef {
             _py: py,
             data: self.data.as_ref().unwrap(),
@@ -347,12 +383,15 @@
 
     /// Mutably borrows the wrapped value.
     ///
+    /// Borrowing fails if the underlying reference has been invalidated.
+    ///
     /// Typically `T` is an iterator. If `T` is an immutable reference,
     /// `get_mut()` is useless since the inner value can't be mutated.
     pub fn try_borrow_mut<'a>(
         &'a mut self,
         py: Python<'a>,
     ) -> PyResult<PyLeakedRefMut<'a, T>> {
+        self.validate_generation(py)?;
         Ok(PyLeakedRefMut {
             _py: py,
             data: self.data.as_mut().unwrap(),
@@ -364,6 +403,13 @@
     /// Typically `T` is a static reference to a container, and `U` is an
     /// iterator of that container.
     ///
+    /// # Panics
+    ///
+    /// Panics if the underlying reference has been invalidated.
+    ///
+    /// This is typically called immediately after the `PyLeaked` is obtained.
+    /// In which case, the reference must be valid and no panic would occur.
+    ///
     /// # Safety
     ///
     /// The lifetime of the object passed in to the function `f` is cheated.
@@ -375,6 +421,11 @@
         py: Python,
         f: impl FnOnce(T) -> U,
     ) -> PyLeaked<U> {
+        // Needs to test the generation value to make sure self.data reference
+        // is still intact.
+        self.validate_generation(py)
+            .expect("map() over invalidated leaked reference");
+
         // f() could make the self.data outlive. That's why map() is unsafe.
         // In order to make this function safe, maybe we'll need a way to
         // temporarily restrict the lifetime of self.data and translate the
@@ -384,6 +435,18 @@
             inner: self.inner.clone_ref(py),
             data: Some(new_data),
             py_shared_state: self.py_shared_state,
+            generation: self.generation,
+        }
+    }
+
+    fn validate_generation(&self, py: Python) -> PyResult<()> {
+        if self.py_shared_state.current_generation(py) == self.generation {
+            Ok(())
+        } else {
+            Err(PyErr::new::<exc::RuntimeError, _>(
+                py,
+                "Cannot access to leaked reference after mutation",
+            ))
         }
     }
 }
@@ -582,6 +645,41 @@
     }
 
     #[test]
+    fn test_leaked_borrow_after_mut() {
+        let (gil, owner) = prepare_env();
+        let py = gil.python();
+        let leaked = owner.string_shared(py).leak_immutable().unwrap();
+        owner.string(py).py_shared_state.leak_count.replace(0); // XXX cheat
+        owner.string_shared(py).borrow_mut().unwrap().clear();
+        owner.string(py).py_shared_state.leak_count.replace(1); // XXX cheat
+        assert!(leaked.try_borrow(py).is_err());
+    }
+
+    #[test]
+    fn test_leaked_borrow_mut_after_mut() {
+        let (gil, owner) = prepare_env();
+        let py = gil.python();
+        let leaked = owner.string_shared(py).leak_immutable().unwrap();
+        let mut leaked_iter = unsafe { leaked.map(py, |s| s.chars()) };
+        owner.string(py).py_shared_state.leak_count.replace(0); // XXX cheat
+        owner.string_shared(py).borrow_mut().unwrap().clear();
+        owner.string(py).py_shared_state.leak_count.replace(1); // XXX cheat
+        assert!(leaked_iter.try_borrow_mut(py).is_err());
+    }
+
+    #[test]
+    #[should_panic(expected = "map() over invalidated leaked reference")]
+    fn test_leaked_map_after_mut() {
+        let (gil, owner) = prepare_env();
+        let py = gil.python();
+        let leaked = owner.string_shared(py).leak_immutable().unwrap();
+        owner.string(py).py_shared_state.leak_count.replace(0); // XXX cheat
+        owner.string_shared(py).borrow_mut().unwrap().clear();
+        owner.string(py).py_shared_state.leak_count.replace(1); // XXX cheat
+        let _leaked_iter = unsafe { leaked.map(py, |s| s.chars()) };
+    }
+
+    #[test]
     fn test_borrow_mut_while_leaked() {
         let (gil, owner) = prepare_env();
         let py = gil.python();