`infer_tu_main -name-filter/-file-filter` limits analysis to interesting functions.

PiperOrigin-RevId: 566270192
Change-Id: I83419219ad5a8d31e95b3cfae9253381f455df97
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index e8edd98..93716e7 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -115,13 +115,15 @@
     name = "infer_tu_test",
     srcs = ["infer_tu_test.cc"],
     deps = [
+        ":collect_evidence",
         ":infer_tu",
         ":inference_cc_proto",
         "//nullability:proto_matchers",
+        "@llvm-project//clang:ast",
         "@llvm-project//clang:ast_matchers",
+        "@llvm-project//clang:basic",
         "@llvm-project//clang:index",
         "@llvm-project//clang:testing",
-        "@llvm-project//clang/unittests:dataflow_testing_support",
         "@llvm-project//llvm:Support",
         "@llvm-project//third-party/unittest:gmock",
         "@llvm-project//third-party/unittest:gtest",
@@ -133,10 +135,12 @@
     name = "infer_tu_main",
     srcs = ["infer_tu_main.cc"],
     deps = [
+        ":collect_evidence",
         ":infer_tu",
         ":inference_cc_proto",
         "@absl//absl/log:check",
         "@llvm-project//clang:ast",
+        "@llvm-project//clang:basic",
         "@llvm-project//clang:frontend",
         "@llvm-project//clang:index",
         "@llvm-project//clang:tooling",
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index 9354d14..cd783b2 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -19,7 +19,8 @@
 
 namespace clang::tidy::nullability {
 
-std::vector<Inference> inferTU(ASTContext& Ctx) {
+std::vector<Inference> inferTU(ASTContext& Ctx,
+                               llvm::function_ref<bool(const Decl&)> Filter) {
   if (!Ctx.getLangOpts().CPlusPlus) {
     llvm::errs() << "Skipping non-C++ input file: "
                  << Ctx.getSourceManager()
@@ -35,9 +36,12 @@
   // Collect all evidence.
   auto Sites = EvidenceSites::discover(Ctx);
   auto Emitter = evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); });
-  for (const auto* Decl : Sites.Declarations)
+  for (const auto* Decl : Sites.Declarations) {
+    if (Filter && !Filter(*Decl)) continue;
     collectEvidenceFromTargetDeclaration(*Decl, Emitter);
+  }
   for (const auto* Impl : Sites.Implementations) {
+    if (Filter && !Filter(*Impl)) continue;
     if (auto Err = collectEvidenceFromImplementation(*Impl, Emitter)) {
       llvm::errs() << "Skipping function: " << toString(std::move(Err)) << "\n";
       Impl->print(llvm::errs());
diff --git a/nullability/inference/infer_tu.h b/nullability/inference/infer_tu.h
index 7abdb7c..1f0083f 100644
--- a/nullability/inference/infer_tu.h
+++ b/nullability/inference/infer_tu.h
@@ -9,15 +9,20 @@
 
 #include "nullability/inference/inference.proto.h"
 #include "clang/AST/ASTContext.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace clang::tidy::nullability {
+struct EvidenceSites;
 
 // Performs nullability inference within the scope of a single translation unit.
 //
 // This is not as powerful as running inference over the whole codebase, but is
 // useful in observing the behavior of the inference system.
 // It also lets us write tests for the whole inference system.
-std::vector<Inference> inferTU(ASTContext &);
+//
+// If Filter is provided, only considers decls for which it returns true.
+std::vector<Inference> inferTU(
+    ASTContext &, llvm::function_ref<bool(const Decl &)> Filter = nullptr);
 
 }  // namespace clang::tidy::nullability
 
diff --git a/nullability/inference/infer_tu_main.cc b/nullability/inference/infer_tu_main.cc
index 0946fbe..6591fe1 100644
--- a/nullability/inference/infer_tu_main.cc
+++ b/nullability/inference/infer_tu_main.cc
@@ -15,11 +15,14 @@
 #include <utility>
 
 #include "absl/log/check.h"
+#include "nullability/inference/collect_evidence.h"
 #include "nullability/inference/infer_tu.h"
 #include "nullability/inference/inference.proto.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclarationName.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/Frontend/FrontendActions.h"
@@ -33,8 +36,10 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/raw_ostream.h"
 
 llvm::cl::OptionCategory Opts("infer_tu_main options");
@@ -58,6 +63,16 @@
     llvm::cl::desc("Include trivial inferences (annotated, no conflicts)"),
     llvm::cl::init(false),
 };
+llvm::cl::opt<std::string> FileFilter{  // Read by DeclFilter::checkLocation.
+    "file-filter",
+    llvm::cl::desc("Regular expression filenames must match to be analyzed. "
+                   "May be negated with - prefix."),
+};
+llvm::cl::opt<std::string> NameFilter{  // Read by DeclFilter::checkName.
+    "name-filter",
+    llvm::cl::desc("Regular expression decl names must match to be analyzed. "
+                   "May be negated with - prefix."),
+};
 
 namespace clang::tidy::nullability {
 namespace {
@@ -128,13 +143,60 @@
   return false;
 }
 
+// Decl predicate applying the -file-filter and -name-filter flags.
+struct DeclFilter {
+  bool operator()(const Decl &D) const {
+    auto &SM = D.getASTContext().getSourceManager();
+    if (!checkLocation(D.getLocation(), SM)) return false;
+    if (auto *ND = llvm::dyn_cast<NamedDecl>(&D))  // Others: no name check.
+      if (!checkName(*ND)) return false;
+    return true;
+  }
+  // True if -file-filter is unset or accepts the file containing Loc.
+  bool checkLocation(SourceLocation Loc, const SourceManager &SM) const {
+    if (!FileFilter.getNumOccurrences()) return true;
+    auto ID = SM.getFileID(SM.getFileLoc(Loc));
+    auto [It, Inserted] = FileCache.try_emplace(ID);
+    if (Inserted) {  // First lookup of this file: compute and cache.
+      static auto &Pattern = *new RegexFlagFilter(FileFilter);
+      auto *FID = SM.getFileEntryForID(ID);
+      It->second = !FID || Pattern(FID->getName());  // No file entry: accept.
+    }
+    return It->second;
+  }
+  // True if -name-filter is unset or accepts ND's qualified name.
+  bool checkName(const NamedDecl &ND) const {
+    if (!NameFilter.getNumOccurrences()) return true;
+    static auto &Pattern = *new RegexFlagFilter(NameFilter);
+    return Pattern(ND.getQualifiedNameAsString());
+  }
+  // Per-file checkLocation() results; mutable so const calls can fill it.
+  mutable llvm::DenseMap<FileID, bool> FileCache;
+  struct RegexFlagFilter {  // Regex from a flag; a leading "-" negates it.
+    explicit RegexFlagFilter(llvm::StringRef Regex)
+        : Negative(Regex.consume_front("-")), Pattern(Regex) {
+      std::string Err;
+      CHECK(Pattern.isValid(Err)) << Regex.str() << ": " << Err;
+    }
+    // Returns whether Text passes: a match, or a non-match when negated.
+    bool operator()(llvm::StringRef Text) const {
+      bool Match = Pattern.match(Text);
+      return Negative ? !Match : Match;
+    }
+
+    bool Negative;
+    llvm::Regex Pattern;
+  };
+};
+
 class Action : public SyntaxOnlyAction {
   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &,
                                                  llvm::StringRef) override {
     class Consumer : public ASTConsumer {
       void HandleTranslationUnit(ASTContext &Ctx) override {
         llvm::errs() << "Running inference...\n";
-        auto Results = inferTU(Ctx);
+
+        auto Results = inferTU(Ctx, DeclFilter());
         if (!IncludeTrivial)
           llvm::erase_if(Results, [](Inference &I) {
             llvm::erase_if(*I.mutable_slot_inference(), isTrivial);
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 90e8a84..92f68e1 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -7,12 +7,16 @@
 #include <optional>
 #include <vector>
 
+#include "nullability/inference/collect_evidence.h"
 #include "nullability/inference/inference.proto.h"
 #include "nullability/proto_matchers.h"
+#include "clang/AST/Decl.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Basic/LLVM.h"
 #include "clang/Index/USRGeneration.h"
 #include "clang/Testing/TestAST.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/third-party/unittest/googlemock/include/gmock/gmock.h"
@@ -21,6 +25,7 @@
 namespace clang::tidy::nullability {
 namespace {
 using ast_matchers::hasName;
+using testing::_;
 
 MATCHER_P2(inferredSlot, I, Nullability, "") {
   return arg.slot() == I && arg.nullability() == Nullability;
@@ -34,9 +39,9 @@
 AST_MATCHER(Decl, isCanonical) { return Node.isCanonicalDecl(); }
 
 class InferTUTest : public ::testing::Test {
+ protected:
   std::optional<TestAST> AST;
 
- protected:
   void build(llvm::StringRef Code) {
     TestInputs Inputs = Code;
     Inputs.ExtraFiles["nullability.h"] = R"cc(
@@ -194,5 +199,18 @@
                                  {inferredSlot(0, Inference::NULLABLE)})));
 }
 
+TEST_F(InferTUTest, Filter) {
+  build(R"cc(
+    int* target1() { return nullptr; }
+    int* target2() { return nullptr; }
+  )cc");
+  // Reject target2 so that only target1 is analyzed and reported.
+  auto SkipTarget2 = [](const Decl &Candidate) {
+    return cast<NamedDecl>(Candidate).getNameAsString() != "target2";
+  };
+  EXPECT_THAT(inferTU(AST->context(), SkipTarget2),
+              ElementsAre(inference(hasName("target1"), {_})));
+}
+
 }  // namespace
 }  // namespace clang::tidy::nullability