Emit evidence for virtual methods and overrides that constrain each other.

The nullabilities of return and parameter types for these functions are partially constrained by the overriding/overridden method(s). So when collecting evidence for one, collect the same evidence for the overriding or overridden method as appropriate for that constraint.

Because callers can virtually call Derived::foo or Base::foo from a Base object and can call only Derived::foo from a Derived object (without very explicitly referencing the Base method decl), the Derived API can be only as restrictive as the Base API from the *caller's* perspective in terms of what it can return or what it accepts for parameters. Derived::foo can be less restrictive, or equally restrictive, but not more, e.g. Derived::foo can accept Nullable parameters instead of just Base::foo's Nonnulls, but Derived::foo can't require Nonnull parameters instead of Base::foo's allowed Nullables.

Therefore, we apply the following rules for evidence collection. For return types, evidence pointing towards Nonnull emitted for Base methods is emitted also for Derived methods and evidence pointing towards Nullable emitted for Derived methods is emitted also for Base methods. For parameter types, this is flipped. See the many tests in this change for examples of evidence collected and inferences made in various scenarios.

PiperOrigin-RevId: 646930709
Change-Id: I541ec5b9973d80f26114e53f81d9fbfae38ff24c
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index 61e9a28..d93dbf1 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -64,6 +64,7 @@
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace clang::tidy::nullability {
@@ -86,15 +87,123 @@
   return It->second;
 }
 
