Add `CcInclude` and `format_cc_includes` to `common/code_gen_utils.rs`.
PiperOrigin-RevId: 485412993
diff --git a/common/code_gen_utils.rs b/common/code_gen_utils.rs
index 59e7361..84f1e9c 100644
--- a/common/code_gen_utils.rs
+++ b/common/code_gen_utils.rs
@@ -5,7 +5,9 @@
use anyhow::{anyhow, ensure, Result};
use once_cell::sync::Lazy;
use proc_macro2::TokenStream;
-use std::collections::HashSet;
+use quote::{quote, ToTokens};
+use std::collections::{BTreeSet, HashSet};
+use std::rc::Rc;
// TODO(lukasza): Consider adding more items into `code_gen_utils` (this crate).
// For example, the following items from `src_code_gen.rs` will be most likely
@@ -32,6 +34,66 @@
)
}
+/// `CcInclude` represents a single `#include ...` directive in C++.
+///
+/// The derived `Ord` compares the variant first (in declaration order) and the
+/// header name second, so every `SystemHeader` sorts before every `UserHeader`.
+/// `format_cc_includes` relies on this ordering to group system headers ahead
+/// of user headers, so keep `SystemHeader` declared first.
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+pub enum CcInclude {
+    /// A system header, emitted as `#include <name>`. The name is parsed into
+    /// tokens when emitted, so prefer the dedicated constructors (e.g.
+    /// `CcInclude::cstddef()`) over building this variant directly.
+    SystemHeader(&'static str),
+    /// A user header, emitted as `#include "path"` (a quoted string literal).
+    UserHeader(Rc<str>),
+}
+
+impl CcInclude {
+    /// Returns the include for the `<cstddef>` standard header, which declares
+    /// C++ types such as `std::size_t` and `std::ptrdiff_t`.
+    /// Reference: https://en.cppreference.com/w/cpp/header/cstddef
+    pub fn cstddef() -> Self {
+        CcInclude::SystemHeader("cstddef")
+    }
+
+    /// Returns the include for the `<memory>` standard header.
+    /// Reference: https://en.cppreference.com/w/cpp/header/memory
+    pub fn memory() -> Self {
+        CcInclude::SystemHeader("memory")
+    }
+
+    /// Returns a user include for `path`, i.e. `#include "<path>"`
+    /// (for example `#include "some/path/to/header.h"`).
+    pub fn user_header(path: Rc<str>) -> Self {
+        CcInclude::UserHeader(path)
+    }
+}
+
+impl ToTokens for CcInclude {
+    /// Appends this directive to `tokens` as
+    /// `__HASH_TOKEN__ include <target> __NEWLINE__`, where `<target>` is either
+    /// `< name >` (system header) or a string literal (user header).
+    fn to_tokens(&self, tokens: &mut TokenStream) {
+        let directive = match self {
+            Self::SystemHeader(header_name) => {
+                // NOTE(review): `SystemHeader` is a public variant, so a caller can
+                // bypass the constructors and supply a name that fails to parse here
+                // (which panics); the `expect` message assumes the constructors are used.
+                let header_tokens: TokenStream = header_name
+                    .parse()
+                    .expect("`pub` API of `CcInclude` guarantees validity of system includes");
+                quote! { __HASH_TOKEN__ include < #header_tokens > __NEWLINE__ }
+            }
+            Self::UserHeader(header_path) => {
+                // Interpolating the path produces a quoted string literal.
+                quote! { __HASH_TOKEN__ include #header_path __NEWLINE__ }
+            }
+        };
+        directive.to_tokens(tokens);
+    }
+}
+
+/// Formats a set of `CcInclude`s as a sequence of `#include` directives,
+/// following
+/// [the Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html#Names_and_Order_of_Includes)
+/// where possible.
+///
+/// Includes are emitted in the set's iteration order (system headers first,
+/// then user headers, each group sorted by name). When both groups are present,
+/// an empty line separates them.
+pub fn format_cc_includes(set_of_includes: &BTreeSet<CcInclude>) -> TokenStream {
+    let mut result = TokenStream::default();
+    let mut previous_was_system_header = false;
+    for current in set_of_includes {
+        let is_system_header = matches!(current, CcInclude::SystemHeader(_));
+        // The set yields every `SystemHeader` before any `UserHeader`, so this
+        // system -> user transition happens at most once.
+        if previous_was_system_header && !is_system_header {
+            quote! { __NEWLINE__ }.to_tokens(&mut result);
+        }
+        current.to_tokens(&mut result);
+        previous_was_system_header = is_system_header;
+    }
+    result
+}
+
static RESERVED_CC_KEYWORDS: Lazy<HashSet<&'static str>> = Lazy::new(|| {
// `RESERVED_CC_KEYWORDS` are based on https://en.cppreference.com/w/cpp/keyword
[
@@ -142,6 +204,7 @@
use super::*;
use quote::quote;
use token_stream_matchers::assert_cc_matches;
+ use token_stream_printer::cc_tokens_to_formatted_string;
#[test]
fn test_format_cc_ident_basic() {
@@ -218,4 +281,65 @@
quote! { std::vector<int> }
);
}
+
+    #[test]
+    fn test_cc_include_to_tokens_for_system_header() {
+        // Named `cc_include` so the interpolation below is not mistaken for a
+        // literal C++ `#include` directive.
+        let cc_include = CcInclude::cstddef();
+        let actual = quote! { #cc_include };
+        assert_cc_matches!(
+            actual,
+            quote! {
+                __HASH_TOKEN__ include <cstddef>
+            }
+        );
+    }
+
+    #[test]
+    fn test_cc_include_to_tokens_for_user_header() {
+        // Named `cc_include` so the interpolation below is not mistaken for a
+        // literal C++ `#include` directive.
+        let cc_include = CcInclude::user_header("some/path/to/header.h".into());
+        let actual = quote! { #cc_include };
+        assert_cc_matches!(
+            actual,
+            quote! {
+                __HASH_TOKEN__ include "some/path/to/header.h"
+            }
+        );
+    }
+
+    #[test]
+    fn test_cc_include_ord() {
+        // Listed in expected ascending order: system headers (by name), then
+        // user headers (by path). Every earlier entry must compare less than
+        // every later one.
+        let expected_order = [
+            CcInclude::cstddef(),
+            CcInclude::memory(),
+            CcInclude::user_header("a.h".into()),
+            CcInclude::user_header("b.h".into()),
+        ];
+        for (index, lower) in expected_order.iter().enumerate() {
+            for higher in &expected_order[index + 1..] {
+                assert!(lower < higher);
+            }
+        }
+    }
+
+    #[test]
+    fn test_format_cc_includes() {
+        // Deliberately unsorted and with a duplicate, so the test verifies that
+        // output order and deduplication come from the `BTreeSet` (i.e. from
+        // `CcInclude`'s `Ord`) rather than from the order of insertion.
+        let includes = [
+            CcInclude::user_header("b.h".into()),
+            CcInclude::memory(),
+            CcInclude::user_header("a.h".into()),
+            CcInclude::cstddef(),
+            CcInclude::memory(),
+        ]
+        .into_iter()
+        .collect::<BTreeSet<_>>();
+
+        let tokens = format_cc_includes(&includes);
+        // The leading `__NEWLINE__` matches the newline that opens the expected
+        // raw string below.
+        let actual = cc_tokens_to_formatted_string(quote! { __NEWLINE__ #tokens }).unwrap();
+        assert_eq!(
+            actual,
+            r#"
+#include <cstddef>
+#include <memory>
+
+#include "a.h"
+#include "b.h"
+"#
+        );
+    }
}