Add public_headers to the IR
These will be needed by the generated C++ code.
PiperOrigin-RevId: 391729880
diff --git a/rs_bindings_from_cc/ast_consumer.h b/rs_bindings_from_cc/ast_consumer.h
index 059007c..949f022 100644
--- a/rs_bindings_from_cc/ast_consumer.h
+++ b/rs_bindings_from_cc/ast_consumer.h
@@ -5,18 +5,23 @@
#ifndef CRUBIT_RS_BINDINGS_FROM_CC_AST_CONSUMER_H_
#define CRUBIT_RS_BINDINGS_FROM_CC_AST_CONSUMER_H_
+#include <string>
+
#include "rs_bindings_from_cc/ast_visitor.h"
#include "rs_bindings_from_cc/ir.h"
+#include "third_party/absl/types/span.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTConsumer.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTContext.h"
namespace rs_bindings_from_cc {
-// Consumes the Clang AST of the header and generates the intermediate
-// representation (`IR`).
+// Consumes the Clang AST created from `public_headers` (a collection of paths
+// in the format suitable for a google3-relative quote include) and generates
+// the intermediate representation (`IR`).
class AstConsumer : public clang::ASTConsumer {
public:
- explicit AstConsumer(IR &ir) : ast_visitor_(ir) {}
+ explicit AstConsumer(absl::Span<const std::string> public_headers, IR &ir)
+ : ast_visitor_(public_headers, ir) {}
void HandleTranslationUnit(clang::ASTContext &context) override;
diff --git a/rs_bindings_from_cc/ast_visitor.cc b/rs_bindings_from_cc/ast_visitor.cc
index c79204f..41a15f0 100644
--- a/rs_bindings_from_cc/ast_visitor.cc
+++ b/rs_bindings_from_cc/ast_visitor.cc
@@ -30,6 +30,20 @@
bool AstVisitor::TraverseTranslationUnitDecl(
clang::TranslationUnitDecl* translation_unit_decl) {
mangler_.reset(translation_unit_decl->getASTContext().createMangleContext());
+
+ // TODO(hlopko): Make the generated C++ code include-what-you-use clean.
+ // Currently we pass public headers of the library to the src_code_gen.
+ // Through those Clang has access to all declarations needed by the public API
+ // of the library. However the code violates IWYU - it will not directly
+ // include all the headers declaring names used in the generated source. This
+ // could be fixed by passing not only public headers of the library to the
+ // tool, but also all public headers of the direct dependencies of the
+ // library. This way, if the library is IWYU clean, the generated code will be
+ // too.
+ for (const std::string& header : public_headers_) {
+ ir_.UsedHeaders().emplace_back(HeaderName(absl::Cord(header)));
+ }
+
return Base::TraverseTranslationUnitDecl(translation_unit_decl);
}
diff --git a/rs_bindings_from_cc/ast_visitor.h b/rs_bindings_from_cc/ast_visitor.h
index 5ca91db..2f14147 100644
--- a/rs_bindings_from_cc/ast_visitor.h
+++ b/rs_bindings_from_cc/ast_visitor.h
@@ -6,10 +6,12 @@
#define CRUBIT_RS_BINDINGS_FROM_CC_AST_VISITOR_H_
#include <memory>
+#include <string>
#include "rs_bindings_from_cc/ir.h"
#include "third_party/absl/container/flat_hash_set.h"
#include "third_party/absl/strings/cord.h"
+#include "third_party/absl/types/span.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/Decl.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/Mangle.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/RecursiveASTVisitor.h"
@@ -17,13 +19,15 @@
namespace rs_bindings_from_cc {
-// Iterates over the AST nodes of the header and creates intermediate
-// representation of the import (`IR`).
+// Iterates over the AST created from `public_headers` (a collection of paths
+// in the format suitable for a google3-relative quote include) and creates
+// an intermediate representation of the import (`IR`).
class AstVisitor : public clang::RecursiveASTVisitor<AstVisitor> {
public:
using Base = clang::RecursiveASTVisitor<AstVisitor>;
- explicit AstVisitor(IR &ir) : ir_(ir) {}
+ explicit AstVisitor(absl::Span<const std::string> public_headers, IR &ir)
+ : public_headers_(public_headers), ir_(ir) {}
// These functions are called by the base class while visiting the different
// parts of the AST. The API follows the rules of the base class which is
@@ -39,6 +43,7 @@
Identifier GetTranslatedName(const clang::NamedDecl *named_decl) const;
Type ConvertType(clang::QualType qual_type) const;
+ absl::Span<const std::string> public_headers_;
IR &ir_;
std::unique_ptr<clang::MangleContext> mangler_;
absl::flat_hash_set<const clang::Decl *> seen_decls_;
diff --git a/rs_bindings_from_cc/ast_visitor_test.cc b/rs_bindings_from_cc/ast_visitor_test.cc
index 336e4ac..7fedb88 100644
--- a/rs_bindings_from_cc/ast_visitor_test.cc
+++ b/rs_bindings_from_cc/ast_visitor_test.cc
@@ -4,15 +4,19 @@
#include <memory>
#include <string>
+#include <utility>
#include <vector>
+#include "devtools/cymbal/common/clang_tool.h"
#include "rs_bindings_from_cc/frontend_action.h"
#include "rs_bindings_from_cc/ir.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
+#include "third_party/absl/container/flat_hash_map.h"
#include "third_party/absl/strings/string_view.h"
+#include "third_party/absl/strings/substitute.h"
+#include "third_party/absl/types/span.h"
#include "third_party/llvm/llvm-project/clang/include/clang/Frontend/FrontendAction.h"
-#include "third_party/llvm/llvm-project/clang/include/clang/Tooling/Tooling.h"
namespace rs_bindings_from_cc {
namespace {
@@ -20,28 +24,53 @@
using ::testing::IsEmpty;
using ::testing::SizeIs;
-IR ImportCode(const absl::string_view code,
+constexpr absl::string_view kVirtualInputPath =
+ "ast_visitor_test_virtual_input.cc";
+
+IR ImportCode(absl::Span<const absl::string_view> header_files_contents,
const std::vector<absl::string_view>& args) {
- IR ir;
+ std::vector<std::string> headers;
+ absl::flat_hash_map<std::string, std::string> file_contents;
+ std::string virtual_input_file_content;
+
+ int counter = 0;
+ for (const absl::string_view header_content : header_files_contents) {
+ std::string filename(
+ absl::Substitute("test/testing_header_$0.h", counter++));
+ file_contents.insert({filename, std::string(header_content)});
+ absl::SubstituteAndAppend(&virtual_input_file_content, "#include \"$0\"\n",
+ filename);
+ headers.emplace_back(std::move(filename));
+ }
+
+ file_contents.insert(
+ {std::string(kVirtualInputPath), virtual_input_file_content});
+
std::vector<std::string> args_as_strings(args.begin(), args.end());
- clang::tooling::runToolOnCodeWithArgs(
- std::make_unique<rs_bindings_from_cc::FrontendAction>(ir), code,
- args_as_strings);
+ args_as_strings.emplace_back(std::string("--syntax-only"));
+ args_as_strings.emplace_back(std::string(kVirtualInputPath));
+
+ IR ir;
+ devtools::cymbal::RunToolWithClangFlagsOnCode(
+ args_as_strings, file_contents,
+ std::make_unique<rs_bindings_from_cc::FrontendAction>(headers, ir));
return ir;
}
TEST(AstVisitorTest, TestNoop) {
- IR ir = ImportCode("// nothing interesting there.", {});
+ IR ir = ImportCode({"// nothing interesting there."}, {});
EXPECT_THAT(ir.Functions(), IsEmpty());
+ EXPECT_THAT(ir.UsedHeaders(), SizeIs(1));
+ EXPECT_EQ(ir.UsedHeaders()[0].IncludePath(), "test/testing_header_0.h");
}
TEST(AstVisitorTest, TestIREmptyOnInvalidInput) {
- IR ir = ImportCode("int foo(); But this is not C++", {});
+ IR ir = ImportCode({"int foo(); But this is not C++"}, {});
EXPECT_THAT(ir.Functions(), IsEmpty());
}
TEST(AstVisitorTest, TestImportFuncWithVoidReturnType) {
- IR ir = ImportCode("void Foo();", {});
+ IR ir = ImportCode({"void Foo();"}, {});
ASSERT_THAT(ir.Functions(), SizeIs(1));
Func func = ir.Functions()[0];
EXPECT_EQ(func.Ident().Ident(), "Foo");
@@ -51,7 +80,7 @@
}
TEST(AstVisitorTest, TestImportTwoFuncs) {
- IR ir = ImportCode("void Foo(); void Bar();", {});
+ IR ir = ImportCode({"void Foo(); void Bar();"}, {});
ASSERT_THAT(ir.Functions(), SizeIs(2));
Func foo = ir.Functions()[0];
@@ -67,18 +96,24 @@
EXPECT_THAT(bar.Params(), IsEmpty());
}
+TEST(AstVisitorTest, TestImportTwoFuncsFromTwoHeaders) {
+ IR ir = ImportCode({"void Foo();", "void Bar();"}, {});
+ ASSERT_THAT(ir.Functions(), SizeIs(2));
+ Func foo = ir.Functions()[0];
+ EXPECT_EQ(foo.Ident().Ident(), "Foo");
+ Func bar = ir.Functions()[1];
+ EXPECT_EQ(bar.Ident().Ident(), "Bar");
+}
+
TEST(AstVisitorTest, TestImportFuncJustOnce) {
- IR ir = ImportCode(
- "void Foo();"
- "void Foo();",
- {});
+ IR ir = ImportCode({"void Foo(); void Foo();"}, {});
ASSERT_THAT(ir.Functions(), SizeIs(1));
Func func = ir.Functions()[0];
EXPECT_EQ(func.Ident().Ident(), "Foo");
}
TEST(AstVisitorTest, TestImportFuncParams) {
- IR ir = ImportCode("int Add(int a, int b);", {});
+ IR ir = ImportCode({"int Add(int a, int b);"}, {});
EXPECT_THAT(ir.Functions(), SizeIs(1));
Func func = ir.Functions()[0];
diff --git a/rs_bindings_from_cc/frontend_action.cc b/rs_bindings_from_cc/frontend_action.cc
index 6de4495..a4d3708 100644
--- a/rs_bindings_from_cc/frontend_action.cc
+++ b/rs_bindings_from_cc/frontend_action.cc
@@ -14,7 +14,7 @@
std::unique_ptr<clang::ASTConsumer> FrontendAction::CreateASTConsumer(
clang::CompilerInstance &, llvm::StringRef) {
- return std::make_unique<AstConsumer>(ir_);
+ return std::make_unique<AstConsumer>(public_headers_, ir_);
}
} // namespace rs_bindings_from_cc
diff --git a/rs_bindings_from_cc/frontend_action.h b/rs_bindings_from_cc/frontend_action.h
index 833769f..e9c414c 100644
--- a/rs_bindings_from_cc/frontend_action.h
+++ b/rs_bindings_from_cc/frontend_action.h
@@ -6,8 +6,10 @@
#define CRUBIT_RS_BINDINGS_FROM_CC_FRONTEND_ACTION_H_
#include <memory>
+#include <string>
#include "rs_bindings_from_cc/ir.h"
+#include "third_party/absl/types/span.h"
#include "third_party/llvm/llvm-project/clang/include/clang/AST/ASTConsumer.h"
#include "third_party/llvm/llvm-project/clang/include/clang/Frontend/CompilerInstance.h"
#include "third_party/llvm/llvm-project/clang/include/clang/Frontend/FrontendAction.h"
@@ -18,12 +20,14 @@
// (`IR`) into the `ir` parameter.
class FrontendAction : public clang::ASTFrontendAction {
public:
- explicit FrontendAction(IR &ir) : ir_(ir) {}
+ explicit FrontendAction(absl::Span<const std::string> public_headers, IR &ir)
+ : public_headers_(public_headers), ir_(ir) {}
std::unique_ptr<clang::ASTConsumer> CreateASTConsumer(
clang::CompilerInstance &, llvm::StringRef) override;
private:
+ absl::Span<const std::string> public_headers_;
IR &ir_;
};
diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc
index 2f8e5a8..8aecec4 100644
--- a/rs_bindings_from_cc/ir.cc
+++ b/rs_bindings_from_cc/ir.cc
@@ -12,6 +12,12 @@
namespace rs_bindings_from_cc {
+nlohmann::json HeaderName::ToJson() const {
+ nlohmann::json result;
+ result["name"] = std::string(name_);
+ return result;
+}
+
nlohmann::json Type::ToJson() const {
nlohmann::json result;
result["rs_name"] = std::string(rs_name_);
@@ -45,11 +51,18 @@
}
nlohmann::json IR::ToJson() const {
+ std::vector<nlohmann::json> used_headers;
+ for (const HeaderName& header : used_headers_) {
+ used_headers.push_back(header.ToJson());
+ }
+
std::vector<nlohmann::json> functions;
for (const Func& func : functions_) {
functions.push_back(func.ToJson());
}
+
nlohmann::json result;
+ result["used_headers"] = used_headers;
result["functions"] = functions;
return result;
}
diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h
index 46b891b..628cd94 100644
--- a/rs_bindings_from_cc/ir.h
+++ b/rs_bindings_from_cc/ir.h
@@ -20,6 +20,21 @@
namespace rs_bindings_from_cc {
+// A name of a public header of the C++ library.
+class HeaderName {
+ public:
+ explicit HeaderName(absl::Cord name) : name_(std::move(name)) {}
+
+ const absl::Cord &IncludePath() const { return name_; }
+
+ nlohmann::json ToJson() const;
+
+ private:
+ // Header pathname in the format suitable for a google3-relative quote
+ // include.
+ absl::Cord name_;
+};
+
// A type involved in the bindings. It has the knowledge about how the type is
// spelled in Rust and in C++ code.
//
@@ -117,14 +132,22 @@
class IR {
public:
explicit IR() {}
- explicit IR(std::vector<Func> functions) : functions_(std::move(functions)) {}
+ explicit IR(std::vector<HeaderName> used_headers, std::vector<Func> functions)
+ : used_headers_(std::move(used_headers)),
+ functions_(std::move(functions)) {}
nlohmann::json ToJson() const;
+ const std::vector<HeaderName> &UsedHeaders() const { return used_headers_; }
+ std::vector<HeaderName> &UsedHeaders() { return used_headers_; }
+
const std::vector<Func> &Functions() const { return functions_; }
std::vector<Func> &Functions() { return functions_; }
private:
+ // Collection of public headers that were used to construct the AST this `IR`
+ // is generated from.
+ std::vector<HeaderName> used_headers_;
std::vector<Func> functions_;
};
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
index 4272337..bb00a19 100644
--- a/rs_bindings_from_cc/ir.rs
+++ b/rs_bindings_from_cc/ir.rs
@@ -2,6 +2,8 @@
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//! Types and deserialization logic for IR. See docs in
+//! `rs_bindings_from_cc/ir.h` for more information.
use anyhow::Result;
use serde::Deserialize;
use std::io::Read;
@@ -11,6 +13,11 @@
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct HeaderName {
+ pub name: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
pub struct IRType {
pub rs_name: String,
}
@@ -37,6 +44,7 @@
#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
pub struct IR {
+ pub used_headers: Vec<HeaderName>,
pub functions: Vec<Func>,
}
@@ -48,6 +56,7 @@
fn test_deserializing() {
let input = r#"
{
+ "used_headers": [{ "name": "foo/bar.h" }],
"functions": [
{
"identifier": { "identifier": "hello_world" },
@@ -65,6 +74,7 @@
"#;
let ir = deserialize_ir(input.as_bytes()).unwrap();
let expected = IR {
+ used_headers: vec![HeaderName { name: "foo/bar.h".to_string() }],
functions: vec![Func {
identifier: Identifier { identifier: "hello_world".to_string() },
mangled_name: "$$mangled_name$$".to_string(),
diff --git a/rs_bindings_from_cc/ir_test.cc b/rs_bindings_from_cc/ir_test.cc
index e6649ea..81209ef 100644
--- a/rs_bindings_from_cc/ir_test.cc
+++ b/rs_bindings_from_cc/ir_test.cc
@@ -22,6 +22,7 @@
TEST(IrTest, TestIR) {
nlohmann::json expected = nlohmann::json::parse(
R"j({
+ "used_headers": [{ "name": "foo/bar.h" }],
"functions": [{
"identifier": { "identifier": "hello_world" },
"mangled_name": "#$mangled_name$#",
@@ -34,7 +35,8 @@
]
}]
})j");
- EXPECT_EQ(IR({Func(Identifier(absl::Cord("hello_world")),
+ EXPECT_EQ(IR({HeaderName(absl::Cord("foo/bar.h"))},
+ {Func(Identifier(absl::Cord("hello_world")),
absl::Cord("#$mangled_name$#"), Type(absl::Cord("i32")),
{FuncParam(Type(absl::Cord("i32")),
Identifier(absl::Cord("arg")))})})
diff --git a/rs_bindings_from_cc/rs_bindings_from_cc.cc b/rs_bindings_from_cc/rs_bindings_from_cc.cc
index 0e4d116..710a8e8 100644
--- a/rs_bindings_from_cc/rs_bindings_from_cc.cc
+++ b/rs_bindings_from_cc/rs_bindings_from_cc.cc
@@ -32,7 +32,8 @@
"output path for the C++ source file with bindings implementation");
ABSL_FLAG(std::vector<std::string>, public_headers, std::vector<std::string>(),
"public headers of the cc_library this tool should generate bindings "
- "for, in a format suitable for usage in #include \"\".");
+ "for, in a format suitable for usage in google3-relative quote "
+ "include (#include \"\").");
constexpr absl::string_view kVirtualInputPath =
"rs_bindings_from_cc_virtual_input.cc";
@@ -63,15 +64,16 @@
rs_bindings_from_cc::IR ir;
if (devtools::cymbal::RunToolWithClangFlagsOnCode(
command_line, file_contents,
- std::make_unique<rs_bindings_from_cc::FrontendAction>(ir))) {
+ std::make_unique<rs_bindings_from_cc::FrontendAction>(public_headers,
+ ir))) {
std::string rs_api = rs_bindings_from_cc::GenerateRustApi(ir);
std::string rs_api_impl = "// No bindings implementation code was needed.";
CHECK_OK(file::SetContents(rs_out, rs_api, file::Defaults()));
CHECK_OK(file::SetContents(cc_out, rs_api_impl, file::Defaults()));
return 0;
- } else {
- CHECK_OK(file::Delete(rs_out, file::Defaults()));
- CHECK_OK(file::Delete(cc_out, file::Defaults()));
- return 1;
}
+
+ CHECK_OK(file::Delete(rs_out, file::Defaults()));
+ CHECK_OK(file::Delete(cc_out, file::Defaults()));
+ return 1;
}
diff --git a/rs_bindings_from_cc/src_code_gen.rs b/rs_bindings_from_cc/src_code_gen.rs
index 71a8161..bfba321 100644
--- a/rs_bindings_from_cc/src_code_gen.rs
+++ b/rs_bindings_from_cc/src_code_gen.rs
@@ -95,6 +95,7 @@
#[test]
fn test_gen_src_code() -> Result<()> {
let ir = IR {
+ used_headers: vec![],
functions: vec![Func {
identifier: Identifier { identifier: "add".to_string() },
mangled_name: "_Z3Addii".to_string(),
@@ -114,7 +115,8 @@
let result = gen_src_code(ir)?;
assert_eq!(
result,
- quote! {#[inline(always)]
+ quote! {
+ #[inline(always)]
pub fn add(a: i32, b: i32) -> i32 {
unsafe { crate::detail::__rust_thunk__add(a, b) }
}
diff --git a/rs_bindings_from_cc/src_code_gen_test.cc b/rs_bindings_from_cc/src_code_gen_test.cc
index 7b2329d..d64ea28 100644
--- a/rs_bindings_from_cc/src_code_gen_test.cc
+++ b/rs_bindings_from_cc/src_code_gen_test.cc
@@ -18,10 +18,11 @@
using ::testing::StrEq;
TEST(SrcGenTest, FFIIntegration) {
- IR ir({Func(
- Identifier(absl::Cord("hello_world")), absl::Cord("$$mangled_name$$"),
- Type(absl::Cord("i32")),
- {FuncParam(Type(absl::Cord("i32")), Identifier(absl::Cord("arg")))})});
+ IR ir({HeaderName(absl::Cord("foo/bar.h"))},
+ {Func(Identifier(absl::Cord("hello_world")),
+ absl::Cord("$$mangled_name$$"), Type(absl::Cord("i32")),
+ {FuncParam(Type(absl::Cord("i32")),
+ Identifier(absl::Cord("arg")))})});
std::string rs_api = GenerateRustApi(ir);
EXPECT_THAT(
rs_api,