Collect information from calls to constructors.

Also add test for collecting information from constructor bodies, just to make it clear we are covering them.

PiperOrigin-RevId: 578915862
Change-Id: Ic4fd38e5f701e56ff4831991def3ef52cb322535
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index d3d8ce5..e6e1a42 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -23,6 +23,7 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclBase.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/OperationKinds.h"
@@ -216,20 +217,12 @@
   }
 }
 
-void collectEvidenceFromCallExpr(
+template <typename CallOrConstructExpr>
+void collectEvidenceFromArgsAndParams(
+    const FunctionDecl &CalleeDecl, const CallOrConstructExpr &Expr,
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
-    const Formula &InferableSlotsConstraint, const CFGElement &Element,
-    const dataflow::Environment &Env,
+    const Formula &InferableSlotsConstraint, const dataflow::Environment &Env,
     llvm::function_ref<EvidenceEmitter> Emit) {
-  // Is this CFGElement a call to a function?
-  auto CFGStmt = Element.getAs<clang::CFGStmt>();
-  if (!CFGStmt) return;
-  auto *CallExpr = dyn_cast_or_null<clang::CallExpr>(CFGStmt->getStmt());
-  if (!CallExpr || !CallExpr->getCalleeDecl()) return;
-  auto *CalleeDecl =
-      dyn_cast_or_null<clang::FunctionDecl>(CallExpr->getCalleeDecl());
-  if (!CalleeDecl || !isInferenceTarget(*CalleeDecl)) return;
-
   unsigned ParamI = 0;
   unsigned ArgI = 0;
   // Member operator calls hold the function object as the first argument,
@@ -237,23 +230,23 @@
   // For example: Given struct S { bool operator+(int*); }
   // The CXXMethodDecl has one parameter, but a call S{}+p is a
   // CXXOperatorCallExpr with two arguments: an S and an int*.
-  if (isa<clang::CXXOperatorCallExpr>(CallExpr) &&
+  if (isa<clang::CXXOperatorCallExpr>(Expr) &&
       isa<clang::CXXMethodDecl>(CalleeDecl))
     ++ArgI;
 
   // For each pointer parameter of the callee, ...
-  for (; ParamI < CalleeDecl->param_size(); ++ParamI, ++ArgI) {
+  for (; ParamI < CalleeDecl.param_size(); ++ParamI, ++ArgI) {
     auto ParamType =
-        CalleeDecl->getParamDecl(ParamI)->getType().getNonReferenceType();
+        CalleeDecl.getParamDecl(ParamI)->getType().getNonReferenceType();
     if (!isSupportedPointerType(ParamType)) continue;
     // the corresponding argument should also be a pointer.
-    CHECK(isSupportedPointerType(CallExpr->getArg(ArgI)->getType()));
+    CHECK(isSupportedPointerType(Expr.getArg(ArgI)->getType()));
 
     dataflow::PointerValue *PV =
-        getPointerValueFromExpr(CallExpr->getArg(ArgI), Env);
+        getPointerValueFromExpr(Expr.getArg(ArgI), Env);
     if (!PV) continue;
 
-    SourceLocation ArgLoc = CallExpr->getArg(ArgI)->getExprLoc();
+    SourceLocation ArgLoc = Expr.getArg(ArgI)->getExprLoc();
 
     // TODO: Include inferred annotations from previous rounds when propagating.
     auto ParamNullability = getNullabilityAnnotationsFromType(ParamType);
@@ -280,10 +273,47 @@
       default:
         ArgEvidenceKind = Evidence::UNKNOWN_ARGUMENT;
     }
-    Emit(*CalleeDecl, paramSlot(ParamI), ArgEvidenceKind, ArgLoc);
+    Emit(CalleeDecl, paramSlot(ParamI), ArgEvidenceKind, ArgLoc);
   }
 }
 
+void collectEvidenceFromCallExpr(
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
+    const Formula &InferableSlotsConstraint, const CFGElement &Element,
+    const dataflow::Environment &Env,
+    llvm::function_ref<EvidenceEmitter> Emit) {
+  // Is this CFGElement a call to a function?
+  auto CFGStmt = Element.getAs<clang::CFGStmt>();
+  if (!CFGStmt) return;
+  auto *CallExpr = dyn_cast_or_null<clang::CallExpr>(CFGStmt->getStmt());
+  if (!CallExpr || !CallExpr->getCalleeDecl()) return;
+  auto *CalleeDecl =
+      dyn_cast_or_null<clang::FunctionDecl>(CallExpr->getCalleeDecl());
+  if (!CalleeDecl || !isInferenceTarget(*CalleeDecl)) return;
+
+  collectEvidenceFromArgsAndParams(*CalleeDecl, *CallExpr, InferableCallerSlots,
+                                   InferableSlotsConstraint, Env, Emit);
+}
+
+void collectEvidenceFromConstructExpr(
+    std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
+    const Formula &InferableSlotsConstraint, const CFGElement &Element,
+    const dataflow::Environment &Env,
+    llvm::function_ref<EvidenceEmitter> Emit) {
+  auto CFGStmt = Element.getAs<clang::CFGStmt>();
+  if (!CFGStmt) return;
+  auto *ConstructExpr =
+      dyn_cast_or_null<clang::CXXConstructExpr>(CFGStmt->getStmt());
+  if (!ConstructExpr || !ConstructExpr->getConstructor()) return;
+  auto *ConstructorDecl = dyn_cast_or_null<clang::CXXConstructorDecl>(
+      ConstructExpr->getConstructor());
+  if (!ConstructorDecl || !isInferenceTarget(*ConstructorDecl)) return;
+
+  collectEvidenceFromArgsAndParams(*ConstructorDecl, *ConstructExpr,
+                                   InferableSlots, InferableSlotsConstraint,
+                                   Env, Emit);
+}
+
 void collectEvidenceFromReturn(
     std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
     const Formula &InferableSlotsConstraint, const CFGElement &Element,
@@ -364,6 +394,8 @@
   collectEvidenceFromDereference(InferableSlots, Element, Env, Emit);
   collectEvidenceFromCallExpr(InferableSlots, InferableSlotsConstraint, Element,
                               Env, Emit);
+  collectEvidenceFromConstructExpr(InferableSlots, InferableSlotsConstraint,
+                                   Element, Env, Emit);
   collectEvidenceFromReturn(InferableSlots, InferableSlotsConstraint, Element,
                             Env, Emit);
   collectEvidenceFromAssignment(InferableSlots, Element, Env, Emit);
@@ -401,12 +433,12 @@
       std::optional<const Decl *> fingerprintedDecl;
       Slot Slot;
       if (auto *FD = clang::dyn_cast_or_null<FunctionDecl>(&D)) {
-        fingerprintedDecl = (ValueDecl *)FD;
+        fingerprintedDecl = FD;
         Slot = SLOT_RETURN_TYPE;
       } else if (auto *PD = clang::dyn_cast_or_null<ParmVarDecl>(&D)) {
         if (auto *Parent = clang::dyn_cast_or_null<FunctionDecl>(
                 PD->getParentFunctionOrMethod())) {
-          fingerprintedDecl = (ValueDecl *)Parent;
+          fingerprintedDecl = Parent;
           Slot = paramSlot(PD->getFunctionScopeIndex());
         }
       }
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 58ed6ca..87567b3 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -452,6 +452,22 @@
                                 functionNamed("operator()"))));
 }
 
