Fix generating bindings for functions named `operator1`.
I realized that `generate_func` in `src_code_gen.rs` was incorrectly
looking just at the presence of the "operator" prefix to determine
if something is an operator VS a regular function.
PiperOrigin-RevId: 427337548
diff --git a/rs_bindings_from_cc/importer.cc b/rs_bindings_from_cc/importer.cc
index 35ad45f..455670f 100644
--- a/rs_bindings_from_cc/importer.cc
+++ b/rs_bindings_from_cc/importer.cc
@@ -1011,12 +1011,7 @@
// clang-format off
#define OVERLOADED_OPERATOR(name, spelling, ...) \
case clang::OO_##name: { \
- std::string name = "operator"; \
- if ('a' <= spelling[0] && spelling[0] <= 'z') { \
- absl::StrAppend(&name, " "); \
- } \
- absl::StrAppend(&name, spelling); \
- return {Identifier(std::move(name))}; \
+ return {Operator(spelling)}; \
}
#include "third_party/llvm/llvm-project/clang/include/clang/Basic/OperatorKinds.def"
#undef OVERLOADED_OPERATOR
diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc
index 0b2833e..7b502b2 100644
--- a/rs_bindings_from_cc/ir.cc
+++ b/rs_bindings_from_cc/ir.cc
@@ -139,10 +139,9 @@
return result;
}
-nlohmann::json FuncParam::ToJson() const {
+nlohmann::json Operator::ToJson() const {
nlohmann::json result;
- result["type"] = type.ToJson();
- result["identifier"] = identifier.ToJson();
+ result["name"] = name_;
return result;
}
@@ -155,6 +154,26 @@
}
}
+nlohmann::json ToJson(const UnqualifiedIdentifier& unqualified_identifier) {
+ nlohmann::json result;
+ if (auto* id = std::get_if<Identifier>(&unqualified_identifier)) {
+ result["Identifier"] = id->ToJson();
+ } else if (auto* op = std::get_if<Operator>(&unqualified_identifier)) {
+ result["Operator"] = op->ToJson();
+ } else {
+ SpecialName special_name = std::get<SpecialName>(unqualified_identifier);
+ result[SpecialNameToString(special_name)] = nullptr;
+ }
+ return result;
+}
+
+nlohmann::json FuncParam::ToJson() const {
+ nlohmann::json result;
+ result["type"] = type.ToJson();
+ result["identifier"] = identifier.ToJson();
+ return result;
+}
+
std::ostream& operator<<(std::ostream& o, const SpecialName& special_name) {
return o << SpecialNameToString(special_name);
}
@@ -192,11 +211,7 @@
nlohmann::json Func::ToJson() const {
nlohmann::json func;
- if (auto* id = std::get_if<Identifier>(&name)) {
- func["name"]["Identifier"] = id->ToJson();
- } else {
- func["name"][SpecialNameToString(std::get<SpecialName>(name))] = nullptr;
- }
+ func["name"] = rs_bindings_from_cc::ToJson(name);
func["owning_target"] = owning_target.value();
if (doc_comment) {
func["doc_comment"] = *doc_comment;
diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h
index 7e8598e..5cdf8b5 100644
--- a/rs_bindings_from_cc/ir.h
+++ b/rs_bindings_from_cc/ir.h
@@ -224,6 +224,27 @@
return o << std::setw(internal::kJsonIndent) << id.Ident();
}
+class Operator {
+ public:
+ explicit Operator(std::string name) : name_(std::move(name)) {
+ CHECK(!name_.empty()) << "Operator name cannot be empty.";
+ }
+
+ absl::string_view Name() const { return name_; }
+
+ nlohmann::json ToJson() const;
+
+ private:
+ std::string name_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream, const Operator& op) {
+ char first_char = op.Name()[0];
+ const char* separator = ('a' <= first_char) && (first_char <= 'z') ? " " : "";
+ return stream << std::setw(internal::kJsonIndent) << "`operator" << separator
+ << op.Name() << "`";
+}
+
// A function parameter.
//
// Examples:
@@ -253,7 +274,8 @@
// Note that constructors are given a separate variant, so that we can treat
// them differently. After all, they are not invoked or defined like normal
// functions.
-using UnqualifiedIdentifier = std::variant<Identifier, SpecialName>;
+using UnqualifiedIdentifier = std::variant<Identifier, Operator, SpecialName>;
+nlohmann::json ToJson(const UnqualifiedIdentifier& unqualified_identifier);
struct MemberFuncMetadata {
enum ReferenceQualification : char {
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
index 9967c08..e9302bc 100644
--- a/rs_bindings_from_cc/ir.rs
+++ b/rs_bindings_from_cc/ir.rs
@@ -134,6 +134,27 @@
}
}
+#[derive(PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct Operator {
+ pub name: String,
+}
+
+impl Operator {
+ pub fn cc_name(&self) -> String {
+ let separator = match self.name.chars().next() {
+ Some(c) if c.is_alphabetic() => " ",
+ _ => "",
+ };
+ format!("operator{separator}{name}", separator = separator, name = self.name)
+ }
+}
+
+impl fmt::Debug for Operator {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str(&format!("\"{}\"", &self.cc_name()))
+ }
+}
+
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize)]
#[serde(transparent)]
pub struct DeclId(pub usize);
@@ -160,6 +181,7 @@
#[derive(PartialEq, Eq, Hash, Clone, Deserialize)]
pub enum UnqualifiedIdentifier {
Identifier(Identifier),
+ Operator(Operator),
Constructor,
Destructor,
}
@@ -177,6 +199,7 @@
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UnqualifiedIdentifier::Identifier(identifier) => fmt::Debug::fmt(identifier, f),
+ UnqualifiedIdentifier::Operator(op) => fmt::Debug::fmt(op, f),
UnqualifiedIdentifier::Constructor => f.write_str("Constructor"),
UnqualifiedIdentifier::Destructor => f.write_str("Destructor"),
}
diff --git a/rs_bindings_from_cc/ir_from_cc_test.rs b/rs_bindings_from_cc/ir_from_cc_test.rs
index 26ffb5b..0505d00 100644
--- a/rs_bindings_from_cc/ir_from_cc_test.rs
+++ b/rs_bindings_from_cc/ir_from_cc_test.rs
@@ -1119,7 +1119,7 @@
};"#,
)
.unwrap();
- let function_names: HashSet<&str> = ir
+ let operator_names: HashSet<&str> = ir
.functions()
.filter(|f| {
// Only SomeStruct member functions (excluding stddef.h stuff).
@@ -1129,14 +1129,14 @@
.unwrap_or_default()
})
.flat_map(|f| match &f.name {
- UnqualifiedIdentifier::Identifier(id) => Some(id.identifier.as_ref()),
+ UnqualifiedIdentifier::Operator(op) => Some(op.name.as_ref()),
_ => None,
})
.collect();
- assert!(function_names.contains("operator="));
- assert!(function_names.contains("operator new"));
- assert!(function_names.contains("operator new[]"));
- assert!(function_names.contains("operator=="));
+ assert!(operator_names.contains("="));
+ assert!(operator_names.contains("new"));
+ assert!(operator_names.contains("new[]"));
+ assert!(operator_names.contains("=="));
}
#[test]
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 77f5892..7297ae2 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -180,6 +180,7 @@
let func_name = match &func.name {
UnqualifiedIdentifier::Identifier(id) => id.identifier.clone(),
+ UnqualifiedIdentifier::Operator(op) => op.cc_name(),
UnqualifiedIdentifier::Destructor => {
format!("~{}", record.expect("destructor must be associated with a record"))
}
@@ -270,7 +271,7 @@
let func_name: syn::Ident;
let format_first_param_as_self: bool;
match &func.name {
- UnqualifiedIdentifier::Identifier(id) if id.identifier == "operator==" => {
+ UnqualifiedIdentifier::Operator(op) if op.name == "==" => {
if param_type_kinds.len() != 2 {
bail!("Unexpected number of parameters in operator==: {:?}", func);
}
@@ -294,7 +295,7 @@
_ => return make_unsupported_result("operator== where operands are not const references"),
};
}
- UnqualifiedIdentifier::Identifier(id) if id.identifier.starts_with("operator") => {
+ UnqualifiedIdentifier::Operator(_) => {
return make_unsupported_result("Bindings for this kind of operator are not supported");
}
UnqualifiedIdentifier::Identifier(id) => {
@@ -466,7 +467,7 @@
}
let func_body = match &func.name {
- UnqualifiedIdentifier::Identifier(_) => {
+ UnqualifiedIdentifier::Identifier(_) | UnqualifiedIdentifier::Operator(_) => {
let mut body = quote! { crate::detail::#thunk_ident( #( #thunk_args ),* ) };
// Only need to wrap everything in an `unsafe { ... }` block if
// the *whole* api function is safe.
@@ -1417,6 +1418,10 @@
let thunk_ident = thunk_ident(func);
let implementation_function = match &func.name {
+ UnqualifiedIdentifier::Operator(op) => {
+ let name = syn::parse_str::<TokenStream>(&op.name)?;
+ quote! { operator #name }
+ }
UnqualifiedIdentifier::Identifier(id) => {
let fn_ident = format_cc_ident(&id.identifier);
let static_method_metadata = func
@@ -1463,7 +1468,8 @@
None => false,
Some(meta) => match &func.name {
UnqualifiedIdentifier::Constructor | UnqualifiedIdentifier::Destructor => false,
- UnqualifiedIdentifier::Identifier(_) => meta.instance_method_metadata.is_some(),
+ UnqualifiedIdentifier::Identifier(_) | UnqualifiedIdentifier::Operator(_) =>
+ meta.instance_method_metadata.is_some(),
},
};
let (implementation_function, arg_expressions) = if !needs_this_deref {
@@ -1520,7 +1526,7 @@
mod tests {
use super::*;
use anyhow::anyhow;
- use ir_testing::{ir_from_cc, ir_from_cc_dependency, ir_func, ir_record};
+ use ir_testing::{ir_from_cc, ir_from_cc_dependency, ir_func, ir_record, retrieve_func};
use token_stream_matchers::{
assert_cc_matches, assert_cc_not_matches, assert_rs_matches, assert_rs_not_matches,
};
@@ -2846,13 +2852,7 @@
];
for (type_str, is_copy_expected) in tests.iter() {
let ir = ir_from_cc(&template.replace("PARAM_TYPE", type_str))?;
- let f = ir
- .functions()
- .find(|f| match &f.name {
- UnqualifiedIdentifier::Identifier(id) => id.identifier == "func",
- _ => false,
- })
- .expect("IR should contain a function named 'func'");
+ let f = retrieve_func(&ir, "func");
let t = RsTypeKind::new(&f.params[0].type_.rs_type, &ir)?;
assert_eq!(*is_copy_expected, t.implements_copy(), "Testing '{}'", type_str);
}
@@ -2868,20 +2868,8 @@
void bar(SomeStruct& bar_param);",
)?;
let record = ir.records().next().unwrap();
- let foo_func = ir
- .functions()
- .find(|f| {
- matches!(&f.name, UnqualifiedIdentifier::Identifier(id)
- if id.identifier == "foo")
- })
- .unwrap();
- let bar_func = ir
- .functions()
- .find(|f| {
- matches!(&f.name, UnqualifiedIdentifier::Identifier(id)
- if id.identifier == "bar")
- })
- .unwrap();
+ let foo_func = retrieve_func(&ir, "foo");
+ let bar_func = retrieve_func(&ir, "bar");
// const-ref + lifetimes in C++ ===> shared-ref in Rust
assert_eq!(foo_func.params.len(), 1);
@@ -2909,13 +2897,7 @@
void foo(const SomeStruct& foo_param);",
)?;
let record = ir.records().next().unwrap();
- let foo_func = ir
- .functions()
- .find(|f| {
- matches!(&f.name, UnqualifiedIdentifier::Identifier(id)
- if id.identifier == "foo")
- })
- .unwrap();
+ let foo_func = retrieve_func(&ir, "foo");
// const-ref + *no* lifetimes in C++ ===> const-pointer in Rust
assert_eq!(foo_func.params.len(), 1);
diff --git a/rs_bindings_from_cc/test/struct/operators/operators.h b/rs_bindings_from_cc/test/struct/operators/operators.h
index 9ec5d28..5e7cdb9 100644
--- a/rs_bindings_from_cc/test/struct/operators/operators.h
+++ b/rs_bindings_from_cc/test/struct/operators/operators.h
@@ -24,6 +24,11 @@
return (i % 10) == (other.i % 10);
}
+ // Test that method names starting with "operator" are not confused with real
+ // operator names (e.g. accidentally treating "operator1" as an unrecognized /
+ // unsupported operator).
+ inline int operator1() const { return i; }
+
int i;
};
diff --git a/rs_bindings_from_cc/test/struct/operators/test.rs b/rs_bindings_from_cc/test/struct/operators/test.rs
index b9a2f8b..76735da 100644
--- a/rs_bindings_from_cc/test/struct/operators/test.rs
+++ b/rs_bindings_from_cc/test/struct/operators/test.rs
@@ -32,6 +32,12 @@
}
#[test]
+ fn test_non_operator_method_name() {
+ let s2 = TestStruct2 { i: 2005 };
+ assert_eq!(2005, s2.operator1());
+ }
+
+ #[test]
fn test_eq_out_of_line_definition() {
let s1 = OperandForOutOfLineDefinition { i: 1005 };
let s2 = OperandForOutOfLineDefinition { i: 2005 };