Use include guard instead of pragma once.

PiperOrigin-RevId: 672718959
Change-Id: Ie0dd3e2f37d6c2f421bba14bcc7b3a994340a257
diff --git a/cc_bindings_from_rs/bindings.rs b/cc_bindings_from_rs/bindings.rs
index de8f624..964e7cd 100644
--- a/cc_bindings_from_rs/bindings.rs
+++ b/cc_bindings_from_rs/bindings.rs
@@ -80,6 +80,9 @@
         #[input]
         fn no_thunk_name_mangling(&self) -> bool;
 
+        #[input]
+        fn h_out_include_guard(&self) -> IncludeGuard;
+
         fn support_header(&self, suffix: &'tcx str) -> CcInclude;
 
         fn repr_attrs(&self, did: DefId) -> Rc<[rustc_attr::ReprAttr]>;
@@ -112,6 +115,12 @@
     pub struct Database;
 }
 
+#[derive(Clone, Debug)]
+pub enum IncludeGuard {
+    PragmaOnce,
+    Guard(String),
+}
+
 fn support_header<'tcx>(db: &dyn BindingsGenerator<'tcx>, suffix: &'tcx str) -> CcInclude {
     CcInclude::support_lib_header(db.crubit_support_path_format(), suffix.into())
 }
@@ -121,6 +130,29 @@
     pub rs_body: TokenStream,
 }
 
+fn add_include_guard(db: &dyn BindingsGenerator<'_>, h_body: TokenStream) -> Result<TokenStream> {
+    match db.h_out_include_guard() {
+        IncludeGuard::PragmaOnce => Ok(quote! {
+            __HASH_TOKEN__ pragma once __NEWLINE__
+            __NEWLINE__
+
+            #h_body
+        }),
+        IncludeGuard::Guard(include_guard_str) => {
+            let include_guard = format_cc_ident(include_guard_str.as_str())?;
+            Ok(quote! {
+                __HASH_TOKEN__ ifndef #include_guard __NEWLINE__
+                __HASH_TOKEN__ define #include_guard __NEWLINE__
+                __NEWLINE__
+
+                #h_body
+
+                __HASH_TOKEN__ endif __COMMENT__ #include_guard_str __NEWLINE__
+            })
+        }
+    }
+}
+
 pub fn generate_bindings(db: &Database) -> Result<Output> {
     let tcx = db.tcx();
 
@@ -151,14 +183,10 @@
         let src = quote! { __COMMENT__ #txt };
         Output { h_body: src.clone(), rs_body: src }
     });
-
+    let h_body = add_include_guard(db, h_body)?;
     let h_body = quote! {
         #top_comment
 
-        // TODO(b/251445877): Replace `#pragma once` with include guards.
-        __HASH_TOKEN__ pragma once __NEWLINE__
-        __NEWLINE__
-
         #h_body
     };
 
@@ -9586,6 +9614,7 @@
             /* crate_name_to_features= */ Default::default(),
             /* errors = */ Rc::new(IgnoreErrors),
             /* no_thunk_name_mangling= */ false,
+            /* include_guard */ IncludeGuard::PragmaOnce,
         )
     }
 
diff --git a/cc_bindings_from_rs/cc_bindings_from_rs.rs b/cc_bindings_from_rs/cc_bindings_from_rs.rs
index a2ff39f..476ef3c 100644
--- a/cc_bindings_from_rs/cc_bindings_from_rs.rs
+++ b/cc_bindings_from_rs/cc_bindings_from_rs.rs
@@ -15,7 +15,7 @@
 use std::path::Path;
 use std::rc::Rc;
 
-use bindings::Database;
+use bindings::{Database, IncludeGuard};
 use cmdline::Cmdline;
 use code_gen_utils::CcInclude;
 use error_report::{ErrorReport, ErrorReporting, IgnoreErrors};
@@ -25,7 +25,10 @@
 };
 
 fn turn_off_clang_format(mut h_body: String) -> String {
-    h_body.insert_str(h_body.find("#pragma once").unwrap(), "// clang-format off\n");
+    h_body.insert_str(
+        h_body.find("#ifndef").unwrap_or_else(|| h_body.find("#pragma once").unwrap()),
+        "// clang-format off\n",
+    );
     h_body
 }
 
@@ -54,7 +57,11 @@
             crate_name_to_features.entry(crate_name.as_str().into()).or_default();
         *accumulated_features |= *features
     }