+static llvm::DenseSet<const CXXMethodDecl *> getOverridden(
+    const CXXMethodDecl *Derived) {
+  llvm::DenseSet<const CXXMethodDecl *> Overridden;
+  for (const CXXMethodDecl *Base : Derived->overridden_methods()) {
+    Overridden.insert(Base);
+    for (const CXXMethodDecl *BaseOverridden : getOverridden(Base)) {
+      Overridden.insert(BaseOverridden);
+    }
+  }
+  return Overridden;
+}
+
+/// Shared base class for visitors that walk the AST for evidence collection
+/// purposes, to ensure they see the same nodes.
+template <typename Derived>
+struct EvidenceLocationsWalker : public RecursiveASTVisitor<Derived> {
+  // We do want to see concrete code, including function instantiations.
+  bool shouldVisitTemplateInstantiations() const { return true; }
+
+  // In order to collect from more default member initializers, we do want to
+  // see defaulted default constructors, which are implicitly-defined
+  // functions whether the declaration is implicit or explicit. We also want
+  // to see lambda bodies in the form of operator() definitions that are not
+  // themselves implicit but show up in an implicit context.
+  bool shouldVisitImplicitCode() const { return true; }
+};
+
+using VirtualMethodOverridesMap =
+    absl::flat_hash_map<const CXXMethodDecl *,
+                        llvm::DenseSet<const CXXMethodDecl *>>;
+
+/// Collect a map from virtual methods to a set of their overrides.
+static VirtualMethodOverridesMap getVirtualMethodOverrides(ASTContext &Ctx) {
+  struct Walker : public EvidenceLocationsWalker<Walker> {
+    VirtualMethodOverridesMap Out;
+
+    bool VisitCXXMethodDecl(const CXXMethodDecl *MD) {
+      if (MD->isVirtual()) {
+        for (const auto *O : getOverridden(MD)) {
+          Out[O].insert(MD);
+        }
+      }
+      return true;
+    }
+  };
+
+  Walker W;
+  W.TraverseAST(Ctx);
+  return std::move(W.Out);
+}
+
+namespace {
+enum VirtualMethodEvidenceFlowDirection {
+  kFromBaseToDerived,
+  kFromDerivedToBase,
+  // Bidirectional evidence or new evidence kinds that create bidirectional
+  // information could be used for low-priority heuristics, e.g. Nonnull returns
+  // in all derived => Nonnull return for the base. These are not currently
+  // supported, though.
+};
+}  // namespace
+
+static VirtualMethodEvidenceFlowDirection getFlowDirection(Evidence::Kind Kind,
+                                                           bool ForReturnSlot) {
+  switch (Kind) {
+    case Evidence::ANNOTATED_NONNULL:
+    case Evidence::UNCHECKED_DEREFERENCE:
+    case Evidence::NONNULL_ARGUMENT:
+    case Evidence::NONNULL_RETURN:
+    case Evidence::ASSIGNED_TO_NONNULL:
+    case Evidence::ABORT_IF_NULL:
+    case Evidence::ARITHMETIC:
+    case Evidence::GCC_NONNULL_ATTRIBUTE:
+      // Evidence pointing toward Unknown is only used to prevent Nonnull
+      // inferences; it cannot override Nullable. So propagate it in the same
+      // direction we do for Nonnull-pointing evidence.
+    case Evidence::ANNOTATED_UNKNOWN:
+    case Evidence::UNKNOWN_ARGUMENT:
+    case Evidence::UNKNOWN_RETURN:
+      return ForReturnSlot ? kFromBaseToDerived : kFromDerivedToBase;
+    case Evidence::ANNOTATED_NULLABLE:
+    case Evidence::NULLABLE_ARGUMENT:
+    case Evidence::NULLABLE_RETURN:
+    case Evidence::ASSIGNED_TO_MUTABLE_NULLABLE:
+    case Evidence::ASSIGNED_FROM_NULLABLE:
+    case Evidence::NULLPTR_DEFAULT_MEMBER_INITIALIZER:
+      return ForReturnSlot ? kFromDerivedToBase : kFromBaseToDerived;
+  }
+}
+
+static llvm::DenseSet<const CXXMethodDecl *>
+getAdditionalTargetsForVirtualMethod(
+    const CXXMethodDecl *MD, Evidence::Kind Kind, bool ForReturnSlot,
+    const VirtualMethodOverridesMap &OverridesMap) {
+  VirtualMethodEvidenceFlowDirection FlowDirection =
+      getFlowDirection(Kind, ForReturnSlot);
+  switch (FlowDirection) {
+    case kFromBaseToDerived:
+      if (auto It = OverridesMap.find(MD); It != OverridesMap.end())
+        return It->second;
+      return {};
+    case kFromDerivedToBase:
+      return getOverridden(MD);
+  }
+}
+
 llvm::unique_function<EvidenceEmitter> evidenceEmitter(
     llvm::unique_function<void(const Evidence &) const> Emit,
-    nullability::USRCache &USRCache) {
+    USRCache &USRCache, ASTContext &Ctx) {
   class EvidenceEmitterImpl {
    public:
     EvidenceEmitterImpl(
         llvm::unique_function<void(const Evidence &) const> Emit,
-        nullability::USRCache &USRCache)
-        : Emit(std::move(Emit)), USRCache(USRCache) {}
+        nullability::USRCache &USRCache, ASTContext &Ctx)
+        : Emit(std::move(Emit)),
+          USRCache(USRCache),
+          OverridesMap(getVirtualMethodOverrides(Ctx)) {}
 
     void operator()(const Decl &Target, Slot S, Evidence::Kind Kind,
                     SourceLocation Loc) const {
@@ -121,13 +230,28 @@
         E.set_location(Loc.printToString(SM));
 
       Emit(E);
+
+      // Virtual methods and their overrides constrain each other's
+      // nullabilities, so propagate evidence in the appropriate direction based
+      // on the evidence kind and whether the evidence is for the return type or
+      // a parameter type.
+      if (auto *MD = dyn_cast<CXXMethodDecl>(&Target); MD && MD->isVirtual()) {
+        for (const auto *O : getAdditionalTargetsForVirtualMethod(
+                 MD, Kind, S == SLOT_RETURN_TYPE, OverridesMap)) {
+          USR = getOrGenerateUSR(USRCache, *O);
+          if (USR.empty()) return;  // Can't emit without a USR
+          E.mutable_symbol()->set_usr(USR);
+          Emit(E);
+        }
+      }
     }
 
    private:
     llvm::unique_function<void(const Evidence &) const> Emit;
     nullability::USRCache &USRCache;
+    const VirtualMethodOverridesMap OverridesMap;
   };
