Refactor diagnoser function to return CFGElement.

PiperOrigin-RevId: 487174212
diff --git a/nullability_verification/pointer_nullability_diagnosis.cc b/nullability_verification/pointer_nullability_diagnosis.cc
index d1aa1e0..6559beb 100644
--- a/nullability_verification/pointer_nullability_diagnosis.cc
+++ b/nullability_verification/pointer_nullability_diagnosis.cc
@@ -41,20 +41,20 @@
          isNullableOrUntracked(E, Env);
 }
 
-llvm::Optional<const Stmt*> diagnoseDereference(const UnaryOperator* UnaryOp,
-                                                const MatchFinder::MatchResult&,
-                                                const Environment& Env) {
+llvm::Optional<CFGElement> diagnoseDereference(const UnaryOperator* UnaryOp,
+                                               const MatchFinder::MatchResult&,
+                                               const Environment& Env) {
   if (isNullableOrUntracked(UnaryOp->getSubExpr(), Env)) {
-    return UnaryOp;
+    return llvm::Optional<CFGElement>(CFGStmt(UnaryOp));
   }
   return llvm::None;
 }
 
-llvm::Optional<const Stmt*> diagnoseArrow(
-    const MemberExpr* MemberExpr, const MatchFinder::MatchResult& Result,
-    const Environment& Env) {
+llvm::Optional<CFGElement> diagnoseArrow(const MemberExpr* MemberExpr,
+                                         const MatchFinder::MatchResult& Result,
+                                         const Environment& Env) {
   if (isNullableOrUntracked(MemberExpr->getBase(), Env)) {
-    return MemberExpr;
+    return llvm::Optional<CFGElement>(CFGStmt(MemberExpr));
   }
   return llvm::None;
 }
@@ -78,7 +78,7 @@
 // TODO(b/233582219): Handle call expressions whose callee is not a decl (e.g.
 // a function returned from another function), or when the callee cannot be
 // interpreted as a function type (e.g. a pointer to a function pointer).
-llvm::Optional<const Stmt*> diagnoseCallExpr(
+llvm::Optional<CFGElement> diagnoseCallExpr(
     const CallExpr* CE, const MatchFinder::MatchResult& Result,
     const Environment& Env) {
   auto* Callee = CE->getCalleeDecl();
@@ -96,11 +96,11 @@
   }
 
   return isIncompatibleArgumentList(ParamTypes, Args, Env, *Result.Context)
-             ? llvm::Optional<const Stmt*>(CE)
+             ? llvm::Optional<CFGElement>(CFGStmt(CE))
              : llvm::None;
 }
 
-llvm::Optional<const Stmt*> diagnoseConstructExpr(
+llvm::Optional<CFGElement> diagnoseConstructExpr(
     const CXXConstructExpr* CE, const MatchFinder::MatchResult& Result,
     const Environment& Env) {
   auto ConstructorParamTypes = CE->getConstructor()
@@ -110,11 +110,11 @@
   ArrayRef<const Expr*> ConstructorArgs(CE->getArgs(), CE->getNumArgs());
   return isIncompatibleArgumentList(ConstructorParamTypes, ConstructorArgs, Env,
                                     *Result.Context)
-             ? llvm::Optional<const Stmt*>(CE)
+             ? llvm::Optional<CFGElement>(CFGStmt(CE))
              : llvm::None;
 }
 
-llvm::Optional<const Stmt*> diagnoseReturn(
+llvm::Optional<CFGElement> diagnoseReturn(
     const ReturnStmt* RS, const MatchFinder::MatchResult& Result,
     const Environment& Env) {
   auto ReturnType = cast<FunctionDecl>(Env.getDeclCtx())->getReturnType();
@@ -124,11 +124,11 @@
   assert(ReturnExpr->getType()->isPointerType());
 
   return isIncompatibleAssignment(ReturnType, ReturnExpr, Env, *Result.Context)
-             ? llvm::Optional<const Stmt*>(RS)
+             ? llvm::Optional<CFGElement>(CFGStmt(RS))
              : llvm::None;
 }
 
-llvm::Optional<const Stmt*> diagnoseMemberInitializer(
+llvm::Optional<CFGElement> diagnoseMemberInitializer(
     const CXXCtorInitializer* CI, const MatchFinder::MatchResult& Result,
     const Environment& Env) {
   assert(CI->isAnyMemberInitializer());
@@ -139,19 +139,12 @@
   auto MemberInitExpr = CI->getInit();
   return isIncompatibleAssignment(MemberType, MemberInitExpr, Env,
                                   *Result.Context)
-             // TODO(b/233582219): CtorInitializer is not compatible with the
-             // return type as it is not a Stmt. Therefore, we currently return
-             // the expression in the initializer. The return type should be
-             // modified to work over different AST nodes. For example,
-             // returning a SourceLocation or creating a Diagnostic base class
-             // that will contain more information about the violation and store
-             // the relevant AST nodes.
-             ? llvm::Optional<const Stmt*>(MemberInitExpr)
+             ? llvm::Optional<CFGElement>(CFGInitializer(CI))
              : llvm::None;
 }
 
 auto buildDiagnoser() {
-  return CFGMatchSwitchBuilder<const Environment, llvm::Optional<const Stmt*>>()
+  return CFGMatchSwitchBuilder<const Environment, llvm::Optional<CFGElement>>()
       // (*)
       .CaseOfCFGStmt<UnaryOperator>(isPointerDereference(), diagnoseDereference)
       // (->)
diff --git a/nullability_verification/pointer_nullability_diagnosis.h b/nullability_verification/pointer_nullability_diagnosis.h
index e688210..5547e77 100644
--- a/nullability_verification/pointer_nullability_diagnosis.h
+++ b/nullability_verification/pointer_nullability_diagnosis.h
@@ -30,14 +30,14 @@
   ///
   /// TODO(b/233582219): Extend diagnosis to return more information, e.g. the
   /// type of violation.
-  llvm::Optional<const Stmt*> diagnose(const CFGElement* Elt, ASTContext& Ctx,
-                                       const dataflow::Environment& Env) {
+  llvm::Optional<CFGElement> diagnose(const CFGElement* Elt, ASTContext& Ctx,
+                                      const dataflow::Environment& Env) {
     return Diagnoser(*Elt, Ctx, Env);
   }
 
  private:
   dataflow::CFGMatchSwitch<const dataflow::Environment,
-                           llvm::Optional<const Stmt*>>
+                           llvm::Optional<CFGElement>>
       Diagnoser;
 };
 
diff --git a/nullability_verification/pointer_nullability_verification_test.cc b/nullability_verification/pointer_nullability_verification_test.cc
index b5af9d1..12a1a44 100644
--- a/nullability_verification/pointer_nullability_verification_test.cc
+++ b/nullability_verification/pointer_nullability_verification_test.cc
@@ -29,7 +29,7 @@
 using ::testing::Test;
 
 void checkDiagnostics(llvm::StringRef SourceCode) {
-  std::vector<const Stmt *> Diagnostics;
+  std::vector<CFGElement> Diagnostics;
   PointerNullabilityDiagnoser Diagnoser;
   ASSERT_THAT_ERROR(
       checkDataflow<PointerNullabilityAnalysis>(
@@ -58,9 +58,17 @@
               ExpectedLines.insert(Line);
             }
             auto &SrcMgr = AnalysisData.ASTCtx.getSourceManager();
-            for (auto *Stmt : Diagnostics) {
-              ActualLines.insert(
-                  SrcMgr.getPresumedLineNumber(Stmt->getBeginLoc()));
+            for (auto Element : Diagnostics) {
+              if (Optional<CFGStmt> stmt = Element.getAs<CFGStmt>()) {
+                ActualLines.insert(SrcMgr.getPresumedLineNumber(
+                    stmt->getStmt()->getBeginLoc()));
+              } else if (Optional<CFGInitializer> init =
+                             Element.getAs<CFGInitializer>()) {
+                ActualLines.insert(SrcMgr.getPresumedLineNumber(
+                    init->getInitializer()->getSourceLocation()));
+              } else {
+                ADD_FAILURE() << "this code should not be reached";
+              }
             }
             EXPECT_THAT(ActualLines, ContainerEq(ExpectedLines));
           }),