Split PointerNullabilityDiagnostic Error Codes to produce more specific messages in findings

PiperOrigin-RevId: 581954638
Change-Id: Ieaab4c5df1032ffe382241adf96d5d8ce7724f71
diff --git a/nullability/pointer_nullability_diagnosis.cc b/nullability/pointer_nullability_diagnosis.cc
index 5f5b229..c9b6b2e 100644
--- a/nullability/pointer_nullability_diagnosis.cc
+++ b/nullability/pointer_nullability_diagnosis.cc
@@ -4,6 +4,8 @@
 
 #include "nullability/pointer_nullability_diagnosis.h"
 
+#include <optional>
+
 #include "absl/log/check.h"
 #include "nullability/pointer_nullability.h"
 #include "nullability/pointer_nullability_matchers.h"
@@ -38,11 +40,14 @@
 
 // Diagnoses whether `E` violates the expectation that it is nonnull.
 SmallVector<PointerNullabilityDiagnostic> diagnoseNonnullExpected(
-    const Expr *E, const Environment &Env) {
+    const Expr *E, const Environment &Env,
+    PointerNullabilityDiagnostic::Context DiagCtx,
+    std::optional<std::string> ParamName = std::nullopt) {
   if (auto *ActualVal = getPointerValueFromExpr(E, Env)) {
     if (isNullable(*ActualVal, Env))
       return {{PointerNullabilityDiagnostic::ErrorCode::ExpectedNonnull,
-               CharSourceRange::getTokenRange(E->getSourceRange())}};
+               DiagCtx, CharSourceRange::getTokenRange(E->getSourceRange()),
+               std::move(ParamName)}};
     return {};
   }
 
@@ -54,7 +59,7 @@
            "unsafe:\n";
     E->dump();
   });
-  return {{PointerNullabilityDiagnostic::ErrorCode::Untracked,
+  return {{PointerNullabilityDiagnostic::ErrorCode::Untracked, DiagCtx,
            CharSourceRange::getTokenRange(E->getSourceRange())}};
 }
 
@@ -62,30 +67,36 @@
 // set by `DeclaredType`.
 SmallVector<PointerNullabilityDiagnostic> diagnoseTypeExprCompatibility(
     QualType DeclaredType, const Expr *E, const Environment &Env,
-    ASTContext &Ctx) {
+    ASTContext &Ctx, PointerNullabilityDiagnostic::Context DiagCtx,
+    std::optional<std::string> ParamName = std::nullopt) {
   CHECK(isSupportedRawPointerType(DeclaredType));
   return getNullabilityKind(DeclaredType, Ctx) == NullabilityKind::NonNull
-             ? diagnoseNonnullExpected(E, Env)
+             ? diagnoseNonnullExpected(E, Env, DiagCtx, ParamName)
              : SmallVector<PointerNullabilityDiagnostic>{};
 }
 
 SmallVector<PointerNullabilityDiagnostic> diagnoseDereference(
     const UnaryOperator *UnaryOp, const MatchFinder::MatchResult &,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
-  return diagnoseNonnullExpected(UnaryOp->getSubExpr(), State.Env);
+  return diagnoseNonnullExpected(
+      UnaryOp->getSubExpr(), State.Env,
+      PointerNullabilityDiagnostic::Context::NullableDereference);
 }
 
 SmallVector<PointerNullabilityDiagnostic> diagnoseArrow(
     const MemberExpr *MemberExpr, const MatchFinder::MatchResult &Result,
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
-  return diagnoseNonnullExpected(MemberExpr->getBase(), State.Env);
+  return diagnoseNonnullExpected(
+      MemberExpr->getBase(), State.Env,
+      PointerNullabilityDiagnostic::Context::NullableDereference);
 }
 
 // Diagnoses whether any of the arguments are incompatible with the
 // corresponding type in the function prototype.
 SmallVector<PointerNullabilityDiagnostic> diagnoseArgumentCompatibility(
     const FunctionProtoType &CalleeFPT, ArrayRef<const Expr *> Args,
-    const Environment &Env, ASTContext &Ctx) {
+    ArrayRef<const ParmVarDecl *> ParmDecls, const Environment &Env,
+    ASTContext &Ctx) {
   auto ParamTypes = CalleeFPT.getParamTypes();
   // C-style varargs cannot be annotated and therefore are unchecked.
   if (CalleeFPT.isVariadic()) {
@@ -96,9 +107,15 @@
   SmallVector<PointerNullabilityDiagnostic> Diagnostics;
   for (unsigned int I = 0; I < Args.size(); ++I) {
     auto ParamType = ParamTypes[I].getNonReferenceType();
-    if (isSupportedRawPointerType(ParamType))
-      Diagnostics.append(
-          diagnoseTypeExprCompatibility(ParamType, Args[I], Env, Ctx));
+    if (isSupportedRawPointerType(ParamType)) {
+      std::string ParamName = (I < ParmDecls.size())
+                                  ? ParmDecls[I]->getDeclName().getAsString()
+                                  : "";
+      Diagnostics.append(diagnoseTypeExprCompatibility(
+          ParamType, Args[I], Env, Ctx,
+          PointerNullabilityDiagnostic::Context::FunctionArgument,
+          std::move(ParamName)));
+    }
   }
   return Diagnostics;
 }
@@ -166,6 +183,7 @@
       State.Lattice.getExprNullability(GivenExpr);
   if (MaybeComputed == nullptr)
     return {{PointerNullabilityDiagnostic::ErrorCode::Untracked,
+             PointerNullabilityDiagnostic::Context::Other,
              CharSourceRange::getTokenRange(CE->getSourceRange())}};
 
   if (*MaybeComputed == Expected) return {};
@@ -184,6 +202,7 @@
   });
 
   return {{PointerNullabilityDiagnostic::ErrorCode::AssertFailed,
+           PointerNullabilityDiagnostic::Context::Other,
            CharSourceRange::getTokenRange(CE->getSourceRange())}};
 }
 
