Infer Nonnull parameters from the passing of an influenced value to a Nonnull parameter.

An influenced value is a value that is made Nonnull by the inferred parameter being marked Nonnull.

PiperOrigin-RevId: 568588041
Change-Id: I14939fc1f7a035721829cfcd4ed1e2b10fe1fa10
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index 7d5837b..05441ff 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -36,6 +36,7 @@
 #include "clang/Analysis/FlowSensitive/Value.h"
 #include "clang/Analysis/FlowSensitive/WatchedLiteralsSolver.h"
 #include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Specifiers.h"
 #include "clang/Index/USRGeneration.h"
 #include "llvm/ADT/DenseMap.h"
@@ -109,6 +110,33 @@
   return {nullptr, SourceLocation()};
 }
 
+// Records evidence derived from the assumption that Value is nonnull.
+// It may be dereferenced, passed as a nonnull param, etc, per EvidenceKind.
+void collectMustBeNonnullEvidence(
+    const dataflow::PointerValue &Value, const dataflow::Environment &Env,
+    SourceLocation Loc,
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
+    Evidence::Kind EvidenceKind, llvm::function_ref<EvidenceEmitter> Emit) {
+  auto &A = Env.getDataflowAnalysisContext().arena();
+  auto &NotIsNull = A.makeNot(getPointerNullState(Value).IsNull);
+
+  // If the flow conditions already imply that Value is not null, then we don't
+  // have any new evidence of a necessary annotation.
+  if (Env.flowConditionImplies(NotIsNull)) return;
+
+  // Otherwise, if an inferrable slot being annotated Nonnull would imply that
+  // Value is not null, then we have evidence suggesting that slot should be
+  // annotated. For now, we simply choose the first such slot, sidestepping
+  // complexities around the possibility of multiple such slots, any one of
+  // which would be sufficient if annotated Nonnull.
+  for (auto &[Nullability, Slot] : InferrableSlots) {
+    auto &SlotNonnullImpliesValueNonnull =
+        A.makeImplies(Nullability.isNonnull(A), NotIsNull);
+    if (Env.flowConditionImplies(SlotNonnullImpliesValueNonnull))
+      Emit(*Env.getCurrentFunc(), Slot, EvidenceKind, Loc);
+  }
+}
+
 void collectEvidenceFromDereference(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableSlots,
     const CFGElement &Element, const dataflow::Environment &Env,
@@ -125,24 +153,8 @@
   dataflow::PointerValue *DereferencedValue =
       getPointerValueFromExpr(Target, Env);
   if (!DereferencedValue) return;
-  auto &A = Env.getDataflowAnalysisContext().arena();
-  auto &NotIsNull = A.makeNot(getPointerNullState(*DereferencedValue).IsNull);
-
-  // If the flow conditions already imply the dereferenced value is not null,
-  // then we don't have any new evidence of a necessary annotation.
-  if (Env.flowConditionImplies(NotIsNull)) return;
-
-  // Otherwise, if an inferrable slot being annotated Nonnull would imply that
-  // the dereferenced value is not null, then we have evidence suggesting that
-  // slot should be annotated. For now, we simply choose the first such slot,
-  // sidestepping complexities around the possibility of multiple such slots,
-  // any one of which would be sufficient if annotated Nonnull.
-  for (auto &[Nullability, Slot] : InferrableSlots) {
-    auto &SlotNonnullImpliesDerefValueNonnull =
-        A.makeImplies(Nullability.isNonnull(A), NotIsNull);
-    if (Env.flowConditionImplies(SlotNonnullImpliesDerefValueNonnull))
-      Emit(*Env.getCurrentFunc(), Slot, Evidence::UNCHECKED_DEREFERENCE, Loc);
-  }
+  collectMustBeNonnullEvidence(*DereferencedValue, Env, Loc, InferrableSlots,
+                               Evidence::UNCHECKED_DEREFERENCE, Emit);
 }
 
 const dataflow::Formula *getInferrableSlotsUnknownConstraint(
@@ -159,6 +171,20 @@
   return InferrableSlotsUnknown;
 }
 
