Use previous inference results in PointerNullabilityAnalysis.

This allows inferred nullability to propagate further: for example, a nullability inferred for a function's return value or parameters in a previous round is now used when analyzing another function's implementation that calls that function or otherwise uses those decls.

PiperOrigin-RevId: 576133969
Change-Id: I241f413fede47ab773cfe77b78435a58c552a279
diff --git a/nullability/BUILD b/nullability/BUILD
index 4ca6a1c..3ba5bcf 100644
--- a/nullability/BUILD
+++ b/nullability/BUILD
@@ -15,6 +15,8 @@
         "@absl//absl/log:check",
         "@llvm-project//clang:analysis",
         "@llvm-project//clang:ast",
+        "@llvm-project//clang:basic",
+        "@llvm-project//llvm:Support",
     ],
 )
 
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index b620bad..1432ac8 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -16,6 +16,7 @@
         "//nullability:pointer_nullability_analysis",
         "//nullability:pointer_nullability_lattice",
         "//nullability:type_nullability",
+        "@absl//absl/container:flat_hash_map",
         "@absl//absl/log:check",
         "@llvm-project//clang:analysis",
         "@llvm-project//clang:ast",
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index d210e7f..5fdff82 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
 #include "nullability/inference/inferable.h"
 #include "nullability/inference/inference.proto.h"
@@ -52,6 +53,10 @@
 using ::clang::dataflow::Environment;
 using ::clang::dataflow::Formula;
 
+using ConcreteNullabilityCache =
+    absl::flat_hash_map<const Decl *,
+                        std::optional<const PointerTypeNullability>>;
+
 std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &Decl) {
   auto [It, Inserted] = Cache.try_emplace(&Decl);
   if (Inserted) {
@@ -174,7 +179,7 @@
 // was made in the previous round or there was no previous round.
 const Formula &getInferableSlotsAsInferredOrUnknownConstraint(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
-    const PreviousInferences &PreviousInferences, USRCache &USRCache,
+    USRCache &USRCache, const PreviousInferences &PreviousInferences,
     dataflow::Arena &A, const Decl &CurrentFunc) {
   const Formula *Constraint = &A.makeLiteral(true);
   std::string_view USR = getOrGenerateUSR(USRCache, CurrentFunc);
@@ -254,8 +259,8 @@
 
     // Emit evidence of the parameter's nullability. First, calculate that
     // nullability based on InferableSlots for the caller being assigned to
-    // Unknown, to reflect the current annotations and not all possible
-    // annotations for them.
+    // Unknown or their previously-inferred value, to reflect the current
+    // annotations and not all possible annotations for them.
     NullabilityKind ArgNullability =
         getNullability(*PV, Env, &InferableSlotsConstraint);
     Evidence::Kind ArgEvidenceKind;
@@ -331,6 +336,49 @@
       return Evidence::ANNOTATED_NULLABLE;
   }
 }
+
+// Builds the callback the analysis uses to replace the nullability a Decl has
+// in source with the nullability inferred for it in the previous round.
+//
+// Only function return types and function parameters are looked up; for any
+// other Decl the callback returns std::nullopt. In practice an override only
+// applies to Decls with no nullability spelled in source: we are only given
+// non-trivial previous inferences, and "inferences" that merely read a
+// source annotation were marked trivial.
+//
+// Results are memoized in Cache. The returned pointers point into Cache, an
+// absl::flat_hash_map, which does not guarantee pointer stability across
+// insertions, so consume each result before a later lookup can insert.
+auto getConcreteNullabilityOverrideFromPreviousInferences(
+    ConcreteNullabilityCache &Cache, USRCache &USRCache,
+    const PreviousInferences &PreviousInferences) {
+  return [&](const Decl &D) -> std::optional<const PointerTypeNullability *> {
+    auto [It, Inserted] = Cache.try_emplace(&D);
+    if (Inserted) {
+      // Identify the function that owns the slot and which slot D occupies.
+      const FunctionDecl *Func = nullptr;
+      Slot InferredSlot = SLOT_RETURN_TYPE;
+      if (const auto *FD = clang::dyn_cast<FunctionDecl>(&D)) {
+        Func = FD;
+      } else if (const auto *PD = clang::dyn_cast<ParmVarDecl>(&D)) {
+        Func = clang::dyn_cast_or_null<FunctionDecl>(
+            PD->getParentFunctionOrMethod());
+        if (Func) InferredSlot = paramSlot(PD->getFunctionScopeIndex());
+      }
+      // The freshly inserted cache entry is already an empty optional.
+      if (!Func) return std::nullopt;
+      const auto Fingerprint =
+          fingerprint(getOrGenerateUSR(USRCache, *Func), InferredSlot);
+      if (PreviousInferences.Nullable.contains(Fingerprint)) {
+        It->second.emplace(NullabilityKind::Nullable);
+      } else if (PreviousInferences.Nonnull.contains(Fingerprint)) {
+        It->second.emplace(NullabilityKind::NonNull);
+      }
+    }
+    if (!It->second) return std::nullopt;
+    return &*It->second;
+  };
+}
 }  // namespace
 
 llvm::Error collectEvidenceFromImplementation(
@@ -365,9 +413,14 @@
   }
   const auto &InferableSlotsConstraint =
       getInferableSlotsAsInferredOrUnknownConstraint(
-          InferableSlots, PreviousInferences, USRCache, AnalysisContext.arena(),
+          InferableSlots, USRCache, PreviousInferences, AnalysisContext.arena(),
           Decl);
 
