Generate `operator<` bindings.
For now operator<=> is completely ignored, and even if both operator< and operator<=> are present we construct partial_cmp from lt (i.e. the operator< thunk).
Bindings are also not generated for operator< when the operands have different types, since that's a much more complicated case: implementing partial_cmp in Rust using lt would require transitivity of operator< and operator==, and we'd need to query salsa for them while avoiding cyclic queries.
PiperOrigin-RevId: 480291403
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index b223f93..3982cc9 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -92,6 +92,9 @@
) -> Result<Option<Rc<(RsSnippet, RsSnippet, Rc<FunctionId>)>>>;
fn overloaded_funcs(&self) -> Rc<HashSet<Rc<FunctionId>>>;
+
+ // TODO(b/236687702): convert the get_binding function into a query once
+ // ImplKind implements Eq.
}
#[salsa::database(BindingsGeneratorStorage)]
@@ -372,6 +375,10 @@
/// An Unpin constructor trait, e.g. From or Clone, with a list of parameter
/// types.
UnpinConstructor { name: TokenStream, params: Vec<RsTypeKind> },
+ /// The PartialEq trait.
+ PartialEq { params: Vec<RsTypeKind> },
+ /// The PartialOrd trait.
+ PartialOrd { params: Vec<RsTypeKind> },
/// Any other trait, e.g. Eq.
Other { name: TokenStream, params: Vec<RsTypeKind>, is_unsafe_fn: bool },
}
@@ -382,6 +389,8 @@
match self {
Self::CtorNew(params)
| Self::UnpinConstructor { params, .. }
+ | Self::PartialEq { params }
+ | Self::PartialOrd { params }
| Self::Other { params, .. } => params.iter(),
}
}
@@ -399,6 +408,14 @@
let params = format_generic_params(params);
quote! {#name #params}.to_tokens(tokens)
}
+ Self::PartialEq { params } => {
+ let params = format_generic_params(params);
+ quote! {PartialEq #params}.to_tokens(tokens)
+ }
+ Self::PartialOrd { params } => {
+ let params = format_generic_params(params);
+ quote! {PartialOrd #params}.to_tokens(tokens)
+ }
Self::CtorNew(arg_types) => {
let arg_types = format_tuple_except_singleton(arg_types);
quote! { ::ctor::CtorNew < #arg_types > }.to_tokens(tokens)
@@ -667,11 +684,7 @@
RsTypeKind::Record { record: lhs_record, .. } => {
func_name = make_rs_ident("eq");
impl_kind = ImplKind::new_trait(
- TraitName::Other {
- name: quote! {PartialEq},
- params: vec![(**rhs).clone()],
- is_unsafe_fn: false,
- },
+ TraitName::PartialEq { params: vec![(**rhs).clone()] },
lhs_record,
&ir,
/* format_first_param_as_self= */ true,
@@ -686,6 +699,62 @@
}
};
}
+ UnqualifiedIdentifier::Operator(op) if op.name == "<=>" => {
+ bail!("Three-way comparison operator not yet supported (b/219827738)");
+ }
+ UnqualifiedIdentifier::Operator(op) if op.name == "<" => {
+ assert_eq!(
+ param_types.len(),
+ 2,
+ "Unexpected number of parameters in operator<: {func:?}"
+ );
+ match (¶m_types[0], ¶m_types[1]) {
+ (
+ RsTypeKind::Reference { referent: lhs, mutability: Mutability::Const, .. },
+ RsTypeKind::Reference { referent: rhs, mutability: Mutability::Const, .. },
+ ) => match (&**lhs, &**rhs) {
+ (
+ RsTypeKind::Record { record: lhs_record, .. },
+ RsTypeKind::Record { record: rhs_record, .. },
+ ) => {
+ if lhs_record != rhs_record {
+ bail!("operator< where lhs and rhs are not the same type.");
+ }
+ // PartialOrd requires PartialEq, so we need to make sure operator== is
+ // implemented for this Record type.
+ match get_binding(
+ db,
+ UnqualifiedIdentifier::Operator(Operator { name: "==".to_string() }),
+ param_types,
+ ) {
+ Some((
+ _,
+ ImplKind::Trait { trait_name: TraitName::PartialEq { .. }, .. },
+ )) => {
+ func_name = make_rs_ident("lt");
+ impl_kind = ImplKind::new_trait(
+ TraitName::PartialOrd { params: vec![(**rhs).clone()] },
+ lhs_record,
+ &ir,
+ /* format_first_param_as_self= */
+ true,
+ )?;
+ }
+ _ => bail!("operator< where operator== is missing."),
+ }
+ }
+ (RsTypeKind::Record { .. }, _) => {
+ bail!("operator< where lhs and rhs are not the same type.");
+ }
+ _ => {
+ bail!("operator< where lhs doesn't refer to a record.",);
+ }
+ },
+ _ => {
+ bail!("operator< where operands are not const references.",);
+ }
+ };
+ }
UnqualifiedIdentifier::Operator(op) if op.name == "=" => {
assert_eq!(
param_types.len(),
@@ -966,6 +1035,35 @@
Ok(Some((func_name, impl_kind)))
}
+/// Returns the generated bindings for a function with the given name and param
+/// types. If none exists, returns None.
+fn get_binding(
+ db: &dyn BindingsGenerator,
+ expected_function_name: UnqualifiedIdentifier,
+ expected_param_types: &[RsTypeKind],
+) -> Option<(Ident, ImplKind)> {
+ return db
+ .ir()
+ // TODO(jeanpierreda): make this O(1) using a hash table lookup.
+ .functions()
+ .filter(|function| {
+ function.name == expected_function_name
+ && generate_func(db, (*function).clone()).ok().flatten().is_some()
+ })
+ .find_map(|function| {
+ let mut function_param_types = function
+ .params
+ .iter()
+ .map(|param| db.rs_type_kind(param.type_.rs_type.clone()))
+ .collect::<Result<Vec<_>>>()
+ .ok()?;
+ if !function_param_types.iter().eq(expected_param_types) {
+ return None;
+ }
+ api_func_shape(db, function, &mut function_param_types).ok().flatten()
+ });
+}
+
/// Mutates the provided parameters so that nontrivial by-value parameters are,
/// instead, materialized in the caller and passed by rvalue reference.
fn materialize_ctor_in_caller(func: &Func, params: &mut [RsTypeKind]) {
@@ -1216,6 +1314,22 @@
quote! {
type #name = #quoted_return_type;
}
+ } else if let TraitName::PartialOrd { params: _ } = trait_name {
+ quote! {
+ #[inline(always)]
+ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+ if self == other {
+ return Some(core::cmp::Ordering::Equal);
+ }
+ if self < other {
+ return Some(core::cmp::Ordering::Less);
+ }
+ if other < self {
+ return Some(core::cmp::Ordering::Greater);
+ }
+ None
+ }
+ }
} else {
quote! {}
};
@@ -6176,6 +6290,99 @@
}
#[test]
+ fn test_impl_lt_for_member_function() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct final {
+ inline bool operator==(const SomeStruct& other) const {
+ return i == other.i;
+ }
+ inline bool operator<(const SomeStruct& other) const {
+ return i < other.i;
+ }
+ int i;
+ };"#,
+ )?;
+ let BindingsTokens { rs_api, rs_api_impl } = generate_bindings_tokens(ir)?;
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ impl PartialOrd<crate::SomeStruct> for SomeStruct {
+ #[inline(always)]
+ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+ if self == other {
+ return Some(core::cmp::Ordering::Equal);
+ }
+ if self < other {
+ return Some(core::cmp::Ordering::Less);
+ }
+ if other < self {
+ return Some(core::cmp::Ordering::Greater);
+ }
+ None
+ }
+ #[inline(always)]
+ fn lt<'a, 'b>(&'a self, other: &'b crate::SomeStruct) -> bool {
+ unsafe { crate::detail::__rust_thunk___ZNK10SomeStructltERKS_(self, other) }
+ }
+ }
+ }
+ );
+ assert_cc_matches!(
+ rs_api_impl,
+ quote! {
+ extern "C" bool __rust_thunk___ZNK10SomeStructltERKS_(
+ const struct SomeStruct* __this, const struct SomeStruct* other) {
+ return __this->operator<(*other);
+ }
+ }
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_impl_lt_for_free_function() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct final {
+ inline bool operator==(const SomeStruct& other) const {
+ return i == other.i;
+ }
+ int i;
+ };
+ bool operator<(const SomeStruct& lhs, const SomeStruct& rhs) {
+ return lhs.i < rhs.i;
+ }"#,
+ )?;
+ let rs_api = generate_bindings_tokens(ir)?.rs_api;
+ assert_rs_matches!(
+ rs_api,
+ quote! {
+ impl PartialOrd<crate::SomeStruct> for SomeStruct {
+ #[inline(always)]
+ fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+ if self == other {
+ return Some(core::cmp::Ordering::Equal);
+ }
+ if self < other {
+ return Some(core::cmp::Ordering::Less);
+ }
+ if other < self {
+ return Some(core::cmp::Ordering::Greater);
+ }
+ None
+ }
+ #[inline(always)]
+ fn lt<'a, 'b>(&'a self, rhs: &'b crate::SomeStruct) -> bool {
+ unsafe { crate::detail::__rust_thunk___ZltRK10SomeStructS1_(self, rhs) }
+ }
+ }
+ }
+ );
+ Ok(())
+ }
+
+ #[test]
fn test_assign() -> Result<()> {
let ir = ir_from_cc(
r#"
@@ -6280,6 +6487,78 @@
}
#[test]
+ fn test_impl_lt_different_operands() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct1 final {
+ int i;
+ };
+ struct SomeStruct2 final {
+ inline bool operator==(const SomeStruct1& other) const {
+ return i == other.i;
+ }
+ inline bool operator<(const SomeStruct1& other) const {
+ return i < other.i;
+ };
+ int i;
+ };"#,
+ )?;
+ let rs_api = generate_bindings_tokens(ir)?.rs_api;
+ assert_rs_not_matches!(rs_api, quote! {impl PartialOrd});
+ Ok(())
+ }
+
+ #[test]
+ fn test_impl_lt_non_const_member_function() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct final {
+ inline bool operator==(const SomeStruct& other) const {
+ return i == other.i;
+ }
+ int i;
+ bool operator<(const SomeStruct& other) /* no `const` here */;
+ };"#,
+ )?;
+ let rs_api = generate_bindings_tokens(ir)?.rs_api;
+ assert_rs_not_matches!(rs_api, quote! {impl PartialOrd});
+ Ok(())
+ }
+
+ #[test]
+ fn test_impl_lt_rhs_by_value() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct final {
+ inline bool operator==(const SomeStruct& other) const {
+ return i == other.i;
+ }
+ int i;
+ bool operator<(SomeStruct other) const;
+ };"#,
+ )?;
+ let rs_api = generate_bindings_tokens(ir)?.rs_api;
+ assert_rs_not_matches!(rs_api, quote! {impl PartialOrd});
+ Ok(())
+ }
+
+ #[test]
+ fn test_impl_lt_missing_eq_impl() -> Result<()> {
+ let ir = ir_from_cc(
+ r#"#pragma clang lifetime_elision
+ struct SomeStruct final {
+ inline bool operator<(const SomeStruct& other) const {
+ return i < other.i;
+ }
+ int i;
+ };"#,
+ )?;
+ let rs_api = generate_bindings_tokens(ir)?.rs_api;
+ assert_rs_not_matches!(rs_api, quote! {impl PartialOrd});
+ Ok(())
+ }
+
+ #[test]
fn test_thunk_ident_function() -> Result<()> {
let ir = ir_from_cc("inline int foo() {}")?;
let func = retrieve_func(&ir, "foo");
diff --git a/rs_bindings_from_cc/test/struct/operators/operators.cc b/rs_bindings_from_cc/test/struct/operators/operators.cc
index 08c0e4f..ce9f767 100644
--- a/rs_bindings_from_cc/test/struct/operators/operators.cc
+++ b/rs_bindings_from_cc/test/struct/operators/operators.cc
@@ -9,10 +9,19 @@
return (i % 10) == (other.i % 10);
}
+bool OperandForOutOfLineDefinition::operator<(
+ const OperandForOutOfLineDefinition& other) const {
+ return (i % 10) < (other.i % 10);
+}
+
bool operator==(const OperandForFreeFunc& lhs, const OperandForFreeFunc& rhs) {
return (lhs.i % 10) == (rhs.i % 10);
}
+bool operator<(const OperandForFreeFunc& lhs, const OperandForFreeFunc& rhs) {
+ return (lhs.i % 10) < (rhs.i % 10);
+}
+
namespace test_namespace_bindings {
// bool operator==(const OperandForFreeFuncInDifferentNamespace& lhs,
diff --git a/rs_bindings_from_cc/test/struct/operators/operators.h b/rs_bindings_from_cc/test/struct/operators/operators.h
index 01c3278..4ceedba 100644
--- a/rs_bindings_from_cc/test/struct/operators/operators.h
+++ b/rs_bindings_from_cc/test/struct/operators/operators.h
@@ -26,6 +26,19 @@
return (i % 10) == (other.i % 10);
}
+ // Comparison with the same struct. Should generate:
+ // impl PartialOrd for TestStruct2
+ // `PartialOrd<TestStruct2>` also ok.
+ inline bool operator<(const TestStruct2& other) const {
+ return (i % 10) < (other.i % 10);
+ }
+
+ // Comparison with another struct. Shouldn't generate anything since the
+ // operands are not of the same type.
+ inline bool operator<(const TestStruct1& other) const {
+ return (i % 10) < (other.i % 10);
+ }
+
// Test that method names starting with "operator" are not confused with real
// operator names (e.g. accidentally treating "operator1" as an unrecognized /
// unsupported operator).
@@ -40,6 +53,11 @@
// Non-`inline` definition. Should generate:
// impl PartialEq for TestStructForOutOfLineDefinition
bool operator==(const OperandForOutOfLineDefinition& other) const;
+
+ // Non-`inline` definition. Should generate:
+ // impl PartialOrd for TestStructForOutOfLineDefinition
+ bool operator<(const OperandForOutOfLineDefinition& other) const;
+
int i;
};
@@ -53,6 +71,10 @@
// impl PartialEq for TestStructForFreeFunc.
bool operator==(const OperandForFreeFunc& lhs, const OperandForFreeFunc& rhs);
+// Non-member function. Should generate:
+// impl PartialOrd for TestStructForFreeFunc.
+bool operator<(const OperandForFreeFunc& lhs, const OperandForFreeFunc& rhs);
+
//////////////////////////////////////////////////////////////////////
struct OperandForFreeFuncInDifferentNamespace final {
diff --git a/rs_bindings_from_cc/test/struct/operators/operators_test.rs b/rs_bindings_from_cc/test/struct/operators/operators_test.rs
index 81c5941..5a1a3a9 100644
--- a/rs_bindings_from_cc/test/struct/operators/operators_test.rs
+++ b/rs_bindings_from_cc/test/struct/operators/operators_test.rs
@@ -28,6 +28,22 @@
assert_impl_all!(TestStruct2: PartialEq<TestStruct1>);
assert_not_impl_any!(TestStruct1: PartialEq<TestStruct2>);
}
+ #[test]
+ fn test_lt_member_func_same_operands() {
+ let s1 = TestStruct2 { i: 1001 };
+ let s2 = TestStruct2 { i: 2002 };
+ let s3 = TestStruct2 { i: 3000 };
+ assert!(s1 < s2);
+ assert!(s1 >= s3);
+ }
+
+ #[test]
+ fn test_lt_member_func_different_operands() {
+ // PartialOrd is only implemented if the operands of operator< are of the same
+ // type.
+ assert_not_impl_any!(TestStruct2: PartialOrd<TestStruct1>);
+ assert_not_impl_any!(TestStruct1: PartialOrd<TestStruct2>);
+ }
#[test]
fn test_non_operator_method_name() {
@@ -68,6 +84,22 @@
}
#[test]
+ fn test_lt_out_of_line_definition() {
+ let s1 = OperandForOutOfLineDefinition { i: 1001 };
+ let s2 = OperandForOutOfLineDefinition { i: 2002 };
+ let s3 = OperandForOutOfLineDefinition { i: 3000 };
+ assert!(s1 < s2);
+ assert!(s1 >= s3);
+ }
+
+ #[test]
+ fn test_lt_free_func() {
+ let s1 = OperandForFreeFunc { i: 1001 };
+ let s2 = OperandForFreeFunc { i: 2002 };
+ assert!(s1 < s2);
+ }
+
+ #[test]
fn test_many_operators_neg() {
let s = ManyOperators { i: 7 };
assert_eq!(-7, (-&s).i);