`infer_tu_main -name-filter/-file-filter` limits analysis to interesting functions.
PiperOrigin-RevId: 566270192
Change-Id: I83419219ad5a8d31e95b3cfae9253381f455df97
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index e8edd98..93716e7 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -115,13 +115,15 @@
name = "infer_tu_test",
srcs = ["infer_tu_test.cc"],
deps = [
+ ":collect_evidence",
":infer_tu",
":inference_cc_proto",
"//nullability:proto_matchers",
+ "@llvm-project//clang:ast",
"@llvm-project//clang:ast_matchers",
+ "@llvm-project//clang:basic",
"@llvm-project//clang:index",
"@llvm-project//clang:testing",
- "@llvm-project//clang/unittests:dataflow_testing_support",
"@llvm-project//llvm:Support",
"@llvm-project//third-party/unittest:gmock",
"@llvm-project//third-party/unittest:gtest",
@@ -133,10 +135,12 @@
name = "infer_tu_main",
srcs = ["infer_tu_main.cc"],
deps = [
+ ":collect_evidence",
":infer_tu",
":inference_cc_proto",
"@absl//absl/log:check",
"@llvm-project//clang:ast",
+ "@llvm-project//clang:basic",
"@llvm-project//clang:frontend",
"@llvm-project//clang:index",
"@llvm-project//clang:tooling",
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index 9354d14..cd783b2 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -19,7 +19,8 @@
namespace clang::tidy::nullability {
-std::vector<Inference> inferTU(ASTContext& Ctx) {
+std::vector<Inference> inferTU(ASTContext& Ctx,
+ llvm::function_ref<bool(const Decl&)> Filter) {
if (!Ctx.getLangOpts().CPlusPlus) {
llvm::errs() << "Skipping non-C++ input file: "
<< Ctx.getSourceManager()
@@ -35,9 +36,12 @@
// Collect all evidence.
auto Sites = EvidenceSites::discover(Ctx);
auto Emitter = evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); });
- for (const auto* Decl : Sites.Declarations)
+ for (const auto* Decl : Sites.Declarations) {
+ if (Filter && !Filter(*Decl)) continue;
collectEvidenceFromTargetDeclaration(*Decl, Emitter);
+ }
for (const auto* Impl : Sites.Implementations) {
+ if (Filter && !Filter(*Impl)) continue;
if (auto Err = collectEvidenceFromImplementation(*Impl, Emitter)) {
llvm::errs() << "Skipping function: " << toString(std::move(Err)) << "\n";
Impl->print(llvm::errs());
diff --git a/nullability/inference/infer_tu.h b/nullability/inference/infer_tu.h
index 7abdb7c..1f0083f 100644
--- a/nullability/inference/infer_tu.h
+++ b/nullability/inference/infer_tu.h
@@ -9,15 +9,20 @@
#include "nullability/inference/inference.proto.h"
#include "clang/AST/ASTContext.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
namespace clang::tidy::nullability {
+struct EvidenceSites;
// Performs nullability inference within the scope of a single translation unit.
//
// This is not as powerful as running inference over the whole codebase, but is
// useful in observing the behavior of the inference system.
// It also lets us write tests for the whole inference system.
-std::vector<Inference> inferTU(ASTContext &);
+//
+// If Filter is provided, only considers decls that return true.
+std::vector<Inference> inferTU(
+ ASTContext &, llvm::function_ref<bool(const Decl &)> Filter = nullptr);
} // namespace clang::tidy::nullability
diff --git a/nullability/inference/infer_tu_main.cc b/nullability/inference/infer_tu_main.cc
index 0946fbe..6591fe1 100644
--- a/nullability/inference/infer_tu_main.cc
+++ b/nullability/inference/infer_tu_main.cc
@@ -15,11 +15,14 @@
#include <utility>
#include "absl/log/check.h"
+#include "nullability/inference/collect_evidence.h"
#include "nullability/inference/infer_tu.h"
#include "nullability/inference/inference.proto.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Decl.h"
+#include "clang/AST/DeclarationName.h"
#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/SourceLocation.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Frontend/FrontendActions.h"
@@ -33,8 +36,10 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
+#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
llvm::cl::OptionCategory Opts("infer_tu_main options");
@@ -58,6 +63,16 @@
llvm::cl::desc("Include trivial inferences (annotated, no conflicts)"),
llvm::cl::init(false),
};
+llvm::cl::opt<std::string> FileFilter{
+ "file-filter",
+ llvm::cl::desc("Regular expression filenames must match to be analyzed. "
+ "May be negated with - prefix."),
+};
+llvm::cl::opt<std::string> NameFilter{
+ "name-filter",
+ llvm::cl::desc("Regular expression decl names must match to be analyzed. "
+ "May be negated with - prefix."),
+};
namespace clang::tidy::nullability {
namespace {
@@ -128,13 +143,60 @@
return false;
}
+// Selects which declarations to analyze based on filter flags.
+struct DeclFilter {
+ bool operator()(const Decl &D) const {
+ auto &SM = D.getDeclContext()->getParentASTContext().getSourceManager();
+ if (!checkLocation(D.getLocation(), SM)) return false;
+ if (auto *ND = llvm::dyn_cast<NamedDecl>(&D))
+ if (!checkName(*ND)) return false;
+ return true;
+ }
+
+ bool checkLocation(SourceLocation Loc, const SourceManager &SM) const {
+ if (!FileFilter.getNumOccurrences()) return true;
+ auto ID = SM.getFileID(SM.getFileLoc(Loc));
+ auto [It, Inserted] = FileCache.try_emplace(ID);
+ if (Inserted) {
+ static auto &Pattern = *new RegexFlagFilter(FileFilter);
+ auto *FID = SM.getFileEntryForID(ID);
+ It->second = !FID || Pattern(FID->getName());
+ }
+ return It->second;
+ }
+
+ bool checkName(const NamedDecl &ND) const {
+ if (!NameFilter.getNumOccurrences()) return true;
+ static auto &Pattern = *new RegexFlagFilter(NameFilter);
+ return Pattern(ND.getQualifiedNameAsString());
+ }
+
+ mutable llvm::DenseMap<FileID, bool> FileCache;
+ struct RegexFlagFilter {
+ RegexFlagFilter(llvm::StringRef Regex)
+ : Negative(Regex.consume_front("-")), Pattern(Regex) {
+ std::string Err;
+ CHECK(Pattern.isValid(Err)) << Regex.str() << ": " << Err;
+ }
+
+ bool operator()(llvm::StringRef Text) {
+ bool Match = Pattern.match(Text);
+ return Negative ? !Match : Match;
+ }
+
+ bool Negative;
+ llvm::Regex Pattern;
+ };
+};
+
class Action : public SyntaxOnlyAction {
std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &,
llvm::StringRef) override {
class Consumer : public ASTConsumer {
void HandleTranslationUnit(ASTContext &Ctx) override {
llvm::errs() << "Running inference...\n";
- auto Results = inferTU(Ctx);
+
+ auto Results = inferTU(Ctx, DeclFilter());
if (!IncludeTrivial)
llvm::erase_if(Results, [](Inference &I) {
llvm::erase_if(*I.mutable_slot_inference(), isTrivial);
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 90e8a84..92f68e1 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -7,12 +7,16 @@
#include <optional>
#include <vector>
+#include "nullability/inference/collect_evidence.h"
#include "nullability/inference/inference.proto.h"
#include "nullability/proto_matchers.h"
+#include "clang/AST/Decl.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Basic/LLVM.h"
#include "clang/Index/USRGeneration.h"
#include "clang/Testing/TestAST.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/third-party/unittest/googlemock/include/gmock/gmock.h"
@@ -21,6 +25,7 @@
namespace clang::tidy::nullability {
namespace {
using ast_matchers::hasName;
+using testing::_;
MATCHER_P2(inferredSlot, I, Nullability, "") {
return arg.slot() == I && arg.nullability() == Nullability;
@@ -34,9 +39,9 @@
AST_MATCHER(Decl, isCanonical) { return Node.isCanonicalDecl(); }
class InferTUTest : public ::testing::Test {
+ protected:
std::optional<TestAST> AST;
- protected:
void build(llvm::StringRef Code) {
TestInputs Inputs = Code;
Inputs.ExtraFiles["nullability.h"] = R"cc(
@@ -194,5 +199,18 @@
{inferredSlot(0, Inference::NULLABLE)})));
}
+TEST_F(InferTUTest, Filter) {
+ build(R"cc(
+ int* target1() { return nullptr; }
+ int* target2() { return nullptr; }
+ )cc");
+ EXPECT_THAT(inferTU(AST->context(),
+ [&](const Decl &D) {
+ return cast<NamedDecl>(D).getNameAsString() !=
+ "target2";
+ }),
+ ElementsAre(inference(hasName("target1"), {_})));
+}
+
} // namespace
} // namespace clang::tidy::nullability