@@ -197,15 +216,18 @@
   //   member function type").
   //   Note that in `(obj.*nullable_pmf)()` the deref is *before* the call.
   if (!CE->getDirectCallee() && !isa<CXXMemberCallExpr>(CE)) {
-    auto D = diagnoseNonnullExpected(CE->getCallee(), State.Env);
+    auto D =
+        diagnoseNonnullExpected(CE->getCallee(), State.Env,
+                                PointerNullabilityDiagnostic::Context::Other);
     if (!D.empty()) return D;
   }
 
-  if (auto *FD = CE->getDirectCallee())
+  if (auto *FD = CE->getDirectCallee()) {
     if (FD->getDeclName().isIdentifier() &&
         FD->getName() == "__assert_nullability") {
       return diagnoseAssertNullabilityCall(CE, State, *Result.Context);
     }
+  }
 
   auto *Callee = CE->getCalleeDecl();
   // TODO(mboehme): Retrieve the nullability directly from the callee using
@@ -240,7 +262,10 @@
       isa<CXXMethodDecl>(CE->getDirectCallee())) {
     Args = Args.drop_front();
   }
-  return diagnoseArgumentCompatibility(*CalleeFPT, Args, State.Env,
+  ArrayRef<const ParmVarDecl *> ParmDecls = {};
+  if (Callee->getAsFunction())
+    ParmDecls = Callee->getAsFunction()->parameters();
+  return diagnoseArgumentCompatibility(*CalleeFPT, Args, ParmDecls, State.Env,
                                        *Result.Context);
 }
 
@@ -250,8 +275,11 @@
   auto *CalleeFPT = CE->getConstructor()->getType()->getAs<FunctionProtoType>();
   if (!CalleeFPT) return {};
   ArrayRef<const Expr *> ConstructorArgs(CE->getArgs(), CE->getNumArgs());
-  return diagnoseArgumentCompatibility(*CalleeFPT, ConstructorArgs, State.Env,
-                                       *Result.Context);
+
+  return diagnoseArgumentCompatibility(
+      *CalleeFPT, ConstructorArgs,
+      CE->getConstructor()->getAsFunction()->parameters(), State.Env,
+      *Result.Context);
 }
 
 SmallVector<PointerNullabilityDiagnostic> diagnoseReturn(
@@ -267,8 +295,9 @@
   auto *ReturnExpr = RS->getRetValue();
   CHECK(isSupportedRawPointerType(ReturnExpr->getType()));
 
-  return diagnoseTypeExprCompatibility(ReturnType, ReturnExpr, State.Env,
-                                       *Result.Context);
+  return diagnoseTypeExprCompatibility(
+      ReturnType, ReturnExpr, State.Env, *Result.Context,
+      PointerNullabilityDiagnostic::Context::ReturnValue);
 }
 
 SmallVector<PointerNullabilityDiagnostic> diagnoseMemberInitializer(
@@ -279,8 +308,9 @@
   if (!isSupportedRawPointerType(MemberType)) return {};
 
   auto *MemberInitExpr = CI->getInit();
-  return diagnoseTypeExprCompatibility(MemberType, MemberInitExpr, State.Env,
-                                       *Result.Context);
+  return diagnoseTypeExprCompatibility(
+      MemberType, MemberInitExpr, State.Env, *Result.Context,
+      PointerNullabilityDiagnostic::Context::Initializer);
 }
 
 }  // namespace
diff --git a/nullability/pointer_nullability_diagnosis.h b/nullability/pointer_nullability_diagnosis.h
index 806166a..d41cf09 100644
--- a/nullability/pointer_nullability_diagnosis.h
+++ b/nullability/pointer_nullability_diagnosis.h
@@ -6,6 +6,8 @@
 #define CRUBIT_NULLABILITY_POINTER_NULLABILITY_DIAGNOSIS_H_
 
 #include <functional>
+#include <optional>
+#include <string>
 
 #include "nullability/pointer_nullability_lattice.h"
 #include "clang/AST/ASTContext.h"
@@ -29,7 +31,23 @@
     AssertFailed,
   };
   ErrorCode Code;
+  /// Context in which the error occurred.
+  enum class Context {
+    /// Dereferencing a pointer.
+    NullableDereference,
+    /// Initializing a variable.
+    Initializer,
+    /// Value of a return statement.
+    ReturnValue,
+    /// Function argument.
+    FunctionArgument,
+    Other
+  } Ctx = Context::Other;
   CharSourceRange Range;
+  /// Name of the parameter that the argument is being passed to.
+  /// Populated only if `Ctx` is `FunctionArgument` and the parameter name is
+  /// known.
+  std::optional<std::string> ParamName;
 };
 
 /// Checks that nullable pointers are used safely, using nullability information