Use previous inferences and inferable slot overrides when collecting evidence from binding a value to a type.

PiperOrigin-RevId: 579184731
Change-Id: Iffc11650fddb1b9573b3797077e4511af208927f
diff --git a/nullability/BUILD b/nullability/BUILD
index b06dd32..316833b 100644
--- a/nullability/BUILD
+++ b/nullability/BUILD
@@ -4,6 +4,7 @@
 
 cc_library(
     name = "pointer_nullability_lattice",
+    srcs = ["pointer_nullability_lattice.cc"],
     hdrs = ["pointer_nullability_lattice.h"],
     visibility = [
         "//nullability/inference:__pkg__",
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index e6e1a42..9caf03c 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -200,17 +200,39 @@
   return *Constraint;
 }
 
+auto getNullabilityAnnotationsFromTypeAndOverrides(
+    QualType Type, const Decl *D, const PointerNullabilityLattice &Lattice) {
+  auto N = getNullabilityAnnotationsFromType(Type);
+  if (N.empty()) {
+    // We expect this not to be the case, but not to a crash-worthy level, so
+    // just log if it is.
+    llvm::errs() << "Nullability for type " << Type.getAsString();
+    if (auto *ND = dyn_cast_or_null<clang::NamedDecl>(D)) {
+      llvm::errs() << "for Decl named " << ND->getName();
+    }
+    llvm::errs() << " requested with overrides, but is an empty vector.\n";
+  } else {
+    Lattice.overrideNullabilityFromDecl(D, N);
+  }
+  return N;
+}
+
 void collectEvidenceFromBindingToType(
     TypeNullability &TypeNullability,
     const dataflow::PointerValue &PointerValue,
     std::vector<std::pair<PointerTypeNullability, Slot>>
         &InferableSlotsFromValueContext,
-    const dataflow::Environment &Env, SourceLocation ValueLoc,
-    llvm::function_ref<EvidenceEmitter> Emit) {
+    const Formula &InferableSlotsConstraint, const dataflow::Environment &Env,
+    SourceLocation ValueLoc, llvm::function_ref<EvidenceEmitter> Emit) {
   //  TODO: Account for variance and each layer of nullability when we handle
   //  more than top-level pointers.
   if (TypeNullability.empty()) return;
-  if (TypeNullability[0].concrete() == NullabilityKind::NonNull) {
+  PointerTypeNullability &TopLevel = TypeNullability[0];
+  dataflow::Arena &A = Env.arena();
+  if (TopLevel.concrete() == NullabilityKind::NonNull ||
+      (TopLevel.isSymbolic() &&
+       Env.proves(
+           A.makeImplies(InferableSlotsConstraint, TopLevel.isNonnull(A))))) {
     collectMustBeNonnullEvidence(PointerValue, Env, ValueLoc,
                                  InferableSlotsFromValueContext,
                                  Evidence::BOUND_TO_NONNULL, Emit);
@@ -221,7 +243,8 @@
 void collectEvidenceFromArgsAndParams(
     const FunctionDecl &CalleeDecl, const CallOrConstructExpr &Expr,
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
-    const Formula &InferableSlotsConstraint, const dataflow::Environment &Env,
+    const Formula &InferableSlotsConstraint,
+    const PointerNullabilityLattice &Lattice, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   unsigned ParamI = 0;
   unsigned ArgI = 0;
@@ -236,8 +259,8 @@
 
   // For each pointer parameter of the callee, ...
   for (; ParamI < CalleeDecl.param_size(); ++ParamI, ++ArgI) {
-    auto ParamType =
-        CalleeDecl.getParamDecl(ParamI)->getType().getNonReferenceType();
+    const auto *ParamDecl = CalleeDecl.getParamDecl(ParamI);
+    const auto ParamType = ParamDecl->getType().getNonReferenceType();
     if (!isSupportedPointerType(ParamType)) continue;
     // the corresponding argument should also be a pointer.
     CHECK(isSupportedPointerType(Expr.getArg(ArgI)->getType()));
@@ -248,13 +271,14 @@
 
     SourceLocation ArgLoc = Expr.getArg(ArgI)->getExprLoc();
 
-    // TODO: Include inferred annotations from previous rounds when propagating.
-    auto ParamNullability = getNullabilityAnnotationsFromType(ParamType);
+    auto ParamNullability = getNullabilityAnnotationsFromTypeAndOverrides(
+        ParamType, ParamDecl, Lattice);
 
     // Collect evidence from the binding of the argument to the parameter's
     // nullability, if known.
-    collectEvidenceFromBindingToType(ParamNullability, *PV,
-                                     InferableCallerSlots, Env, ArgLoc, Emit);
+    collectEvidenceFromBindingToType(
+        ParamNullability, *PV, InferableCallerSlots, InferableSlotsConstraint,
+        Env, ArgLoc, Emit);
 
     // Emit evidence of the parameter's nullability. First, calculate that
     // nullability based on InferableSlots for the caller being assigned to
@@ -280,7 +304,7 @@
 void collectEvidenceFromCallExpr(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
     const Formula &InferableSlotsConstraint, const CFGElement &Element,
-    const dataflow::Environment &Env,
+    const PointerNullabilityLattice &Lattice, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   // Is this CFGElement a call to a function?
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -292,13 +316,14 @@
   if (!CalleeDecl || !isInferenceTarget(*CalleeDecl)) return;
 
   collectEvidenceFromArgsAndParams(*CalleeDecl, *CallExpr, InferableCallerSlots,
-                                   InferableSlotsConstraint, Env, Emit);
+                                   InferableSlotsConstraint, Lattice, Env,
+                                   Emit);
 }
 
 void collectEvidenceFromConstructExpr(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
     const Formula &InferableSlotsConstraint, const CFGElement &Element,
-    const dataflow::Environment &Env,
+    const PointerNullabilityLattice &Lattice, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
   if (!CFGStmt) return;
@@ -311,7 +336,7 @@
 
   collectEvidenceFromArgsAndParams(*ConstructorDecl, *ConstructExpr,
                                    InferableSlots, InferableSlotsConstraint,
-                                   Env, Emit);
+                                   Lattice, Env, Emit);
 }
 
 void collectEvidenceFromReturn(
@@ -350,7 +375,8 @@
 
 void collectEvidenceFromAssignment(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
-    const CFGElement &Element, const dataflow::Environment &Env,
+    const Formula &InferableSlotsConstraint, const CFGElement &Element,
+    const PointerNullabilityLattice &Lattice, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
   if (!CFGStmt) return;
@@ -364,10 +390,11 @@
         auto *PV = getPointerValueFromExpr(VarDecl->getInit(), Env);
         if (!PV) return;
         TypeNullability TypeNullability =
-            getNullabilityAnnotationsFromType(VarDecl->getType());
-        collectEvidenceFromBindingToType(TypeNullability, *PV, InferableSlots,
-                                         Env, VarDecl->getInit()->getExprLoc(),
-                                         Emit);
+            getNullabilityAnnotationsFromTypeAndOverrides(VarDecl->getType(),
+                                                          VarDecl, Lattice);
+        collectEvidenceFromBindingToType(
+            TypeNullability, *PV, InferableSlots, InferableSlotsConstraint, Env,
+            VarDecl->getInit()->getExprLoc(), Emit);
       }
     }
   }
@@ -379,26 +406,35 @@
       isSupportedPointerType(BinaryOperator->getLHS()->getType())) {
     auto *PV = getPointerValueFromExpr(BinaryOperator->getRHS(), Env);
     if (!PV) return;
-    TypeNullability TypeNullability =
-        getNullabilityAnnotationsFromType(BinaryOperator->getLHS()->getType());
-    collectEvidenceFromBindingToType(TypeNullability, *PV, InferableSlots, Env,
-                                     BinaryOperator->getRHS()->getExprLoc(),
-                                     Emit);
+    TypeNullability TypeNullability;
+    if (auto *DeclRefExpr =
+            dyn_cast_or_null<clang::DeclRefExpr>(BinaryOperator->getLHS())) {
+      TypeNullability = getNullabilityAnnotationsFromTypeAndOverrides(
+          BinaryOperator->getLHS()->getType(), DeclRefExpr->getDecl(), Lattice);
+    } else {
+      TypeNullability = getNullabilityAnnotationsFromType(
+          BinaryOperator->getLHS()->getType());
+    }
+    collectEvidenceFromBindingToType(
+        TypeNullability, *PV, InferableSlots, InferableSlotsConstraint, Env,
+        BinaryOperator->getRHS()->getExprLoc(), Emit);
   }
 }
 
 void collectEvidenceFromElement(
     std::vector<std::pair<PointerTypeNullability, Slot>> InferableSlots,
     const Formula &InferableSlotsConstraint, const CFGElement &Element,
-    const Environment &Env, llvm::function_ref<EvidenceEmitter> Emit) {
+    const PointerNullabilityLattice &Lattice, const Environment &Env,
+    llvm::function_ref<EvidenceEmitter> Emit) {
   collectEvidenceFromDereference(InferableSlots, Element, Env, Emit);
   collectEvidenceFromCallExpr(InferableSlots, InferableSlotsConstraint, Element,
-                              Env, Emit);
+                              Lattice, Env, Emit);
   collectEvidenceFromConstructExpr(InferableSlots, InferableSlotsConstraint,
-                                   Element, Env, Emit);
+                                   Element, Lattice, Env, Emit);
   collectEvidenceFromReturn(InferableSlots, InferableSlotsConstraint, Element,
                             Env, Emit);