-
+    let include_guard = if let Some(include_guard) = &cmdline.h_out_include_guard {
+        IncludeGuard::Guard(include_guard.clone())
+    } else {
+        IncludeGuard::PragmaOnce
+    };
     Database::new(
         tcx,
         crubit_support_path_format,
@@ -62,6 +69,7 @@
         crate_name_to_features.into(),
         errors,
         cmdline.no_thunk_name_mangling,
+        include_guard,
     )
 }
 
@@ -159,6 +167,7 @@
         extra_rustc_args: Vec<String>,
 
         tempdir: TempDir,
+        include_guard: Option<String>,
     }
 
     /// Result of `TestArgs::run` that helps tests access test outputs (e.g. the
@@ -180,6 +189,7 @@
                 panic_mechanism: "abort".to_string(),
                 extra_rustc_args: vec![],
                 tempdir: tempdir()?,
+                include_guard: None,
             })
         }
 
@@ -223,6 +233,11 @@
             self
         }
 
+        fn with_include_guard(mut self, include_guard: &str) -> Self {
+            self.include_guard = Some(include_guard.to_string());
+            self
+        }
+
         /// Invokes `super::run_with_cmdline_args` with default `test_crate.rs`
         /// input (and with other default args + args gathered by
         /// `self`).
@@ -269,6 +284,11 @@
                     error_report_out_path.as_ref().unwrap().display()
                 ));
             }
+
+            if let Some(include_guard) = &self.include_guard {
+                args.push(format!("--h-out-include-guard={include_guard}"));
+            }
+
             args.extend(self.extra_crubit_args.iter().cloned());
             args.extend([
                 "--".to_string(),
@@ -361,6 +381,55 @@
     }
 
     #[test]
+    fn test_with_include_guard() -> Result<()> {
+        let test_args =
+            TestArgs::default_args()?.with_include_guard("CRUBIT_GENERATED_HEADER_FOR_test_crate_");
+        let test_result = test_args.run().expect("Customized include guard should succeed");
+
+        assert!(test_result.h_path.exists());
+        let temp_dir_str = test_args.tempdir.path().to_str().unwrap();
+        let h_body = std::fs::read_to_string(&test_result.h_path)?;
+        #[rustfmt::skip]
+        assert_body_matches(
+            &h_body,
+            &format!(
+                "{}\n{}\n{}",
+r#"// Automatically @generated C++ bindings for the following Rust crate:
+// test_crate
+// Features: <none>
+
+// clang-format off
+#ifndef CRUBIT_GENERATED_HEADER_FOR_test_crate_
+#define CRUBIT_GENERATED_HEADER_FOR_test_crate_
+
+namespace test_crate {
+
+namespace public_module {
+"#,
+ // TODO(b/261185414): Avoid assuming that all source code paths are google3 paths.
+format!("// Generated from: google3/{temp_dir_str}/test_crate.rs;l=2"),
+r#"void public_function();
+
+namespace __crubit_internal {
+extern "C" void
+__crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
+}
+inline void public_function() {
+  return __crubit_internal::
+      __crubit_thunk__ANY_IDENTIFIER_CHARACTERS();
+}
+
+}  // namespace public_module
+
+}  // namespace test_crate
+#endif  // CRUBIT_GENERATED_HEADER_FOR_test_crate_
+"#
+            ),
+        );
+        Ok(())
+    }
+
+    #[test]
     fn test_happy_path() -> Result<()> {
         let test_args = TestArgs::default_args()?;
         let test_result = test_args.run().expect("Default args should succeed");
diff --git a/cc_bindings_from_rs/cmdline.rs b/cc_bindings_from_rs/cmdline.rs
index f792695..4282254 100644
--- a/cc_bindings_from_rs/cmdline.rs
+++ b/cc_bindings_from_rs/cmdline.rs
@@ -22,6 +22,10 @@
     #[clap(long, value_parser, value_name = "FILE")]
     pub h_out: PathBuf,
 
+    /// Include guard for the C++ header file with bindings.
+    #[clap(long, value_parser, value_name = "STRING")]
+    pub h_out_include_guard: Option<String>,
+
     /// Output path for Rust implementation of the bindings.
     #[clap(long, value_parser, value_name = "FILE")]
     pub rs_out: PathBuf,
@@ -259,6 +263,9 @@
       --h-out <FILE>
           Output path for C++ header file with bindings
 
+      --h-out-include-guard <STRING>
+          Include guard for the C++ header file with bindings
+
       --rs-out <FILE>
           Output path for Rust implementation of the bindings