Support translating Rust modules into C++ namespaces.

PiperOrigin-RevId: 494838620
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index 383fc6c..1188ac9 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -499,7 +499,6 @@
 fn format_fn(tcx: TyCtxt, local_def_id: LocalDefId) -> Result<MixedSnippet> {
     let def_id: DefId = local_def_id.to_def_id(); // Convert LocalDefId to DefId.
 
-    let item_name = tcx.item_name(def_id);
     let mut symbol_name = {
         // Call to `mono` is ok - doc comment requires no generic parameters (although
         // lifetime parameters would have been okay).
@@ -564,13 +563,14 @@
         }
     };
 
+    let FullyQualifiedName { mod_path, name, .. } = FullyQualifiedName::new(tcx, def_id);
+
     let mut cc_prereqs = CcPrerequisites::default();
     let cc_tokens = {
         let ret_type = format_ret_ty_for_cc(tcx, sig.output())
             .context("Error formatting function return type")?
             .into_tokens(&mut cc_prereqs);
-        let fn_name =
-            format_cc_ident(item_name.as_str()).context("Error formatting function name")?;
+        let fn_name = format_cc_ident(name.as_str()).context("Error formatting function name")?;
         let arg_names = tcx
             .fn_arg_names(def_id)
             .iter()
@@ -590,7 +590,7 @@
                     .into_tokens(&mut cc_prereqs))
             })
             .collect::<Result<Vec<_>>>()?;
-        if item_name.as_str() == symbol_name.name {
+        if name.as_str() == symbol_name.name {
             quote! {
                 #doc_comment
                 extern "C" #ret_type #fn_name (
@@ -625,7 +625,8 @@
         quote! {}
     } else {
         let crate_name = make_rs_ident(tcx.crate_name(LOCAL_CRATE).as_str());
-        let fn_name = make_rs_ident(item_name.as_str());
+        let mod_path = mod_path.format_for_rs();
+        let fn_name = make_rs_ident(name.as_str());
         let exported_name = make_rs_ident(symbol_name.name);
         let ret_type = format_ty_for_rs(tcx, sig.output())?;
         let arg_names = tcx
@@ -649,7 +650,7 @@
         quote! {
             #[no_mangle]
             extern "C" fn #exported_name( #( #arg_names: #arg_types ),* ) -> #ret_type {
-                :: #crate_name :: #fn_name( #( #arg_names ),* )
+                :: #crate_name :: #mod_path #fn_name( #( #arg_names ),* )
             }
         }
     };
@@ -982,15 +983,31 @@
             toposort::toposort(nodes, deps, preferred_order)
         };
 
-    let MixedSnippet { cc, rs } = {
-        let ordered = ordered.into_iter().map(|def_id| bindings.remove(&def_id).unwrap());
-        let failed = failed.into_iter().map(|def_id| {
-            // TODO(b/260725687): Add test coverage for the error condition below.
-            format_unsupported_def(tcx, def_id, anyhow!("Definition dependency cycle"))
-        });
-        ordered.chain(failed).sum()
-    };
+    // Neighboring `ordered` items that belong to the same namespace should be put
+    // under a single `namespace foo::bar::baz { #items }`.  We don't just translate
+    // `mod foo` => `namespace foo` in a top-down fashion, because of the need to
+    // reorder the bindings of individual items (see `CcPrerequisites::defs`
+    // toposort above).
+    let ordered = ordered
+        .into_iter()
+        .group_by(|local_def_id| FullyQualifiedName::new(tcx, local_def_id.to_def_id()).mod_path)
+        .into_iter()
+        .map(|(mod_path, def_ids)| {
+            let MixedSnippet { rs, cc: CcSnippet { tokens, prereqs } } =
+                def_ids.map(|def_id| bindings.remove(&def_id).unwrap()).sum();
+            let tokens = mod_path.format_with_cc_body(tokens)?;
+            Ok(MixedSnippet { rs, cc: CcSnippet { tokens, prereqs } })
+        })
+        .collect::<Result<Vec<_>>>()?;
 
+    // Replace `failed` ids with unsupported-item comments.
+    let failed = failed.into_iter().map(|def_id| {
+        // TODO(b/260725687): Add test coverage for the error condition below.
+        format_unsupported_def(tcx, def_id, anyhow!("Definition dependency cycle"))
+    });
+
+    // Generate top-level elements of the C++ header file.
+    let MixedSnippet { cc, rs } = ordered.into_iter().chain(failed).sum();
     let h_body = {
         // TODO(b/254690602): Decide whether using `#crate_name` as the name of the
         // top-level namespace is okay (e.g. investigate if this name is globally
@@ -1280,6 +1297,43 @@
     }
 
     #[test]