+  ConcreteNullabilityCache ConcreteNullabilityCache;
+  Analysis.assignNullabilityOverride(
+      getConcreteNullabilityOverrideFromPreviousInferences(
+          ConcreteNullabilityCache, USRCache, PreviousInferences));
+
   return dataflow::runDataflowAnalysis(
              *ControlFlowContext, Analysis, Environment,
              [&](const CFGElement &Element,
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index eca2f3d..7ec9411 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -17,6 +17,7 @@
 #include "clang/Basic/LLVM.h"
 #include "clang/Testing/TestAST.h"
 #include "third_party/llvm/llvm-project/clang/unittests/Analysis/FlowSensitive/TestingSupport.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Error.h"
@@ -537,6 +538,93 @@
                     IsSupersetOf(ExpectedSecondRoundResults)));
 }
 
+TEST(CollectEvidenceFromImplementationTest,
+     AnalysisUsesPreviousInferencesForSlotsOutsideTargetImplementation) {
+  static constexpr llvm::StringRef Src = R"cc(
+    int* returnsToBeNonnull(int* a) {
+      return a;
+    }
+    int* target(int* q) {
+      *q;
+      return returnsToBeNonnull(q);
+    }
+  )cc";
+  std::string TargetUsr = "c:@F@target#*I#";
+  std::string ReturnsToBeNonnullUsr = "c:@F@returnsToBeNonnull#*I#";
+  const llvm::DenseMap<int, std::vector<testing::Matcher<const Evidence&>>>
+      ExpectedNewResultsPerRound = {
+          {{0,
+            {evidence(
+                paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
+                AllOf(functionNamed("target"),
+                      // Double-check that target's usr is as expected before
+                      // we use it to create SlotFingerprints.
+                      ResultOf([](Symbol S) { return S.usr(); }, TargetUsr)))}},
+           {1,
+            {evidence(
+                paramSlot(0), Evidence::NONNULL_ARGUMENT,
+                AllOf(functionNamed("returnsToBeNonnull"),
+                      // Double-check that returnsToBeNonnull's usr is as
+                      // expected before we use it to create SlotFingerprints.
+                      ResultOf([](Symbol S) { return S.usr(); },
+                               ReturnsToBeNonnullUsr)))}},
+           {2,
+            {
+                // No new evidence from target's implementation in this round,
+                // but in a full-TU analysis, this would be the round where we
+                // decide returnsToBeNonnull returns Nonnull, based on the
+                // now-Nonnull argument that is the only return value.
+            }},
+           {3,
+            {evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+                      functionNamed("target"))}}}};
+
+  // Assert first round results because they don't rely on previous inference
+  // propagation at all and in this case are test setup and preconditions.
+  auto FirstRoundResults = collectEvidenceFromTargetFunction(Src);
+  ASSERT_THAT(FirstRoundResults,
+              IsSupersetOf(ExpectedNewResultsPerRound.at(0)));
+  for (const auto& E : ExpectedNewResultsPerRound.at(1)) {
+    ASSERT_THAT(FirstRoundResults, Not(Contains(E)));
+  }
+
+  auto SecondRoundResults = collectEvidenceFromTargetFunction(
+      Src, {.Nonnull = {fingerprint(TargetUsr, paramSlot(0))}});
+  EXPECT_THAT(SecondRoundResults,
+              AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(1))));
+  for (const auto& E : ExpectedNewResultsPerRound.at(2)) {
+    ASSERT_THAT(SecondRoundResults, Not(Contains(E)));
+  }
+
+  auto ThirdRoundResults = collectEvidenceFromTargetFunction(
+      Src, {.Nonnull = {fingerprint(TargetUsr, paramSlot(0)),
+                        fingerprint(ReturnsToBeNonnullUsr, paramSlot(0))}});
+  EXPECT_THAT(ThirdRoundResults,
+              AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(1)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(2))));
+  for (const auto& E : ExpectedNewResultsPerRound.at(3)) {
+    ASSERT_THAT(ThirdRoundResults, Not(Contains(E)));
+  }
+
+  auto FourthRoundResults = collectEvidenceFromTargetFunction(
+      Src,
+      {.Nonnull = {
+           fingerprint(TargetUsr, paramSlot(0)),
+           fingerprint(ReturnsToBeNonnullUsr, paramSlot(0)),
+           // As noted in the Evidence matcher list above, we don't infer the
+           // return type of returnsToBeNonnull from only collecting evidence
+           // from target's implementation, but for the sake of this test, let's
+           // pretend we collected evidence from the entire TU.
+           fingerprint(ReturnsToBeNonnullUsr, SLOT_RETURN_TYPE)}});
+  EXPECT_THAT(FourthRoundResults,
+              AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(1)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(2)),
+                    IsSupersetOf(ExpectedNewResultsPerRound.at(3))));
+}
+
 TEST(CollectEvidenceFromDeclarationTest, VariableDeclIgnored) {
   llvm::StringLiteral Src = "Nullable<int *> target;";
   EXPECT_THAT(collectEvidenceFromTargetDecl(Src), IsEmpty());
diff --git a/nullability/pointer_nullability_analysis.cc b/nullability/pointer_nullability_analysis.cc
index adcd9c8..0395559 100644
--- a/nullability/pointer_nullability_analysis.cc
+++ b/nullability/pointer_nullability_analysis.cc
@@ -494,7 +494,7 @@
 
 // If nullability for the decl D has been overridden, patch N to reflect it.
 // (N is the nullability of an access to D).
-void overrideNullabilityFromDecl(const ValueDecl *D,
+void overrideNullabilityFromDecl(const Decl *D,
                                  PointerNullabilityLattice &Lattice,
                                  TypeNullability &N) {
   // For now, overrides are always for pointer values only, and override only
@@ -706,8 +706,11 @@
     // TODO(mboehme): Instead of relying on Clang to propagate nullability sugar
     // to the `CallExpr`'s type, we should extract nullability directly from the
     // callee `Expr .
-    return substituteNullabilityAnnotationsInFunctionTemplate(CE->getType(),
-                                                              CE);
+    auto Nullability =
+        substituteNullabilityAnnotationsInFunctionTemplate(CE->getType(), CE);
+    overrideNullabilityFromDecl(CE->getCalleeDecl(), State.Lattice,
+                                Nullability);
+    return Nullability;
   });
 }
 