-  return EvidenceEmitterImpl(std::move(Emit), USRCache);
+  return EvidenceEmitterImpl(std::move(Emit), USRCache, Ctx);
 }
 
 namespace {
@@ -1369,9 +1493,9 @@
   }
 }
 
-void collectNonnullAttributeEvidence(const clang::FunctionDecl &Fn,
-                                     unsigned ParamIndex, SourceLocation Loc,
-                                     llvm::function_ref<EvidenceEmitter> Emit) {
+static void collectNonnullAttributeEvidence(
+    const clang::FunctionDecl &Fn, unsigned ParamIndex, SourceLocation Loc,
+    llvm::function_ref<EvidenceEmitter> Emit) {
   const ParmVarDecl *ParamDecl = Fn.getParamDecl(ParamIndex);
   // The attribute does not apply to references-to-pointers or nested pointers
   // or smart pointers.
@@ -1442,19 +1566,9 @@
 }
 
 EvidenceSites EvidenceSites::discover(ASTContext &Ctx) {
-  struct Walker : public RecursiveASTVisitor<Walker> {
+  struct Walker : public EvidenceLocationsWalker<Walker> {
     EvidenceSites Out;
 
-    // We do want to see concrete code, including function instantiations.
-    bool shouldVisitTemplateInstantiations() const { return true; }
-
-    // In order to collect from more default member initializers, we do want to
-    // see defaulted default constructors, which are implicitly-defined
-    // functions whether the declaration is implicit or explicit. We also want
-    // to see lambda bodies in the form of operator() definitions that are not
-    // themselves implicit but show up in an implicit context.
-    bool shouldVisitImplicitCode() const { return true; }
-
     bool VisitFunctionDecl(absl::Nonnull<const FunctionDecl *> FD) {
       if (isInferenceTarget(*FD)) Out.Declarations.insert(FD);
 
diff --git a/nullability/inference/collect_evidence.h b/nullability/inference/collect_evidence.h
index 6cfe4af..66f109e 100644
--- a/nullability/inference/collect_evidence.h
+++ b/nullability/inference/collect_evidence.h
@@ -5,6 +5,7 @@
 #ifndef CRUBIT_NULLABILITY_INFERENCE_COLLECT_EVIDENCE_H_
 #define CRUBIT_NULLABILITY_INFERENCE_COLLECT_EVIDENCE_H_
 
+#include <memory>
 #include <string>
 #include <string_view>
 
@@ -12,7 +13,9 @@
 #include "nullability/inference/slot_fingerprint.h"
 #include "nullability/pointer_nullability_analysis.h"
 #include "nullability/pragma.h"
+#include "clang/AST/ASTContext.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/Analysis/FlowSensitive/Solver.h"
 #include "clang/Basic/SourceLocation.h"
 #include "llvm/ADT/DenseMap.h"
@@ -33,7 +36,8 @@
 /// Creates an EvidenceEmitter that serializes the evidence as Evidence protos.
 /// This emitter caches USR generation, and should be reused for the whole AST.
 llvm::unique_function<EvidenceEmitter> evidenceEmitter(
-    llvm::unique_function<void(const Evidence &) const>, USRCache &USRCache);
+    llvm::unique_function<void(const Evidence &) const>, USRCache &USRCache,
+    ASTContext &Ctx);
 
 struct PreviousInferences {
   const llvm::DenseSet<SlotFingerprint> &Nullable = {};
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 33083ad..35bd59b 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -114,7 +114,7 @@
       collectEvidenceFromDefinition(
           Definition,
           evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
-                          UsrCache),
+                          UsrCache, AST.context()),
           UsrCache, Pragmas, InputInferences),
       llvm::Succeeded());
   return Results;
@@ -157,7 +157,7 @@
   collectEvidenceFromTargetDeclaration(
       *dataflow::test::findValueDecl(AST.context(), "target"),
       evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
-                      USRCache),
+                      USRCache, AST.context()),
       Pragmas);
   return Results;
 }
