In pointer nullability diagnosis, distinguish different kinds of diagnostics.

PiperOrigin-RevId: 549710060
Change-Id: I49ecafba4e91025b8b85f9b273dbd494d2622b1b
diff --git a/nullability/BUILD b/nullability/BUILD
index fe0a564..c5b4784 100644
--- a/nullability/BUILD
+++ b/nullability/BUILD
@@ -80,10 +80,12 @@
         ":pointer_nullability_lattice",
         ":pointer_nullability_matchers",
         ":type_nullability",
+        "@absl//absl/log:check",
         "@llvm-project//clang:analysis",
         "@llvm-project//clang:ast",
         "@llvm-project//clang:ast_matchers",
         "@llvm-project//clang:basic",
+        "@llvm-project//llvm:Support",
     ],
 )
 
diff --git a/nullability/pointer_nullability_diagnosis.cc b/nullability/pointer_nullability_diagnosis.cc
index 6456910..75d1a2b 100644
--- a/nullability/pointer_nullability_diagnosis.cc
+++ b/nullability/pointer_nullability_diagnosis.cc
@@ -4,9 +4,7 @@
 
 #include "nullability/pointer_nullability_diagnosis.h"
 
-#include <optional>
-#include <string>
-
+#include "absl/log/check.h"
 #include "nullability/pointer_nullability.h"
 #include "nullability/pointer_nullability_matchers.h"
 #include "nullability/type_nullability.h"