+void collectEvidenceFromParamAnnotation(
+    TypeNullability &ParamNullability, const dataflow::PointerValue &ArgPV,
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableCallerSlots,
+    const dataflow::Environment &Env, SourceLocation ArgLoc,
+    llvm::function_ref<EvidenceEmitter> Emit) {
+  //  TODO: Account for variance and each layer of nullability when we handle
+  //  more than top-level pointers.
+  if (ParamNullability.empty()) return;
+  if (ParamNullability[0].concrete() == NullabilityKind::NonNull) {
+    collectMustBeNonnullEvidence(ArgPV, Env, ArgLoc, InferrableCallerSlots,
+                                 Evidence::PASSED_TO_NONNULL, Emit);
+  }
+}
+
 void collectEvidenceFromCallExpr(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferrableCallerSlots,
     const CFGElement &Element, const dataflow::Environment &Env,
@@ -185,9 +211,9 @@
 
   // For each pointer parameter of the callee, ...
   for (; ParamI < CalleeDecl->param_size(); ++ParamI, ++ArgI) {
-    if (!isSupportedPointerType(
-            CalleeDecl->getParamDecl(ParamI)->getType().getNonReferenceType()))
-      continue;
+    auto ParamType =
+        CalleeDecl->getParamDecl(ParamI)->getType().getNonReferenceType();
+    if (!isSupportedPointerType(ParamType)) continue;
     // the corresponding argument should also be a pointer.
     CHECK(isSupportedPointerType(CallExpr->getArg(ArgI)->getType()));
 
@@ -195,11 +221,15 @@
         getPointerValueFromExpr(CallExpr->getArg(ArgI), Env);
     if (!PV) continue;
 
-    // TODO: Check if the parameter is annotated. If annotated Nonnull, (instead
-    // of collecting evidence for it?) collect evidence similar to a
-    // dereference, i.e. if the argument is not already proven Nonnull, collect
-    // evidence for a parameter that could be annotated Nonnull as a way to
-    // force the argument to be Nonnull.
+    SourceLocation ArgLoc = CallExpr->getArg(ArgI)->getExprLoc();
+
+    // TODO: Include inferred annotations from previous rounds when propagating.
+    auto ParamNullability = getNullabilityAnnotationsFromType(ParamType);
+
+    // Collect evidence from the binding of the argument to the parameter's
+    // nullability, if known.
+    collectEvidenceFromParamAnnotation(
+        ParamNullability, *PV, InferrableCallerSlots, Env, ArgLoc, Emit);
 
     // Emit evidence of the parameter's nullability. First, calculate that
     // nullability based on InferrableSlots for the caller being assigned to
@@ -219,8 +249,7 @@
       default:
         ArgEvidenceKind = Evidence::UNKNOWN_ARGUMENT;
     }
-    Emit(*CalleeDecl, paramSlot(ParamI), ArgEvidenceKind,
-         CallExpr->getArg(ArgI)->getExprLoc());
+    Emit(*CalleeDecl, paramSlot(ParamI), ArgEvidenceKind, ArgLoc);
   }
 }
 
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 4fca9e2..6b8c6ee 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -441,6 +441,17 @@
                                 functionNamed("operator()"))));
 }
 
+TEST(CollectEvidenceFromImplementationTest, PassedToNonnull) {
+  static constexpr llvm::StringRef Src = R"cc(
+    void callee(Nonnull<int*> i);
+
+    void target(int* p) { callee(p); }
+  )cc";
+  EXPECT_THAT(collectEvidenceFromTargetFunction(Src),
+              Contains(evidence(paramSlot(0), Evidence::PASSED_TO_NONNULL,
+                                functionNamed("target"))));
+}
+
 TEST(CollectEvidenceFromImplementationTest, NotInferenceTarget) {
   static constexpr llvm::StringRef Src = R"cc(
     template <typename T>
diff --git a/nullability/inference/inference.proto b/nullability/inference/inference.proto
index f79533f..b11fc5e 100644
--- a/nullability/inference/inference.proto
+++ b/nullability/inference/inference.proto
@@ -72,6 +72,8 @@
     NONNULL_RETURN = 8;
     // A value with Unknown nullability was returned.
     UNKNOWN_RETURN = 9;
+    // A value was passed to a Nonnull parameter.
+    PASSED_TO_NONNULL = 10;
   }
 }
 
diff --git a/nullability/inference/merge.cc b/nullability/inference/merge.cc
index 26134ea..52ddae4 100644
--- a/nullability/inference/merge.cc
+++ b/nullability/inference/merge.cc
@@ -124,6 +124,7 @@
   if (Counts[Evidence::NONNULL_RETURN] && !Counts[Evidence::NULLABLE_RETURN] &&
       !Counts[Evidence::UNKNOWN_RETURN])
     update(Result, Inference::NONNULL);
+  if (Counts[Evidence::PASSED_TO_NONNULL]) update(Result, Inference::NONNULL);
   if (Result) return *Result;
 
   // Optional "soft" inference heuristics.
diff --git a/nullability/inference/merge_test.cc b/nullability/inference/merge_test.cc
index 279720d..69ce865 100644
--- a/nullability/inference/merge_test.cc
+++ b/nullability/inference/merge_test.cc
@@ -237,5 +237,10 @@
   EXPECT_EQ(Inference::NULLABLE, infer());
 }
 
+TEST_F(InferTest, PassedToNonnull) {
+  add(Evidence::PASSED_TO_NONNULL);
+  EXPECT_EQ(Inference::NONNULL, infer());
+}
+
 }  // namespace
 }  // namespace clang::tidy::nullability
diff --git a/nullability/type_nullability.h b/nullability/type_nullability.h
index 8605452..54d5e96 100644
--- a/nullability/type_nullability.h
+++ b/nullability/type_nullability.h
@@ -44,7 +44,6 @@
 
 /// Is this exactly a pointer type that we track outer nullability for?
 /// This unwraps sugar, i.e. it looks at the canonical type.
-/// Does not unwrap sugar, consider isSupportedPointer(T.getCanonicalType()).
 ///
 /// (For now, only regular `PointerType`s, in future we should consider
 /// supporting pointer-to-member, ObjC pointers, `unique_ptr`, etc).