@@ -2556,6 +2556,282 @@
                                 functionNamed("foo"))));
 }
 
+// Evidence for return type nonnull-ness should flow only from derived to base,
+// so we collect evidence for the base but not the derived.
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualDerivedForReturnNonnull) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual int* foo();
+    };
+
+    struct Derived : public Base {
+      int* foo() override {
+        static int i;
+        return &i;
+      }
+    };
+
+    void target() {
+      Derived D;
+      *D.foo();
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromDefinitionNamed("Derived::foo", Src),
+      UnorderedElementsAre(evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+                                    functionNamed("Derived@F@foo"))));
+
+  EXPECT_THAT(collectFromTargetFuncDefinition(Src),
+              UnorderedElementsAre(evidence(SLOT_RETURN_TYPE,
+                                            Evidence::UNCHECKED_DEREFERENCE,
+                                            functionNamed("Derived@F@foo"))));
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualDerivedForReturnNullable) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual int* foo();
+    };
+
+    struct Derived : public Base {
+      int* foo() override { return nullptr; }
+    };
+  )cc";
+  EXPECT_THAT(
+      collectFromDefinitionNamed("Derived::foo", Src),
+      UnorderedElementsAre(evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("Derived@F@foo")),
+                           evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("Base@F@foo"))));
+
+  // We don't currently have any evidence kinds that can force a non-reference
+  // top-level pointer return type to be nullable from its usage, so no other
+  // expectation.
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualDerivedForParamNonnull) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual void foo(int* p);
+    };
+
+    struct Derived : public Base {
+      void foo(int* p) override { *p; }
+    };
+
+    void target() {
+      int i;
+      Derived D;
+      D.foo(&i);
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromTargetFuncDefinition(Src),
+      UnorderedElementsAre(evidence(paramSlot(0), Evidence::NONNULL_ARGUMENT,
+                                    functionNamed("Derived@F@foo")),
+                           evidence(paramSlot(0), Evidence::NONNULL_ARGUMENT,
+                                    functionNamed("Base@F@foo"))));
+
+  EXPECT_THAT(collectFromDefinitionNamed("Derived::foo", Src),
+              UnorderedElementsAre(
+                  evidence(paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
+                           functionNamed("Derived@F@foo")),
+                  evidence(paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
+                           functionNamed("Base@F@foo"))));
+}
+
+// Evidence for parameter nullable-ness should flow only from base to derived,
+// so we collect evidence for the derived but not the base.
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualDerivedForParamNullable) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual void foo(int* p);
+    };
+
+    struct Derived : public Base {
+      void foo(int* p) override { p = nullptr; }
+    };
+
+    void target() {
+      Derived D;
+      D.foo(nullptr);
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromTargetFuncDefinition(Src),
+      UnorderedElementsAre(evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+                                    functionNamed("Derived@F@foo"))));
+
+  EXPECT_THAT(collectFromDefinitionNamed("Derived::foo", Src),
+              UnorderedElementsAre(evidence(paramSlot(0),
+                                            Evidence::ASSIGNED_FROM_NULLABLE,
+                                            functionNamed("Derived@F@foo"))));
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualBaseForReturnNonnull) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual int* foo() {
+        static int i;
+        return &i;
+      }
+    };
+
+    struct Derived : public Base {
+      int* foo() override;
+    };
+
+    void target() {
+      Base B;
+      *B.foo();
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromDefinitionNamed("Base::foo", Src),
+      UnorderedElementsAre(evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+                                    functionNamed("Base@F@foo")),
+                           evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+                                    functionNamed("Derived@F@foo"))));
+
+  EXPECT_THAT(collectFromTargetFuncDefinition(Src),
+              UnorderedElementsAre(
+                  evidence(SLOT_RETURN_TYPE, Evidence::UNCHECKED_DEREFERENCE,
+                           functionNamed("Base@F@foo")),
+                  evidence(SLOT_RETURN_TYPE, Evidence::UNCHECKED_DEREFERENCE,
+                           functionNamed("Derived@F@foo"))));
+}
+
+// Evidence for return type nullable-ness should flow only from derived to base,
+// so we collect evidence for the base but not the derived.
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualBaseForReturnNullable) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual int* foo() { return nullptr; }
+    };
+
+    struct Derived : public Base {
+      int* foo() override;
+    };
+  )cc";
+  EXPECT_THAT(
+      collectFromDefinitionNamed("Base::foo", Src),
+      UnorderedElementsAre(evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("Base@F@foo"))));
+
+  // We don't currently have any evidence kinds that can force a non-reference
+  // top-level pointer return type to be nullable from its usage, so no other
+  // expectation.
+}
+
+// Evidence for parameter nonnull-ness should flow only from derived to base, so
+// we collect evidence for the base but not the derived.
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualBaseForParamNonnull) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual void foo(int* p) { *p; }
+    };
+
+    struct Derived : public Base {
+      void foo(int* p) override;
+    };
+
+    void target() {
+      int i;
+      Base B;
+      B.foo(&i);
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromTargetFuncDefinition(Src),
+      UnorderedElementsAre(evidence(paramSlot(0), Evidence::NONNULL_ARGUMENT,
+                                    functionNamed("Base@F@foo"))));
+
+  EXPECT_THAT(collectFromDefinitionNamed("Base::foo", Src),
+              UnorderedElementsAre(evidence(paramSlot(0),
+                                            Evidence::UNCHECKED_DEREFERENCE,
+                                            functionNamed("Base@F@foo"))));
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualBaseForParamNullable) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual void foo(int* p) { p = nullptr; }
+    };
+
+    struct Derived : public Base {
+      void foo(int* p) override;
+    };
+
+    void target() {
+      Base B;
+      B.foo(nullptr);
+    }
+  )cc";
+  EXPECT_THAT(
+      collectFromTargetFuncDefinition(Src),
+      UnorderedElementsAre(evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+                                    functionNamed("Base@F@foo")),
+                           evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+                                    functionNamed("Derived@F@foo"))));
+
+  EXPECT_THAT(collectFromDefinitionNamed("Base::foo", Src),
+              UnorderedElementsAre(
+                  evidence(paramSlot(0), Evidence::ASSIGNED_FROM_NULLABLE,
+                           functionNamed("Base@F@foo")),
+                  evidence(paramSlot(0), Evidence::ASSIGNED_FROM_NULLABLE,
+                           functionNamed("Derived@F@foo"))));
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualDerivedMultipleLayers) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual int* foo();
+    };
+
+    struct Derived : public Base {
+      virtual int* foo();
+    };
+
+    struct DerivedDerived : public Derived {
+      int* foo() override { return nullptr; };
+    };
+  )cc";
+
+  EXPECT_THAT(
+      collectFromDefinitionNamed("DerivedDerived::foo", Src),
+      UnorderedElementsAre(evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("DerivedDerived@F@foo")),
+                           evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("Derived@F@foo")),
+                           evidence(SLOT_RETURN_TYPE, Evidence::NULLABLE_RETURN,
+                                    functionNamed("Base@F@foo"))));
+}
+
+TEST(CollectEvidenceFromDefinitionTest, FromVirtualBaseMultipleLayers) {
+  static constexpr llvm::StringRef Src = R"cc(
+    struct Base {
+      virtual void foo(int* p) { p = nullptr; }
+    };
+
+    struct Derived : public Base {
+      virtual void foo(int*);
+    };
+
+    struct DerivedDerived : public Derived {
+      void foo(int*) override;
+    };
+  )cc";
+
+  EXPECT_THAT(collectFromDefinitionNamed("Base::foo", Src),
+              UnorderedElementsAre(
+                  evidence(paramSlot(0), Evidence::ASSIGNED_FROM_NULLABLE,
+                           functionNamed("DerivedDerived@F@foo")),
+                  evidence(paramSlot(0), Evidence::ASSIGNED_FROM_NULLABLE,
+                           functionNamed("Derived@F@foo")),
+                  evidence(paramSlot(0), Evidence::ASSIGNED_FROM_NULLABLE,
+                           functionNamed("Base@F@foo"))));
+}
+
 TEST(CollectEvidenceFromDefinitionTest, NotInferenceTarget) {
   static constexpr llvm::StringRef Src = R"cc(
     void isATarget(Nonnull<int*> a);
@@ -2788,7 +3064,7 @@
           *cast<FunctionDecl>(
               dataflow::test::findValueDecl(AST.context(), "target")),
           evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
-                          UsrCache),
+                          UsrCache, AST.context()),
           UsrCache, Pragmas, /*PreviousInferences=*/{},
           // Enough iterations to collect one piece of evidence but not both.
           []() {
@@ -3300,10 +3576,11 @@
   ASSERT_NE(TargetDecl, nullptr);
 
   USRCache USRCache;
-  EXPECT_DEATH(evidenceEmitter([](const Evidence& E) {}, USRCache)(
-                   *TargetDecl, Slot{}, Evidence::ANNOTATED_UNKNOWN,
-                   TargetDecl->getLocation()),
-               "not an inference target");
+  EXPECT_DEATH(
+      evidenceEmitter([](const Evidence& E) {}, USRCache, AST.context())(
+          *TargetDecl, Slot{}, Evidence::ANNOTATED_UNKNOWN,
+          TargetDecl->getLocation()),
+      "not an inference target");
 }
 
 }  // namespace
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index 051e49f..9c84fa7 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -39,8 +39,8 @@
     std::vector<Evidence> AllEvidence;
 
     // Collect all evidence.