@@ -15,10 +13,18 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/Stmt.h"
+#include "clang/AST/Type.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Analysis/CFG.h"
 #include "clang/Analysis/FlowSensitive/CFGMatchSwitch.h"
 #include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/Specifiers.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "nullability-diagnostic"
 
 namespace clang::tidy::nullability {
 
@@ -26,52 +32,60 @@
 using dataflow::CFGMatchSwitchBuilder;
 using dataflow::Environment;
 using dataflow::TransferStateForDiagnostics;
+using ::llvm::SmallVector;
 
 namespace {
 
-// Returns true if `Expr` is uninterpreted or known to be nullable.
-bool isNullableOrUntracked(const Expr *E, const Environment &Env) {
-  auto *ActualVal = getPointerValueFromExpr(E, Env);
-  if (ActualVal == nullptr) {
+// Diagnoses whether `E` violates the expectation that it is nonnull.
+SmallVector<PointerNullabilityDiagnostic> diagnoseNonnullExpected(
+    const Expr *E, const Environment &Env) {
+  if (auto *ActualVal = getPointerValueFromExpr(E, Env)) {
+    if (isNullable(*ActualVal, Env))
+      return {{PointerNullabilityDiagnostic::ErrorCode::ExpectedNonnull,
+               CharSourceRange::getTokenRange(E->getSourceRange())}};
+    return {};
+  }
+
+  LLVM_DEBUG({
     llvm::dbgs()
-        << "The dataflow analysis framework does not model a PointerValue for "
+        << "The dataflow analysis framework does not model a PointerValue "
+           "for "
            "the following Expr, and thus its dereference is marked as "
            "unsafe:\n";
     E->dump();
-  }
-  return !ActualVal || isNullable(*ActualVal, Env);
+  });
+  return {{PointerNullabilityDiagnostic::ErrorCode::Untracked,
+           CharSourceRange::getTokenRange(E->getSourceRange())}};
 }
 
-// Returns true if an uninterpreted or nullable `Expr` was assigned to a
-// construct with a non-null `DeclaredType`.
-bool isIncompatibleAssignment(QualType DeclaredType, const Expr *E,
-                              const Environment &Env, ASTContext &Ctx) {
+// Diagnoses whether the nullability of `E` is incompatible with the expectation
+// set by `DeclaredType`.
+SmallVector<PointerNullabilityDiagnostic> diagnoseTypeExprCompatibility(
+    QualType DeclaredType, const Expr *E, const Environment &Env,
+    ASTContext &Ctx) {
   CHECK(DeclaredType->isAnyPointerType());
-  return getNullabilityKind(DeclaredType, Ctx) == NullabilityKind::NonNull &&
-         isNullableOrUntracked(E, Env);
+  return getNullabilityKind(DeclaredType, Ctx) == NullabilityKind::NonNull
+             ? diagnoseNonnullExpected(E, Env)
+             : SmallVector<PointerNullabilityDiagnostic>{};
 }
 
-std::optional<CFGElement> diagnoseDereference(
+SmallVector<PointerNullabilityDiagnostic> diagnoseDereference(
     const UnaryOperator *UnaryOp, const MatchFinder::MatchResult &,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
-  if (isNullableOrUntracked(UnaryOp->getSubExpr(), State.Env)) {
-    return std::optional<CFGElement>(CFGStmt(UnaryOp));
-  }
-  return std::nullopt;
+  return diagnoseNonnullExpected(UnaryOp->getSubExpr(), State.Env);
 }
 
-std::optional<CFGElement> diagnoseArrow(
+SmallVector<PointerNullabilityDiagnostic> diagnoseArrow(
     const MemberExpr *MemberExpr, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
-  if (isNullableOrUntracked(MemberExpr->getBase(), State.Env)) {
-    return std::optional<CFGElement>(CFGStmt(MemberExpr));
-  }
-  return std::nullopt;
+  return diagnoseNonnullExpected(MemberExpr->getBase(), State.Env);
 }
 
-bool isIncompatibleArgumentList(const FunctionProtoType &CalleeFPT,
-                                ArrayRef<const Expr *> Args,
-                                const Environment &Env, ASTContext &Ctx) {
+// Diagnoses whether any of the arguments are incompatible with the
+// corresponding type in the function prototype.
+SmallVector<PointerNullabilityDiagnostic> diagnoseArgumentCompatibility(
+    const FunctionProtoType &CalleeFPT, ArrayRef<const Expr *> Args,
+    const Environment &Env, ASTContext &Ctx) {
   auto ParamTypes = CalleeFPT.getParamTypes();
   // C-style varargs cannot be annotated and therefore are unchecked.
   if (CalleeFPT.isVariadic()) {
@@ -79,16 +93,14 @@
     Args = Args.take_front(ParamTypes.size());
   }
   CHECK_EQ(ParamTypes.size(), Args.size());
+  SmallVector<PointerNullabilityDiagnostic> Diagnostics;
   for (unsigned int I = 0; I < Args.size(); ++I) {
     auto ParamType = ParamTypes[I].getNonReferenceType();
-    if (!ParamType->isAnyPointerType()) {
-      continue;
-    }
-    if (isIncompatibleAssignment(ParamType, Args[I], Env, Ctx)) {
-      return true;
-    }
+    if (ParamType->isAnyPointerType())
+      Diagnostics.append(
+          diagnoseTypeExprCompatibility(ParamType, Args[I], Env, Ctx));
   }
-  return false;
+  return Diagnostics;
 }
 
 NullabilityKind parseNullabilityKind(StringRef EnumName) {
@@ -131,7 +143,7 @@
 ///      __assert_nullability<NK_unspecified, NK_nullable>(p);
 ///    }
 /// \endcode
-bool diagnoseAssertNullabilityCall(
+SmallVector<PointerNullabilityDiagnostic> diagnoseAssertNullabilityCall(
     const CallExpr *CE,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State,
     ASTContext &Ctx) {
@@ -152,28 +164,30 @@
   const Expr *GivenExpr = CE->getArg(0);
   const TypeNullability *MaybeComputed =
       State.Lattice.getExprNullability(GivenExpr);
-  if (MaybeComputed == nullptr) {
-    llvm::dbgs()
-        << "Could not evaluate __assert_nullability. Could not find the "
-           "nullability of the argument expression: ";
-    CE->dump();
-    return false;
-  }
-  if (*MaybeComputed == Expected) return true;
-  // The computed and expected nullabilities differ. Print both to aid
-  // debugging.
-  llvm::dbgs() << "__assert_nullability failed at location: ";
-  CE->getExprLoc().print(llvm::dbgs(), Ctx.getSourceManager());
-  llvm::dbgs() << "\nExpression:\n";
-  GivenExpr->dump();
-  llvm::dbgs() << "Expected nullability: ";
-  llvm::dbgs() << nullabilityToString(Expected) << "\n";
-  llvm::dbgs() << "Computed nullability: ";
-  llvm::dbgs() << nullabilityToString(*MaybeComputed) << "\n";
-  return false;
+  if (MaybeComputed == nullptr)
+    return {{PointerNullabilityDiagnostic::ErrorCode::Untracked,
+             CharSourceRange::getTokenRange(CE->getSourceRange())}};
+
+  if (*MaybeComputed == Expected) return {};
+
+  LLVM_DEBUG({
+    // The computed and expected nullabilities differ. Print both to aid
+    // debugging.
+    llvm::dbgs() << "__assert_nullability failed at location: ";
+    CE->getExprLoc().print(llvm::dbgs(), Ctx.getSourceManager());
+    llvm::dbgs() << "\nExpression:\n";
+    GivenExpr->dump();
+    llvm::dbgs() << "Expected nullability: ";
+    llvm::dbgs() << nullabilityToString(Expected) << "\n";
+    llvm::dbgs() << "Computed nullability: ";
+    llvm::dbgs() << nullabilityToString(*MaybeComputed) << "\n";
+  });
+
+  return {{PointerNullabilityDiagnostic::ErrorCode::AssertFailed,
+           CharSourceRange::getTokenRange(CE->getSourceRange())}};
 }
 
-std::optional<CFGElement> diagnoseCallExpr(
+SmallVector<PointerNullabilityDiagnostic> diagnoseCallExpr(
     const CallExpr *CE, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
   // Check whether the callee is null.
@@ -182,29 +196,25 @@
   // - Skip member callees, as they are not pointers at all (rather "bound
   //   member function type").
   //   Note that in `(obj.*nullable_pmf)()` the deref is *before* the call.
-  if (!CE->getDirectCallee() && !isa<CXXMemberCallExpr>(CE) &&
-      isNullableOrUntracked(CE->getCallee(), State.Env)) {
-    return std::optional<CFGElement>(CFGStmt(CE->getCallee()));
+  if (!CE->getDirectCallee() && !isa<CXXMemberCallExpr>(CE)) {
+    auto D = diagnoseNonnullExpected(CE->getCallee(), State.Env);
+    if (!D.empty()) return D;
   }
 
-  if (auto *FD = CE->getDirectCallee()) {
+  if (auto *FD = CE->getDirectCallee())
     if (FD->getDeclName().isIdentifier() &&
-        FD->getName() == "__assert_nullability" &&
-        !diagnoseAssertNullabilityCall(CE, State, *Result.Context)) {
-      // TODO: Handle __assert_nullability failures differently from regular
-      // diagnostic ([[unsafe]]) failures.
-      return std::optional<CFGElement>(CFGStmt(CE));
+        FD->getName() == "__assert_nullability") {
+      return diagnoseAssertNullabilityCall(CE, State, *Result.Context);
     }
-  }
 
   auto *Callee = CE->getCalleeDecl();
   // TODO(mboehme): Retrieve the nullability directly from the callee using
   // `getNullabilityForChild(CE->getCallee())`, as what we have here now
   // doesn't work for callees that don't have a decl.
-  if (!Callee) return std::nullopt;
+  if (!Callee) return {};
 
   auto *CalleeType = Callee->getFunctionType();
-  if (!CalleeType) return std::nullopt;
+  if (!CalleeType) return {};
 
   // TODO(mboehme): We're only looking at the nullability spelled on the
   // `FunctionProtoType`, but there could be extra information in the callee.
@@ -220,7 +230,7 @@
   //   // (not sure if it is today)
   // }
   auto *CalleeFPT = CalleeType->getAs<FunctionProtoType>();
-  if (!CalleeFPT) return std::nullopt;
+  if (!CalleeFPT) return {};
 
   ArrayRef<const Expr *> Args(CE->getArgs(), CE->getNumArgs());
   // The first argument of an member operator call expression is the implicit
@@ -230,62 +240,55 @@
       isa<CXXMethodDecl>(CE->getDirectCallee())) {
     Args = Args.drop_front();
   }
-  return isIncompatibleArgumentList(*CalleeFPT, Args, State.Env,
-                                    *Result.Context)
-             ? std::optional<CFGElement>(CFGStmt(CE))
-             : std::nullopt;
+  return diagnoseArgumentCompatibility(*CalleeFPT, Args, State.Env,
+                                       *Result.Context);
 }
 
-std::optional<CFGElement> diagnoseConstructExpr(
+SmallVector<PointerNullabilityDiagnostic> diagnoseConstructExpr(
     const CXXConstructExpr *CE, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
   auto *CalleeFPT = CE->getConstructor()->getType()->getAs<FunctionProtoType>();
-  if (!CalleeFPT) return std::nullopt;
+  if (!CalleeFPT) return {};
   ArrayRef<const Expr *> ConstructorArgs(CE->getArgs(), CE->getNumArgs());
-  return isIncompatibleArgumentList(*CalleeFPT, ConstructorArgs, State.Env,
-                                    *Result.Context)
-             ? std::optional<CFGElement>(CFGStmt(CE))
-             : std::nullopt;
+  return diagnoseArgumentCompatibility(*CalleeFPT, ConstructorArgs, State.Env,
+                                       *Result.Context);
 }
 
-std::optional<CFGElement> diagnoseReturn(
+SmallVector<PointerNullabilityDiagnostic> diagnoseReturn(
     const ReturnStmt *RS, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
   auto ReturnType = cast<FunctionDecl>(State.Env.getDeclCtx())->getReturnType();
 
   // TODO: Handle non-pointer return types.
   if (!ReturnType->isPointerType()) {
-    return std::nullopt;
+    return {};
   }
 
   auto *ReturnExpr = RS->getRetValue();
   CHECK(ReturnExpr->getType()->isPointerType());
 
-  return isIncompatibleAssignment(ReturnType, ReturnExpr, State.Env,
-                                  *Result.Context)
-             ? std::optional<CFGElement>(CFGStmt(RS))
-             : std::nullopt;
+  return diagnoseTypeExprCompatibility(ReturnType, ReturnExpr, State.Env,
+                                       *Result.Context);
 }
 
-std::optional<CFGElement> diagnoseMemberInitializer(
+SmallVector<PointerNullabilityDiagnostic> diagnoseMemberInitializer(
     const CXXCtorInitializer *CI, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
   CHECK(CI->isAnyMemberInitializer());
   auto MemberType = CI->getAnyMember()->getType();
-  if (!MemberType->isAnyPointerType()) {
-    return std::nullopt;
-  }
-  auto MemberInitExpr = CI->getInit();
-  return isIncompatibleAssignment(MemberType, MemberInitExpr, State.Env,
-                                  *Result.Context)
-             ? std::optional<CFGElement>(CFGInitializer(CI))
-             : std::nullopt;
+  if (!MemberType->isAnyPointerType()) return {};
+
+  auto *MemberInitExpr = CI->getInit();
+  return diagnoseTypeExprCompatibility(MemberType, MemberInitExpr, State.Env,
+                                       *Result.Context);
 }
 
-auto buildDiagnoser() {
+}  // namespace
+
+PointerNullabilityDiagnoser pointerNullabilityDiagnoser() {
   return CFGMatchSwitchBuilder<const dataflow::TransferStateForDiagnostics<
                                    PointerNullabilityLattice>,
-                               std::optional<CFGElement>>()
+                               SmallVector<PointerNullabilityDiagnostic>>()
       // (*)
       .CaseOfCFGStmt<UnaryOperator>(isPointerDereference(), diagnoseDereference)
       // (->)
@@ -299,9 +302,4 @@
       .Build();
 }
 
-}  // namespace
-
-PointerNullabilityDiagnoser::PointerNullabilityDiagnoser()
-    : Diagnoser(buildDiagnoser()) {}
-
 }  // namespace clang::tidy::nullability
diff --git a/nullability/pointer_nullability_diagnosis.h b/nullability/pointer_nullability_diagnosis.h
index e5eb931..806166a 100644
--- a/nullability/pointer_nullability_diagnosis.h
+++ b/nullability/pointer_nullability_diagnosis.h
@@ -5,46 +5,48 @@
 #ifndef CRUBIT_NULLABILITY_POINTER_NULLABILITY_DIAGNOSIS_H_
 #define CRUBIT_NULLABILITY_POINTER_NULLABILITY_DIAGNOSIS_H_
 
-#include <optional>
+#include <functional>
 
 #include "nullability/pointer_nullability_lattice.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Stmt.h"
-#include "clang/Analysis/FlowSensitive/CFGMatchSwitch.h"
-#include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
+#include "clang/Analysis/FlowSensitive/MatchSwitch.h"
+#include "clang/Basic/SourceLocation.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace clang {
 namespace tidy {
 namespace nullability {
 
+/// Diagnoses a nullability-related issue in the associated CFG element.
+struct PointerNullabilityDiagnostic {
+  enum class ErrorCode {
+    /// A nullable pointer was used where a nonnull pointer was expected.
+    ExpectedNonnull,
+    /// A pointer-typed expression was encountered with no corresponding model.
+    Untracked,
+    /// A nullability assertion was violated.
+    AssertFailed,
+  };
+  ErrorCode Code;
+  CharSourceRange Range;
+};
+
 /// Checks that nullable pointers are used safely, using nullability information
 /// that is collected by `PointerNullabilityAnalysis`.
 ///
 /// Examples of null safety violations include dereferencing nullable pointers
 /// without null checks, and assignments between pointers of incompatible
 /// nullability.
-class PointerNullabilityDiagnoser {
- public:
-  PointerNullabilityDiagnoser();
+///
+/// The diagnoser returns an empty vector when no issues are found in the code.
+using PointerNullabilityDiagnoser =
+    std::function<llvm::SmallVector<PointerNullabilityDiagnostic>(
+        const CFGElement &, ASTContext &,
+        const dataflow::TransferStateForDiagnostics<PointerNullabilityLattice>
+            &)>;
 
-  /// Returns the pointer to the statement if null safety is violated, otherwise
-  /// the optional is empty.
-  ///
-  /// TODO(b/233582219): Extend diagnosis to return more information, e.g. the
-  /// type of violation.
-  std::optional<CFGElement> diagnose(
-      const CFGElement *Elt, ASTContext &Ctx,
-      const dataflow::TransferStateForDiagnostics<PointerNullabilityLattice>
-          &State) {
-    return Diagnoser(*Elt, Ctx, State);
-  }
-
- private:
-  dataflow::CFGMatchSwitch<
-      const dataflow::TransferStateForDiagnostics<PointerNullabilityLattice>,
-      std::optional<CFGElement>>
-      Diagnoser;
-};
+PointerNullabilityDiagnoser pointerNullabilityDiagnoser();
 
 }  // namespace nullability
 }  // namespace tidy
diff --git a/nullability/test/check_diagnostics.cc b/nullability/test/check_diagnostics.cc
index 4b4b9c8..291237a 100644
--- a/nullability/test/check_diagnostics.cc
+++ b/nullability/test/check_diagnostics.cc
@@ -4,10 +4,14 @@
 
 #include "nullability/test/check_diagnostics.h"
 
+#include <iterator>
+#include <vector>
+
 #include "nullability/pointer_nullability_analysis.h"
 #include "nullability/pointer_nullability_diagnosis.h"
 #include "clang/Analysis/CFG.h"
 #include "third_party/llvm/llvm-project/clang/unittests/Analysis/FlowSensitive/TestingSupport.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Testing/Support/Error.h"
 #include "third_party/llvm/llvm-project/third-party/unittest/googletest/include/gtest/gtest.h"
 
@@ -39,8 +43,8 @@
 )cc";
 
 bool checkDiagnostics(llvm::StringRef SourceCode) {
-  std::vector<CFGElement> Diagnostics;
-  PointerNullabilityDiagnoser Diagnoser;
+  std::vector<PointerNullabilityDiagnostic> Diagnostics;
+  PointerNullabilityDiagnoser Diagnoser = pointerNullabilityDiagnoser();
   bool Failed = false;
   EXPECT_THAT_ERROR(
       dataflow::test::checkDataflow<PointerNullabilityAnalysis>(
@@ -53,10 +57,8 @@
                                     ASTContext &Ctx, const CFGElement &Elt,
                                     const dataflow::TransferStateForDiagnostics<
                                         PointerNullabilityLattice> &State) {
-                auto EltDiagnostics = Diagnoser.diagnose(&Elt, Ctx, State);
-                if (EltDiagnostics.has_value()) {
-                  Diagnostics.push_back(EltDiagnostics.value());
-                }
+                auto EltDiagnostics = Diagnoser(Elt, Ctx, State);
+                llvm::move(EltDiagnostics, std::back_inserter(Diagnostics));
               })
               .withASTBuildVirtualMappedFiles(
                   {{"preamble.h", kPreamble}, {"new", kNewHeader}})
@@ -73,22 +75,11 @@
               ExpectedLines.insert(Line);
             }
             auto &SrcMgr = AnalysisData.ASTCtx.getSourceManager();
-            for (auto Element : Diagnostics) {
-              if (std::optional<CFGStmt> stmt = Element.getAs<CFGStmt>()) {
-                ActualLines.insert(SrcMgr.getPresumedLineNumber(
-                    stmt->getStmt()->getBeginLoc()));
-              } else if (std::optional<CFGInitializer> init =
-                             Element.getAs<CFGInitializer>()) {
-                ActualLines.insert(SrcMgr.getPresumedLineNumber(
-                    init->getInitializer()->getSourceLocation()));
-              } else {
-                ADD_FAILURE() << "this code should not be reached";
-              }
-            }
+            for (auto Diag : Diagnostics)
+              ActualLines.insert(
+                  SrcMgr.getPresumedLineNumber(Diag.Range.getBegin()));
             EXPECT_THAT(ActualLines, testing::ContainerEq(ExpectedLines));
-            if (ActualLines != ExpectedLines) {
-              Failed = true;
-            }
+            if (ActualLines != ExpectedLines) Failed = true;
           }),
       llvm::Succeeded());
   return !Failed;
diff --git a/nullability/test/function_calls.cc b/nullability/test/function_calls.cc
index 5911e98..51b0b76 100644
--- a/nullability/test/function_calls.cc
+++ b/nullability/test/function_calls.cc
@@ -258,6 +258,17 @@
   )cc"));
 }
 
+TEST(PointerNullabilityTest, CallExprMultiNonnullParams) {
+  EXPECT_TRUE(checkDiagnostics(R"cc(
+    void take(int *_Nonnull, int *_Nullable, int *_Nonnull);
+    void target() {
+      take(nullptr,  // [[unsafe]]
+           nullptr,
+           nullptr);  // [[unsafe]]
+    }
+  )cc"));
+}
+
 TEST(PointerNullabilityTest, CanOverwritePtrWithPtrCreatedFromRefReturnType) {
   // Test that if we create a pointer from a function returning a reference, we
   // can use that pointer to overwrite an existing nullable pointer and make it