+TEST(CollectEvidenceFromImplementationTest, ConstructorCall) {
+  static constexpr llvm::StringRef Src = R"cc(
+    class S {
+     public:
+      S(Nonnull<int*> a);
+    };
+    void target(int* p) { S s(p); }
+  )cc";
+  EXPECT_THAT(
+      collectEvidenceFromTargetFunction(Src),
+      UnorderedElementsAre(evidence(paramSlot(0), Evidence::BOUND_TO_NONNULL,
+                                    functionNamed("target")),
+                           evidence(paramSlot(0), Evidence::UNKNOWN_ARGUMENT,
+                                    functionNamed("S"))));
+}
+
 TEST(CollectEvidenceFromImplementationTest, PassedToNonnull) {
   static constexpr llvm::StringRef Src = R"cc(
     void callee(Nonnull<int*> i);
@@ -709,6 +725,7 @@
     auto Lambda = []() {};  // Not analyzed yet.
 
     struct S {
+      S() {}
       void member();
     };
     void S::member() {}
@@ -716,11 +733,11 @@
   auto Sites = EvidenceSites::discover(AST.context());
   EXPECT_THAT(Sites.Declarations,
               ElementsAre(declNamed("foo"), declNamed("bar"), declNamed("bar"),
-                          declNamed("baz"), declNamed("S::member"),
+                          declNamed("baz"), declNamed("S::S"),
+                          declNamed("S::member"), declNamed("S::member")));
+  EXPECT_THAT(Sites.Implementations,
+              ElementsAre(declNamed("bar"), declNamed("baz"), declNamed("S::S"),
                           declNamed("S::member")));
-  EXPECT_THAT(
-      Sites.Implementations,
-      ElementsAre(declNamed("bar"), declNamed("baz"), declNamed("S::member")));
 }
 
 TEST(EvidenceSitesTest, Variables) {