+    fn test_generated_bindings_modules() {
+        let test_src = r#"
+                pub mod some_module {
+                    pub fn some_func() {}
+                }
+            "#;
+        test_generated_bindings(test_src, |bindings| {
+            let bindings = bindings.expect("Test expects success");
+            assert_cc_matches!(
+                bindings.h_body,
+                quote! {
+                    namespace rust_out {
+                        ... // TODO(b/258265044): This `...` should be removed.
+                            // (there should be no unsupported-item comment
+                            // for the module item).
+                        namespace some_module {
+                            ...
+                            inline void some_func() { ... }
+                            ...
+                        }  // namespace some_module
+                    }  // namespace rust_out
+                }
+            );
+            assert_rs_matches!(
+                bindings.rs_body,
+                quote! {
+                    #[no_mangle]
+                    extern "C"
+                    fn ...() -> () {
+                        ::rust_out::some_module::some_func()
+                    }
+                }
+            );
+        });
+    }
+
+    #[test]
     fn test_generated_bindings_non_pub_items() {
         let test_src = r#"
                 #![allow(dead_code)]
@@ -1292,6 +1346,15 @@
                     x: i32,
                     y: i32,
                 }
+
+                pub mod public_module {
+                    fn priv_func_in_pub_module() {}
+                }
+
+                mod private_module {
+                    pub fn pub_func_in_priv_module() { priv_func_in_priv_module() }
+                    fn priv_func_in_priv_module() {}
+                }
             "#;
         test_generated_bindings(test_src, |bindings| {
             let bindings = bindings.expect("Test expects success");
@@ -1301,6 +1364,18 @@
             assert_rs_not_matches!(bindings.rs_body, quote! { private_function });
             assert_cc_not_matches!(bindings.h_body, quote! { PrivateStruct });
             assert_rs_not_matches!(bindings.rs_body, quote! { PrivateStruct });
+            assert_cc_not_matches!(bindings.h_body, quote! { priv_func_in_priv_module });
+            assert_rs_not_matches!(bindings.rs_body, quote! { priv_func_in_priv_module });
+            assert_cc_not_matches!(bindings.h_body, quote! { priv_func_in_pub_module });
+            assert_rs_not_matches!(bindings.rs_body, quote! { priv_func_in_pub_module });
+
+            // TODO(b/258265044): The test expectations below are (temporarily) incorrect. A public
+            // function in a private module is effectively private - `format_crate` shouldn't
+            // just use `tcx.local_visibility`.
+            assert_cc_matches!(bindings.h_body, quote! { private_module });
+            assert_rs_matches!(bindings.rs_body, quote! { private_module });
+            assert_cc_matches!(bindings.h_body, quote! { pub_func_in_priv_module });
+            assert_rs_matches!(bindings.rs_body, quote! { pub_func_in_priv_module });
         });
     }
 
diff --git a/cc_bindings_from_rs/cc_bindings_from_rs.rs b/cc_bindings_from_rs/cc_bindings_from_rs.rs
index e58f2aa..31d2ab0 100644
--- a/cc_bindings_from_rs/cc_bindings_from_rs.rs
+++ b/cc_bindings_from_rs/cc_bindings_from_rs.rs
@@ -313,11 +313,13 @@
             let rs_input_path = self.tempdir.path().join("test_crate.rs");
             std::fs::write(
                 &rs_input_path,
-                r#" pub fn public_function() {
-                        private_function()
-                    }
+                r#" pub mod public_module {
+                        pub fn public_function() {
+                            private_function()
+                        }
 
-                    fn private_function() {}
+                        fn private_function() {}
+                    }
                 "#,
             )?;
 
@@ -380,7 +382,7 @@
             let mut marked_pattern = expected.to_string();
             marked_pattern.insert_str(longest_matching_expectation_len, "!!!>>>");
             panic!(
-                "h_body didn't match expectations:\n\
+                "Mismatched expectations:\n\
                     #### Actual body (first mismatch follows the \"!!!>>>\" marker):\n\
                     {marked_body}\n\
                     #### Mismatched pattern (mismatch follows the \"!!!>>>\" marker):\n\
@@ -404,6 +406,12 @@
 #pragma once
 
 namespace test_crate {
+
+// Error generating bindings for `public_module` defined at
+// /tmp/.ANY_IDENTIFIER_CHARACTERS/test_crate.rs:1:2: 1:23: Unsupported
+// rustc_hir::hir::ItemKind: module
+
+namespace public_module {
 namespace __crubit_internal {
 extern "C" void
 __crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
@@ -412,6 +420,7 @@
   return __crubit_internal::
       __crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
 }
+}  // namespace public_module
 }  // namespace test_crate"#,
         );
 
@@ -425,8 +434,9 @@
 #![allow(improper_ctypes_definitions)]
 
 #[no_mangle]
