Allow multiple iterations of inference in infer_tu[_main].
Each iteration uses inferences from the previous iteration as additional input.
PiperOrigin-RevId: 577840484
Change-Id: I5c845df5a2633774bc453298ce060759ff3587bc
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index 1432ac8..30b869a 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -105,6 +105,7 @@
":collect_evidence",
":inference_cc_proto",
":merge",
+ ":slot_fingerprint",
"@llvm-project//clang:ast",
"@llvm-project//clang:basic",
"@llvm-project//llvm:Support",
@@ -132,7 +133,6 @@
name = "infer_tu_test",
srcs = ["infer_tu_test.cc"],
deps = [
- ":collect_evidence",
":infer_tu",
":inference_cc_proto",
"//nullability:proto_matchers",
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index 2f9d616..d6d43af 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -10,63 +10,117 @@
#include "nullability/inference/collect_evidence.h"
#include "nullability/inference/inference.proto.h"
#include "nullability/inference/merge.h"
+#include "nullability/inference/slot_fingerprint.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclBase.h"
#include "clang/Basic/SourceManager.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
namespace clang::tidy::nullability {
+namespace {
-std::vector<Inference> inferTU(ASTContext& Ctx,
- llvm::function_ref<bool(const Decl&)> Filter) {
- if (!Ctx.getLangOpts().CPlusPlus) {
- llvm::errs() << "Skipping non-C++ input file: "
- << Ctx.getSourceManager()
- .getFileEntryForID(
- Ctx.getSourceManager().getMainFileID())
- ->getName()
- << "\n";
- return std::vector<Inference>();
- }
+class InferenceManager {
+ public:
+ InferenceManager(ASTContext& Ctx, unsigned Iterations,
+ llvm::function_ref<bool(const Decl&)> Filter)
+ : Ctx(Ctx), Iterations(Iterations), Filter(Filter) {}
- std::vector<Evidence> AllEvidence;
+ std::vector<Inference> inferenceRound(
+ EvidenceSites Sites, USRCache USRCache,
+ PreviousInferences InferencesFromLastRound) const {
+ std::vector<Inference> AllInference;
+ std::vector<Evidence> AllEvidence;
- // Collect all evidence.
- auto Sites = EvidenceSites::discover(Ctx);
- USRCache USRCache;
- auto Emitter =
- evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
- for (const auto* Decl : Sites.Declarations) {
- if (Filter && !Filter(*Decl)) continue;
- collectEvidenceFromTargetDeclaration(*Decl, Emitter);
- }
- for (const auto* Impl : Sites.Implementations) {
- if (Filter && !Filter(*Impl)) continue;
- if (auto Err =
- collectEvidenceFromImplementation(*Impl, Emitter, USRCache)) {
- llvm::errs() << "Skipping function: " << toString(std::move(Err)) << "\n";
- Impl->print(llvm::errs());
+ // Collect all evidence.
+ auto Emitter =
+ evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
+ for (const auto* Decl : Sites.Declarations) {
+ if (Filter && !Filter(*Decl)) continue;
+ collectEvidenceFromTargetDeclaration(*Decl, Emitter);
}
- }
- // Group by symbol.
- llvm::sort(AllEvidence, [&](const Evidence& L, const Evidence& R) {
- return L.symbol().usr() < R.symbol().usr();
- });
- // For each symbol, combine evidence into an inference.
- llvm::ArrayRef<Evidence> RemainingEvidence = AllEvidence;
- std::vector<Inference> AllInference;
- while (!RemainingEvidence.empty()) {
- auto Batch = RemainingEvidence.take_while([&](const Evidence& E) {
- return E.symbol().usr() == RemainingEvidence.front().symbol().usr();
+ for (const auto* Impl : Sites.Implementations) {
+ if (Filter && !Filter(*Impl)) continue;
+ if (auto Err = collectEvidenceFromImplementation(
+ *Impl, Emitter, USRCache, InferencesFromLastRound)) {
+ llvm::errs() << "Skipping function: " << toString(std::move(Err))
+ << "\n";
+ }
+ }
+ // Group by symbol.
+ llvm::sort(AllEvidence, [&](const Evidence& L, const Evidence& R) {
+ return L.symbol().usr() < R.symbol().usr();
});
- RemainingEvidence = RemainingEvidence.drop_front(Batch.size());
- AllInference.push_back(mergeEvidence(Batch));
+ // For each symbol, combine evidence into an inference.
+ llvm::ArrayRef<Evidence> RemainingEvidence = AllEvidence;
+
+ while (!RemainingEvidence.empty()) {
+ auto Batch = RemainingEvidence.take_while([&](const Evidence& E) {
+ return E.symbol().usr() == RemainingEvidence.front().symbol().usr();
+ });
+ RemainingEvidence = RemainingEvidence.drop_front(Batch.size());
+ AllInference.push_back(mergeEvidence(Batch));
+ }
+ return AllInference;
}
- return AllInference;
+ std::vector<Inference> iterativelyInfer() const {
+ if (!Ctx.getLangOpts().CPlusPlus) {
+ llvm::errs() << "Skipping non-C++ input file: "
+ << Ctx.getSourceManager()
+ .getFileEntryForID(
+ Ctx.getSourceManager().getMainFileID())
+ ->getName()
+ << "\n";
+ return std::vector<Inference>();
+ }
+ auto Sites = EvidenceSites::discover(Ctx);
+ USRCache USRCache;
+
+ std::vector<Inference> AllInference = inferenceRound(Sites, USRCache, {});
+
+ for (unsigned Iteration = 1; Iteration < Iterations; ++Iteration) {
+ llvm::DenseSet<SlotFingerprint> NullableFromLastRound;
+ llvm::DenseSet<SlotFingerprint> NonnullFromLastRound;
+
+ for (const auto& Inference : AllInference) {
+ for (const auto& slot_inference : Inference.slot_inference()) {
+ if (slot_inference.trivial() || slot_inference.conflict()) continue;
+ switch (slot_inference.nullability()) {
+ case Inference::NULLABLE:
+ NullableFromLastRound.insert(
+ fingerprint(Inference.symbol().usr(), slot_inference.slot()));
+ break;
+ case Inference::NONNULL:
+ NonnullFromLastRound.insert(
+ fingerprint(Inference.symbol().usr(), slot_inference.slot()));
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ AllInference = inferenceRound(
+ Sites, USRCache, {NullableFromLastRound, NonnullFromLastRound});
+ }
+ return AllInference;
+ }
+
+ private:
+ ASTContext& Ctx;
+ unsigned Iterations;
+ llvm::function_ref<bool(const Decl&)> Filter;
+};
+} // namespace
+
+std::vector<Inference> inferTU(ASTContext& Ctx, unsigned Iterations,
+ llvm::function_ref<bool(const Decl&)> Filter) {
+ return InferenceManager(Ctx, Iterations, Filter).iterativelyInfer();
}
} // namespace clang::tidy::nullability
diff --git a/nullability/inference/infer_tu.h b/nullability/inference/infer_tu.h
index 1f0083f..65d237d 100644
--- a/nullability/inference/infer_tu.h
+++ b/nullability/inference/infer_tu.h
@@ -9,6 +9,7 @@
#include "nullability/inference/inference.proto.h"
#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclBase.h"
#include "llvm/ADT/STLFunctionalExtras.h"
namespace clang::tidy::nullability {
@@ -22,7 +23,8 @@
//
// If Filter is provided, only considers decls that return true.
std::vector<Inference> inferTU(
- ASTContext &, llvm::function_ref<bool(const Decl &)> Filter = nullptr);
+ ASTContext &, unsigned Iterations = 1,
+ llvm::function_ref<bool(const Decl &)> Filter = nullptr);
} // namespace clang::tidy::nullability
diff --git a/nullability/inference/infer_tu_main.cc b/nullability/inference/infer_tu_main.cc
index 6ed1639..656d125 100644
--- a/nullability/inference/infer_tu_main.cc
+++ b/nullability/inference/infer_tu_main.cc
@@ -72,6 +72,11 @@
llvm::cl::desc("Regular expression decl names must match to be analyzed. "
"May be negated with - prefix."),
};
+llvm::cl::opt<unsigned> Iterations{
+ "iterations",
+ llvm::cl::desc("Number of inference iterations"),
+ llvm::cl::init(1),
+};
namespace clang::tidy::nullability {
namespace {
@@ -186,7 +191,7 @@
void HandleTranslationUnit(ASTContext &Ctx) override {
llvm::errs() << "Running inference...\n";
- auto Results = inferTU(Ctx, DeclFilter());
+ auto Results = inferTU(Ctx, Iterations, DeclFilter());
if (!IncludeTrivial)
llvm::erase_if(Results, [](Inference &I) {
llvm::erase_if(
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 92f68e1..4814523 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -7,16 +7,15 @@
#include <optional>
#include <vector>
-#include "nullability/inference/collect_evidence.h"
#include "nullability/inference/inference.proto.h"
#include "nullability/proto_matchers.h"
#include "clang/AST/Decl.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/LLVM.h"
#include "clang/Index/USRGeneration.h"
#include "clang/Testing/TestAST.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/third-party/unittest/googlemock/include/gmock/gmock.h"
@@ -26,6 +25,8 @@
namespace {
using ast_matchers::hasName;
using testing::_;
+using testing::ElementsAre;
+using testing::UnorderedElementsAre;
MATCHER_P2(inferredSlot, I, Nullability, "") {
return arg.slot() == I && arg.nullability() == Nullability;
@@ -204,7 +205,7 @@
int* target1() { return nullptr; }
int* target2() { return nullptr; }
)cc");
- EXPECT_THAT(inferTU(AST->context(),
+ EXPECT_THAT(inferTU(AST->context(), /*Iterations=*/1,
[&](const Decl &D) {
return cast<NamedDecl>(D).getNameAsString() !=
"target2";
@@ -212,5 +213,47 @@
ElementsAre(inference(hasName("target1"), {_})));
}
+TEST_F(InferTUTest, IterationsPropagateInferences) {
+ build(R"cc(
+ int* returnsToBeNonnull(int* a) { return a; }
+ int* target(int* q) {
+ *q;
+ return returnsToBeNonnull(q);
+ }
+ )cc");
+ EXPECT_THAT(
+ inferTU(AST->context(), /*Iterations=*/1),
+ UnorderedElementsAre(
+ inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+ inferredSlot(1, Inference::NONNULL)}),
+ inference(hasName("returnsToBeNonnull"),
+ {inferredSlot(0, Inference::UNKNOWN),
+ inferredSlot(1, Inference::UNKNOWN)})));
+ EXPECT_THAT(
+ inferTU(AST->context(), /*Iterations=*/2),
+ UnorderedElementsAre(
+ inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+ inferredSlot(1, Inference::NONNULL)}),
+ inference(hasName("returnsToBeNonnull"),
+ {inferredSlot(0, Inference::UNKNOWN),
+ inferredSlot(1, Inference::NONNULL)})));
+ EXPECT_THAT(
+ inferTU(AST->context(), /*Iterations=*/3),
+ UnorderedElementsAre(
+ inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+ inferredSlot(1, Inference::NONNULL)}),
+ inference(hasName("returnsToBeNonnull"),
+ {inferredSlot(0, Inference::NONNULL),
+ inferredSlot(1, Inference::NONNULL)})));
+ EXPECT_THAT(
+ inferTU(AST->context(), /*Iterations=*/4),
+ UnorderedElementsAre(
+ inference(hasName("target"), {inferredSlot(0, Inference::NONNULL),
+ inferredSlot(1, Inference::NONNULL)}),
+ inference(hasName("returnsToBeNonnull"),
+ {inferredSlot(0, Inference::NONNULL),
+ inferredSlot(1, Inference::NONNULL)})));
+}
+
} // namespace
} // namespace clang::tidy::nullability