-  collectEvidenceFromAssignment(InferableSlots, Element, Env, Emit);
+  collectEvidenceFromAssignment(InferableSlots, InferableSlotsConstraint,
+                                Element, Lattice, Env, Emit);
   // TODO: add more heuristic collections here
 }
 
@@ -506,7 +542,7 @@
                      PointerNullabilityLattice> &State) {
                collectEvidenceFromElement(InferableSlots,
                                           InferableSlotsConstraint, Element,
-                                          State.Env, Emit);
+                                          State.Lattice, State.Env, Emit);
              })
       .takeError();
 }
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 87567b3..115c3bf 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -65,6 +65,8 @@
     using Nullable [[clang::annotate("Nullable")]] = T;
     template <typename T>
     using Nonnull [[clang::annotate("Nonnull")]] = T;
+    template <typename T>
+    using Unknown [[clang::annotate("Nullability_Unspecified")]] = T;
   )cc";
   Inputs.ExtraArgs.push_back("-include");
   Inputs.ExtraArgs.push_back("nullability.h");
@@ -499,6 +501,18 @@
                                     functionNamed("target"))));
 }
 
+TEST(CollectEvidenceFromImplementationTest, AssignedToNullableOrUnknown) {
+  static constexpr llvm::StringRef Src = R"cc(
+    void target(int* p, int* q, int* r) {
+      Nullable<int*> a = p;
+      int* b = q;
+      Unknown<int*> c = r;
+      q = r;
+    }
+  )cc";
+  EXPECT_THAT(collectEvidenceFromTargetFunction(Src), IsEmpty());
+}
+
 // A crash repro involving callable parameters.
 TEST(CollectEvidenceFromImplementationTest, FunctionPointerParam) {
   static constexpr llvm::StringRef Src = R"cc(
@@ -597,31 +611,31 @@
   std::string ReturnsToBeNonnullUsr = "c:@F@returnsToBeNonnull#*I#";
   const llvm::DenseMap<int, std::vector<testing::Matcher<const Evidence&>>>
       ExpectedNewResultsPerRound = {
-          {{0,
-            {evidence(
-                paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
-                AllOf(functionNamed("target"),
-                      // Double-check that target's usr is as expected before
-                      // we use it to create SlotFingerprints.
-                      ResultOf([](Symbol S) { return S.usr(); }, TargetUsr)))}},
-           {1,
-            {evidence(
-                paramSlot(0), Evidence::NONNULL_ARGUMENT,
-                AllOf(functionNamed("returnsToBeNonnull"),
-                      // Double-check that returnsToBeNonnull's usr is as
-                      // expected before we use it to create SlotFingerprints.
-                      ResultOf([](Symbol S) { return S.usr(); },
-                               ReturnsToBeNonnullUsr)))}},
-           {2,
-            {
-                // No new evidence from target's implementation in this round,
-                // but in a full-TU analysis, this would be the round where we
-                // decide returnsToBeNonnull returns Nonnull, based on the
-                // now-Nonnull argument that is the only return value.
-            }},
-           {3,
-            {evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
-                      functionNamed("target"))}}}};
+          {0,
+           {evidence(
+               paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
+               AllOf(functionNamed("target"),
+                     // Double-check that target's usr is as expected before
+                     // we use it to create SlotFingerprints.
+                     ResultOf([](Symbol S) { return S.usr(); }, TargetUsr)))}},
+          {1,
+           {evidence(
+               paramSlot(0), Evidence::NONNULL_ARGUMENT,
+               AllOf(functionNamed("returnsToBeNonnull"),
+                     // Double-check that returnsToBeNonnull's usr is as
+                     // expected before we use it to create SlotFingerprints.
+                     ResultOf([](Symbol S) { return S.usr(); },
+                              ReturnsToBeNonnullUsr)))}},
+          {2,
+           {
+               // No new evidence from target's implementation in this round,
+               // but in a full-TU analysis, this would be the round where we
+               // decide returnsToBeNonnull returns Nonnull, based on the
+               // now-Nonnull argument that is the only return value.
+           }},
+          {3,
+           {evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+                     functionNamed("target"))}}};
 
   // Assert first round results because they don't rely on previous inference
   // propagation at all and in this case are test setup and preconditions.
