Store USR cache outside of EvidenceEmitter so that it can be shared in more places.

For example, the current function decl's USR is needed to look up whether inferable slots' nullabilities were inferred in previous rounds.

PiperOrigin-RevId: 574873175
Change-Id: Id96d5cb082e7d61464bcc2bdd6825858eac16015
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index 8ddc113..f2964cf 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -7,6 +7,7 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -40,7 +41,6 @@
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Specifiers.h"
 #include "clang/Index/USRGeneration.h"
-#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
@@ -51,13 +51,24 @@
 using ::clang::dataflow::DataflowAnalysisContext;
 using ::clang::dataflow::Environment;
 
+std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &Decl) {
+  auto [It, Inserted] = Cache.try_emplace(&Decl);
+  if (Inserted) {
+    llvm::SmallString<128> USR;
+    if (!index::generateUSRForDecl(&Decl, USR)) It->second = USR.str();
+  }
+  return It->second;
+}
+
 llvm::unique_function<EvidenceEmitter> evidenceEmitter(
-    llvm::unique_function<void(const Evidence &) const> Emit) {
+    llvm::unique_function<void(const Evidence &) const> Emit,
+    nullability::USRCache &USRCache) {
   class EvidenceEmitterImpl {
    public:
     EvidenceEmitterImpl(
-        llvm::unique_function<void(const Evidence &) const> Emit)
-        : Emit(std::move(Emit)) {}
+        llvm::unique_function<void(const Evidence &) const> Emit,
+        nullability::USRCache &USRCache)
+        : Emit(std::move(Emit)), USRCache(USRCache) {}
 
     void operator()(const Decl &Target, Slot S, Evidence::Kind Kind,
                     SourceLocation Loc) const {
@@ -68,13 +79,9 @@
       E.set_slot(S);
       E.set_kind(Kind);
 
-      auto [It, Inserted] = USRCache.try_emplace(&Target);
-      if (Inserted) {
-        llvm::SmallString<128> USR;
-        if (!index::generateUSRForDecl(&Target, USR)) It->second = USR.str();
-      }
-      if (It->second.empty()) return;  // Can't emit without a USR
-      E.mutable_symbol()->set_usr(It->second);
+      std::string_view USR = getOrGenerateUSR(USRCache, Target);
+      if (USR.empty()) return;  // Can't emit without a USR
+      E.mutable_symbol()->set_usr(USR);
 
       // TODO: make collecting and propagating location information optional?
       auto &SM =
@@ -88,10 +95,10 @@
     }
 
    private:
-    mutable llvm::DenseMap<const Decl *, std::string> USRCache;
     llvm::unique_function<void(const Evidence &) const> Emit;
+    nullability::USRCache &USRCache;
   };
-  return EvidenceEmitterImpl(std::move(Emit));
+  return EvidenceEmitterImpl(std::move(Emit), USRCache);
 }
 
 namespace {
@@ -336,6 +343,7 @@
 
 llvm::Error collectEvidenceFromImplementation(
     const Decl &Decl, llvm::function_ref<EvidenceEmitter> Emit,
+    USRCache &USRCache,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull) {
   const FunctionDecl *Func = dyn_cast<FunctionDecl>(&Decl);
diff --git a/nullability/inference/collect_evidence.h b/nullability/inference/collect_evidence.h
index c516359..97a8578 100644
--- a/nullability/inference/collect_evidence.h
+++ b/nullability/inference/collect_evidence.h
@@ -5,12 +5,15 @@
 #ifndef CRUBIT_NULLABILITY_INFERENCE_COLLECT_EVIDENCE_H_
 #define CRUBIT_NULLABILITY_INFERENCE_COLLECT_EVIDENCE_H_
 
+#include <string>
+#include <string_view>
 #include <vector>
 
 #include "nullability/inference/inference.proto.h"
 #include "nullability/inference/slot_fingerprint.h"
 #include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/FunctionExtras.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
@@ -18,13 +21,17 @@
 
 namespace clang::tidy::nullability {
 
+using USRCache = llvm::DenseMap<const Decl *, std::string>;
+
+std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &);
+
 // Callback used to report collected nullability evidence.
 using EvidenceEmitter = void(const Decl &Target, Slot, Evidence::Kind,
                              SourceLocation);
 // Creates an EvidenceEmitter that serializes the evidence as Evidence protos.
 // This emitter caches USR generation, and should be reused for the whole AST.
 llvm::unique_function<EvidenceEmitter> evidenceEmitter(
-    llvm::unique_function<void(const Evidence &) const>);
+    llvm::unique_function<void(const Evidence &) const>, USRCache &USRCache);
 
 // Analyze code (such as a function body) to infer nullability.
 //
@@ -35,7 +42,7 @@
 // It is up to the caller to ensure the implementation is eligible for inference
 // (function has a body, is not dependent, etc).
 llvm::Error collectEvidenceFromImplementation(
-    const Decl &, llvm::function_ref<EvidenceEmitter>,
+    const Decl &, llvm::function_ref<EvidenceEmitter>, USRCache &USRCache,
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable = {},
     const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull = {});
 
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 5202fa7..5b08e1d 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -69,10 +69,13 @@
     llvm::StringRef Source) {
   std::vector<Evidence> Results;
   clang::TestAST AST(getInputsWithAnnotationDefinitions(Source));
+  USRCache usr_cache;
   auto Err = collectEvidenceFromImplementation(
       cast<FunctionDecl>(
           *dataflow::test::findValueDecl(AST.context(), "target")),
-      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); }));
+      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
+                      usr_cache),
+      usr_cache);
   if (Err) ADD_FAILURE() << toString(std::move(Err));
   return Results;
 }
