Constrain inferable slots to previously inferred nullabilities when available, instead of always to Unknown, when querying the nullability of other values.

PiperOrigin-RevId: 574892768
Change-Id: I5e62ec21d4631d8b33889ef2de84625c07abac6b
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index 3997845..b620bad 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -31,6 +31,7 @@
     deps = [
         ":collect_evidence",
         ":inference_cc_proto",
+        ":slot_fingerprint",
         "@llvm-project//clang:ast",
         "@llvm-project//clang:ast_matchers",
         "@llvm-project//clang:basic",
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index f2964cf..ba670ce 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -50,6 +50,7 @@
 namespace clang::tidy::nullability {
 using ::clang::dataflow::DataflowAnalysisContext;
 using ::clang::dataflow::Environment;
+using ::clang::dataflow::Formula;
 
 std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &Decl) {
   auto [It, Inserted] = Cache.try_emplace(&Decl);
@@ -171,20 +172,24 @@
 // constraint representing these slots having a) the nullability inferred from
 // the previous round for this slot or b) Unknown nullability if no inference
 // was made in the previous round or there was no previous round.
-const dataflow::Formula *getInferableSlotsAsInferredOrUnknownConstraint(
+const Formula *getInferableSlotsAsInferredOrUnknownConstraint(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
-    const dataflow::Environment &Env) {
+    USRCache &USRCache, const dataflow::Environment &Env) {
   dataflow::Arena &A = Env.getDataflowAnalysisContext().arena();
-  const dataflow::Formula *InferableSlotsUnknown = &A.makeLiteral(true);
+  const Formula *InferableSlotsUnknown = &A.makeLiteral(true);
+  std::string_view USR = getOrGenerateUSR(USRCache, *Env.getCurrentFunc());
   for (auto &[Nullability, Slot] : InferableSlots) {
-    // TODO: get access to the current function's USR and look up the
-    // fingerprint for that USR and `Slot` in the previously inferred sets,
-    // instead of assuming Unknown.
-    InferableSlotsUnknown = &A.makeAnd(
-        *InferableSlotsUnknown, A.makeAnd(A.makeNot(Nullability.isNullable(A)),
-                                          A.makeNot(Nullability.isNonnull(A))));
+    SlotFingerprint Fingerprint = fingerprint(USR, Slot);
+    const Formula &Nullable = PreviouslyInferredNullable.contains(Fingerprint)
+                                  ? Nullability.isNullable(A)
+                                  : A.makeNot(Nullability.isNullable(A));
+    const Formula &Nonnull = PreviouslyInferredNonnull.contains(Fingerprint)
+                                 ? Nullability.isNonnull(A)
+                                 : A.makeNot(Nullability.isNonnull(A));
+    InferableSlotsUnknown =
+        &A.makeAnd(*InferableSlotsUnknown, A.makeAnd(Nullable, Nonnull));
   }
   return InferableSlotsUnknown;
 }
@@ -207,7 +212,8 @@
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
-    const CFGElement &Element, const dataflow::Environment &Env,
+    USRCache USRCache, const CFGElement &Element,
+    const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   // Is this CFGElement a call to a function?
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -259,7 +265,7 @@
         getNullability(*PV, Env,
                        getInferableSlotsAsInferredOrUnknownConstraint(
                            InferableCallerSlots, PreviouslyInferredNullable,
-                           PreviouslyInferredNonnull, Env));
+                           PreviouslyInferredNonnull, USRCache, Env));
     Evidence::Kind ArgEvidenceKind;
     switch (ArgNullability) {
       case NullabilityKind::Nullable:
@@ -279,7 +285,8 @@
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
-    const CFGElement &Element, const dataflow::Environment &Env,
+    USRCache USRCache, const CFGElement &Element,
+    const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   // Is this CFGElement a return statement?
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -297,7 +304,7 @@
       getNullability(ReturnExpr, Env,
                      getInferableSlotsAsInferredOrUnknownConstraint(
                          InferableSlots, PreviouslyInferredNullable,
-                         PreviouslyInferredNonnull, Env));
+                         PreviouslyInferredNonnull, USRCache, Env));
   Evidence::Kind ReturnEvidenceKind;
   switch (ReturnNullability) {
     case NullabilityKind::Nullable:
@@ -317,13 +324,15 @@
     std::vector<std::pair<PointerTypeNullability, Slot>> InferableSlots,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
-    const CFGElement &Element, const Environment &Env,
+    USRCache USRCache, const CFGElement &Element, const Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   collectEvidenceFromDereference(InferableSlots, Element, Env, Emit);
   collectEvidenceFromCallExpr(InferableSlots, PreviouslyInferredNullable,
-                              PreviouslyInferredNonnull, Element, Env, Emit);
+                              PreviouslyInferredNonnull, USRCache, Element, Env,
+                              Emit);
   collectEvidenceFromReturn(InferableSlots, PreviouslyInferredNullable,
-                            PreviouslyInferredNonnull, Element, Env, Emit);
+                            PreviouslyInferredNonnull, USRCache, Element, Env,
+                            Emit);
   // TODO: add more heuristic collections here
 }
 