@@ -669,6 +683,29 @@
                     IsSupersetOf(ExpectedNewResultsPerRound.at(3))));
 }
 
+TEST(CollectEvidenceFromImplementationTest,
+     PreviousInferencesOfNonTargetParameterNullabilitiesPropagate) {
+  static constexpr llvm::StringRef Src = R"cc(
+    void takesToBeNonnull(int* a) {
+      // Not read when collecting evidence only from Target, but corresponding
+      // inference is explicitly input below.
+      *a;
+    }
+    void target(int* q) { takesToBeNonnull(q); }
+  )cc";
+  std::string TakesToBeNonnullUsr = "c:@F@takesToBeNonnull#*I#";
+
+  // Pretend that in a first round of inferring for all functions, we made this
+  // inference about takesToBeNonnull's first parameter.
+  // This test confirms that we use that information when collecting from
+  // target's implementation.
+  EXPECT_THAT(
+      collectEvidenceFromTargetFunction(
+          Src, {.Nonnull = {fingerprint(TakesToBeNonnullUsr, paramSlot(0))}}),
+      Contains(evidence(paramSlot(0), Evidence::BOUND_TO_NONNULL,
+                        functionNamed("target"))));
+}
+
 TEST(CollectEvidenceFromDeclarationTest, VariableDeclIgnored) {
   llvm::StringLiteral Src = "Nullable<int *> target;";
   EXPECT_THAT(collectEvidenceFromTargetDecl(Src), IsEmpty());
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 4814523..5dc4449 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -215,10 +215,13 @@
 
 TEST_F(InferTUTest, IterationsPropagateInferences) {
   build(R"cc(
+    void takesToBeNonnull(int* x) { *x; }
     int* returnsToBeNonnull(int* a) { return a; }
-    int* target(int* q) {
-      *q;
-      return returnsToBeNonnull(q);
+    int* target(int* p, int* q, int* r) {
+      *p;
+      takesToBeNonnull(q);
+      q = r;
+      return returnsToBeNonnull(p);
     }
   )cc");
   EXPECT_THAT(
@@ -228,31 +231,44 @@
                                         inferredSlot(1, Inference::NONNULL)}),
           inference(hasName("returnsToBeNonnull"),
                     {inferredSlot(0, Inference::UNKNOWN),
-                     inferredSlot(1, Inference::UNKNOWN)})));
+                     inferredSlot(1, Inference::UNKNOWN)}),
+          inference(hasName("takesToBeNonnull"),
+                    {inferredSlot(1, Inference::NONNULL)})));
   EXPECT_THAT(
       inferTU(AST->context(), /*Iterations=*/2),
       UnorderedElementsAre(
           inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
-                                        inferredSlot(1, Inference::NONNULL)}),
+                                        inferredSlot(1, Inference::NONNULL),
+                                        inferredSlot(2, Inference::NONNULL)}),
           inference(hasName("returnsToBeNonnull"),
                     {inferredSlot(0, Inference::UNKNOWN),
-                     inferredSlot(1, Inference::NONNULL)})));
+                     inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("takesToBeNonnull"),
+                    {inferredSlot(1, Inference::NONNULL)})));
   EXPECT_THAT(
       inferTU(AST->context(), /*Iterations=*/3),
       UnorderedElementsAre(
           inference(hasName("target"), {inferredSlot(0, Inference::UNKNOWN),
-                                        inferredSlot(1, Inference::NONNULL)}),
+                                        inferredSlot(1, Inference::NONNULL),
+                                        inferredSlot(2, Inference::NONNULL),
+                                        inferredSlot(3, Inference::NONNULL)}),
           inference(hasName("returnsToBeNonnull"),
                     {inferredSlot(0, Inference::NONNULL),
-                     inferredSlot(1, Inference::NONNULL)})));
+                     inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("takesToBeNonnull"),
+                    {inferredSlot(1, Inference::NONNULL)})));
   EXPECT_THAT(
       inferTU(AST->context(), /*Iterations=*/4),
       UnorderedElementsAre(
           inference(hasName("target"), {inferredSlot(0, Inference::NONNULL),
-                                        inferredSlot(1, Inference::NONNULL)}),
+                                        inferredSlot(1, Inference::NONNULL),
+                                        inferredSlot(2, Inference::NONNULL),
+                                        inferredSlot(3, Inference::NONNULL)}),
           inference(hasName("returnsToBeNonnull"),
                     {inferredSlot(0, Inference::NONNULL),
-                     inferredSlot(1, Inference::NONNULL)})));
+                     inferredSlot(1, Inference::NONNULL)}),
+          inference(hasName("takesToBeNonnull"),
+                    {inferredSlot(1, Inference::NONNULL)})));
 }
 
 }  // namespace