@@ -80,9 +83,11 @@
 std::vector<Evidence> collectEvidenceFromTargetDecl(llvm::StringRef Source) {
   std::vector<Evidence> Results;
   clang::TestAST AST(getInputsWithAnnotationDefinitions(Source));
+  USRCache usr_cache;
   collectEvidenceFromTargetDeclaration(
       *dataflow::test::findValueDecl(AST.context(), "target"),
-      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); }));
+      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
+                      usr_cache));
   return Results;
 }
 
@@ -472,10 +477,13 @@
       "target", TargetInstantiationNodes);
   ASSERT_NE(InstantiationDecl, nullptr);
 
+  USRCache usr_cache;
   std::vector<Evidence> Results;
   auto Err = collectEvidenceFromImplementation(
       *InstantiationDecl,
-      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); }));
+      evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
+                      usr_cache),
+      usr_cache);
   if (Err) ADD_FAILURE() << toString(std::move(Err));
   EXPECT_THAT(Results, IsEmpty());
 }
@@ -610,7 +618,8 @@
       dataflow::test::findValueDecl(AST.context(), "target");
   ASSERT_NE(TargetDecl, nullptr);
 
-  EXPECT_DEATH(evidenceEmitter([](const Evidence& e) {})(
+  USRCache usr_cache;
+  EXPECT_DEATH(evidenceEmitter([](const Evidence& e) {}, usr_cache)(
                    *TargetDecl, Slot{}, Evidence::ANNOTATED_UNKNOWN,
                    TargetDecl->getLocation()),
                "not an inference target");
diff --git a/nullability/inference/infer_tu.cc b/nullability/inference/infer_tu.cc
index cd783b2..2f9d616 100644
--- a/nullability/inference/infer_tu.cc
+++ b/nullability/inference/infer_tu.cc
@@ -11,6 +11,7 @@
 #include "nullability/inference/inference.proto.h"
 #include "nullability/inference/merge.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclBase.h"
 #include "clang/Basic/SourceManager.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -35,14 +36,17 @@
 
   // Collect all evidence.
   auto Sites = EvidenceSites::discover(Ctx);
-  auto Emitter = evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); });
+  USRCache USRCache;
+  auto Emitter =
+      evidenceEmitter([&](auto& E) { AllEvidence.push_back(E); }, USRCache);
   for (const auto* Decl : Sites.Declarations) {
     if (Filter && !Filter(*Decl)) continue;
     collectEvidenceFromTargetDeclaration(*Decl, Emitter);
   }
   for (const auto* Impl : Sites.Implementations) {
     if (Filter && !Filter(*Impl)) continue;
-    if (auto Err = collectEvidenceFromImplementation(*Impl, Emitter)) {
+    if (auto Err =
+            collectEvidenceFromImplementation(*Impl, Emitter, USRCache)) {
       llvm::errs() << "Skipping function: " << toString(std::move(Err)) << "\n";
       Impl->print(llvm::errs());
     }