Open-source lifetime inference/verification code.

PiperOrigin-RevId: 450954978
diff --git a/lifetime_analysis/template_placeholder_support.cc b/lifetime_analysis/template_placeholder_support.cc
new file mode 100644
index 0000000..0daaf7a
--- /dev/null
+++ b/lifetime_analysis/template_placeholder_support.cc
@@ -0,0 +1,226 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "lifetime_analysis/template_placeholder_support.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_replace.h"
+#include "clang/AST/DeclTemplate.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Analysis/CFG.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Tooling/Tooling.h"
+#include "clang/Tooling/Transformer/Stencil.h"
+#include "clang/Tooling/Transformer/Transformer.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/VirtualFileSystem.h"
+
+namespace clang {
+namespace tidy {
+namespace lifetimes {
+
+namespace {
+
+using clang::ast_matchers::MatchFinder;
+
+// Match callback that hands the ASTContext of a matched translation unit to a
+// caller-supplied operation.
+//
+// The matcher this callback is registered with must bind the
+// `clang::TranslationUnitDecl` under the ID "tu"; a match result without such
+// a binding is ignored.
+class TranslationUnitMatcherCallback : public MatchFinder::MatchCallback {
+ public:
+  // Stores `operation`; it is invoked from `run()` for each match that has a
+  // "tu" binding.
+  explicit TranslationUnitMatcherCallback(
+      std::function<void(clang::ASTContext&)> operation)
+      : operation_(std::move(operation)) {}
+
+  // Looks up the node bound to "tu" and, if present, calls the stored
+  // operation with that declaration's ASTContext.
+  void run(const MatchFinder::MatchResult& Result) override {
+    const auto* tu = Result.Nodes.getNodeAs<clang::TranslationUnitDecl>("tu");
+    if (tu == nullptr) return;
+    operation_(tu->getASTContext());
+  }
+
+ private:
+  // Operation to run on the matched translation unit's ASTContext.
+  std::function<void(clang::ASTContext&)> operation_;
+};
+
+}  // namespace
+
+// Produces source text that re-declares each function template in `templates`
+// with its template parameters replaced by freshly named placeholder structs.
+//
+// The text of the main file of `tu` is rewritten as follows:
+//   * For every (template, function) entry, each template parameter gets a
+//     placeholder name `<function>_type_placeholder_<N>`; type locations
+//     inside the function whose spelling equals the parameter name are
+//     rewritten to the placeholder, the function body is replaced by ";", the
+//     function name becomes `<function><placeholder, ...>`, and the
+//     `<...>` template parameter list is removed.
+//   * Every other top-level declaration of the translation unit is replaced
+//     by an empty string.
+//
+// The returned `GeneratedCode` has
+//   * `filename`: the main file's name with "-with-placeholders.cc" appended;
+//   * `code`: an `#include` of the main file, followed by an empty `struct`
+//     definition per placeholder, followed by the rewritten text.
+//
+// Returns an error if any rewrite rule fails to produce its edits or if the
+// collected edits cannot be applied to the source text.
+llvm::Expected<GeneratedCode> GenerateTemplateInstantiationCode(
+    const clang::TranslationUnitDecl* tu,
+    const llvm::DenseMap<clang::FunctionTemplateDecl*,
+                         const clang::FunctionDecl*>& templates) {
+  using clang::ast_matchers::asString;
+  using clang::ast_matchers::decl;
+  using clang::ast_matchers::equalsNode;
+  using clang::ast_matchers::functionDecl;
+  using clang::ast_matchers::functionTemplateDecl;
+  using clang::ast_matchers::hasBody;
+  using clang::ast_matchers::hasParent;
+  using clang::ast_matchers::loc;
+  using clang::ast_matchers::qualType;
+  using clang::ast_matchers::stmt;
+  using clang::ast_matchers::typeLoc;
+  using clang::tooling::Transformer;
+  using clang::transformer::cat;
+  using clang::transformer::charRange;
+  using clang::transformer::edit;
+  using clang::transformer::EditGenerator;
+  using clang::transformer::name;
+  using clang::transformer::node;
+  using clang::transformer::remove;
+
+  auto& context = tu->getASTContext();
+  auto& source_manager = context.getSourceManager();
+  auto file_id = source_manager.getMainFileID();
+  auto source_filename =
+      source_manager.getFilename(source_manager.getLocForStartOfFile(file_id));
+
+  // Text of the whole main file; the edits collected below are applied to it.
+  auto source_code = clang::Lexer::getSourceText(
+      clang::CharSourceRange::getTokenRange(
+          source_manager.getLocForStartOfFile(file_id),
+          source_manager.getLocForEndOfFile(file_id)),
+      source_manager, context.getLangOpts());
+
+  llvm::Error err = llvm::Error::success();
+  clang::tooling::AtomicChanges changes;
+  std::vector<std::unique_ptr<Transformer>> transformers;
+
+  // Shared by all transformers: collects their edits into `changes`, or
+  // records their failure in `err`.
+  auto consumer =
+      [&changes,
+       &err](llvm::Expected<llvm::MutableArrayRef<clang::tooling::AtomicChange>>
+                 c) {
+        if (c) {
+          changes.insert(changes.end(), std::make_move_iterator(c->begin()),
+                         std::make_move_iterator(c->end()));
+          return;
+        }
+        // Take the error exactly once: a second `takeError()` would yield an
+        // already-emptied error and log nothing useful.
+        llvm::Error failure = c.takeError();
+        // Streaming an `llvm::Error` logs it without consuming it.
+        llvm::errs() << failure << "\n";
+        // Accumulate instead of assigning: overwriting an `llvm::Error` that
+        // is unchecked or holds an earlier failure is a programmatic error.
+        err = llvm::joinErrors(std::move(err), std::move(failure));
+      };
+
+  // Start from all top-level declarations; the templates being instantiated
+  // are removed from this set below, and whatever remains gets deleted.
+  clang::TranslationUnitDecl* translation_unit =
+      context.getTranslationUnitDecl();
+  llvm::DenseSet<const clang::Decl*> toplevels(translation_unit->decls_begin(),
+                                               translation_unit->decls_end());
+
+  // Running counter that keeps placeholder names unique across all templates.
+  int placeholder_suffix_idx = 0;
+  std::vector<std::string> placeholder_classes;
+  for (const auto& [tmpl, func] : templates) {
+    toplevels.erase(tmpl);
+    auto* params = tmpl->getTemplateParameters();
+    std::vector<std::string> parameters;
+    llvm::SmallVector<EditGenerator, 2> edits;
+    std::string func_name = func->getNameAsString();
+
+    for (const auto* param : *params) {
+      // TODO(kinuko): check the template parameter types, this only assumes
+      // type parameters for now.
+      std::string placeholder_class = absl::StrCat(
+          func_name, "_type_placeholder_", placeholder_suffix_idx++);
+
+      placeholder_classes.push_back(placeholder_class);
+      parameters.push_back(placeholder_class);
+
+      // Within the function (bound to `func_name` below), rewrite every type
+      // location spelled as the template parameter to the placeholder name.
+      auto change_type_rule =
+          makeRule(typeLoc(loc(qualType(asString(param->getNameAsString())))),
+                   changeTo(cat(placeholder_class)));
+      edits.push_back(rewriteDescendants(func_name, change_type_rule));
+    }
+
+    // Replace the function body with ";" so only a declaration is left.
+    edits.push_back(edit(changeTo(node("body"), cat(";"))));
+    // Append the placeholder argument list to the function name.
+    edits.push_back(edit(
+        changeTo(name(func_name),
+                 cat(absl::StrCat(func->getNameAsString(), "<",
+                                  absl::StrJoin(parameters, ", "), ">")))));
+    // Drop the template parameter list, from '<' through '>' inclusive.
+    edits.push_back(edit(remove(charRange(clang::CharSourceRange::getCharRange(
+        params->getLAngleLoc(), params->getRAngleLoc().getLocWithOffset(1))))));
+
+    auto rule =
+        makeRule(functionDecl(equalsNode(func), hasBody(stmt().bind("body")),
+                              hasParent(functionTemplateDecl()))
+                     .bind(func_name),
+                 flattenVector(edits));
+    transformers.push_back(std::make_unique<Transformer>(rule, consumer));
+  }
+
+  for (const auto* node_to_delete : toplevels) {
+    // Delete all other top-level nodes (we only need the instantiation code as
+    // original code is to be included separately)
+    auto rule = makeRule(decl(equalsNode(node_to_delete)), changeTo(cat("")));
+    transformers.push_back(std::make_unique<Transformer>(rule, consumer));
+  }
+
+  std::string instantiation_code;
+  MatchFinder match_finder;
+  for (const auto& transformer : transformers) {
+    transformer->registerMatchers(&match_finder);
+  }
+  match_finder.matchAST(context);
+
+  // `consumer` might have produced an error.
+  if (err) return std::move(err);
+
+  if ((err = clang::tooling::applyAtomicChanges(
+                 source_filename, source_code, changes,
+                 clang::tooling::ApplyChangesSpec())
+                 .moveInto(instantiation_code))) {
+    return std::move(err);
+  }
+
+  // insertBefore or other transform edits don't work quite well, so simply
+  // build the placeholder definitions as a string and prepend them.
+  std::string placeholder_definitions;
+  for (const auto& placeholder_class : placeholder_classes) {
+    absl::StrAppend(&placeholder_definitions, "struct ", placeholder_class,
+                    " {};\n");
+  }
+
+  GeneratedCode generated;
+  generated.filename = (source_filename + "-with-placeholders.cc").str();
+  generated.code =
+      absl::StrCat("#include \"", source_filename.str(), "\"\n",
+                   placeholder_definitions, instantiation_code);
+  return generated;
+}
+
+// Runs a Clang tool over `code`, presented as a virtual file named `filename`,
+// and calls `operation` with the ASTContext of the resulting translation unit.
+//
+// `code` is added to an in-memory filesystem layered on top of the virtual
+// filesystem used by `original_context`'s file manager. The tool is run with
+// the fixed arguments "-fsyntax-only -std=c++17". If the tool run reports
+// failure, a message is written to `llvm::errs()`.
+void RunToolOnCodeWithOverlay(
+    clang::ASTContext& original_context, const std::string& filename,
+    const std::string& code,
+    const std::function<void(clang::ASTContext&)> operation) {
+  using clang::ast_matchers::MatchFinder;
+  using clang::ast_matchers::translationUnitDecl;
+
+  // Set up an overlay filesystem and add the `code` as a virtual file of it.
+  llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> fs(
+      &original_context.getSourceManager()
+           .getFileManager()
+           .getVirtualFileSystem());
+  auto overlay = llvm::makeIntrusiveRefCnt<llvm::vfs::OverlayFileSystem>(fs);
+  auto memory_fs = llvm::makeIntrusiveRefCnt<llvm::vfs::InMemoryFileSystem>();
+  overlay->pushOverlay(memory_fs);
+  memory_fs->addFile(filename, 0, llvm::MemoryBuffer::getMemBuffer(code));
+
+  // Match the translation unit itself so that `operation` receives its
+  // ASTContext through the callback.
+  MatchFinder match_finder;
+  TranslationUnitMatcherCallback callback(operation);
+  match_finder.addMatcher(translationUnitDecl().bind("tu"), &callback);
+
+  std::unique_ptr<clang::tooling::FrontendActionFactory> factory(
+      clang::tooling::newFrontendActionFactory(&match_finder));
+
+  // TODO(kinuko): get the args from the current ASTContext.
+  const bool succeeded = clang::tooling::runToolOnCodeWithArgs(
+      factory->create(), code, overlay, {"-fsyntax-only", "-std=c++17"},
+      filename, "lifetime-with-placedholder");
+  if (!succeeded) {
+    // This function returns void, so surface the failure here instead of
+    // dropping it silently.
+    llvm::errs() << "RunToolOnCodeWithOverlay: tool run failed for '"
+                 << filename << "'\n";
+  }
+}
+
+}  // namespace lifetimes
+}  // namespace tidy
+}  // namespace clang