Devin Jeanpierre | b368e68 | 2022-05-03 02:23:44 -0700 | [diff] [blame] | 1 | // Part of the Crubit project, under the Apache License v2.0 with LLVM |
| 2 | // Exceptions. See /LICENSE for license information. |
| 3 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | |
| 5 | #![cfg_attr(test, feature(negative_impls))] |
| 6 | |
| 7 | //! # Object-Oriented Programming Support (OOPS). |
| 8 | //! |
| 9 | //! ## Upcasting |
| 10 | //! |
| 11 | //! To cast a reference to its base class type, use `my_reference.upcast()`. |
| 12 | //! For example: |
| 13 | //! |
| 14 | //! ```ignore |
| 15 | //! let x : &mut Derived = ...; |
| 16 | //! let y : Pin<&mut Base> = x.upcast(); |
| 17 | //! ``` |
| 18 | //! |
//! Because base classes are always `!Unpin`, mutable references to base must
//! take the form of `Pin<&mut Base>`. See `docs/unpin.md` for details.
| 22 | //! |
| 23 | //! To implement upcasting, implement the `Inherits` trait. |
| 24 | //! |
| 25 | //! ## Downcasting |
| 26 | //! |
| 27 | //! TODO(b/216195042): dynamic downcasting |
| 28 | //! TODO(b/216195042): static downcasting |
| 29 | |
| 30 | use std::pin::Pin; |
| 31 | |
/// Safe, infallible conversion of a reference (or smart pointer) to a derived
/// class into the corresponding reference to one of its base classes.
///
/// When `Base` is a public, unambiguous base class of `Derived`, the following
/// conversions are available:
///
/// ```ignore
/// &Derived          : Upcast<&Base>
/// Pin<&mut Derived> : Upcast<Pin<&mut Base>>
/// &mut Derived      : Upcast<Pin<&mut Base>>   // needs `Derived: Unpin`
/// &mut Derived      : Upcast<&mut Base>        // needs `Derived: Unpin` and `Base: Unpin`
/// ```
///
/// The last form can never apply to Crubit bindings, because a C++ base class
/// is always `!Unpin`.
///
/// Every type `T` also counts as its own ("improper") base class, so e.g.
/// `&T` can always be upcast to `&T`.
pub trait Upcast<Target> {
    /// Performs the upcast, consuming `self` and returning the base-typed view.
    fn upcast(self) -> Target;
}
| 52 | |
| 53 | /// Upcast `&` -> `&`. |
| 54 | impl<'a, Derived, Base> Upcast<&'a Base> for &'a Derived |
| 55 | where |
| 56 | Derived: Inherits<Base>, |
| 57 | { |
| 58 | fn upcast(self: &'a Derived) -> &'a Base { |
| 59 | unsafe { &*Derived::upcast_ptr(self as *const Derived) } |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | /// Upcast `Pin<&mut>` -> `Pin<&mut>. |
| 64 | impl<'a, Derived, Base> Upcast<Pin<&'a mut Base>> for Pin<&'a mut Derived> |
| 65 | where |
| 66 | Derived: Inherits<Base>, |
| 67 | { |
| 68 | fn upcast(self: Pin<&'a mut Derived>) -> Pin<&'a mut Base> { |
| 69 | unsafe { |
| 70 | let inner = Pin::into_inner_unchecked(self) as *mut Derived; |
| 71 | Pin::new_unchecked(&mut *Derived::upcast_ptr_mut(inner)) |
| 72 | } |
| 73 | } |
| 74 | } |
| 75 | |
| 76 | /// Upcast `&mut` -> `Pin<&mut>. |
| 77 | /// |
| 78 | /// Since all C++ base classes are `!Unpin`, this is the normal shape of a |
| 79 | /// mutable reference upcast for an `Unpin` derived class. |
| 80 | impl<'a, Derived, Base> Upcast<Pin<&'a mut Base>> for &'a mut Derived |
| 81 | where |
| 82 | Pin<&'a mut Derived>: Upcast<Pin<&'a mut Base>>, |
| 83 | Derived: Unpin, |
| 84 | { |
| 85 | fn upcast(self: &'a mut Derived) -> Pin<&'a mut Base> { |
| 86 | Pin::new(self).upcast() |
| 87 | } |
| 88 | } |
| 89 | |
| 90 | /// Upcast `&mut` -> `&mut`. |
| 91 | /// |
Dmitri Gribenko | 1dfdf0d | 2022-05-03 23:23:15 -0700 | [diff] [blame] | 92 | /// This impl is never applicable to C++ types (a C++ base class is `!Unpin`), |
| 93 | /// but could work for inheritance implemented in pure Rust. |
Devin Jeanpierre | b368e68 | 2022-05-03 02:23:44 -0700 | [diff] [blame] | 94 | impl<'a, Derived, Base> Upcast<&'a mut Base> for &'a mut Derived |
| 95 | where |
| 96 | Pin<&'a mut Derived>: Upcast<Pin<&'a mut Base>>, |
| 97 | Derived: Unpin, |
| 98 | Base: Unpin, |
| 99 | { |
| 100 | fn upcast(self: &'a mut Derived) -> &'a mut Base { |
| 101 | Pin::into_inner(Pin::new(self).upcast()) |
| 102 | } |
| 103 | } |
| 104 | |
/// Unsafely upcast a raw pointer. `Derived: Inherits<Base>` means that
/// `Derived` can be upcast to `Base`.
///
/// To upcast in safe code, use the `Upcast` trait. `Inherits` is used for
/// unsafe pointer upcasts, and to implement upcasting.
///
/// (Note that unlike `Upcast`, `Inherits` is not implemented on the pointers
/// themselves -- this is solely for trait coherence reasons, as owning `T` does
/// not currently grant ownership over `*const T` or `*mut T`.)
///
/// # Safety
///
/// Implementations must uphold the safety contract of the unsafe functions in
/// this trait.
///
/// TODO(jeanpierreda): Should this be split into two traits?
/// We could have `Inherits` (with safe functions) and `InheritsVirtual` (with
/// unsafe functions). For now, these are all merged into one trait, as it is
/// not immediately obvious that making raw pointer upcasts a safe operation
/// would be beneficial.
pub unsafe trait Inherits<Base> {
    /// Upcast a `const` pointer.
    ///
    /// # Safety
    ///
    /// Casting follows the same safety and dereferenceability rules as C++:
    ///
    /// If `derived` is a dereferenceable pointer, then the upcasted pointer is
    /// a dereferenceable pointer with the same lifetime.
    ///
    /// If `derived` is null, this returns null.
    ///
    /// If `derived` is non-dereferenceable, and `Base` is a non-virtual base
    /// class, then the return value is non-dereferenceable.
    ///
    /// Otherwise, if `derived` is non-dereferenceable and `Base` is a virtual
    /// base class, the behavior is undefined.
    unsafe fn upcast_ptr(derived: *const Self) -> *const Base;

    /// Upcast a `mut` pointer.
    ///
    /// The default implementation delegates to `upcast_ptr` and casts the
    /// result back to `*mut`.
    ///
    /// # Safety
    ///
    /// Casting follows the same safety and dereferenceability rules as C++:
    ///
    /// If `derived` is a dereferenceable pointer, then the upcasted pointer is
    /// a dereferenceable pointer with the same lifetime.
    ///
    /// If `derived` is null, this returns null.
    ///
    /// If `derived` is non-dereferenceable, and `Base` is a non-virtual base
    /// class, then the return value is non-dereferenceable.
    ///
    /// Otherwise, if `derived` is non-dereferenceable and `Base` is a virtual
    /// base class, the behavior is undefined.
    unsafe fn upcast_ptr_mut(derived: *mut Self) -> *mut Base {
        // SAFETY: this function's contract is identical to `upcast_ptr`'s, and
        // the caller upholds it. The raw-pointer casts keep the provenance of
        // `derived`, so mutable access through the result stays valid.
        unsafe { Self::upcast_ptr(derived) as *mut _ }
    }
}
| 164 | |
| 165 | /// All classes are their own improper base. |
| 166 | unsafe impl<T> Inherits<T> for T { |
| 167 | unsafe fn upcast_ptr(derived: *const Self) -> *const Self { |
| 168 | derived |
| 169 | } |
| 170 | } |
| 171 | |
#[cfg(test)]
mod test {
    use super::*;

    /// Returns the address `x` dereferences to, so upcast results can be
    /// compared against the address of the embedded base field.
    fn ptr_location<T: std::ops::Deref>(x: T) -> usize {
        &*x as *const _ as *const u8 as usize
    }

    #[test]
    fn test_unpin_upcast() {
        #[derive(Default)]
        struct Base(i32);

        #[derive(Default)]
        struct Derived {
            _other_field: u32,
            base: Base,
        }

        unsafe impl Inherits<Base> for Derived {
            unsafe fn upcast_ptr(derived: *const Self) -> *const Base {
                // The `Inherits` contract requires null to map to null.
                if derived.is_null() {
                    return std::ptr::null();
                }
                // SAFETY: `derived` is non-null here, and these tests only pass
                // pointers to live `Derived` values. `addr_of!` computes the
                // field address without materialising an intermediate
                // `&Derived`.
                unsafe { std::ptr::addr_of!((*derived).base) }
            }
        }
        let mut derived = Derived::default();
        assert_eq!(ptr_location(&derived.base), ptr_location::<&Base>((&derived).upcast()));

        let _: *const Base = unsafe { Derived::upcast_ptr(&derived) };
        let _: *mut Base = unsafe { Derived::upcast_ptr_mut(&mut derived) };
        let _: &mut Base = (&mut derived).upcast();
        let _: Pin<&mut Base> = (&mut derived).upcast();
        let _: Pin<&mut Base> = Pin::new(&mut derived).upcast();

        // Null pointers upcast to null.
        let null_base: *const Base = unsafe { Derived::upcast_ptr(std::ptr::null()) };
        assert!(null_base.is_null());
        let null_base_mut: *mut Base = unsafe { Derived::upcast_ptr_mut(std::ptr::null_mut()) };
        assert!(null_base_mut.is_null());

        // This write must not be UB:
        {
            let base: &mut Base = (&mut derived).upcast();
            base.0 = 42;
        }
        assert_eq!(derived.base.0, 42);
    }

    #[test]
    fn test_nonunpin_upcast() {
        #[derive(Default)]
        struct Base(i32);
        impl !Unpin for Base {}

        #[derive(Default)]
        struct Derived {
            _other_field: u32,
            base: Base,
        }
        impl Unpin for Derived {}

        unsafe impl Inherits<Base> for Derived {
            unsafe fn upcast_ptr(derived: *const Self) -> *const Base {
                // The `Inherits` contract requires null to map to null.
                if derived.is_null() {
                    return std::ptr::null();
                }
                // SAFETY: `derived` is non-null here, and these tests only pass
                // pointers to live `Derived` values. `addr_of!` computes the
                // field address without materialising an intermediate
                // `&Derived`.
                unsafe { std::ptr::addr_of!((*derived).base) }
            }
        }
        let mut derived = Derived::default();
        assert_eq!(ptr_location(&derived.base), ptr_location::<&Base>((&derived).upcast()));

        let _: *const Base = unsafe { Derived::upcast_ptr(&derived) };
        let _: *mut Base = unsafe { Derived::upcast_ptr_mut(&mut derived) };
        // let _: &mut Base = (&mut derived).upcast(); // does not compile
        let _: Pin<&mut Base> = (&mut derived).upcast();
        let _: Pin<&mut Base> = Pin::new(&mut derived).upcast();

        // Null pointers upcast to null.
        let null_base: *const Base = unsafe { Derived::upcast_ptr(std::ptr::null()) };
        assert!(null_base.is_null());
        let null_base_mut: *mut Base = unsafe { Derived::upcast_ptr_mut(std::ptr::null_mut()) };
        assert!(null_base_mut.is_null());

        // This write must not be UB:
        // SAFETY: `base` points at the `base` field of the live, exclusively
        // borrowed `derived`, and no other reference to it exists meanwhile.
        unsafe {
            let base: *mut Base = Derived::upcast_ptr_mut(&mut derived);
            (*base).0 = 42;
        }
        assert_eq!(derived.base.0, 42);
    }
}