blob: 1110c12fc0474ef54c66c1383b4e17e06e92eb97 [file] [log] [blame]
Devin Jeanpierreb368e682022-05-03 02:23:44 -07001// Part of the Crubit project, under the Apache License v2.0 with LLVM
2// Exceptions. See /LICENSE for license information.
3// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5#![cfg_attr(test, feature(negative_impls))]
6
7//! # Object-Oriented Programming Support (OOPS).
8//!
9//! ## Upcasting
10//!
11//! To cast a reference to its base class type, use `my_reference.upcast()`.
12//! For example:
13//!
14//! ```ignore
15//! let x : &mut Derived = ...;
16//! let y : Pin<&mut Base> = x.upcast();
17//! ```
18//!
//! Because base classes are always `!Unpin`, mutable references to a base
//! class must take the form of `Pin<&mut Base>`. See `docs/unpin.md` for
//! details.
23//! To implement upcasting, implement the `Inherits` trait.
24//!
25//! ## Downcasting
26//!
27//! TODO(b/216195042): dynamic downcasting
28//! TODO(b/216195042): static downcasting
29
30use std::pin::Pin;
31
/// Infallibly converts a reference or smart pointer to a derived class into
/// the corresponding reference or smart pointer to one of its base classes.
///
/// Given a (public, unambiguous) base class `Base` of `Derived`, the
/// following conversions are provided:
///
/// ```ignore
/// &Derived : Upcast<&Base>
/// Pin<&mut Derived> : Upcast<Pin<&mut Base>>
/// ```
///
/// When `Derived : Unpin`, a plain mutable reference may also be upcast to a
/// pinned one: `&mut Derived : Upcast<Pin<&mut Base>>`.
///
/// If `Base : Unpin` holds as well (this never happens for Crubit-generated
/// bindings, but can for inheritance modeled in pure Rust), then
/// `&mut Derived : Upcast<&mut Base>`.
///
/// Every type `T` counts as its own ("improper") base class, so an identity
/// upcast is always available.
pub trait Upcast<Target> {
    /// Consumes `self` and returns the base-class view of it.
    fn upcast(self) -> Target;
}
52
53/// Upcast `&` -> `&`.
54impl<'a, Derived, Base> Upcast<&'a Base> for &'a Derived
55where
56 Derived: Inherits<Base>,
57{
58 fn upcast(self: &'a Derived) -> &'a Base {
59 unsafe { &*Derived::upcast_ptr(self as *const Derived) }
60 }
61}
62
63/// Upcast `Pin<&mut>` -> `Pin<&mut>.
64impl<'a, Derived, Base> Upcast<Pin<&'a mut Base>> for Pin<&'a mut Derived>
65where
66 Derived: Inherits<Base>,
67{
68 fn upcast(self: Pin<&'a mut Derived>) -> Pin<&'a mut Base> {
69 unsafe {
70 let inner = Pin::into_inner_unchecked(self) as *mut Derived;
71 Pin::new_unchecked(&mut *Derived::upcast_ptr_mut(inner))
72 }
73 }
74}
75
76/// Upcast `&mut` -> `Pin<&mut>.
77///
78/// Since all C++ base classes are `!Unpin`, this is the normal shape of a
79/// mutable reference upcast for an `Unpin` derived class.
80impl<'a, Derived, Base> Upcast<Pin<&'a mut Base>> for &'a mut Derived
81where
82 Pin<&'a mut Derived>: Upcast<Pin<&'a mut Base>>,
83 Derived: Unpin,
84{
85 fn upcast(self: &'a mut Derived) -> Pin<&'a mut Base> {
86 Pin::new(self).upcast()
87 }
88}
89
90/// Upcast `&mut` -> `&mut`.
91///
Dmitri Gribenko1dfdf0d2022-05-03 23:23:15 -070092/// This impl is never applicable to C++ types (a C++ base class is `!Unpin`),
93/// but could work for inheritance implemented in pure Rust.
Devin Jeanpierreb368e682022-05-03 02:23:44 -070094impl<'a, Derived, Base> Upcast<&'a mut Base> for &'a mut Derived
95where
96 Pin<&'a mut Derived>: Upcast<Pin<&'a mut Base>>,
97 Derived: Unpin,
98 Base: Unpin,
99{
100 fn upcast(self: &'a mut Derived) -> &'a mut Base {
101 Pin::into_inner(Pin::new(self).upcast())
102 }
103}
104
/// Unsafely upcast a raw pointer. `Derived : Inherits<Base>` means that
/// `Derived` can be upcast to `Base`.
///
/// To upcast in safe code, use the `Upcast` trait. `Inherits` is used for
/// unsafe pointer upcasts, and to implement upcasting.
///
/// (Note that unlike `Upcast`, `Inherits` is not implemented on the pointers
/// themselves -- this is solely for trait coherence reasons, as owning `T` does
/// not currently grant ownership over `*const T` or `*mut T`.)
///
/// ## Safety
///
/// Implementations must uphold the safety contract of the unsafe functions in
/// this trait.
///
/// In particular, `upcast_ptr` must preserve the provenance of its argument.
/// Compute the result with raw-pointer operations only, for example
/// `std::ptr::addr_of!((*derived).field)` or pointer arithmetic. Never go
/// through a reference such as `&(*derived).field`.
///
/// The reason is that the default `upcast_ptr_mut` casts the result of
/// `upcast_ptr` to `*mut`, and `Upcast` writes through it. Writing through a
/// pointer derived from a shared reference is undefined behavior. Creating a
/// reference would also require `derived` to be dereferenceable, which the
/// contract below does not guarantee.
///
/// TODO(jeanpierreda): Should this be split into two traits?
/// We could have `Inherits` (with safe functions) and `InheritsVirtual` (with
/// unsafe functions). For now, these are all merged into one trait, as it is
/// not an immediately obvious benefit to make raw pointer upcasts a safe
/// operation.
pub unsafe trait Inherits<Base> {
    /// Upcast a `const` pointer.
    ///
    /// ## Safety
    ///
    /// Casting follows the same safety and dereferenceability rules as C++:
    ///
    /// If `derived` is a dereferenceable pointer, then the upcast pointer is a
    /// dereferenceable pointer with the same lifetime.
    ///
    /// If `derived` is null, this returns null.
    ///
    /// If `derived` is non-dereferenceable, and `Base` is a non-virtual base
    /// class, then the return value is non-dereferenceable.
    ///
    /// Otherwise, if `derived` is non-dereferenceable and `Base` is a virtual
    /// base class, the behavior is undefined.
    unsafe fn upcast_ptr(derived: *const Self) -> *const Base;

    /// Upcast a `mut` pointer.
    ///
    /// ## Safety
    ///
    /// Casting follows the same safety and dereferenceability rules as C++:
    ///
    /// If `derived` is a dereferenceable pointer, then the upcast pointer is a
    /// dereferenceable pointer with the same lifetime.
    ///
    /// If `derived` is null, this returns null.
    ///
    /// If `derived` is non-dereferenceable, and `Base` is a non-virtual base
    /// class, then the return value is non-dereferenceable.
    ///
    /// Otherwise, if `derived` is non-dereferenceable and `Base` is a virtual
    /// base class, the behavior is undefined.
    ///
    /// The default implementation delegates to `upcast_ptr` and casts the
    /// result back to `*mut`. It keeps write permission only because
    /// implementers guarantee that `upcast_ptr` preserves provenance (see the
    /// trait-level safety section).
    unsafe fn upcast_ptr_mut(derived: *mut Self) -> *mut Base {
        // SAFETY: the caller upholds the same contract `upcast_ptr` requires.
        // Implementers guarantee `upcast_ptr` preserves provenance, so the
        // resulting pointer still carries `derived`'s write permission.
        unsafe { Self::upcast_ptr(derived) as *mut _ }
    }
}
164
165/// All classes are their own improper base.
166unsafe impl<T> Inherits<T> for T {
167 unsafe fn upcast_ptr(derived: *const Self) -> *const Self {
168 derived
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    /// Returns the address of the value that `x` dereferences to, so tests can
    /// check where an upcast lands.
    fn ptr_location<T: std::ops::Deref>(x: T) -> usize {
        &*x as *const _ as *const u8 as usize
    }

    #[test]
    fn test_unpin_upcast() {
        #[derive(Default)]
        struct Base(i32);

        #[derive(Default)]
        struct Derived {
            _other_field: u32,
            base: Base,
        }

        // SAFETY: null maps to null. Otherwise the `base` field is projected
        // with `addr_of!`, without creating an intermediate reference, so the
        // result keeps the provenance (and write permission) of `derived`. The
        // default `upcast_ptr_mut` relies on this.
        unsafe impl Inherits<Base> for Derived {
            unsafe fn upcast_ptr(derived: *const Self) -> *const Base {
                if derived.is_null() {
                    return std::ptr::null();
                }
                unsafe { std::ptr::addr_of!((*derived).base) }
            }
        }
        let mut derived = Derived::default();
        assert_eq!(ptr_location(&derived.base), ptr_location::<&Base>((&derived).upcast()));

        let _: *const Base = unsafe { Derived::upcast_ptr(&derived) };
        let _: *mut Base = unsafe { Derived::upcast_ptr_mut(&mut derived) };
        let _: &mut Base = (&mut derived).upcast();
        let _: Pin<&mut Base> = (&mut derived).upcast();
        let _: Pin<&mut Base> = Pin::new(&mut derived).upcast();

        // This write must not be UB: the `&mut Base` is derived from
        // `&mut derived` through raw-pointer operations only.
        {
            let base: &mut Base = (&mut derived).upcast();
            base.0 = 42;
        }
        assert_eq!(derived.base.0, 42);
    }

    #[test]
    fn test_nonunpin_upcast() {
        #[derive(Default)]
        struct Base(i32);
        impl !Unpin for Base {}

        #[derive(Default)]
        struct Derived {
            _other_field: u32,
            base: Base,
        }
        impl Unpin for Derived {}

        // SAFETY: null maps to null. Otherwise the `base` field is projected
        // with `addr_of!`, without creating an intermediate reference, so the
        // result keeps the provenance (and write permission) of `derived`. The
        // default `upcast_ptr_mut` relies on this.
        unsafe impl Inherits<Base> for Derived {
            unsafe fn upcast_ptr(derived: *const Self) -> *const Base {
                if derived.is_null() {
                    return std::ptr::null();
                }
                unsafe { std::ptr::addr_of!((*derived).base) }
            }
        }
        let mut derived = Derived::default();
        assert_eq!(ptr_location(&derived.base), ptr_location::<&Base>((&derived).upcast()));

        let _: *const Base = unsafe { Derived::upcast_ptr(&derived) };
        let _: *mut Base = unsafe { Derived::upcast_ptr_mut(&mut derived) };
        // `&mut Derived -> &mut Base` is intentionally unavailable here because
        // `Base: !Unpin`.
        let _: Pin<&mut Base> = (&mut derived).upcast();
        let _: Pin<&mut Base> = Pin::new(&mut derived).upcast();

        // This write must not be UB: `base` is derived from `&mut derived`
        // through raw-pointer operations only, so it carries write permission.
        unsafe {
            let base: *mut Base = Derived::upcast_ptr_mut(&mut derived);
            (*base).0 = 42;
        }
        assert_eq!(derived.base.0, 42);
    }
}