Add `CcInclude` to `common/code_gen_utils.rs`.
PiperOrigin-RevId: 485412993
diff --git a/common/BUILD b/common/BUILD
index 9cce4ab..00023ee 100644
--- a/common/BUILD
+++ b/common/BUILD
@@ -30,6 +30,7 @@
"@crate_index//:anyhow",
"@crate_index//:once_cell",
"@crate_index//:proc-macro2",
+ "@crate_index//:quote",
],
)
@@ -38,7 +39,7 @@
crate = ":code_gen_utils",
deps = [
":token_stream_matchers",
- "@crate_index//:quote",
+ ":token_stream_printer",
],
)
diff --git a/common/code_gen_utils.rs b/common/code_gen_utils.rs
index 59e7361..84f1e9c 100644
--- a/common/code_gen_utils.rs
+++ b/common/code_gen_utils.rs
@@ -5,7 +5,9 @@
use anyhow::{anyhow, ensure, Result};
use once_cell::sync::Lazy;
use proc_macro2::TokenStream;
-use std::collections::HashSet;
+use quote::{quote, ToTokens};
+use std::collections::{BTreeSet, HashSet};
+use std::rc::Rc;
// TODO(lukasza): Consider adding more items into `code_gen_utils` (this crate).
// For example, the following items from `src_code_gen.rs` will be most likely
@@ -32,6 +34,66 @@
)
}
+/// `CcInclude` represents a single `#include ...` directive in C++.
+///
+/// The derived `Ord` compares variants in declaration order first, so every
+/// `SystemHeader` sorts before every `UserHeader`; two values of the same
+/// variant are ordered by their header strings. `format_cc_includes` relies
+/// on this ordering when it walks a `BTreeSet<CcInclude>`.
+#[derive(Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+pub enum CcInclude {
+ /// A header emitted in angle brackets: `#include <name>`. The string is
+ /// parsed into tokens by the `ToTokens` impl, which panics if parsing fails.
+ SystemHeader(&'static str),
+ /// A header emitted as a quoted path: `#include "some/path.h"`.
+ UserHeader(Rc<str>),
+}
+
+impl CcInclude {
+    /// Returns the directive `#include <cstddef>`. That header supplies C++
+    /// types such as `std::size_t` and `std::ptrdiff_t`. See also
+    /// https://en.cppreference.com/w/cpp/header/cstddef
+    pub fn cstddef() -> Self {
+        CcInclude::SystemHeader("cstddef")
+    }
+
+    /// Returns the directive `#include <memory>`.
+    /// See also https://en.cppreference.com/w/cpp/header/memory
+    pub fn memory() -> Self {
+        CcInclude::SystemHeader("memory")
+    }
+
+    /// Returns a user include for `path`, i.e. `#include "some/path/to/header.h"`.
+    pub fn user_header(path: Rc<str>) -> Self {
+        CcInclude::UserHeader(path)
+    }
+}
+
+impl ToTokens for CcInclude {
+    /// Appends the tokens of this `#include` directive, followed by a
+    /// `__NEWLINE__` marker, to `tokens`.
+    ///
+    /// Panics if a `SystemHeader` holds a string that cannot be parsed into a
+    /// `TokenStream`.
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        let directive = match self {
+            CcInclude::SystemHeader(name) => {
+                // The header name has to be tokenized first: it is placed
+                // between `<` and `>` as bare tokens rather than as a literal.
+                let name = name
+                    .parse::<TokenStream>()
+                    .expect("`pub` API of `CcInclude` guarantees validity of system includes");
+                quote! { __HASH_TOKEN__ include < #name > __NEWLINE__ }
+            }
+            CcInclude::UserHeader(header_path) => {
+                quote! { __HASH_TOKEN__ include #header_path __NEWLINE__ }
+            }
+        };
+        tokens.extend(directive);
+    }
+}
+
+/// Renders every `CcInclude` in `set_of_includes`, in the set's iteration
+/// order, and separates the system headers from the user headers with an
+/// empty line. This tries to follow the guidance from
+/// [the Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes).
+pub fn format_cc_includes(set_of_includes: &BTreeSet<CcInclude>) -> TokenStream {
+    let mut output = TokenStream::new();
+    let mut remaining = set_of_includes.iter().peekable();
+    while let Some(current) = remaining.next() {
+        current.to_tokens(&mut output);
+
+        // A system header directly followed by a user header marks the
+        // boundary between the two groups; emit the separating empty line.
+        let at_group_boundary = matches!(
+            (current, remaining.peek()),
+            (CcInclude::SystemHeader(_), Some(CcInclude::UserHeader(_)))
+        );
+        if at_group_boundary {
+            output.extend(quote! { __NEWLINE__ });
+        }
+    }
+    output
+}
+
static RESERVED_CC_KEYWORDS: Lazy<HashSet<&'static str>> = Lazy::new(|| {
// `RESERVED_CC_KEYWORDS` are based on https://en.cppreference.com/w/cpp/keyword
[
@@ -142,6 +204,7 @@
use super::*;
use quote::quote;
use token_stream_matchers::assert_cc_matches;
+ use token_stream_printer::cc_tokens_to_formatted_string;
#[test]
fn test_format_cc_ident_basic() {
@@ -218,4 +281,65 @@
quote! { std::vector<int> }
);
}
+
+ #[test]
+ fn test_cc_include_to_tokens_for_system_header() {
+ // Verifies that a system header is rendered with angle brackets.
+ let include = CcInclude::cstddef();
+ assert_cc_matches!(
+ // `#include` is `quote!` interpolation of the local variable `include`
+ // (via its `ToTokens` impl), not a C++ preprocessor directive.
+ quote! { #include },
+ // NOTE(review): the pattern omits the trailing `__NEWLINE__` that
+ // `to_tokens` emits; presumably `assert_cc_matches!` tolerates
+ // surrounding tokens - confirm against `token_stream_matchers`.
+ quote! {
+ __HASH_TOKEN__ include <cstddef>
+ }
+ );
+ }
+
+ #[test]
+ fn test_cc_include_to_tokens_for_user_header() {
+ // Verifies that a user header is rendered as a quoted path.
+ let include = CcInclude::user_header("some/path/to/header.h".into());
+ assert_cc_matches!(
+ // `#include` is `quote!` interpolation of the local variable `include`
+ // (via its `ToTokens` impl), not a C++ preprocessor directive.
+ quote! { #include },
+ // NOTE(review): the pattern omits the trailing `__NEWLINE__` that
+ // `to_tokens` emits; presumably `assert_cc_matches!` tolerates
+ // surrounding tokens - confirm against `token_stream_matchers`.
+ quote! {
+ __HASH_TOKEN__ include "some/path/to/header.h"
+ }
+ );
+ }
+
+    #[test]
+    fn test_cc_include_ord() {
+        // Listed in the order the values are expected to sort: system headers
+        // first, then user headers, each group ordered by its string.
+        let ascending = [
+            CcInclude::cstddef(),
+            CcInclude::memory(),
+            CcInclude::user_header("a.h".into()),
+            CcInclude::user_header("b.h".into()),
+        ];
+
+        // Every element must compare strictly less than each element after it.
+        for (idx, earlier) in ascending.iter().enumerate() {
+            for later in &ascending[idx + 1..] {
+                assert!(earlier < later, "{earlier:?} should sort before {later:?}");
+            }
+        }
+    }
+
+    #[test]
+    fn test_format_cc_includes() {
+        // Two system headers and two user headers: the formatted output should
+        // list each group in order with one empty line between the groups.
+        let include_set = BTreeSet::from([
+            CcInclude::cstddef(),
+            CcInclude::memory(),
+            CcInclude::user_header("a.h".into()),
+            CcInclude::user_header("b.h".into()),
+        ]);
+
+        let formatted_tokens = format_cc_includes(&include_set);
+        // The leading `__NEWLINE__` lets the expected text below start on its
+        // own line.
+        let rendered =
+            cc_tokens_to_formatted_string(quote! { __NEWLINE__ #formatted_tokens }).unwrap();
+        let expected = r#"
+#include <cstddef>
+#include <memory>
+
+#include "a.h"
+#include "b.h"
+"#;
+        assert_eq!(rendered, expected);
+    }
}
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
index 4e6d2ba..0e752c8 100644
--- a/rs_bindings_from_cc/ir.rs
+++ b/rs_bindings_from_cc/ir.rs
@@ -102,7 +102,7 @@
#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
pub struct HeaderName {
- pub name: String,
+ pub name: Rc<str>,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Deserialize)]
@@ -897,7 +897,7 @@
"#;
let ir = deserialize_ir(input.as_bytes()).unwrap();
let expected = FlatIR {
- used_headers: vec![HeaderName { name: "foo/bar.h".to_string() }],
+ used_headers: vec![HeaderName { name: "foo/bar.h".into() }],
current_target: "//foo:bar".into(),
top_level_item_ids: vec![],
items: vec![],
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index df3fef7..7a5167f 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
use arc_anyhow::{Context, Result};
+use code_gen_utils::{format_cc_includes, CcInclude};
use error_report::{anyhow, bail, ensure, ErrorReport, ErrorReporting, IgnoreErrors};
use ffi_types::*;
use ir::*;
@@ -3679,25 +3680,29 @@
.map(|record| cc_struct_layout_assertion(record, &ir))
.collect::<Result<Vec<_>>>()?;
- let mut standard_headers = <BTreeSet<Ident>>::new();
- standard_headers.insert(format_ident!("memory")); // ubiquitous.
+ let mut internal_includes = BTreeSet::new();
+ internal_includes.insert(CcInclude::memory()); // ubiquitous.
if ir.records().next().is_some() {
- standard_headers.insert(format_ident!("cstddef"));
+ internal_includes.insert(CcInclude::cstddef());
};
-
- let mut includes = vec!["cxx20_backports.h", "offsetof.h"]
- .into_iter()
- .map(|hdr| format!("{}/{}", crubit_support_path, hdr))
- .collect_vec();
+ for crubit_header in ["cxx20_backports.h", "offsetof.h"] {
+ internal_includes.insert(CcInclude::user_header(
+ format!("{crubit_support_path}/{crubit_header}").into(),
+ ));
+ }
+ let internal_includes = format_cc_includes(&internal_includes);
// In order to generate C++ thunk in all the cases Clang needs to be able to
- // access declarations from public headers of the C++ library.
- includes.extend(ir.used_headers().map(|hdr| hdr.name.clone()));
+ // access declarations from public headers of the C++ library. We don't
+ // process these includes via `format_cc_includes` to preserve their
+ // original order (some libraries require certain headers to be included
+ // first - e.g. `config.h`).
+ let ir_includes =
+ ir.used_headers().map(|hdr| CcInclude::user_header(hdr.name.clone())).collect_vec();
Ok(quote! {
- #( __HASH_TOKEN__ include <#standard_headers> __NEWLINE__)*
- __NEWLINE__
- #( __HASH_TOKEN__ include #includes __NEWLINE__)* __NEWLINE__
+ #internal_includes
+ #( #ir_includes )* __NEWLINE__
__HASH_TOKEN__ pragma clang diagnostic push __NEWLINE__
// Disable Clang thread-safety-analysis warnings that would otherwise
// complain about thunks that call mutex locking functions in an unpaired way.