-    auto Emitter =
-        evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
+    auto Emitter = evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); },
+                                   USRCache, Ctx);
     for (const auto* Decl : Sites.Declarations) {
       if (Filter && !Filter(*Decl)) continue;
       collectEvidenceFromTargetDeclaration(*Decl, Emitter, Pragmas);
diff --git a/nullability/inference/infer_tu_test.cc b/nullability/inference/infer_tu_test.cc
index 7ec90f3..0866ea5 100644
--- a/nullability/inference/infer_tu_test.cc
+++ b/nullability/inference/infer_tu_test.cc
@@ -581,5 +581,165 @@
                                     {inferredSlot(0, Nullability::NONNULL)})));
 }
 
+using InferTUVirtualMethodsTest = InferTUTest;
+
+TEST_F(InferTUVirtualMethodsTest, SafeVarianceNoConflicts) {
+  build(R"cc(
+    struct Base {
+      virtual int* foo(int* p) {
+        *p;
+        return nullptr;
+      }
+    };
+
+    struct Derived : public Base {
+      int* foo(int* p) override {
+        static int i = 0;
+        p = nullptr;
+        return &i;
+      }
+    };
+  )cc");
+
+  EXPECT_THAT(infer(),
+              UnorderedElementsAre(
+                  inference(hasName("Base::foo"),
+                            {inferredSlot(0, Nullability::NULLABLE),
+                             inferredSlot(1, Nullability::NONNULL)}),
+                  inference(hasName("Derived::foo"),
+                            {inferredSlot(0, Nullability::NONNULL),
+                             inferredSlot(1, Nullability::NULLABLE)})));
+}
+
+TEST_F(InferTUVirtualMethodsTest, BaseConstrainsDerived) {
+  build(R"cc(
+    struct Base {
+      virtual Nonnull<int*> foo(int* p) {
+        static int i = 0;
+        p = nullptr;
+        return &i;
+      }
+    };
+
+    struct Derived : public Base {
+      int* foo(int* p) override;
+    };
+  )cc");
+
+  EXPECT_THAT(infer(),
+              UnorderedElementsAre(
+                  inference(hasName("Base::foo"),
+                            {inferredSlot(0, Nullability::NONNULL),
+                             inferredSlot(1, Nullability::NULLABLE)}),
+                  inference(hasName("Derived::foo"),
+                            {inferredSlot(0, Nullability::NONNULL),
+                             inferredSlot(1, Nullability::NULLABLE)})));
+}
+
+TEST_F(InferTUVirtualMethodsTest, DerivedConstrainsBase) {
+  build(R"cc(
+    struct Base {
+      virtual int* foo(int* p);
+    };
+
+    struct Derived : public Base {
+      int* foo(int* p) override {
+        *p;
+        return nullptr;
+      }
+    };
+  )cc");
+
+  EXPECT_THAT(infer(), UnorderedElementsAre(
+                           inference(hasName("Base::foo"),
+                                     {inferredSlot(0, Nullability::NULLABLE),
+                                      inferredSlot(1, Nullability::NONNULL)}),
+                           inference(hasName("Derived::foo"),
+                                     {inferredSlot(0, Nullability::NULLABLE),
+                                      inferredSlot(1, Nullability::NONNULL)})));
+}
+
+TEST_F(InferTUVirtualMethodsTest, Conflict) {
+  build(R"cc(
+    struct Base {
+      virtual int* foo(int* p);
+    };
+
+    struct Derived : public Base {
+      int* foo(int* p) override {
+        *p;
+        return nullptr;
+      }
+    };
+
+    void usage() {
+      Base B;
+      // Conflict-producing nonnull return type evidence is only possible
+      // from a usage site. Since we need a usage, produce the parameter
+      // evidence here as well.
+      *B.foo(nullptr);
+    }
+  )cc");
+
+  EXPECT_THAT(
+      infer(),
+      UnorderedElementsAre(
+          inference(hasName("Base::foo"),
+                    {inferredSlot(0, Nullability::NONNULL, /*Conflict*/ true),
+                     inferredSlot(1, Nullability::NONNULL, /*Conflict*/ true)}),
+          inference(
+              hasName("Derived::foo"),
+              {inferredSlot(0, Nullability::NONNULL, /*Conflict*/ true),
+               inferredSlot(1, Nullability::NONNULL, /*Conflict*/ true)})));
+}
+
+TEST_F(InferTUVirtualMethodsTest, MultipleDerived) {
+  build(R"cc(
+    struct Base {
+      virtual void foo(int* p) { p = nullptr; }
+    };
+
+    struct DerivedA : public Base {
+      void foo(int* p) override;
+    };
+
+    struct DerivedB : public Base {
+      void foo(int* p) override;
+    };
+  )cc");
+  EXPECT_THAT(infer(),
+              UnorderedElementsAre(
+                  inference(hasName("Base::foo"),
+                            {inferredSlot(1, Nullability::NULLABLE)}),
+                  inference(hasName("DerivedA::foo"),
+                            {inferredSlot(1, Nullability::NULLABLE)}),
+                  inference(hasName("DerivedB::foo"),
+                            {inferredSlot(1, Nullability::NULLABLE)})));
+}
+
+TEST_F(InferTUVirtualMethodsTest, MultipleBase) {
+  build(R"cc(
+    struct BaseA {
+      virtual void foo(int* p);
+    };
+
+    struct BaseB {
+      virtual void foo(int* p);
+    };
+
+    struct Derived : public BaseA, public BaseB {
+      void foo(int* p) override { *p; }
+    };
+  )cc");
+
+  EXPECT_THAT(infer(), UnorderedElementsAre(
+                           inference(hasName("BaseA::foo"),
+                                     {inferredSlot(1, Nullability::NONNULL)}),
+                           inference(hasName("BaseB::foo"),
+                                     {inferredSlot(1, Nullability::NONNULL)}),
+                           inference(hasName("Derived::foo"),
+                                     {inferredSlot(1, Nullability::NONNULL)})));
+}
+
 }  // namespace
 }  // namespace clang::tidy::nullability