@@ -379,9 +388,10 @@
              [&](const CFGElement &Element,
                  const dataflow::DataflowAnalysisState<
                      PointerNullabilityLattice> &State) {
-               collectEvidenceFromElement(
-                   InferableSlots, PreviouslyInferredNullable,
-                   PreviouslyInferredNonnull, Element, State.Env, Emit);
+               collectEvidenceFromElement(InferableSlots,
+                                          PreviouslyInferredNullable,
+                                          PreviouslyInferredNonnull, USRCache,
+                                          Element, State.Env, Emit);
              })
       .takeError();
 }
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 5b08e1d..7c97878 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -9,6 +9,7 @@
 #include <vector>
 
 #include "nullability/inference/inference.proto.h"
+#include "nullability/inference/slot_fingerprint.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
@@ -16,6 +17,7 @@
 #include "clang/Basic/LLVM.h"
 #include "clang/Testing/TestAST.h"
 #include "third_party/llvm/llvm-project/clang/unittests/Analysis/FlowSensitive/TestingSupport.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/raw_ostream.h"
@@ -29,10 +31,13 @@
 using ::clang::ast_matchers::isTemplateInstantiation;
 using ::clang::ast_matchers::match;
 using ::testing::_;
+using ::testing::AllOf;
 using ::testing::Contains;
 using ::testing::ElementsAre;
 using ::testing::IsEmpty;
+using ::testing::IsSupersetOf;
 using ::testing::Not;
+using ::testing::ResultOf;
 using ::testing::SizeIs;
 using ::testing::UnorderedElementsAre;
 
@@ -66,7 +71,9 @@
 }
 
 std::vector<Evidence> collectEvidenceFromTargetFunction(
-    llvm::StringRef Source) {
+    llvm::StringRef Source,
+    const llvm::DenseSet<SlotFingerprint>& PreviouslyInferredNullable = {},
+    const llvm::DenseSet<SlotFingerprint>& PreviouslyInferredNonnull = {}) {
   std::vector<Evidence> Results;
   clang::TestAST AST(getInputsWithAnnotationDefinitions(Source));
   USRCache usr_cache;
@@ -75,7 +82,7 @@
           *dataflow::test::findValueDecl(AST.context(), "target")),
       evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
                       usr_cache),
-      usr_cache);
+      usr_cache, PreviouslyInferredNullable, PreviouslyInferredNonnull);
   if (Err) ADD_FAILURE() << toString(std::move(Err));
   return Results;
 }
@@ -488,6 +495,52 @@
   EXPECT_THAT(Results, IsEmpty());
 }
 
+TEST(CollectEvidenceFromImplementationTest, PropagatesPreviousInferences) {
+  static constexpr llvm::StringRef Src = R"cc(
+    void calledWithToBeNullable(int* x);
+    void calledWithToBeNonnull(int* a);
+    void target(int* p, int* q) {
+      target(nullptr, q);
+      calledWithToBeNullable(p);
+      *q;
+      calledWithToBeNonnull(q);
+    }
+  )cc";
+  std::string TargetUsr = "c:@F@target#*I#S0_#";
+  std::vector ExpectedBothRoundResults = {
+      evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+               AllOf(functionNamed("target"),
+                     // Double-check that target's usr is as expected before we
+                     // use it to create SlotFingerprints.
+                     ResultOf([](Symbol S) { return S.usr(); }, TargetUsr))),
+      evidence(paramSlot(1), Evidence::UNCHECKED_DEREFERENCE,
+               functionNamed("target")),
+  };
+  std::vector ExpectedSecondRoundResults = {
+      evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+               functionNamed("calledWithToBeNullable")),
+      evidence(paramSlot(0), Evidence::NONNULL_ARGUMENT,
+               functionNamed("calledWithToBeNonnull"))};
+
+  // Only proceed if we have the correct USR for target and the first round
+  // results contain the evidence needed to produce our expected inferences and
+  // do not contain the evidence only found from propagating inferences from the
+  // first round.
+  auto FirstRoundResults = collectEvidenceFromTargetFunction(Src);
+  ASSERT_THAT(FirstRoundResults, IsSupersetOf(ExpectedBothRoundResults));
+  for (const auto& E : ExpectedSecondRoundResults) {
+    ASSERT_THAT(FirstRoundResults, Not(Contains(E)));
+  }
+
+  EXPECT_THAT(
+      collectEvidenceFromTargetFunction(Src, /* PreviouslyInferredNullable= */
+                                        {fingerprint(TargetUsr, paramSlot(0))},
+                                        /* PreviouslyInferredNonnull= */
+                                        {fingerprint(TargetUsr, paramSlot(1))}),
+      AllOf(IsSupersetOf(ExpectedBothRoundResults),
+            IsSupersetOf(ExpectedSecondRoundResults)));
+}
+
 TEST(CollectEvidenceFromDeclarationTest, VariableDeclIgnored) {
   llvm::StringLiteral Src = "Nullable<int *> target;";
   EXPECT_THAT(collectEvidenceFromTargetDecl(Src), IsEmpty());