# HG changeset patch # User Yuya Nishihara # Date 1570875983 -32400 # Node ID 945d4dba5e78f38537606761a134fbfebbedc24a # Parent b9f79109021150a64af3053ab56ea81fa00689df rust-cpython: add stub wrapper that'll prevent leaked data from being mutated In order to allow mutation of PySharedRefCell value while PyLeaked reference exists, we need yet another "borrow" scope where mutation is prohibited. try_borrow<'a> and try_borrow_mut<'a> defines the "borrow" scope <'a>. The subsequent patches will implement leak counter based on this scope. PyLeakedRef and PyLeakedRefMut could be unified to PyLeakedRef<&T> and PyLeakedRef<&mut T> respectively, but I didn't do that since it seemed a bit weird that deref_mut() would return a mutable reference to an immutable reference. diff -r b9f791090211 -r 945d4dba5e78 rust/hg-cpython/src/ref_sharing.rs --- a/rust/hg-cpython/src/ref_sharing.rs Sat Oct 12 19:10:51 2019 +0900 +++ b/rust/hg-cpython/src/ref_sharing.rs Sat Oct 12 19:26:23 2019 +0900 @@ -25,6 +25,7 @@ use crate::exceptions::AlreadyBorrowed; use cpython::{PyClone, PyObject, PyResult, Python}; use std::cell::{Cell, Ref, RefCell, RefMut}; +use std::ops::{Deref, DerefMut}; /// Manages the shared state between Python and Rust #[derive(Debug, Default)] @@ -333,17 +334,29 @@ } } - /// Returns an immutable reference to the inner value. - pub fn get_ref<'a>(&'a self, _py: Python<'a>) -> &'a T { - self.data.as_ref().unwrap() + /// Immutably borrows the wrapped value. + pub fn try_borrow<'a>( + &'a self, + py: Python<'a>, + ) -> PyResult> { + Ok(PyLeakedRef { + _py: py, + data: self.data.as_ref().unwrap(), + }) } - /// Returns a mutable reference to the inner value. + /// Mutably borrows the wrapped value. /// /// Typically `T` is an iterator. If `T` is an immutable reference, /// `get_mut()` is useless since the inner value can't be mutated. - pub fn get_mut<'a>(&'a mut self, _py: Python<'a>) -> &'a mut T { - self.data.as_mut().unwrap() + pub fn try_borrow_mut<'a>( + &'a mut self, + py: Python<'a>, + ) -> PyResult> { + Ok(PyLeakedRefMut { + _py: py, + data: self.data.as_mut().unwrap(), + }) } /// Converts the inner value by the given function. @@ -389,6 +402,40 @@ } } +/// Immutably borrowed reference to a leaked value. +pub struct PyLeakedRef<'a, T> { + _py: Python<'a>, + data: &'a T, +} + +impl Deref for PyLeakedRef<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.data + } +} + +/// Mutably borrowed reference to a leaked value. +pub struct PyLeakedRefMut<'a, T> { + _py: Python<'a>, + data: &'a mut T, +} + +impl Deref for PyLeakedRefMut<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.data + } +} + +impl DerefMut for PyLeakedRefMut<'_, T> { + fn deref_mut(&mut self) -> &mut T { + self.data + } +} + /// Defines a `py_class!` that acts as a Python iterator over a Rust iterator. /// /// TODO: this is a bit awkward to use, and a better (more complicated) @@ -457,7 +504,8 @@ def __next__(&self) -> PyResult<$success_type> { let mut inner_opt = self.inner(py).borrow_mut(); if let Some(leaked) = inner_opt.as_mut() { - match leaked.get_mut(py).next() { + let mut iter = leaked.try_borrow_mut(py)?; + match iter.next() { None => { // replace Some(inner) by None, drop $leaked inner_opt.take(); @@ -512,6 +560,28 @@ } #[test] + fn test_leaked_borrow() { + let (gil, owner) = prepare_env(); + let py = gil.python(); + let leaked = owner.string_shared(py).leak_immutable().unwrap(); + let leaked_ref = leaked.try_borrow(py).unwrap(); + assert_eq!(*leaked_ref, "new"); + } + + #[test] + fn test_leaked_borrow_mut() { + let (gil, owner) = prepare_env(); + let py = gil.python(); + let leaked = owner.string_shared(py).leak_immutable().unwrap(); + let mut leaked_iter = unsafe { leaked.map(py, |s| s.chars()) }; + let mut leaked_ref = leaked_iter.try_borrow_mut(py).unwrap(); + assert_eq!(leaked_ref.next(), Some('n')); + assert_eq!(leaked_ref.next(), Some('e')); + assert_eq!(leaked_ref.next(), Some('w')); + assert_eq!(leaked_ref.next(), None); + } + + #[test] fn test_borrow_mut_while_leaked() { let (gil, owner) = prepare_env(); let py = gil.python();