diff --git a/nullability/pointer_nullability_analysis.h b/nullability/pointer_nullability_analysis.h
index ca94c8a..65ba3bd 100644
--- a/nullability/pointer_nullability_analysis.h
+++ b/nullability/pointer_nullability_analysis.h
@@ -5,19 +5,21 @@
 #ifndef CRUBIT_NULLABILITY_POINTER_NULLABILITY_ANALYSIS_H_
 #define CRUBIT_NULLABILITY_POINTER_NULLABILITY_ANALYSIS_H_
 
+#include <optional>
 #include <utility>
 
 #include "nullability/pointer_nullability_lattice.h"
 #include "nullability/type_nullability.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/AST/Type.h"
 #include "clang/Analysis/FlowSensitive/Arena.h"
 #include "clang/Analysis/FlowSensitive/CFGMatchSwitch.h"
 #include "clang/Analysis/FlowSensitive/DataflowAnalysis.h"
 #include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
-#include "clang/Analysis/FlowSensitive/Formula.h"
 #include "clang/Analysis/FlowSensitive/Value.h"
+#include "llvm/ADT/FunctionExtras.h"
 
 namespace clang {
 namespace tidy {
@@ -67,6 +69,13 @@
   PointerTypeNullability assignNullabilityVariable(const ValueDecl *D,
                                                    dataflow::Arena &);
 
+  void assignNullabilityOverride(
+      llvm::unique_function<
+          std::optional<const PointerTypeNullability *>(const Decl &) const>
+          Override) {
+    NFS.ConcreteNullabilityOverride = std::move(Override);
+  }
+
   void transfer(const CFGElement &Elt, PointerNullabilityLattice &Lattice,
                 dataflow::Environment &Env);
 
diff --git a/nullability/pointer_nullability_lattice.h b/nullability/pointer_nullability_lattice.h
index 2a27d8e..0b7b906 100644
--- a/nullability/pointer_nullability_lattice.h
+++ b/nullability/pointer_nullability_lattice.h
@@ -6,6 +6,7 @@
 #define CRUBIT_NULLABILITY_POINTER_NULLABILITY_LATTICE_H_
 
 #include <functional>
+#include <optional>
 #include <ostream>
 
 #include "absl/container/flat_hash_map.h"
@@ -14,18 +15,26 @@
 #include "clang/AST/Expr.h"
 #include "clang/Analysis/FlowSensitive/DataflowAnalysisContext.h"
 #include "clang/Analysis/FlowSensitive/DataflowLattice.h"
+#include "clang/Basic/LLVM.h"
+#include "llvm/ADT/FunctionExtras.h"
 
 namespace clang::tidy::nullability {
-
 class PointerNullabilityLattice {
  public:
   struct NonFlowSensitiveState {
     absl::flat_hash_map<const Expr *, TypeNullability> ExprToNullability;
     // Overridden symbolic nullability for pointer-typed decls.
     // These are set by PointerNullabilityAnalysis::assignNullabilityVariable,
-    // and take precedence over the declared type.
+    // and take precedence over the declared type and over any result from
+    // ConcreteNullabilityOverride.
     absl::flat_hash_map<const ValueDecl *, PointerTypeNullability>
         DeclTopLevelNullability;
+    // Returns overriding concrete nullability for decls. This is set by
+    // PointerNullabilityAnalysis::assignNullabilityOverride, and the result, if
+    // present, takes precedence over the declared type.
+    llvm::unique_function<std::optional<const PointerTypeNullability *>(
+        const Decl &) const>
+        ConcreteNullabilityOverride = [](const Decl &) { return std::nullopt; };
   };
 
   PointerNullabilityLattice(NonFlowSensitiveState &NFS) : NFS(NFS) {}
@@ -54,11 +63,18 @@
   }
 
   // Returns overridden nullability information associated with a declaration.
-  // For now we only track top-level decl nullability symbolically.
-  const PointerTypeNullability *getDeclNullability(const ValueDecl *D) const {
-    auto It = NFS.DeclTopLevelNullability.find(D);
-    if (It == NFS.DeclTopLevelNullability.end()) return nullptr;
-    return &It->second;
+  // Symbolic top-level nullability assigned to the decl takes precedence;
+  // otherwise the concrete nullability override, if any, is consulted.
+  const PointerTypeNullability *getDeclNullability(const Decl *D) const {
+    if (!D) return nullptr;
+    if (const auto *VD = dyn_cast_or_null<ValueDecl>(D)) {
+      auto It = NFS.DeclTopLevelNullability.find(VD);
+      if (It != NFS.DeclTopLevelNullability.end()) return &It->second;
+    }
+    if (const std::optional<const PointerTypeNullability *> N =
+            NFS.ConcreteNullabilityOverride(*D))
+      return *N;
+    return nullptr;
   }
 
   bool operator==(const PointerNullabilityLattice &Other) const { return true; }