blob: 203a0891c33af7cf39b378e89c1948a1edcbdb28 [file] [log] [blame] [edit]
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//! Memoization of functions, in the style of Salsa.
//!
//! See the documentation for `query_group!` for more details.
//!
//! # Why not Salsa?
//!
//! Salsa is not compatible with rustc: types like `Ty<'tcx>`, which are not
//! `'static`, cannot be used as function parameters or return values in Salsa.
//! See https://github.com/salsa-rs/salsa/issues/424.
//!
//! Since Salsa is off the table, this module offers a very simplified
//! alternative, which implements _only_ memoization, as that is all we need.
//!
//! # Differences with Salsa
//!
//! * Supports non-`'static` types.
//! * Syntactic differences when initially setting up the trait and database.
//! * Immutable input and no support for recomputation given mutated inputs.
//! * Correspondingly, no requirement that the *return* types implement `Eq` or
//! `Hash`.
//! * Supports `#[break_cycles_with = <default-value>]`, which generates a
//! function that returns <default-value> if a cycle is detected.
//!
//! There are more substantial differences with Salsa 2022 - this was written
//! based on Salsa 0.16. We don't need to match exactly the API, but the
//! differences are kept relatively small so as to support eventually either
//! going back to Salsa, or evolving towards something closer to what
//! Salsa implements.
/// `query_group!` defines a collection of memoized functions, and the shared
/// inputs that all of those functions can access.
///
/// These functions are defined as a trait, as well as a concrete type that
/// stores the inputs and memoized values, and implements that trait.
///
/// The structure of an invocation is always as follows:
///
/// ```
/// query_group! {
/// trait QueryGroupName {
/// // First, all shared inputs are specified, in order.
/// //
/// // These are available to all of the memoized functions, by calling `db.some_input()`, etc.,
/// // in order to access the input value. In other words, they form a sort of immutable global
/// // state, useful to pass immutable configuration values to _all_ of the memoized functions.
/// //
/// // The inputs are immutable, and must be `Clone`. Because they are immutable and global,
/// // they do not need to be `Eq` or `Hash`, and they are not compared when memoizing
/// // functions.
/// //
/// // As a rule of thumb, anything which is a constant for the lifetime of an execution, but
/// // is not an actual compile time constant, should be an input.
///
/// #[input]
/// /// Doc comment goes after `#[input]`.
/// fn some_input(&self) -> InputType;
/// //...
///
/// // After all of the inputs, the actual memoized functions are specified.
/// //
/// // Memoized functions are pure functions, which take any number of (`Clone+Eq+Hash`)
/// // parameters, and return a `Clone` return type.
/// //
/// // Each of these must be implemented in the _same module_ as a top-level function, with
/// // the same function signature except taking `&dyn QueryGroupName` instead of `&self`.
/// //
/// // The functions can call methods on the `&dyn QueryGroupName` to access both the shared
/// // inputs, as well to call other memoized functions.
/// //
/// // Inputs to the computation which change from call to call should be specified as function
/// // parameters. Inputs which never change can be either `#[input]` functions, or actual
/// // program globals.
/// //
/// // When called on the `&dyn QueryGroupName` or directly on the concrete type, the functions
/// // will be memoized, and the return value will be cached, automatically.
/// //
/// // Some functions may need to gracefully handle cycles, in which case they should be
/// // annotated with `#[break_cycles_with = <default_value>]`. This will generate a function
/// // that returns <default_value> if a cycle is detected, but will _not_ cache the result.
/// // All `#[break_cycles_with = ..]` functions must appear before all
/// // non-`#[break_cycles_with = ..]` functions.
///
/// #[break_cycles_with = ReturnType::default()]
/// /// Doc comment goes after `#[break_cycles_with = ..]`.
/// fn may_be_cyclic(&self, arg: ArgType) -> ReturnType;
///
/// /// Doc comment goes here.
/// fn some_function(&self, arg: ArgType) -> ReturnType;
/// }
/// // The concrete type for the storage of inputs and memoized values.
/// // `query_group!` will add the required fields to this struct.
/// struct Database;
/// }
///
/// // The non-memoized implementation of the memoized functions
/// fn may_by_cyclic(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
/// // ...
/// }
///
/// fn some_function(db: &dyn QueryGroupName, arg: ArgType) -> ReturnType {
/// // ...
/// }
/// ```
///
/// A new instance of `Database` can be created by using `Database::new(input,
/// values, here)`. It will implement the trait defined above: the `#[input]`
/// functions will return the corresponding values passed in to `Database::new`,
/// and the memoized functions will call the corresponding top-level functions.
///
/// For example, above, one could run `let db =
/// Database::new(InputType::default())`.
///
/// Now, if you call `db.some_function(...)`, it will either return a cached
/// value (if one is present), or else execute `some_function(&db as &dyn
/// QueryGroupName, ...)`, and cache and return the result. A direct call to
/// `some_function` (instead of `db.some_function`) is not directly memoized.
///
/// Because the results are saved, it's very important that the functions are
/// pure and have no side-effects. In particular, internal mutability on the
/// inputs and arguments is best avoided.
///
/// Important notes:
///
/// * In the trait definition, all `#[input]` functions must be declared before
/// all (non-`#[input]`) memoized functions. The order of the input functions
/// is the order of their parameters in `new()`.
/// * Every (non-`#[input]`) trait method _must_ have a matching top-level
/// function, with `&self` replaced by `&dyn QueryGroupName`.
/// * Since all trait methods are memoized, their arguments must be `Clone`,
/// `Eq`, and `Hash`.
///
/// # Non-`'static` types
///
/// The trait and struct can accept a lifetime parameter, but the syntax is a
/// little trickier:
///
/// ```
/// query_group! {
/// trait QueryGroupName<'a> {
/// #[input]
/// fn some_input(&self) -> &'a InputType;
/// fn some_function(&self, arg: &'a ArgType) -> &'a ReturnType;
/// }
/// struct Database;
/// }
///
/// fn<'a> some_function(db: &dyn QueryGroupName<'a>, arg: &'a ArgType) -> &'a ReturnType {
/// // ...
/// }
/// ```
///
/// Note that under the hood, `struct Database` will be replaced by something
/// like:
///
/// ```
/// struct Database<'a> {...}
/// ```
///
/// And so you may need to specify the lifetime in some uses.
#[macro_export]
macro_rules! query_group {
(
$trait_vis:vis trait $trait:ident $(<$($type_param:tt),*>)?{
$(
// TODO(jeanpierreda): Ideally would allow putting the doc-comment first,
// but this causes parsing ambiguity.
#[input]
$(#[doc = $input_doc:literal])*
fn $input_function:ident(&self $(,)?) -> $input_type:ty;
)*
$(
// TODO(jeanpierreda): Ideally would allow putting the doc-comment first,
// but this causes parsing ambiguity.
#[break_cycles_with = $break_cycles_default_value:expr]
$(#[doc = $break_cycles_doc:literal])*
fn $break_cycles_function:ident(
&self
$(
, $break_cycles_arg:ident : $break_cycles_arg_type:ty
)*
$(,)?
) -> $break_cycles_return_type:ty;
)*
$(
$(#[doc = $function_doc:literal])*
fn $function:ident(
&self
$(
, $arg:ident : $arg_type:ty
)*
$(,)?
) -> $return_type:ty;
)*
$(
#[provided]
$(#[doc = $provided_doc:literal])*
fn $provided_function:ident(&$provided_self:ident $(, $provided_arg:ident : $provided_arg_type:ty)* $(,)?) -> $provided_type:ty { $($provided_body:tt)* }
)*
}
$struct_vis:vis struct $database_struct:ident;
) => {
// First, yes, generate the trait.
$trait_vis trait $trait $(<$($type_param),*>)?{
$(
$(#[doc = $input_doc])*
fn $input_function(&self) -> $input_type { unimplemented!(concat!("input function '", stringify!($input_function), "'")) }
)*
$(
$(#[doc = $break_cycles_doc])*
fn $break_cycles_function(
&self,
$(
$break_cycles_arg : $break_cycles_arg_type
),*
) -> $break_cycles_return_type { _ = ($($break_cycles_arg),*); unimplemented!(concat!("break cycles function '", stringify!($break_cycles_function), "'")) }
)*
$(
$(#[doc = $function_doc])*
fn $function(
&self,
$(
$arg : $arg_type
),*
) -> $return_type { _ = ($($arg),*); unimplemented!(concat!("function '", stringify!($function), "'")) }
)*
$(
$(#[doc = $provided_doc])*
fn $provided_function(&$provided_self $(, $provided_arg: $provided_arg_type)*) -> $provided_type { _ = ($($provided_arg),*); unimplemented!(concat!("provided function '", stringify!($provided_function), "'")) }
)*
}
// Avoid having to repeat the `$type_param` metavariable below, since its repetition
// should not be zipped with other metavariable repetitions.
// This also avoids having to repeat out `&dyn $trait $(<$($type_param),*>)?` all over the
// place.
macro_rules! dyn_trait {
() => { &dyn $trait $(<$($type_param),*>)? };
}
// Now we can generate a database struct that contains the lookup tables.
$struct_vis struct $database_struct $(<$($type_param),*>)? {
__unwinding_cycles: ::core::cell::Cell<u32>,
$(
$input_function: $input_type,
)*
$(
$break_cycles_function: $crate::internal::FnAndTable<
fn(dyn_trait!(), $($break_cycles_arg_type),*) -> $break_cycles_return_type,
($($break_cycles_arg_type,)*),
$break_cycles_return_type
>,
)*
$(
$function: $crate::internal::FnAndTable<
fn(dyn_trait!(), $($arg_type),*) -> $return_type,
($($arg_type,)*),
$return_type
>,
)*
}
// ...and an implementation of the trait.
impl $(<$($type_param),*>)? $trait $(<$($type_param),*>)? for $database_struct $(<$($type_param),*>)? {
$(
fn $input_function(&self) -> $input_type {
// Have to be very careful to clone whatever the top level value is.
// In particular, if it's a reference `&T`, clone the _reference_ to get another `&T`,
// not the referent to get a `T`, like if we just cloned `self.$input_function`.
(&self.$input_function).clone()
}
)*
$(
fn $break_cycles_function(
&self,
$(
$break_cycles_arg : $break_cycles_arg_type
),*
) -> $break_cycles_return_type {
self.$break_cycles_function.table.internal_memoized_call(
($(
$break_cycles_arg,
)*),
|($($break_cycles_arg,)*)| {
// Force the use of &dyn $trait, so that we don't rule out separate compilation later.
(self.$break_cycles_function.fn_ptr)(self as &dyn $trait, $($break_cycles_arg),*)
},
&self.__unwinding_cycles,
).unwrap_or($break_cycles_default_value)
}
)*
$(
fn $function(
&self,
$(
$arg : $arg_type
),*
) -> $return_type {
self.$function.table.internal_memoized_call(
($(
$arg,
)*),
|($($arg,)*)| {
// Force the use of &dyn $trait, so that we don't rule out separate compilation later.
(self.$function.fn_ptr)(self as &dyn $trait, $($arg),*)
},
&self.__unwinding_cycles,
).unwrap_or_else(
|| panic!("Cycle detected: '{}' depends on its own return value", stringify!($function)),
)
}
)*
$(
fn $provided_function(
&$provided_self,
$(
$provided_arg : $provided_arg_type
),*
) -> $provided_type {
$( $provided_body )*
}
)*
}
// and the new() function for initialization.
impl $(<$($type_param),*>)? $database_struct $(<$($type_param),*>)? {
$struct_vis fn new(
$($input_function: $input_type,)*
$($break_cycles_function: fn(dyn_trait!(), $($break_cycles_arg_type),*) -> $break_cycles_return_type,)*
$($function: fn(dyn_trait!(), $($arg_type),*) -> $return_type,)*
) -> Self {
Self {
__unwinding_cycles: ::core::cell::Cell::new(0),
$(
$input_function,
)*
$(
$break_cycles_function: $crate::internal::FnAndTable {
fn_ptr: $break_cycles_function,
table: Default::default(),
},
)*
$(
$function: $crate::internal::FnAndTable {
fn_ptr: $function,
table: Default::default(),
},
)*
}
}
}
}
}
#[doc(hidden)]
pub mod internal {
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::hash::Hash;
#[derive(Copy, Clone, PartialEq, Eq)]
enum FoundCycle {
No,
Yes,
}
pub struct FnAndTable<F, Args, Return>
where
Args: Clone + Eq + Hash,
Return: Clone,
{
pub fn_ptr: F,
pub table: MemoizationTable<Args, Return>,
}
pub struct MemoizationTable<Args, Return>
where
Args: Clone + Eq + Hash,
Return: Clone,
{
memoized: RefCell<HashMap<Args, Return>>,
active: RefCell<HashMap<Args, FoundCycle>>,
}
// Separate `impl` instead of `#[derive(Default)]` because the `derive` would
// needlessly require that `Args` and `Return` also implement `Default`.
impl<Args, Return> Default for MemoizationTable<Args, Return>
where
Args: Clone + Eq + Hash,
Return: Clone,
{
fn default() -> Self {
Self { memoized: RefCell::new(HashMap::new()), active: RefCell::new(HashMap::new()) }
}
}
impl<Args, Return> MemoizationTable<Args, Return>
where
Args: Clone + Eq + Hash,
Return: Clone,
{
pub fn internal_memoized_call<F>(
&self,
args: Args,
f: F,
unwinding_cycles: &Cell<u32>,
) -> Option<Return>
where
F: FnOnce(Args) -> Return,
{
if let Some(return_value) = self.memoized.borrow().get(&args) {
return Some(return_value.clone());
}
if let Some(found_cycle) = self.active.borrow_mut().get_mut(&args) {
// We're in a cycle.
if *found_cycle == FoundCycle::No {
// Only increase the count if we haven't hit this cycle before.
unwinding_cycles.set(unwinding_cycles.get() + 1);
}
*found_cycle = FoundCycle::Yes;
return None;
}
self.active.borrow_mut().insert(args.clone(), FoundCycle::No);
let return_value = f(args.clone());
let found_cycle = self.active.borrow_mut().remove(&args).expect(
"Internal error: currently-active cycle detection args not found. \
Most likely this is because of a buggy Clone/Eq/Hash impl.",
);
if found_cycle == FoundCycle::Yes {
// We did hit outselves in a cycle but now we've broken out of it.
// If we hit ourselves multiple times, we were careful to only increment this
// count once.
unwinding_cycles.set(unwinding_cycles.get() - 1);
}
if unwinding_cycles.get() == 0 {
// No cycles, we can safely cache the result knowing that we haven't depended on
// any cycle default values.
self.memoized.borrow_mut().insert(args, return_value.clone());
}
Some(return_value)
}
}
}
#[cfg(test)]
pub mod tests {
use googletest::prelude::*;
use std::cell::{Cell, RefCell};
use std::rc::Rc;
#[gtest]
fn test_basic_memoization() {
crate::query_group! {
pub trait Add10 {
#[input]
/// Tracker for how many times this function is called so we can check
/// that memoization is indeed happening. This is just for testing;
/// memoized functions in non-test code shouldn't have side effects,
/// and inputs in non-test code shouldn't have internal mutability.
fn call_counter(&self) -> Rc<Cell<i32>>;
fn add10(&self, arg: i32) -> i32;
}
pub struct Database;
}
fn add10(db: &dyn Add10, arg: i32) -> i32 {
db.call_counter().set(db.call_counter().get() + 1);
arg + 10
}
let db = Database::new(Rc::new(Cell::new(0)), add10);
assert_eq!(db.add10(100), 110);
assert_eq!(db.call_counter().get(), 1);
assert_eq!(db.add10(100), 110);
assert_eq!(db.call_counter().get(), 1);
assert_eq!(db.add10(200), 210);
assert_eq!(db.call_counter().get(), 2);
}
/// The raison d'etre of this module: memoization with an attached lifetime.
///
/// This test is similar to test_basic_memoization, except that it accepts
/// and returns references.
#[gtest]
fn test_nonstatic_memoization() {
crate::query_group! {
pub trait Add10<'a> {
#[input]
fn call_counter(&self) -> &'a Cell<i32>;
fn add10(&self, arg: &'a i32) -> i32;
// non-input function with 'a in return type
fn identity(&self, arg: &'a i32) -> &'a i32;
}
pub struct Database;
}
fn add10<'a>(db: &dyn Add10<'a>, arg: &'a i32) -> i32 {
db.call_counter().set(db.call_counter().get() + 1);
*arg + 10
}
fn identity<'a>(db: &dyn Add10<'a>, arg: &'a i32) -> &'a i32 {
db.call_counter().set(db.call_counter().get() + 1);
arg
}
let count = Cell::new(0);
let db = Database::new(&count, add10, identity);
assert_eq!(db.add10(&100), 110);
assert_eq!(count.get(), 1);
assert_eq!(db.add10(&100), 110);
assert_eq!(count.get(), 1);
assert_eq!(db.add10(&200), 210);
assert_eq!(count.get(), 2);
assert_eq!(db.identity(&100), &100);
assert_eq!(count.get(), 3);
assert_eq!(db.identity(&100), &100);
assert_eq!(count.get(), 3);
}
#[gtest]
#[should_panic(expected = "Cycle detected: 'add10' depends on its own return value")]
fn test_cycle() {
crate::query_group! {
pub trait Add10 {
fn add10(&self, arg: i32) -> i32;
}
pub struct Database;
}
fn add10(db: &dyn Add10, arg: i32) -> i32 {
db.add10(arg) // infinite recursion!
}
let db = Database::new(add10);
db.add10(1);
}
#[gtest]
fn test_break_cycles_with_option() {
crate::query_group! {
pub trait Add10 {
#[break_cycles_with = None]
fn add10(&self, arg: i32) -> Option<i32>;
}
pub struct Database;
}
fn add10(db: &dyn Add10, arg: i32) -> Option<i32> {
db.add10(arg)
}
let db = Database::new(add10);
assert_eq!(db.add10(1), None);
}
#[gtest]
fn test_break_cycles_with_sentinel() {
crate::query_group! {
pub trait Add10 {
#[break_cycles_with = -1]
fn add10(&self, arg: i32) -> i32;
}
pub struct Database;
}
fn add10(db: &dyn Add10, arg: i32) -> i32 {
db.add10(arg)
}
let db = Database::new(add10);
assert_eq!(db.add10(1), -1);
}
#[gtest]
fn test_calls_in_cycle_are_not_memoized() {
crate::query_group! {
pub trait Table {
#[input]
fn logging(&self) -> Rc<RefCell<Vec<String>>>;
#[input]
fn records(&self) -> &'static [Record];
#[break_cycles_with = false]
fn is_unsafe(&self, name: &'static str) -> bool;
fn record(&self, name: &'static str) -> Record;
}
pub struct Database;
}
#[derive(Clone)]
struct Record {
name: &'static str,
is_unsafe: bool,
fields: &'static [&'static str],
}
// Returns whether or not a record is unsafe, checking recursively.
fn is_unsafe(db: &dyn Table, name: &'static str) -> bool {
let record = db.record(name);
let outcome =
record.is_unsafe || record.fields.iter().any(|&field| db.is_unsafe(field));
db.logging().borrow_mut().push(format!("is_unsafe({name}) = {outcome}"));
outcome
}
// Helper function so we can refer to records by name instead of by index.
fn record(db: &dyn Table, name: &'static str) -> Record {
db.records()
.iter()
.find(|record| record.name == name)
.expect("Record not found")
.clone()
}
let logging = Rc::default();
let db = Database::new(
Rc::clone(&logging),
&[
Record { name: "A", is_unsafe: false, fields: &["B", "Unsafe"] },
Record { name: "B", is_unsafe: false, fields: &["A"] },
Record { name: "Unsafe", is_unsafe: true, fields: &[] },
],
is_unsafe,
record,
);
// When checking if A is unsafe, it will first ask B, which will try to ask A
// again, defaulting to false. So B says "I guess I'm safe", but _doesn't_
// memoize that result. A will then see that it has Unsafe which is unsafe, so A
// will memoize itself as unsafe. But when we go to ask B if it's unsafe now, it
// will have correctly _not_ memoized that it's safe, and so it will ask
// A again, which will again say "I am unsafe", and so B will correctly memoize
// that it's unsafe.
assert!(db.is_unsafe("A"));
assert!(db.is_unsafe("B"));
assert_eq!(
logging.borrow().clone(),
vec![
"is_unsafe(B) = false".to_string(), // this is the cycle-default value
"is_unsafe(Unsafe) = true".to_string(),
"is_unsafe(A) = true".to_string(),
"is_unsafe(B) = true".to_string(), // as we can see, the default wasn't memoized
]
);
}
#[gtest]
fn test_finite_recursion() {
crate::query_group! {
pub trait Add10 {
#[input]
fn call_counter(&self) -> Rc<Cell<i32>>;
fn add10(&self, arg: i32) -> i32;
}
pub struct Database;
}
fn add10(db: &dyn Add10, arg: i32) -> i32 {
db.call_counter().set(db.call_counter().get() + 1);
if (arg % 10) != 0 {
db.add10(arg - 1) + 1 // Some recursion, but not infinite!
} else {
arg + 10
}
}
let db = Database::new(Rc::new(Cell::new(0)), add10);
assert_eq!(db.add10(100), 110);
assert_eq!(db.call_counter().get(), 1);
assert_eq!(db.add10(100), 110);
assert_eq!(db.call_counter().get(), 1);
assert_eq!(db.add10(205), 215);
assert_eq!(db.call_counter().get(), 7);
assert_eq!(db.add10(205), 215);
assert_eq!(db.call_counter().get(), 7);
}
/// As an edge case (which perhaps isn't optimized well[^1]), you can even
/// memoize a function which accepts no additional arguments, as a way
/// of running a fixed computation at most once.
///
/// [^1]: Since it has no inputs at all, we could store it as a bare `Option<ReturnType>`, but
/// instead we use `HashMap<(), ReturnType>`. Well, good enough.
#[gtest]
fn test_argless() {
crate::query_group! {
pub trait Argless {
#[input]
fn call_counter(&self) -> Rc<Cell<i32>>;
fn argless_function(&self) -> Rc<i32>;
}
pub struct Database;
}
fn argless_function(db: &dyn Argless) -> Rc<i32> {
db.call_counter().set(db.call_counter().get() + 1);
Rc::new(0)
}
let db = Database::new(Rc::new(Cell::new(0)), argless_function);
assert_eq!(db.call_counter().get(), 0);
let argless_return = db.argless_function();
assert_eq!(db.call_counter().get(), 1);
let argless_return_2 = db.argless_function();
assert_eq!(db.call_counter().get(), 1);
assert!(Rc::ptr_eq(&argless_return, &argless_return_2));
}
#[gtest]
fn test_provided_fn() {
crate::query_group! {
pub trait Db {
#[input]
fn input_fn(&self) -> i32;
#[provided]
fn provided_fn(&self, x: i32) -> i32 { self.input_fn() + x }
}
pub struct Database;
}
let db = Database::new(42);
let result = db.provided_fn(10);
expect_eq!(result, 52);
}
}