diff --git a/nullability/pointer_nullability_analysis.cc b/nullability/pointer_nullability_analysis.cc
index 466bb99..b02582d 100644
--- a/nullability/pointer_nullability_analysis.cc
+++ b/nullability/pointer_nullability_analysis.cc
@@ -554,25 +554,12 @@
   transferFlowSensitiveCallExpr(MCE, Result, State);
 }
 
-// If nullability for the decl D has been overridden, patch N to reflect it.
-// (N is the nullability of an access to D).
-void overrideNullabilityFromDecl(const Decl *D,
-                                 PointerNullabilityLattice &Lattice,
-                                 TypeNullability &N) {
-  // For now, overrides are always for pointer values only, and override only
-  // the top-level nullability.
-  if (auto *PN = Lattice.getDeclNullability(D)) {
-    CHECK(!N.empty());
-    N.front() = *PN;
-  }
-}
-
 void transferNonFlowSensitiveDeclRefExpr(
     const DeclRefExpr *DRE, const MatchFinder::MatchResult &MR,
     TransferState<PointerNullabilityLattice> &State) {
   computeNullability(DRE, State, [&] {
     auto Nullability = getNullabilityAnnotationsFromType(DRE->getType());
-    overrideNullabilityFromDecl(DRE->getDecl(), State.Lattice, Nullability);
+    State.Lattice.overrideNullabilityFromDecl(DRE->getDecl(), Nullability);
     return Nullability;
   });
 }