-extern "C" fn __crubit_thunk__ANY_IDENTIFIER_CHARACTERS() -> () {
-    ::test_crate::public_function()
+extern "C" fn __crubit_thunk__ANY_IDENTIFIER_CHARACTERS()
+-> () {
+    ::test_crate::public_module::public_function()
 }
 "#,
         );
diff --git a/cc_bindings_from_rs/test/modules/BUILD b/cc_bindings_from_rs/test/modules/BUILD
new file mode 100644
index 0000000..52578e0
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/BUILD
@@ -0,0 +1,37 @@
+"""End-to-end tests of `cc_bindings_from_rs`, focusing on
+module/namespace-related bindings."""
+
+load(
+    "@rules_rust//rust:defs.bzl",
+    "rust_library",
+)
+load(
+    "//cc_bindings_from_rs/bazel_support:cc_bindings_from_rust_rule.bzl",
+    "cc_bindings_from_rust",
+)
+
+licenses(["notice"])
+
+rust_library(
+    name = "modules",
+    testonly = 1,
+    srcs = ["modules.rs"],
+    deps = [
+        "//common:rust_allocator_shims",
+    ],
+)
+
+cc_bindings_from_rust(
+    name = "modules_cc_api",
+    testonly = 1,
+    crate = ":modules",
+)
+
+cc_test(
+    name = "modules_test",
+    srcs = ["modules_test.cc"],
+    deps = [
+        ":modules_cc_api",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/cc_bindings_from_rs/test/modules/modules.rs b/cc_bindings_from_rs/test/modules/modules.rs
new file mode 100644
index 0000000..0922e5a
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/modules.rs
@@ -0,0 +1,12 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+//! This crate is used as a test input for `cc_bindings_from_rs` and the
+//! generated C++ bindings are then tested via `modules_test.cc`.
+
+pub mod basic_module {
+    pub fn add_i32(x: i32, y: i32) -> i32 {
+        x + y
+    }
+}
diff --git a/cc_bindings_from_rs/test/modules/modules_test.cc b/cc_bindings_from_rs/test/modules/modules_test.cc
new file mode 100644
index 0000000..3359b8a
--- /dev/null
+++ b/cc_bindings_from_rs/test/modules/modules_test.cc
@@ -0,0 +1,19 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <utility>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "cc_bindings_from_rs/test/modules/modules_cc_api.h"
+
+namespace crubit {
+namespace {
+
+TEST(ModulesTest, BasicModule) {
+  ASSERT_EQ(123 + 456, modules::basic_module::add_i32(123, 456));
+}
+
+}  // namespace
+}  // namespace crubit
diff --git a/common/code_gen_utils.rs b/common/code_gen_utils.rs
index 46f1f8a..9f84dc3 100644
--- a/common/code_gen_utils.rs
+++ b/common/code_gen_utils.rs
@@ -68,10 +68,26 @@
     }
 
     pub fn format_for_cc(&self) -> Result<TokenStream> {
-        let namespace_cc_idents =
-            self.0.iter().map(|ns| format_cc_ident(ns)).collect::<Result<Vec<_>>>()?;
+        let namespace_cc_idents = self.cc_idents()?;
         Ok(quote! { #(#namespace_cc_idents::)* })
     }
+
+    pub fn format_with_cc_body(&self, body: TokenStream) -> Result<TokenStream> {
+        if self.0.is_empty() {
+            Ok(body)
+        } else {
+            let namespace_cc_idents = self.cc_idents()?;
+            Ok(quote! {
+                namespace #(#namespace_cc_idents)::* {
+                    #body
+                }
+            })
+        }
+    }
+
+    fn cc_idents(&self) -> Result<Vec<TokenStream>> {
+        self.0.iter().map(|ns| format_cc_ident(ns)).collect()
+    }
 }
 
 /// `CcInclude` represents a single `#include ...` directive in C++.
@@ -476,4 +492,26 @@
         assert!(msg.contains("`reinterpret_cast`"));
         assert!(msg.contains("C++ reserved keyword"));
     }
+
+    #[test]
+    fn test_namespace_qualifier_format_with_cc_body_top_level_namespace() {
+        let ns = create_namespace_qualifier_for_tests(&[]);
+        assert_cc_matches!(
+            ns.format_with_cc_body(quote! { cc body goes here }).unwrap(),
+            quote! { cc body goes here },
+        );
+    }
+
+    #[test]
+    fn test_namespace_qualifier_format_with_cc_body_nested_namespace() {
+        let ns = create_namespace_qualifier_for_tests(&["foo", "bar", "baz"]);
+        assert_cc_matches!(
+            ns.format_with_cc_body(quote! { cc body goes here }).unwrap(),
+            quote! {
+                namespace foo::bar::baz {
+                    cc body goes here
+                }  // namespace foo::bar::baz
+            },
+        );
+    }
 }