Send IR to Rust to generate rust source code using quote!
To pass IR across the FFI boundary we serialize it to a json string, pass the string to Rust, and return a string with generated Rust source code back.
Alternatives considered:
Protobuf instead of json
------------------------
using protobuf as a serialization format would have some advantages (defining types only once, we wouldn't have to write our own serialization logic in C++), we decided not to use it because:
* there is no approved way of using protobufs in Rust in google3 today, and we didn't want to special case us. We could still use protobuf on the C++ side and generate json using it, so at least we don't have to write serialization code ourselves.
* we felt going with json and manual serialization will be more flexible in the uncertain future
* we tossed a coin and json won
Implementing our own Rust code generator in C++
-----------------------------------------------
We sketched the code in unknown commit, and we decided going with Rust solution is more readable and maintainable. We also plan to use Clang syntax trees to generate C++ source code, of which quote!, proc_macro2, and syn are moral equivalents.
PiperOrigin-RevId: 389566607
diff --git a/rs_bindings_from_cc/ir.cc b/rs_bindings_from_cc/ir.cc
new file mode 100644
index 0000000..2f8e5a8
--- /dev/null
+++ b/rs_bindings_from_cc/ir.cc
@@ -0,0 +1,57 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "rs_bindings_from_cc/ir.h"
+
+#include <string>
+#include <vector>
+
+#include "third_party/absl/strings/cord.h"
+#include "third_party/json/src/json.hpp"
+
+namespace rs_bindings_from_cc {
+
+nlohmann::json Type::ToJson() const {
+ nlohmann::json result;
+ result["rs_name"] = std::string(rs_name_);
+ return result;
+}
+
+nlohmann::json Identifier::ToJson() const {
+ nlohmann::json result;
+ result["identifier"] = std::string(identifier_);
+ return result;
+}
+
+nlohmann::json FuncParam::ToJson() const {
+ nlohmann::json result;
+ result["type"] = type_.ToJson();
+ result["identifier"] = identifier_.ToJson();
+ return result;
+}
+
+nlohmann::json Func::ToJson() const {
+ std::vector<nlohmann::json> params;
+ for (const FuncParam& param : params_) {
+ params.push_back(param.ToJson());
+ }
+ nlohmann::json result;
+ result["identifier"] = identifier_.ToJson();
+ result["mangled_name"] = std::string(mangled_name_);
+ result["return_type"] = return_type_.ToJson();
+ result["params"] = params;
+ return result;
+}
+
+nlohmann::json IR::ToJson() const {
+ std::vector<nlohmann::json> functions;
+ for (const Func& func : functions_) {
+ functions.push_back(func.ToJson());
+ }
+ nlohmann::json result;
+ result["functions"] = functions;
+ return result;
+}
+
+} // namespace rs_bindings_from_cc
diff --git a/rs_bindings_from_cc/ir.h b/rs_bindings_from_cc/ir.h
index 43bf33b..8c5cb68 100644
--- a/rs_bindings_from_cc/ir.h
+++ b/rs_bindings_from_cc/ir.h
@@ -16,6 +16,7 @@
#include "base/logging.h"
#include "third_party/absl/strings/cord.h"
+#include "third_party/json/src/json.hpp"
namespace rs_bindings_from_cc {
@@ -35,6 +36,8 @@
const absl::Cord &RsName() const { return rs_name_; }
+ nlohmann::json ToJson() const;
+
private:
absl::Cord rs_name_;
};
@@ -56,6 +59,8 @@
const absl::Cord &Ident() const { return identifier_; }
+ nlohmann::json ToJson() const;
+
private:
absl::Cord identifier_;
};
@@ -73,31 +78,18 @@
const Type &ParamType() const { return type_; }
const Identifier &Ident() const { return identifier_; }
+ nlohmann::json ToJson() const;
+
private:
Type type_;
Identifier identifier_;
};
-// All parameters of a function.
-//
-// Invariants:
-// `params` can be empty.
-class FuncParams {
- public:
- explicit FuncParams(std::vector<FuncParam> params)
- : params_(std::move(params)) {}
-
- const std::vector<FuncParam> &Params() const { return params_; }
-
- private:
- std::vector<FuncParam> params_;
-};
-
// A function involved in the bindings.
class Func {
public:
explicit Func(Identifier identifier, absl::Cord mangled_name,
- Type return_type, FuncParams params)
+ Type return_type, std::vector<FuncParam> params)
: identifier_(std::move(identifier)),
mangled_name_(std::move(mangled_name)),
return_type_(std::move(return_type)),
@@ -107,13 +99,15 @@
const Type &ReturnType() const { return return_type_; }
const Identifier &Ident() const { return identifier_; }
- const FuncParams &Params() const { return params_; }
+ const std::vector<FuncParam> &Params() const { return params_; }
+
+ nlohmann::json ToJson() const;
private:
Identifier identifier_;
absl::Cord mangled_name_;
Type return_type_;
- FuncParams params_;
+ std::vector<FuncParam> params_;
};
// A complete intermediate representation of bindings for publicly accessible
@@ -122,6 +116,8 @@
public:
explicit IR(std::vector<Func> functions) : functions_(std::move(functions)) {}
+ nlohmann::json ToJson() const;
+
private:
std::vector<Func> functions_;
};
diff --git a/rs_bindings_from_cc/ir.rs b/rs_bindings_from_cc/ir.rs
new file mode 100644
index 0000000..4272337
--- /dev/null
+++ b/rs_bindings_from_cc/ir.rs
@@ -0,0 +1,80 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+use anyhow::Result;
+use serde::Deserialize;
+use std::io::Read;
+
+pub fn deserialize_ir<R: Read>(reader: R) -> Result<IR> {
+ Ok(serde_json::from_reader(reader)?)
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct IRType {
+ pub rs_name: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct Identifier {
+ pub identifier: String,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct FuncParam {
+ #[serde(rename(deserialize = "type"))]
+ pub type_: IRType,
+ pub identifier: Identifier,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct Func {
+ pub identifier: Identifier,
+ pub mangled_name: String,
+ pub return_type: IRType,
+ pub params: Vec<FuncParam>,
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
+pub struct IR {
+ pub functions: Vec<Func>,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_deserializing() {
+ let input = r#"
+ {
+ "functions": [
+ {
+ "identifier": { "identifier": "hello_world" },
+ "mangled_name": "$$mangled_name$$",
+ "params": [
+ {
+ "identifier": { "identifier": "arg" },
+ "type": { "rs_name":"i32" }
+ }
+ ],
+ "return_type": { "rs_name": "i32" }
+ }
+ ]
+ }
+ "#;
+ let ir = deserialize_ir(input.as_bytes()).unwrap();
+ let expected = IR {
+ functions: vec![Func {
+ identifier: Identifier { identifier: "hello_world".to_string() },
+ mangled_name: "$$mangled_name$$".to_string(),
+ return_type: IRType { rs_name: "i32".to_string() },
+ params: vec![FuncParam {
+ type_: IRType { rs_name: "i32".to_string() },
+ identifier: Identifier { identifier: "arg".to_string() },
+ }],
+ }],
+ };
+ assert_eq!(ir, expected);
+ }
+}
diff --git a/rs_bindings_from_cc/ir_test.cc b/rs_bindings_from_cc/ir_test.cc
new file mode 100644
index 0000000..e6649ea
--- /dev/null
+++ b/rs_bindings_from_cc/ir_test.cc
@@ -0,0 +1,46 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "rs_bindings_from_cc/ir.h"
+
+#include <string>
+
+#include "testing/base/public/gunit.h"
+#include "third_party/absl/strings/cord.h"
+#include "third_party/json/src/json.hpp"
+
+namespace rs_bindings_from_cc {
+
+namespace {
+
+TEST(IrTest, TestTypeToJson) {
+ nlohmann::json expected = nlohmann::json::parse(R"j({ "rs_name": "i32" })j");
+ EXPECT_EQ(Type(absl::Cord("i32")).ToJson(), expected);
+}
+
+TEST(IrTest, TestIR) {
+ nlohmann::json expected = nlohmann::json::parse(
+ R"j({
+ "functions": [{
+ "identifier": { "identifier": "hello_world" },
+ "mangled_name": "#$mangled_name$#",
+ "return_type": { "rs_name": "i32" },
+ "params": [
+ {
+ "identifier": {"identifier": "arg" },
+ "type": { "rs_name": "i32" }
+ }
+ ]
+ }]
+ })j");
+ EXPECT_EQ(IR({Func(Identifier(absl::Cord("hello_world")),
+ absl::Cord("#$mangled_name$#"), Type(absl::Cord("i32")),
+ {FuncParam(Type(absl::Cord("i32")),
+ Identifier(absl::Cord("arg")))})})
+ .ToJson(),
+ expected);
+}
+
+} // namespace
+} // namespace rs_bindings_from_cc
diff --git a/rs_bindings_from_cc/rs_src_code_gen.cc b/rs_bindings_from_cc/rs_src_code_gen.cc
new file mode 100644
index 0000000..32b1e00
--- /dev/null
+++ b/rs_bindings_from_cc/rs_src_code_gen.cc
@@ -0,0 +1,48 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "rs_bindings_from_cc/rs_src_code_gen.h"
+
+#include <stddef.h>
+
+#include <string>
+
+#include "rs_bindings_from_cc/ir.h"
+#include "third_party/absl/strings/string_view.h"
+#include "third_party/json/src/json.hpp"
+
+namespace rs_bindings_from_cc {
+
+struct FfiU8SliceBox {
+ const char* ptr;
+ size_t size;
+};
+
+struct FfiU8Slice {
+ const char* ptr;
+ size_t size;
+};
+
+static FfiU8Slice MakeFfiU8Slice(absl::string_view s) {
+ FfiU8Slice result;
+ result.ptr = s.data();
+ result.size = s.size();
+ return result;
+}
+
+// This function is implemented in Rust.
+extern "C" FfiU8SliceBox GenerateRustApiImpl(FfiU8Slice);
+
+// This function is implemented in Rust.
+extern "C" void FreeFfiU8SliceBox(FfiU8SliceBox);
+
+std::string GenerateRustApi(const IR& ir) {
+ std::string json = ir.ToJson().dump();
+ FfiU8SliceBox slice_box = GenerateRustApiImpl(MakeFfiU8Slice(json));
+ std::string rs_api(slice_box.ptr, slice_box.size);
+ FreeFfiU8SliceBox(slice_box);
+ return rs_api;
+}
+
+} // namespace rs_bindings_from_cc
diff --git a/rs_bindings_from_cc/rs_src_code_gen.h b/rs_bindings_from_cc/rs_src_code_gen.h
new file mode 100644
index 0000000..1a5e733
--- /dev/null
+++ b/rs_bindings_from_cc/rs_src_code_gen.h
@@ -0,0 +1,19 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef CRUBIT_RS_BINDINGS_FROM_CC_RS_SRC_CODE_GEN_H_
+#define CRUBIT_RS_BINDINGS_FROM_CC_RS_SRC_CODE_GEN_H_
+
+#include <string>
+
+#include "rs_bindings_from_cc/ir.h"
+
+namespace rs_bindings_from_cc {
+
+// Generates Rust bindings source code from the given `IR`.
+std::string GenerateRustApi(const IR &ir);
+
+} // namespace rs_bindings_from_cc
+
+#endif // CRUBIT_RS_BINDINGS_FROM_CC_RS_SRC_CODE_GEN_H_
diff --git a/rs_bindings_from_cc/rs_src_code_gen.rs b/rs_bindings_from_cc/rs_src_code_gen.rs
new file mode 100644
index 0000000..98f152d
--- /dev/null
+++ b/rs_bindings_from_cc/rs_src_code_gen.rs
@@ -0,0 +1,180 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+use anyhow::Result;
+use ir::*;
+use itertools::Itertools;
+use quote::format_ident;
+use quote::quote;
+use std::boxed::Box;
+use std::iter::Iterator;
+use std::panic::catch_unwind;
+use std::process;
+use std::slice;
+use syn::*;
+
+#[repr(C)]
+pub struct FfiU8Slice {
+ ptr: *const u8,
+ size: usize,
+}
+
+impl FfiU8Slice {
+ /// Borrows data pointed to by this `FfiU8Slice` as a slice.
+ fn as_slice(&self) -> &[u8] {
+ // Safety:
+ // Instances of `FfiU8Slice` are only created by FFI functions, which are unsafe themselves
+ // so it's their responsibility to maintain safety.
+ unsafe { slice::from_raw_parts(self.ptr, self.size) }
+ }
+}
+
+#[repr(C)]
+pub struct FfiU8SliceBox {
+ ptr: *const u8,
+ size: usize,
+}
+
+impl FfiU8SliceBox {
+ fn from_boxed_slice(bytes: Box<[u8]>) -> FfiU8SliceBox {
+ let slice = Box::leak(bytes);
+ FfiU8SliceBox { ptr: slice.as_mut_ptr(), size: slice.len() }
+ }
+
+ /// Consumes self and returns boxed slice.
+ fn into_boxed_slice(self) -> Box<[u8]> {
+ // Safety:
+ // Instances of `FfiU8SliceBox` are either created by `from_boxed_slice`, which is safe,
+ // or by FFI functions, which are unsafe themselves so it's their responsibility to maintain
+ // safety.
+ unsafe { Box::from_raw(slice::from_raw_parts_mut(self.ptr as *mut u8, self.size)) }
+ }
+}
+
+/// Deserializes IR from `json` and generates Rust bindings source code.
+///
+/// This function panics on error.
+///
+/// Ownership:
+/// * function doesn't take ownership of (in other words it borrows) the param `json`
+/// * function passes ownership of the returned value to the caller
+///
+/// Safety:
+/// * function expects that param `json` is a FfiU8Slice for a valid array of bytes with the
+/// given size.
+/// * function expects that param `json` doesn't change during the call.
+#[no_mangle]
+pub unsafe extern "C" fn GenerateRustApiImpl(json: FfiU8Slice) -> FfiU8SliceBox {
+ catch_unwind(|| {
+ let result = gen_rs_api(json.as_slice());
+ // it is ok to abort with the error message here.
+ FfiU8SliceBox::from_boxed_slice(result.unwrap().into_bytes().into_boxed_slice())
+ })
+ .unwrap_or_else(|_| process::abort())
+}
+
+/// Frees C-string allocated by Rust.
+#[no_mangle]
+pub unsafe extern "C" fn FreeFfiU8SliceBox(sb: FfiU8SliceBox) {
+ catch_unwind(|| {
+ let _ = sb.into_boxed_slice();
+ })
+ .unwrap_or_else(|_| process::abort())
+}
+
+fn gen_rs_api(json: &[u8]) -> Result<String> {
+ let ir = deserialize_ir(json)?;
+ Ok(gen_src_code(ir)?)
+}
+
+fn gen_src_code(ir: IR) -> Result<String> {
+ let mut thunks = vec![];
+ let mut api_funcs = vec![];
+ for func in ir.functions {
+ let mangled_name = &func.mangled_name;
+ let ident = make_ident(&func.identifier.identifier);
+ let thunk_ident = format_ident!("__rust_thunk__{}", &func.identifier.identifier);
+ let return_type_name = make_ident(&func.return_type.rs_name);
+
+ let param_idents =
+ func.params.iter().map(|p| make_ident(&p.identifier.identifier)).collect_vec();
+
+ let param_types = func.params.iter().map(|p| make_ident(&p.type_.rs_name)).collect_vec();
+
+ api_funcs.push(quote! {
+ #[inline(always)]
+ pub fn #ident( #( #param_idents: #param_types ),* ) -> #return_type_name {
+ unsafe { crate::detail::#thunk_ident( #( #param_idents ),* ) }
+ }
+ });
+
+ thunks.push(quote! {
+ #[link_name = #mangled_name]
+ pub(crate) fn #thunk_ident( #( #param_idents: #param_types ),* ) -> #return_type_name ;
+ });
+ }
+
+ let result = quote! {
+ #( #api_funcs )*
+
+ mod detail {
+ extern "C" {
+ #( #thunks )*
+ }
+ }
+ };
+
+ Ok(result.to_string())
+}
+
+fn make_ident(ident: &str) -> Ident {
+ format_ident!("{}", ident)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::gen_src_code;
+ use super::Result;
+ use ir::*;
+ use quote::quote;
+
+ #[test]
+ fn test_gen_src_code() -> Result<()> {
+ let ir = IR {
+ functions: vec![Func {
+ identifier: Identifier { identifier: "add".to_string() },
+ mangled_name: "_Z3Addii".to_string(),
+ return_type: IRType { rs_name: "i32".to_string() },
+ params: vec![
+ FuncParam {
+ identifier: Identifier { identifier: "a".to_string() },
+ type_: IRType { rs_name: "i32".to_string() },
+ },
+ FuncParam {
+ identifier: Identifier { identifier: "b".to_string() },
+ type_: IRType { rs_name: "i32".to_string() },
+ },
+ ],
+ }],
+ };
+ let result = gen_src_code(ir)?;
+ assert_eq!(
+ result,
+ quote! {#[inline(always)]
+ pub fn add(a: i32, b: i32) -> i32 {
+ unsafe { crate::detail::__rust_thunk__add(a, b) }
+ }
+
+ mod detail {
+ extern "C" {
+ #[link_name = "_Z3Addii"]
+ pub(crate) fn __rust_thunk__add(a: i32, b: i32) -> i32;
+ } // extern
+ } // mod detail
+ }
+ .to_string()
+ );
+ Ok(())
+ }
+}
diff --git a/rs_bindings_from_cc/rs_src_code_gen_test.cc b/rs_bindings_from_cc/rs_src_code_gen_test.cc
new file mode 100644
index 0000000..eac421e
--- /dev/null
+++ b/rs_bindings_from_cc/rs_src_code_gen_test.cc
@@ -0,0 +1,43 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "rs_bindings_from_cc/rs_src_code_gen.h"
+
+#include <string>
+
+#include "rs_bindings_from_cc/ir.h"
+#include "testing/base/public/gmock.h"
+#include "testing/base/public/gunit.h"
+#include "third_party/absl/strings/cord.h"
+
+namespace rs_bindings_from_cc {
+
+namespace {
+
+using ::testing::StrEq;
+
+TEST(RsSrcGenTest, FFIIntegration) {
+ IR ir({Func(
+ Identifier(absl::Cord("hello_world")), absl::Cord("$$mangled_name$$"),
+ Type(absl::Cord("i32")),
+ {FuncParam(Type(absl::Cord("i32")), Identifier(absl::Cord("arg")))})});
+ std::string rs_api = GenerateRustApi(ir);
+ EXPECT_THAT(
+ rs_api,
+ StrEq(
+ // TODO(hlopko): Run generated sources through rustfmt.
+ "# [inline (always)] "
+ "pub fn hello_world (arg : i32) -> i32 { "
+ "unsafe { crate :: detail :: __rust_thunk__hello_world (arg) } "
+ "} "
+ "mod detail { "
+ "extern \"C\" { "
+ "# [link_name = \"$$mangled_name$$\"] "
+ "pub (crate) fn __rust_thunk__hello_world (arg : i32) -> i32 ; "
+ "} "
+ "}"));
+}
+
+} // namespace
+} // namespace rs_bindings_from_cc