Collect evidence for return type from return statements.

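Records Nonnull, Nullable, or Unknown evidence for each pointer-typed
return value. Also factors the "all inferrable slots are Unknown"
constraint out into getInferrableSlotsUnknownConstraint and removes the
PointerNullabilityLattice parameter from the evidence collectors.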

PiperOrigin-RevId: 549329488
Change-Id: I3c95554d674e5555a25cc99d771509162b7cb78b
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index 2f91725..9c423a3 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -93,9 +93,8 @@
 }
 
 void collectEvidenceFromDereference(
-    std::vector<std::pair<PointerTypeNullability, Slot>> InferrableSlots,
-    const CFGElement &Element, const PointerNullabilityLattice &Lattice,
-    const dataflow::Environment &Env,
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
+    const CFGElement &Element, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   // Is this CFGElement a dereference of a pointer?
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -132,10 +131,27 @@
   }
 }
 
+// Returns a formula that holds exactly when every slot in InferrableSlots is
+// Unknown, i.e. neither Nullable nor Nonnull. With no slots, this is the
+// literal `true`. Callers pass it to getNullability as an extra constraint so
+// that computed nullability reflects the slots' current annotations rather
+// than all possible annotations.
+const dataflow::Formula *getInferrableSlotsUnknownConstraint(
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
+    const dataflow::Environment &Env) {
+  dataflow::Arena &A = Env.getDataflowAnalysisContext().arena();
+  const dataflow::Formula *AllSlotsUnknown = &A.makeLiteral(true);
+  for (auto &[Nullability, Slot] : InferrableSlots) {
+    AllSlotsUnknown = &A.makeAnd(
+        *AllSlotsUnknown, A.makeAnd(A.makeNot(Nullability.isNullable(A)),
+                                    A.makeNot(Nullability.isNonnull(A))));
+  }
+  return AllSlotsUnknown;
+}
+
 void collectEvidenceFromCallExpr(
-    std::vector<std::pair<PointerTypeNullability, Slot>> InferrableCallerSlots,
-    const CFGElement &Element, const PointerNullabilityLattice &Lattice,
-    const dataflow::Environment &Env,
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableCallerSlots,
+    const CFGElement &Element, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
   // Is this CFGElement a call to a function?
   auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -169,16 +185,9 @@
     // nullability based on InferrableSlots for the caller being assigned to
     // Unknown, to reflect the current annotations and not all possible
     // annotations for them.
-    dataflow::Arena &A = Env.getDataflowAnalysisContext().arena();
-    const dataflow::Formula *CallerSlotsUnknown = &A.makeLiteral(true);
-    for (auto &[Nullability, Slot] : InferrableCallerSlots) {
-      CallerSlotsUnknown = &A.makeAnd(
-          *CallerSlotsUnknown, A.makeAnd(A.makeNot(Nullability.isNullable(A)),
-                                         A.makeNot(Nullability.isNonnull(A))));
-    }
-
-    NullabilityKind ArgNullability =
-        getNullability(*PV, Env, CallerSlotsUnknown);
+    NullabilityKind ArgNullability = getNullability(
+        *PV, Env,
+        getInferrableSlotsUnknownConstraint(InferrableCallerSlots, Env));
     Evidence::Kind ArgEvidenceKind;
     switch (ArgNullability) {
       case NullabilityKind::Nullable:
@@ -195,12 +204,50 @@
   }
 }
 
+// If Element is a `return` of a pointer-typed value, emits NULLABLE_RETURN,
+// NONNULL_RETURN, or UNKNOWN_RETURN evidence for the current function's return
+// type, according to the nullability of the returned value computed with all
+// InferrableSlots constrained to Unknown.
+void collectEvidenceFromReturn(
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
+    const CFGElement &Element, const dataflow::Environment &Env,
+    llvm::function_ref<EvidenceEmitter> Emit) {
+  // Is this CFGElement a return statement?
+  auto CFGStmt = Element.getAs<clang::CFGStmt>();
+  if (!CFGStmt) return;
+  auto *ReturnStmt = dyn_cast_or_null<clang::ReturnStmt>(CFGStmt->getStmt());
+  if (!ReturnStmt) return;
+  auto *ReturnExpr = ReturnStmt->getRetValue();
+  if (!ReturnExpr || !ReturnExpr->getType()->isPointerType()) return;
+
+  NullabilityKind ReturnNullability =
+      getNullability(ReturnExpr, Env,
+                     getInferrableSlotsUnknownConstraint(InferrableSlots, Env));
+  Evidence::Kind ReturnEvidenceKind;
+  switch (ReturnNullability) {
+    case NullabilityKind::Nullable:
+      ReturnEvidenceKind = Evidence::NULLABLE_RETURN;
+      break;
+    case NullabilityKind::NonNull:
+      ReturnEvidenceKind = Evidence::NONNULL_RETURN;
+      break;
+    default:
+      ReturnEvidenceKind = Evidence::UNKNOWN_RETURN;
+  }
+  Emit(*Env.getCurrentFunc(), SLOT_RETURN_TYPE, ReturnEvidenceKind,
+       ReturnExpr->getExprLoc());
+}
+
+// Runs the dereference, call-expression, and return evidence collectors on
+// Element. InferrableSlots is taken by reference, matching those collectors,
+// so the vector is not copied for every CFG element visited.
 void collectEvidenceFromElement(
-    std::vector<std::pair<PointerTypeNullability, Slot>> InferrableSlots,
-    const CFGElement &Element, const PointerNullabilityLattice &Lattice,
-    const Environment &Env, llvm::function_ref<EvidenceEmitter> Emit) {
-  collectEvidenceFromDereference(InferrableSlots, Element, Lattice, Env, Emit);
-  collectEvidenceFromCallExpr(InferrableSlots, Element, Lattice, Env, Emit);
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
+    const CFGElement &Element, const Environment &Env,
+    llvm::function_ref<EvidenceEmitter> Emit) {
+  collectEvidenceFromDereference(InferrableSlots, Element, Env, Emit);
+  collectEvidenceFromCallExpr(InferrableSlots, Element, Env, Emit);
+  collectEvidenceFromReturn(InferrableSlots, Element, Env, Emit);
   // TODO: add location information.
   // TODO: add more heuristic collections here
 }
@@ -257,8 +304,8 @@
           [&](const CFGElement &Element,
               const dataflow::DataflowAnalysisState<PointerNullabilityLattice>
                   &State) {
-            collectEvidenceFromElement(InferrableSlots, Element, State.Lattice,
-                                       State.Env, Emit);
+            collectEvidenceFromElement(InferrableSlots, Element, State.Env,
+                                       Emit);
           });
 
   return llvm::Error::success();