Add public_headers to the IR

These will be needed by the generated C++ code.

PiperOrigin-RevId: 391729880
diff --git a/rs_bindings_from_cc/ast_consumer.h b/rs_bindings_from_cc/ast_consumer.h
index 059007c..949f022 100644
--- a/rs_bindings_from_cc/ast_consumer.h
+++ b/rs_bindings_from_cc/ast_consumer.h
@@ -5,18 +5,23 @@
 #ifndef CRUBIT_RS_BINDINGS_FROM_CC_AST_CONSUMER_H_
 #define CRUBIT_RS_BINDINGS_FROM_CC_AST_CONSUMER_H_
 
+#include <string>
+
 #include "rs_bindings_from_cc/ast_visitor.h"
 #include "rs_bindings_from_cc/ir.h"
+#include "third_party/absl/types/span.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTConsumer.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTContext.h"
 
 namespace rs_bindings_from_cc {
 
-// Consumes the Clang AST of the header and generates the intermediate
-// representation (`IR`).
+// Consumes the Clang AST created from `public_headers` (a collection of paths
+// in the format suitable for a google3-relative quote include) and generates
+// the intermediate representation (`IR`).
 class AstConsumer : public clang::ASTConsumer {
  public:
-  explicit AstConsumer(IR &ir) : ast_visitor_(ir) {}
+  explicit AstConsumer(absl::Span<const std::string> public_headers, IR &ir)
+      : ast_visitor_(public_headers, ir) {}
 
   void HandleTranslationUnit(clang::ASTContext &context) override;
 
diff --git a/rs_bindings_from_cc/ast_visitor.cc b/rs_bindings_from_cc/ast_visitor.cc
index c79204f..41a15f0 100644
--- a/rs_bindings_from_cc/ast_visitor.cc
+++ b/rs_bindings_from_cc/ast_visitor.cc
@@ -30,6 +30,20 @@
 bool AstVisitor::TraverseTranslationUnitDecl(
     clang::TranslationUnitDecl* translation_unit_decl) {
   mangler_.reset(translation_unit_decl->getASTContext().createMangleContext());
+
+  // TODO(hlopko): Make the generated C++ code include-what-you-use clean.
+  // Currently we pass the public headers of the library to the src_code_gen.
+  // Through those, Clang has access to all declarations needed by the public
+  // API of the library. However, the generated code violates IWYU: it does not
+  // directly include all the headers declaring names used in the generated
+  // source. This could be fixed by passing to the tool not only the public
+  // headers of the library, but also all public headers of its direct
+  // dependencies. That way, if the library is IWYU clean, the generated code
+  // will be too.
+  for (const std::string& header : public_headers_) {
+    ir_.UsedHeaders().emplace_back(HeaderName(absl::Cord(header)));
+  }
+
   return Base::TraverseTranslationUnitDecl(translation_unit_decl);
 }
 
diff --git a/rs_bindings_from_cc/ast_visitor.h b/rs_bindings_from_cc/ast_visitor.h
index 5ca91db..2f14147 100644
--- a/rs_bindings_from_cc/ast_visitor.h
+++ b/rs_bindings_from_cc/ast_visitor.h
@@ -6,10 +6,12 @@
 #define CRUBIT_RS_BINDINGS_FROM_CC_AST_VISITOR_H_
 
 #include <memory>
+#include <string>
 
 #include "rs_bindings_from_cc/ir.h"
 #include "third_party/absl/container/flat_hash_set.h"
 #include "third_party/absl/strings/cord.h"
+#include "third_party/absl/types/span.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/Decl.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/Mangle.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/RecursiveASTVisitor.h"
@@ -17,13 +19,15 @@
 
 namespace rs_bindings_from_cc {
 
-// Iterates over the AST nodes of the header and creates intermediate
-// representation of the import (`IR`).
+// Iterates over the AST created from `public_headers` (a collection of paths
+// in the format suitable for a google3-relative quote include) and creates
+// an intermediate representation of the import (`IR`).
 class AstVisitor : public clang::RecursiveASTVisitor<AstVisitor> {
  public:
   using Base = clang::RecursiveASTVisitor<AstVisitor>;
 
-  explicit AstVisitor(IR &ir) : ir_(ir) {}
+  explicit AstVisitor(absl::Span<const std::string> public_headers, IR &ir)
+      : public_headers_(public_headers), ir_(ir) {}
 
   // These functions are called by the base class while visiting the different
   // parts of the AST. The API follows the rules of the base class which is
@@ -39,6 +43,7 @@
   Identifier GetTranslatedName(const clang::NamedDecl *named_decl) const;
   Type ConvertType(clang::QualType qual_type) const;
 
+  absl::Span<const std::string> public_headers_;
   IR &ir_;
   std::unique_ptr<clang::MangleContext> mangler_;
   absl::flat_hash_set<const clang::Decl *> seen_decls_;
diff --git a/rs_bindings_from_cc/ast_visitor_test.cc b/rs_bindings_from_cc/ast_visitor_test.cc
index 336e4ac..7fedb88 100644
--- a/rs_bindings_from_cc/ast_visitor_test.cc
+++ b/rs_bindings_from_cc/ast_visitor_test.cc
@@ -4,15 +4,19 @@
 
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
+#include "devtools/cymbal/common/clang_tool.h"
 #include "rs_bindings_from_cc/frontend_action.h"
 #include "rs_bindings_from_cc/ir.h"
 #include "testing/base/public/gmock.h"
 #include "testing/base/public/gunit.h"
+#include "third_party/absl/container/flat_hash_map.h"
 #include "third_party/absl/strings/string_view.h"
+#include "third_party/absl/strings/substitute.h"
+#include "third_party/absl/types/span.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/Frontend/FrontendAction.h"
-#include "third_party/llvm/llvm-project/clang/include/clang/Tooling/Tooling.h"
 
 namespace rs_bindings_from_cc {
 namespace {
@@ -20,28 +24,53 @@
 using ::testing::IsEmpty;
 using ::testing::SizeIs;
 
-IR ImportCode(const absl::string_view code,
+constexpr absl::string_view kVirtualInputPath =
+    "ast_visitor_test_virtual_input.cc";
+
+IR ImportCode(absl::Span<const absl::string_view> header_files_contents,
               const std::vector<absl::string_view>& args) {
-  IR ir;
+  std::vector<std::string> headers;
+  absl::flat_hash_map<std::string, std::string> file_contents;
+  std::string virtual_input_file_content;
+
+  int counter = 0;
+  for (const absl::string_view header_content : header_files_contents) {
+    std::string filename(
+        absl::Substitute("test/testing_header_$0.h", counter++));
+    file_contents.insert({filename, std::string(header_content)});
+    absl::SubstituteAndAppend(&virtual_input_file_content, "#include \"$0\"\n",
+                              filename);
+    headers.emplace_back(std::move(filename));
+  }
+
+  file_contents.insert(
+      {std::string(kVirtualInputPath), virtual_input_file_content});
+
   std::vector<std::string> args_as_strings(args.begin(), args.end());
-  clang::tooling::runToolOnCodeWithArgs(
-      std::make_unique<rs_bindings_from_cc::FrontendAction>(ir), code,
-      args_as_strings);
+  args_as_strings.emplace_back(std::string("--syntax-only"));
+  args_as_strings.emplace_back(std::string(kVirtualInputPath));
+
+  IR ir;
+  devtools::cymbal::RunToolWithClangFlagsOnCode(
+      args_as_strings, file_contents,
+      std::make_unique<rs_bindings_from_cc::FrontendAction>(headers, ir));
   return ir;
 }
 
 TEST(AstVisitorTest, TestNoop) {
-  IR ir = ImportCode("// nothing interesting there.", {});
+  IR ir = ImportCode({"// nothing interesting there."}, {});
   EXPECT_THAT(ir.Functions(), IsEmpty());
+  ASSERT_THAT(ir.UsedHeaders(), SizeIs(1));  // Fatal: guards the [0] below.
+  EXPECT_EQ(ir.UsedHeaders()[0].IncludePath(), "test/testing_header_0.h");
 }
 
 TEST(AstVisitorTest, TestIREmptyOnInvalidInput) {
-  IR ir = ImportCode("int foo(); But this is not C++", {});
+  IR ir = ImportCode({"int foo(); But this is not C++"}, {});
   EXPECT_THAT(ir.Functions(), IsEmpty());
 }
 
 TEST(AstVisitorTest, TestImportFuncWithVoidReturnType) {
-  IR ir = ImportCode("void Foo();", {});
+  IR ir = ImportCode({"void Foo();"}, {});
   ASSERT_THAT(ir.Functions(), SizeIs(1));
   Func func = ir.Functions()[0];
   EXPECT_EQ(func.Ident().Ident(), "Foo");
@@ -51,7 +80,7 @@
 }
 
 TEST(AstVisitorTest, TestImportTwoFuncs) {
-  IR ir = ImportCode("void Foo(); void Bar();", {});
+  IR ir = ImportCode({"void Foo(); void Bar();"}, {});
   ASSERT_THAT(ir.Functions(), SizeIs(2));
 
   Func foo = ir.Functions()[0];
@@ -67,18 +96,24 @@
   EXPECT_THAT(bar.Params(), IsEmpty());
 }
 
+TEST(AstVisitorTest, TestImportTwoFuncsFromTwoHeaders) {
+  IR ir = ImportCode({"void Foo();", "void Bar();"}, {});
+  ASSERT_THAT(ir.Functions(), SizeIs(2));
+  Func foo = ir.Functions()[0];
+  EXPECT_EQ(foo.Ident().Ident(), "Foo");
+  Func bar = ir.Functions()[1];
+  EXPECT_EQ(bar.Ident().Ident(), "Bar");
+}
+
 TEST(AstVisitorTest, TestImportFuncJustOnce) {
-  IR ir = ImportCode(
-      "void Foo();"
-      "void Foo();",
-      {});
+  IR ir = ImportCode({"void Foo(); void Foo();"}, {});
   ASSERT_THAT(ir.Functions(), SizeIs(1));
   Func func = ir.Functions()[0];
   EXPECT_EQ(func.Ident().Ident(), "Foo");
 }
 
 TEST(AstVisitorTest, TestImportFuncParams) {
-  IR ir = ImportCode("int Add(int a, int b);", {});
+  IR ir = ImportCode({"int Add(int a, int b);"}, {});
   EXPECT_THAT(ir.Functions(), SizeIs(1));
 
   Func func = ir.Functions()[0];
diff --git a/rs_bindings_from_cc/frontend_action.cc b/rs_bindings_from_cc/frontend_action.cc
index 6de4495..a4d3708 100644
--- a/rs_bindings_from_cc/frontend_action.cc
+++ b/rs_bindings_from_cc/frontend_action.cc
@@ -14,7 +14,7 @@
 
 std::unique_ptr<clang::ASTConsumer> FrontendAction::CreateASTConsumer(
     clang::CompilerInstance &, llvm::StringRef) {
-  return std::make_unique<AstConsumer>(ir_);
+  return std::make_unique<AstConsumer>(public_headers_, ir_);
 }
 
 }  // namespace rs_bindings_from_cc
diff --git a/rs_bindings_from_cc/frontend_action.h b/rs_bindings_from_cc/frontend_action.h
index 833769f..e9c414c 100644
--- a/rs_bindings_from_cc/frontend_action.h
+++ b/rs_bindings_from_cc/frontend_action.h
@@ -6,8 +6,10 @@
 #define CRUBIT_RS_BINDINGS_FROM_CC_FRONTEND_ACTION_H_
 
 #include <memory>
+#include <string>
 
 #include "rs_bindings_from_cc/ir.h"
+#include "third_party/absl/types/span.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTConsumer.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/Frontend/CompilerInstance.h"
 #include "third_party/llvm/llvm-project/clang/include/clang/Frontend/FrontendAction.h"
@@ -18,12 +20,14 @@
 // (`IR`) into the `ir` parameter.
 class FrontendAction : public clang::ASTFrontendAction {
  public:
-  explicit FrontendAction(IR &ir) : ir_(ir) {}
+  explicit FrontendAction(absl::Span<const std::string> public_headers, IR &ir)
+      : public_headers_(public_headers), ir_(ir) {}
 
   std::unique_ptr<clang::ASTConsumer> CreateASTConsumer(
       clang::CompilerInstance &, llvm::StringRef) override;
 
  private:
+  absl::Span<const std::string> public_headers_;
   IR &ir_;
 };
 
diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc
index 2f8e5a8..8aecec4 100644
--- a/rs_bindings_from_cc/ir.cc
+++ b/rs_bindings_from_cc/ir.cc
@@ -12,6 +12,12 @@
 
 namespace rs_bindings_from_cc {
 
+nlohmann::json HeaderName::ToJson() const {
+  nlohmann::json result;
+  result["name"] = std::string(name_);
+  return result;
+}
+
 nlohmann::json Type::ToJson() const {
   nlohmann::json result;
   result["rs_name"] = std::string(rs_name_);
@@ -45,11 +51,18 @@
 }
 
 nlohmann::json IR::ToJson() const {
+  std::vector<nlohmann::json> used_headers;
+  for (const HeaderName& header : used_headers_) {
+    used_headers.push_back(header.ToJson());
+  }
+
   std::vector<nlohmann::json> functions;
   for (const Func& func : functions_) {
     functions.push_back(func.ToJson());
   }
+
   nlohmann::json result;
+  result["used_headers"] = used_headers;
   result["functions"] = functions;
   return result;
 }
diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h
index 46b891b..628cd94 100644
--- a/rs_bindings_from_cc/ir.h
+++ b/rs_bindings_from_cc/ir.h
@@ -20,6 +20,21 @@
 
 namespace rs_bindings_from_cc {
 
+// A name of a public header of the C++ library.
+class HeaderName {
+ public:
+  explicit HeaderName(absl::Cord name) : name_(std::move(name)) {}
+
+  const absl::Cord &IncludePath() const { return name_; }
+
+  nlohmann::json ToJson() const;
+
+ private:
+  // Header pathname in the format suitable for a google3-relative quote
+  // include.
+  absl::Cord name_;
+};
+
 // A type involved in the bindings. It has the knowledge about how the type is
 // spelled in Rust and in C++ code.
 //
@@ -117,14 +132,22 @@
 class IR {
  public:
   explicit IR() {}
-  explicit IR(std::vector<Func> functions) : functions_(std::move(functions)) {}
+  explicit IR(std::vector<HeaderName> used_headers, std::vector<Func> functions)
+      : used_headers_(std::move(used_headers)),
+        functions_(std::move(functions)) {}
 
   nlohmann::json ToJson() const;
 
+  const std::vector<HeaderName> &UsedHeaders() const { return used_headers_; }
+  std::vector<HeaderName> &UsedHeaders() { return used_headers_; }
+
   const std::vector<Func> &Functions() const { return functions_; }
   std::vector<Func> &Functions() { return functions_; }
 
  private:
+  // Collection of public headers that were used to construct the AST this `IR`
+  // is generated from.
+  std::vector<HeaderName> used_headers_;
   std::vector<Func> functions_;
 };
 
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
index 4272337..bb00a19 100644
--- a/rs_bindings_from_cc/ir.rs
+++ b/rs_bindings_from_cc/ir.rs
@@ -2,6 +2,8 @@
 // Exceptions. See /LICENSE for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+//! Types and deserialization logic for IR. See docs in
+//! `rs_bindings_from_cc/ir.h` for more information.
 use anyhow::Result;
 use serde::Deserialize;
 use std::io::Read;
@@ -11,6 +13,11 @@
 }
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct HeaderName {
+    pub name: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
 pub struct IRType {
     pub rs_name: String,
 }
@@ -37,6 +44,7 @@
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
 pub struct IR {
+    pub used_headers: Vec<HeaderName>,
     pub functions: Vec<Func>,
 }
 
@@ -48,6 +56,7 @@
     fn test_deserializing() {
         let input = r#"
         {
+            "used_headers": [{ "name": "foo/bar.h" }],
             "functions": [
                 {
                     "identifier": { "identifier": "hello_world" },
@@ -65,6 +74,7 @@
         "#;
         let ir = deserialize_ir(input.as_bytes()).unwrap();
         let expected = IR {
+            used_headers: vec![HeaderName { name: "foo/bar.h".to_string() }],
             functions: vec![Func {
                 identifier: Identifier { identifier: "hello_world".to_string() },
                 mangled_name: "$$mangled_name$$".to_string(),
diff --git a/rs_bindings_from_cc/ir_test.cc b/rs_bindings_from_cc/ir_test.cc
index e6649ea..81209ef 100644
--- a/rs_bindings_from_cc/ir_test.cc
+++ b/rs_bindings_from_cc/ir_test.cc
@@ -22,6 +22,7 @@
 TEST(IrTest, TestIR) {
   nlohmann::json expected = nlohmann::json::parse(
       R"j({
+            "used_headers": [{ "name": "foo/bar.h" }],
             "functions": [{
               "identifier": { "identifier": "hello_world" },
               "mangled_name": "#$mangled_name$#",
@@ -34,7 +35,8 @@
               ]
             }]
       })j");
-  EXPECT_EQ(IR({Func(Identifier(absl::Cord("hello_world")),
+  EXPECT_EQ(IR({HeaderName(absl::Cord("foo/bar.h"))},
+               {Func(Identifier(absl::Cord("hello_world")),
                      absl::Cord("#$mangled_name$#"), Type(absl::Cord("i32")),
                      {FuncParam(Type(absl::Cord("i32")),
                                 Identifier(absl::Cord("arg")))})})
diff --git a/rs_bindings_from_cc/rs_bindings_from_cc.cc b/rs_bindings_from_cc/rs_bindings_from_cc.cc
index 0e4d116..710a8e8 100644
--- a/rs_bindings_from_cc/rs_bindings_from_cc.cc
+++ b/rs_bindings_from_cc/rs_bindings_from_cc.cc
@@ -32,7 +32,8 @@
           "output path for the C++ source file with bindings implementation");
 ABSL_FLAG(std::vector<std::string>, public_headers, std::vector<std::string>(),
           "public headers of the cc_library this tool should generate bindings "
-          "for, in a format suitable for usage in #include \"\".");
+          "for, in a format suitable for usage in google3-relative quote "
+          "include (#include \"\").");
 
 constexpr absl::string_view kVirtualInputPath =
     "rs_bindings_from_cc_virtual_input.cc";
@@ -63,15 +64,16 @@
   rs_bindings_from_cc::IR ir;
   if (devtools::cymbal::RunToolWithClangFlagsOnCode(
           command_line, file_contents,
-          std::make_unique<rs_bindings_from_cc::FrontendAction>(ir))) {
+          std::make_unique<rs_bindings_from_cc::FrontendAction>(public_headers,
+                                                                ir))) {
     std::string rs_api = rs_bindings_from_cc::GenerateRustApi(ir);
     std::string rs_api_impl = "// No bindings implementation code was needed.";
     CHECK_OK(file::SetContents(rs_out, rs_api, file::Defaults()));
     CHECK_OK(file::SetContents(cc_out, rs_api_impl, file::Defaults()));
     return 0;
-  } else {
-    CHECK_OK(file::Delete(rs_out, file::Defaults()));
-    CHECK_OK(file::Delete(cc_out, file::Defaults()));
-    return 1;
   }
+
+  CHECK_OK(file::Delete(rs_out, file::Defaults()));
+  CHECK_OK(file::Delete(cc_out, file::Defaults()));
+  return 1;
 }
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 71a8161..bfba321 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -95,6 +95,7 @@
     #[test]
     fn test_gen_src_code() -> Result<()> {
         let ir = IR {
+            used_headers: vec![],
             functions: vec![Func {
                 identifier: Identifier { identifier: "add".to_string() },
                 mangled_name: "_Z3Addii".to_string(),
@@ -114,7 +115,8 @@
         let result = gen_src_code(ir)?;
         assert_eq!(
             result,
-            quote! {#[inline(always)]
+            quote! {
+                #[inline(always)]
                 pub fn add(a: i32, b: i32) -> i32 {
                     unsafe { crate::detail::__rust_thunk__add(a, b) }
                 }
diff --git a/rs_bindings_from_cc/src_code_gen_test.cc b/rs_bindings_from_cc/src_code_gen_test.cc
index 7b2329d..d64ea28 100644
--- a/rs_bindings_from_cc/src_code_gen_test.cc
+++ b/rs_bindings_from_cc/src_code_gen_test.cc
@@ -18,10 +18,11 @@
 using ::testing::StrEq;
 
 TEST(SrcGenTest, FFIIntegration) {
-  IR ir({Func(
-      Identifier(absl::Cord("hello_world")), absl::Cord("$$mangled_name$$"),
-      Type(absl::Cord("i32")),
-      {FuncParam(Type(absl::Cord("i32")), Identifier(absl::Cord("arg")))})});
+  IR ir({HeaderName(absl::Cord("foo/bar.h"))},
+        {Func(Identifier(absl::Cord("hello_world")),
+              absl::Cord("$$mangled_name$$"), Type(absl::Cord("i32")),
+              {FuncParam(Type(absl::Cord("i32")),
+                         Identifier(absl::Cord("arg")))})});
   std::string rs_api = GenerateRustApi(ir);
   EXPECT_THAT(
       rs_api,