Allow multiple iterations of inference in infer_tu[_main].

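Each iteration uses the inferences from the previous iteration as additional
input when collecting evidence from function implementations. inferTU() takes
the iteration count as a new `Iterations` parameter (default 1, i.e. a single
pass), and infer_tu_main exposes it through a new --iterations flag (also
defaulting to 1).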

PiperOrigin-RevId: 577840484
Change-Id: I5c845df5a2633774bc453298ce060759ff3587bc
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index 1432ac8..30b869a 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -105,6 +105,7 @@
         ":collect_evidence",
         ":inference_cc_proto",
         ":merge",
+        ":slot_fingerprint",
         "@llvm-project//clang:ast",
         "@llvm-project//clang:basic",
         "@llvm-project//llvm:Support",
@@ -132,7 +133,6 @@
     name = "infer_tu_test",
     srcs = ["infer_tu_test.cc"],
     deps = [
-        ":collect_evidence",
         ":infer_tu",
         ":inference_cc_proto",
         "//nullability:proto_matchers",
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index 2f9d616..d6d43af 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -10,63 +10,117 @@
 #include "nullability/inference/collect_evidence.h"
 #include "nullability/inference/inference.proto.h"
 #include "nullability/inference/merge.h"
+#include "nullability/inference/slot_fingerprint.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceManager.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace clang::tidy::nullability {
+namespace {
 
-std::vector<Inference> inferTU(ASTContext& Ctx,
-                               llvm::function_ref<bool(const Decl&)> Filter) {
-  if (!Ctx.getLangOpts().CPlusPlus) {
-    llvm::errs() << "Skipping non-C++ input file: "
-                 << Ctx.getSourceManager()
-                        .getFileEntryForID(
-                            Ctx.getSourceManager().getMainFileID())
-                        ->getName()
-                 << "\n";
-    return std::vector<Inference>();
-  }
+class InferenceManager {
+ public:
+  InferenceManager(ASTContext& Ctx, unsigned Iterations,
+                   llvm::function_ref<bool(const Decl&)> Filter)
+      : Ctx(Ctx), Iterations(Iterations), Filter(Filter) {}
 
-  std::vector<Evidence> AllEvidence;
+  std::vector<Inference> inferenceRound(
+      EvidenceSites Sites, USRCache USRCache,
+      PreviousInferences InferencesFromLastRound) const {
+    std::vector<Inference> AllInference;
+    std::vector<Evidence> AllEvidence;
 
-  // Collect all evidence.
-  auto Sites = EvidenceSites::discover(Ctx);
-  USRCache USRCache;
-  auto Emitter =
-      evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
-  for (const auto* Decl : Sites.Declarations) {
-    if (Filter && !Filter(*Decl)) continue;
-    collectEvidenceFromTargetDeclaration(*Decl, Emitter);
-  }
-  for (const auto* Impl : Sites.Implementations) {
-    if (Filter && !Filter(*Impl)) continue;
-    if (auto Err =
-            collectEvidenceFromImplementation(*Impl, Emitter, USRCache)) {
-      llvm::errs() << "Skipping function: " << toString(std::move(Err)) << "\n";
-      Impl->print(llvm::errs());
+    // Collect all evidence.
+    auto Emitter =
+        evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
+    for (const auto* Decl : Sites.Declarations) {
+      if (Filter && !Filter(*Decl)) continue;
+      collectEvidenceFromTargetDeclaration(*Decl, Emitter);
     }
-  }
-  // Group by symbol.
-  llvm::sort(AllEvidence, [&](const Evidence& L, const Evidence& R) {
-    return L.symbol().usr() < R.symbol().usr();
-  });
-  // For each symbol, combine evidence into an inference.
-  llvm::ArrayRef<Evidence> RemainingEvidence = AllEvidence;
-  std::vector<Inference> AllInference;
-  while (!RemainingEvidence.empty()) {
-    auto Batch = RemainingEvidence.take_while([&](const Evidence& E) {
-      return E.symbol().usr() == RemainingEvidence.front().symbol().usr();
+    for (const auto* Impl : Sites.Implementations) {
+      if (Filter && !Filter(*Impl)) continue;
+      if (auto Err = collectEvidenceFromImplementation(
+              *Impl, Emitter, USRCache, InferencesFromLastRound)) {
+        llvm::errs() << "Skipping function: " << toString(std::move(Err))
+                     << "\n";
+      }
+    }
+    // Group by symbol.
+    llvm::sort(AllEvidence, [&](const Evidence& L, const Evidence& R) {
+      return L.symbol().usr() < R.symbol().usr();
     });
-    RemainingEvidence = RemainingEvidence.drop_front(Batch.size());
-    AllInference.push_back(mergeEvidence(Batch));
+    // For each symbol, combine evidence into an inference.
+    llvm::ArrayRef<Evidence> RemainingEvidence = AllEvidence;
+
+    while (!RemainingEvidence.empty()) {
+      auto Batch = RemainingEvidence.take_while([&](const Evidence& E) {
+        return E.symbol().usr() == RemainingEvidence.front().symbol().usr();
+      });
+      RemainingEvidence = RemainingEvidence.drop_front(Batch.size());
+      AllInference.push_back(mergeEvidence(Batch));
+    }
+    return AllInference;
   }
 
-  return AllInference;
+  std::vector<Inference> iterativelyInfer() const {
+    if (!Ctx.getLangOpts().CPlusPlus) {
+      llvm::errs() << "Skipping non-C++ input file: "
+                   << Ctx.getSourceManager()
+                          .getFileEntryForID(
+                              Ctx.getSourceManager().getMainFileID())
+                          ->getName()
+                   << "\n";
+      return std::vector<Inference>();
+    }
+    auto Sites = EvidenceSites::discover(Ctx);
+    USRCache USRCache;
+
+    std::vector<Inference> AllInference = inferenceRound(Sites, USRCache, {});
+
+    for (unsigned Iteration = 1; Iteration < Iterations; ++Iteration) {
+      llvm::DenseSet<SlotFingerprint> NullableFromLastRound;
+      llvm::DenseSet<SlotFingerprint> NonnullFromLastRound;
+
+      for (const auto& Inference : AllInference) {
+        for (const auto& slot_inference : Inference.slot_inference()) {
+          if (slot_inference.trivial() || slot_inference.conflict()) continue;
+          switch (slot_inference.nullability()) {
+            case Inference::NULLABLE:
+              NullableFromLastRound.insert(
+                  fingerprint(Inference.symbol().usr(), slot_inference.slot()));
+              break;
+            case Inference::NONNULL:
+              NonnullFromLastRound.insert(
+                  fingerprint(Inference.symbol().usr(), slot_inference.slot()));
+              break;
+            default:
+              break;
+          }
+        }
+      }
+
+      AllInference = inferenceRound(
+          Sites, USRCache, {NullableFromLastRound, NonnullFromLastRound});
+    }
+    return AllInference;
+  }
+
+ private:
+  ASTContext& Ctx;
+  unsigned Iterations;
+  llvm::function_ref<bool(const Decl&)> Filter;
+};
+}  // namespace
+
+std::vector<Inference> inferTU(ASTContext& Ctx, unsigned Iterations,
+                               llvm::function_ref<bool(const Decl&)> Filter) {
+  return InferenceManager(Ctx, Iterations, Filter).iterativelyInfer();
 }
 
 }  // namespace clang::tidy::nullability
diff --git a/nullability/inference/infer_tu.h b/nullability/inference/infer_tu.h
index 1f0083f..65d237d 100644
--- a/nullability/inference/infer_tu.h
+++ b/nullability/inference/infer_tu.h
@@ -9,6 +9,7 @@
 
 #include "nullability/inference/inference.proto.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclBase.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace clang::tidy::nullability {
@@ -22,7 +23,8 @@
 //
 // If Filter is provided, only considers decls that return true.
 std::vector<Inference> inferTU(
-    ASTContext &, llvm::function_ref<bool(const Decl &)> Filter = nullptr);
+    ASTContext &, unsigned Iterations = 1,
+    llvm::function_ref<bool(const Decl &)> Filter = nullptr);
 
 }  // namespace clang::tidy::nullability
 
diff --git a/nullability/inference/infer_tu_main.cc b/nullability/inference/infer_tu_main.cc
index 6ed1639..656d125 100644
--- a/nullability/inference/infer_tu_main.cc
+++ b/nullability/inference/infer_tu_main.cc
@@ -72,6 +72,11 @@
     llvm::cl::desc("Regular expression decl names must match to be analyzed. "
                    "May be negated with - prefix."),
 };
+llvm::cl::opt<unsigned> Iterations{
+    "iterations",
+    llvm::cl::desc("Number of inference iterations"),
+    llvm::cl::init(1),
+};
 
 namespace clang::tidy::nullability {
 namespace {
@@ -186,7 +191,7 @@
       void HandleTranslationUnit(ASTContext &Ctx) override {
         llvm::errs() << "Running inference...\n";
 
-        auto Results = inferTU(Ctx, DeclFilter());
+        auto Results = inferTU(Ctx, Iterations, DeclFilter());
         if (!IncludeTrivial)
           llvm::erase_if(Results, [](Inference &I) {
             llvm::erase_if(
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 92f68e1..4814523 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -7,16 +7,15 @@
 #include <optional>
 #include <vector>
 
-#include "nullability/inference/collect_evidence.h"
 #include "nullability/inference/inference.proto.h"
 #include "nullability/proto_matchers.h"
 #include "clang/AST/Decl.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/ASTMatchers/ASTMatchersMacros.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Index/USRGeneration.h"
 #include "clang/Testing/TestAST.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/third-party/unittest/googlemock/include/gmock/gmock.h"
@@ -26,6 +25,8 @@
 namespace {
 using ast_matchers::hasName;
 using testing::_;
+using testing::ElementsAre;
+using testing::UnorderedElementsAre;
 
 MATCHER_P2(inferredSlot, I, Nullability, "") {
   return arg.slot() == I && arg.nullability() == Nullability;
@@ -204,7 +205,7 @@
     int* target1() { return nullptr; }
     int* target2() { return nullptr; }
   )cc");
-  EXPECT_THAT(inferTU(AST->context(),
+  EXPECT_THAT(inferTU(AST->context(), /*Iterations=*/1,
                       [&](const Decl &D) {
                         return cast<NamedDecl>(D).getNameAsString() !=
                                "target2";
@@ -212,5 +213,47 @@
               ElementsAre(inference(hasName("target1"), {_})));
 }
 
+TEST_F(InferTUTest, IterationsPropagateInferences) {
+  build(R"cc(
+    int* returnsToBeNonnull(int* a) { return a; }
+    int* target(int* q) {
+      *q;
+      return returnsToBeNonnull(q);
+    }
+  )cc");
+  EXPECT_THAT(
+      inferTU(AST->context(), /*Iterations=*/1),
+      UnorderedElementsAre(
+          inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+                                        inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("returnsToBeNonnull"),
+                    {inferredSlot(0, Inference::UNKNOWN),
+                     inferredSlot(1, Inference::UNKNOWN)})));
+  EXPECT_THAT(
+      inferTU(AST->context(), /*Iterations=*/2),
+      UnorderedElementsAre(
+          inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+                                        inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("returnsToBeNonnull"),
+                    {inferredSlot(0, Inference::UNKNOWN),
+                     inferredSlot(1, Inference::NONNULL)})));
+  EXPECT_THAT(
+      inferTU(AST->context(), /*Iterations=*/3),
+      UnorderedElementsAre(
+          inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
+                                        inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("returnsToBeNonnull"),
+                    {inferredSlot(0, Inference::NONNULL),
+                     inferredSlot(1, Inference::NONNULL)})));
+  EXPECT_THAT(
+      inferTU(AST->context(), /*Iterations=*/4),
+      UnorderedElementsAre(
+          inference(hasName("target"), {inferredSlot(0, Inference::NONNULL),
+                                        inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("returnsToBeNonnull"),
+                    {inferredSlot(0, Inference::NONNULL),
+                     inferredSlot(1, Inference::NONNULL)})));
+}
+
 }  // namespace
 }  // namespace clang::tidy::nullability