@@ -595,8 +582,7 @@
     }
     auto Nullability = substituteNullabilityAnnotationsInClassTemplate(
         MemberType, BaseNullability, ME->getBase()->getType());
-    overrideNullabilityFromDecl(ME->getMemberDecl(), State.Lattice,
-                                Nullability);
+    State.Lattice.overrideNullabilityFromDecl(ME->getMemberDecl(), Nullability);
     return Nullability;
   });
 }
@@ -771,8 +757,8 @@
     auto Nullability =
         substituteNullabilityAnnotationsInFunctionTemplate(CE->getType(), CE);
     if (!Nullability.empty()) {
-      overrideNullabilityFromDecl(CE->getCalleeDecl(), State.Lattice,
-                                  Nullability);
+      State.Lattice.overrideNullabilityFromDecl(CE->getCalleeDecl(),
+                                                Nullability);
     }
     return Nullability;
   });
diff --git a/nullability/pointer_nullability_lattice.cc b/nullability/pointer_nullability_lattice.cc
new file mode 100644
index 0000000..abab6ab
--- /dev/null
+++ b/nullability/pointer_nullability_lattice.cc
@@ -0,0 +1,45 @@
+// Part of the Crubit project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "nullability/pointer_nullability_lattice.h"
+
+#include <optional>
+
+#include "absl/log/check.h"
+#include "nullability/type_nullability.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
+#include "clang/Basic/LLVM.h"
+
+namespace clang::tidy::nullability {
+namespace {
+// Returns overridden nullability information associated with a declaration.
+// For now we only track top-level decl nullability symbolically and check for
+// concrete nullability override results.
+const PointerTypeNullability *getDeclNullability(
+    const Decl *D,
+    const PointerNullabilityLattice::NonFlowSensitiveState &NFS) {
+  if (!D) return nullptr;
+  if (const auto *VD = dyn_cast_or_null<ValueDecl>(D)) {
+    auto It = NFS.DeclTopLevelNullability.find(VD);
+    if (It != NFS.DeclTopLevelNullability.end()) return &It->second;
+  }
+  if (const std::optional<const PointerTypeNullability *> N =
+          NFS.ConcreteNullabilityOverride(*D))
+    return *N;
+  return nullptr;
+}
+}  // namespace
+
+void PointerNullabilityLattice::overrideNullabilityFromDecl(
+    const Decl *D, TypeNullability &N) const {
+  // For now, overrides are always for pointer values only, and override only
+  // the top-level nullability.
+  if (auto *PN = getDeclNullability(D, NFS)) {
+    CHECK(!N.empty());
+    N.front() = *PN;
+  }
+}
+
+}  // namespace clang::tidy::nullability
diff --git a/nullability/pointer_nullability_lattice.h b/nullability/pointer_nullability_lattice.h
index 0b7b906..932fbb3 100644
--- a/nullability/pointer_nullability_lattice.h
+++ b/nullability/pointer_nullability_lattice.h
@@ -62,20 +62,9 @@
     return Iterator->second;
   }
 
-  // Returns overridden nullability information associated with a declaration.
-  // For now we only track top-level decl nullability symbolically and check for
-  // concrete nullability override results.
-  const PointerTypeNullability *getDeclNullability(const Decl *D) const {
-    if (!D) return nullptr;
-    if (const auto *VD = dyn_cast_or_null<ValueDecl>(D)) {
-      auto It = NFS.DeclTopLevelNullability.find(VD);
-      if (It != NFS.DeclTopLevelNullability.end()) return &It->second;
-    }
-    if (const std::optional<const PointerTypeNullability *> N =
-            NFS.ConcreteNullabilityOverride(*D))
-      return *N;
-    return nullptr;
-  }
+  // If nullability for the decl D has been overridden, patch N to reflect it.
+  // (N is the nullability of an access to D).
+  void overrideNullabilityFromDecl(const Decl *D, TypeNullability &N) const;
 
   bool operator==(const PointerNullabilityLattice &Other) const { return true; }