Support translating Rust modules into C++ namespaces.
PiperOrigin-RevId: 494838620
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index 383fc6c..1188ac9 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -499,7 +499,6 @@
fn format_fn(tcx: TyCtxt, local_def_id: LocalDefId) -> Result<MixedSnippet> {
let def_id: DefId = local_def_id.to_def_id(); // Convert LocalDefId to DefId.
- let item_name = tcx.item_name(def_id);
let mut symbol_name = {
// Call to `mono` is ok - doc comment requires no generic parameters (although
// lifetime parameters would have been okay).
@@ -564,13 +563,14 @@
}
};
+ let FullyQualifiedName { mod_path, name, .. } = FullyQualifiedName::new(tcx, def_id);
+
let mut cc_prereqs = CcPrerequisites::default();
let cc_tokens = {
let ret_type = format_ret_ty_for_cc(tcx, sig.output())
.context("Error formatting function return type")?
.into_tokens(&mut cc_prereqs);
- let fn_name =
- format_cc_ident(item_name.as_str()).context("Error formatting function name")?;
+ let fn_name = format_cc_ident(name.as_str()).context("Error formatting function name")?;
let arg_names = tcx
.fn_arg_names(def_id)
.iter()
@@ -590,7 +590,7 @@
.into_tokens(&mut cc_prereqs))
})
.collect::<Result<Vec<_>>>()?;
- if item_name.as_str() == symbol_name.name {
+ if name.as_str() == symbol_name.name {
quote! {
#doc_comment
extern "C" #ret_type #fn_name (
@@ -625,7 +625,8 @@
quote! {}
} else {
let crate_name = make_rs_ident(tcx.crate_name(LOCAL_CRATE).as_str());
- let fn_name = make_rs_ident(item_name.as_str());
+ let mod_path = mod_path.format_for_rs();
+ let fn_name = make_rs_ident(name.as_str());
let exported_name = make_rs_ident(symbol_name.name);
let ret_type = format_ty_for_rs(tcx, sig.output())?;
let arg_names = tcx
@@ -649,7 +650,7 @@
quote! {
#[no_mangle]
extern "C" fn #exported_name( #( #arg_names: #arg_types ),* ) -> #ret_type {
- :: #crate_name :: #fn_name( #( #arg_names ),* )
+ :: #crate_name :: #mod_path #fn_name( #( #arg_names ),* )
}
}
};
@@ -982,15 +983,31 @@
toposort::toposort(nodes, deps, preferred_order)
};
- let MixedSnippet { cc, rs } = {
- let ordered = ordered.into_iter().map(|def_id| bindings.remove(&def_id).unwrap());
- let failed = failed.into_iter().map(|def_id| {
- // TODO(b/260725687): Add test coverage for the error condition below.
- format_unsupported_def(tcx, def_id, anyhow!("Definition dependency cycle"))
- });
- ordered.chain(failed).sum()
- };
+ // Neighboring `ordered` items that belong to the same namespace should be put
+ // under a single `namespace foo::bar::baz { #items }`. We don't just translate
+ // `mod foo` => `namespace foo` in a top-down fashion, because of the need to
+ // reorder the bindings of individual items (see `CcPrerequisites::defs`
+ // toposort above).
+ let ordered = ordered
+ .into_iter()
+ .group_by(|local_def_id| FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path)
+ .into_iter()
+ .map(|(mod_path, def_ids)| {
+ let MixedSnippet { rs, cc: CcSnippet { tokens, prereqs } } =
+ def_ids.map(|def_id| bindings.remove(&def_id).unwrap()).sum();
+ let tokens = mod_path.format_with_cc_body(tokens)?;
+ Ok(MixedSnippet { rs, cc: CcSnippet { tokens, prereqs } })
+ })
+ .collect::<Result<Vec<_>>>()?;
+ // Replace `failed` ids with unsupported-item comments.
+ let failed = failed.into_iter().map(|def_id| {
+ // TODO(b/260725687): Add test coverage for the error condition below.
+ format_unsupported_def(tcx, def_id, anyhow!("Definition dependency cycle"))
+ });
+
+ // Generate top-level elements of the C++ header file.
+ let MixedSnippet { cc, rs } = ordered.into_iter().chain(failed).sum();
let h_body = {
// TODO(b/254690602): Decide whether using `#crate_name` as the name of the
// top-level namespace is okay (e.g. investigate if this name is globally
@@ -1280,6 +1297,43 @@
}
#[test]
+ fn test_generated_bindings_modules() {
+ let test_src = r#"
+ pub mod some_module {
+ pub fn some_func() {}
+ }
+ "#;
+ test_generated_bindings(test_src, |bindings| {
+ let bindings = bindings.expect("Test expects success");
+ assert_cc_matches!(
+ bindings.h_body,
+ quote! {
+ namespace rust_out {
+ ... // TODO(b/258265044): This `...` should be removed.
+ // (there should be no unsupported-item comment
+ // for the module item).
+ namespace some_module {
+ ...
+ inline void some_func() { ... }
+ ...
+ } // namespace some_module
+ } // namespace rust_out
+ }
+ );
+ assert_rs_matches!(
+ bindings.rs_body,
+ quote! {
+ #[no_mangle]
+ extern "C"
+ fn ...() -> () {
+ ::rust_out::some_module::some_func()
+ }
+ }
+ );
+ });
+ }
+
+ #[test]
fn test_generated_bindings_non_pub_items() {
let test_src = r#"
#![allow(dead_code)]
@@ -1292,6 +1346,15 @@
x: i32,
y: i32,
}
+
+ pub mod public_module {
+ fn priv_func_in_pub_module() {}
+ }
+
+ mod private_module {
+ pub fn pub_func_in_priv_module() { priv_func_in_priv_module() }
+ fn priv_func_in_priv_module() {}
+ }
"#;
test_generated_bindings(test_src, |bindings| {
let bindings = bindings.expect("Test expects success");
@@ -1301,6 +1364,18 @@
assert_rs_not_matches!(bindings.rs_body, quote! { private_function });
assert_cc_not_matches!(bindings.h_body, quote! { PrivateStruct });
assert_rs_not_matches!(bindings.rs_body, quote! { PrivateStruct });
+ assert_cc_not_matches!(bindings.h_body, quote! { priv_func_in_priv_module });
+ assert_rs_not_matches!(bindings.rs_body, quote! { priv_func_in_priv_module });
+ assert_cc_not_matches!(bindings.h_body, quote! { priv_func_in_pub_module });
+ assert_rs_not_matches!(bindings.rs_body, quote! { priv_func_in_pub_module });
+
+ // TODO(b/258265044): The test expectations below are (temporarily) incorrect. A public
+ // function in a private module is effectively private - `format_crate` shouldn't
+ // just use `tcx.local_visibility`.
+ assert_cc_matches!(bindings.h_body, quote! { private_module });
+ assert_rs_matches!(bindings.rs_body, quote! { private_module });
+ assert_cc_matches!(bindings.h_body, quote! { pub_func_in_priv_module });
+ assert_rs_matches!(bindings.rs_body, quote! { pub_func_in_priv_module });
});
}
diff --git a/cc_bindings_from_rs/cc_bindings_from_rs.rs b/cc_bindings_from_rs/cc_bindings_from_rs.rs
index e58f2aa..31d2ab0 100644
--- a/cc_bindings_from_rs/cc_bindings_from_rs.rs
+++ b/cc_bindings_from_rs/cc_bindings_from_rs.rs
@@ -313,11 +313,13 @@
let rs_input_path = self.tempdir.path().join("test_crate.rs");
std::fs::write(
&rs_input_path,
- r#" pub fn public_function() {
- private_function()
- }
+ r#" pub mod public_module {
+ pub fn public_function() {
+ private_function()
+ }
- fn private_function() {}
+ fn private_function() {}
+ }
"#,
)?;
@@ -380,7 +382,7 @@
let mut marked_pattern = expected.to_string();
marked_pattern.insert_str(longest_matching_expectation_len, "!!!>>>");
panic!(
- "h_body didn't match expectations:\n\
+ "Mismatched expectations:\n\
#### Actual body (first mismatch follows the \"!!!>>>\" marker):\n\
{marked_body}\n\
#### Mismatched pattern (mismatch follows the \"!!!>>>\" marker):\n\
@@ -404,6 +406,12 @@
#pragma once
namespace test_crate {
+
+// Error generating bindings for `public_module` defined at
+// /tmp/.ANY_IDENTIFIER_CHARACTERS/test_crate.rs:1:2: 1:23: Unsupported
+// rustc_hir::hir::ItemKind: module
+
+namespace public_module {
namespace __crubit_internal {
extern "C" void
__crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
@@ -412,6 +420,7 @@
return __crubit_internal::
__crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
}
+} // namespace public_module
} // namespace test_crate"#,
);
@@ -425,8 +434,9 @@
#![allow(improper_ctypes_definitions)]
#[no_mangle]
-extern "C" fn __crubit_thunk__ANY_IDENTIFIER_CHARACTERS() -> () {
- ::test_crate::public_function()
+extern "C" fn __crubit_thunk__ANY_IDENTIFIER_CHARACTERS()
+-> () {
+ ::test_crate::public_module::public_function()
}
"#,
);
diff --git a/cc_bindings_from_rs/test/modules/BUILD b/cc_bindings_from_rs/test/modules/BUILD
new file mode 100644
index 0000000..52578e0
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/BUILD
@@ -0,0 +1,37 @@
+"""End-to-end tests of `cc_bindings_from_rs`, focusing on
+module/namespace-related bindings."""
+
+load(
+ "@rules_rust//rust:defs.bzl",
+ "rust_library",
+)
+load(
+ "//cc_bindings_from_rs/bazel_support:cc_bindings_from_rust_rule.bzl",
+ "cc_bindings_from_rust",
+)
+
+licenses(["notice"])
+
+rust_library(
+ name = "modules",
+ testonly = 1,
+ srcs = ["modules.rs"],
+ deps = [
+ "//common:rust_allocator_shims",
+ ],
+)
+
+cc_bindings_from_rust(
+ name = "modules_cc_api",
+ testonly = 1,
+ crate = ":modules",
+)
+
+cc_test(
+ name = "modules_test",
+ srcs = ["modules_test.cc"],
+ deps = [
+ ":modules_cc_api",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/cc_bindings_from_rs/test/modules/modules.rs b/cc_bindings_from_rs/test/modules/modules.rs
new file mode 100644
index 0000000..0922e5a
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/modules.rs
@@ -0,0 +1,12 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//! This crate is used as a test input for `cc_bindings_from_rs` and the
+//! generated C++ bindings are then tested via `modules_test.cc`.
+
+pub mod basic_module {
+ pub fn add_i32(x: i32, y: i32) -> i32 {
+ x + y
+ }
+}
diff --git a/cc_bindings_from_rs/test/modules/modules_test.cc b/cc_bindings_from_rs/test/modules/modules_test.cc
new file mode 100644
index 0000000..3359b8a
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/modules_test.cc
@@ -0,0 +1,19 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "cc_bindings_from_rs/test/modules/modules_cc_api.h"
+
+namespace crubit {
+namespace {
+
+TEST(ModulesTest, BasicModule) {
+ ASSERT_EQ(123 + 456, modules::basic_module::add_i32(123, 456));
+}
+
+} // namespace
+} // namespace crubit
diff --git a/common/code_gen_utils.rs b/common/code_gen_utils.rs
index 46f1f8a..9f84dc3 100644
--- a/common/code_gen_utils.rs
+++ b/common/code_gen_utils.rs
@@ -68,10 +68,26 @@
}
pub fn format_for_cc(&self) -> Result<TokenStream> {
- let namespace_cc_idents =
- self.0.iter().map(|ns| format_cc_ident(ns)).collect::<Result<Vec<_>>>()?;
+ let namespace_cc_idents = self.cc_idents()?;
Ok(quote! { #(#namespace_cc_idents::)* })
}
+
+ pub fn format_with_cc_body(&self, body: TokenStream) -> Result<TokenStream> {
+ if self.0.is_empty() {
+ Ok(body)
+ } else {
+ let namespace_cc_idents = self.cc_idents()?;
+ Ok(quote! {
+ namespace #(#namespace_cc_idents)::* {
+ #body
+ }
+ })
+ }
+ }
+
+ fn cc_idents(&self) -> Result<Vec<TokenStream>> {
+ self.0.iter().map(|ns| format_cc_ident(ns)).collect()
+ }
}
/// `CcInclude` represents a single `#include ...` directive in C++.
@@ -476,4 +492,26 @@
assert!(msg.contains("`reinterpret_cast`"));
assert!(msg.contains("C++ reserved keyword"));
}
+
+ #[test]
+ fn test_namespace_qualifier_format_with_cc_body_top_level_namespace() {
+ let ns = create_namespace_qualifier_for_tests(&[]);
+ assert_cc_matches!(
+ ns.format_with_cc_body(quote! { cc body goes here }).unwrap(),
+ quote! { cc body goes here },
+ );
+ }
+
+ #[test]
+ fn test_namespace_qualifier_format_with_cc_body_nested_namespace() {
+ let ns = create_namespace_qualifier_for_tests(&["foo", "bar", "baz"]);
+ assert_cc_matches!(
+ ns.format_with_cc_body(quote! { cc body goes here }).unwrap(),
+ quote! {
+ namespace foo::bar::baz {
+ cc body goes here
+ } // namespace foo::bar::baz
+ },
+ );
+ }
}