Introduce `assert_nullability`, a debugging function for nullability verification
Currently, our means of testing for the presence of nullability annotations are limited. We can only check them by explicitly dereferencing pointer expressions and observing whether an "unsafe" warning is emitted. This is inconvenient, because testing a single expression that contains several pointers takes several lines of code (one dereference per pointer).
Convenience aside, our current check is also ambiguous. First, it cannot distinguish between nonnull and unspecified annotations, because dereferencing either kind of pointer is considered safe. Second, when a dereference of a nested expression is flagged as unsafe, it cannot pinpoint *which* subexpression caused the unsafety. Consider, for example, a nullable pointer to a pair of nonnull pointers, or a nonnull pointer to a pair of nullable pointers: `*myPair->first` is unsafe in both cases, but for different reasons. More granular information about which subexpression is nullable would help with debugging.
To solve these issues, we have implemented `__assert_nullability`, a special function that checks, in one line of code, whether all the nullability information computed for an expression matches what we expect. The user passes the expression of interest as the argument to `__assert_nullability` and writes the expression's expected nullability as template arguments, one per pointer contained in the expression.
Consider:
```
enum NullabilityKind {
  NK_nonnull,
  NK_nullable,
  NK_unspecified,
};

template <NullabilityKind... NK, typename T>
void __assert_nullability(T&);

template <typename T0, typename T1>
struct Struct2Arg {
  T0 arg0;
  T1 arg1;
};

void target(Struct2Arg<int *, int * _Nullable> p) {
  __assert_nullability<NK_unspecified, NK_nullable>(p);
}
```
During nullability diagnosis (specifically, in `diagnoseCallExpr`), we now check whether the called function is named `__assert_nullability`. If so, we retrieve from the lattice the nullability vector computed for the call's argument (`p` in the example above) and compare it with the nullabilities passed to the call as template arguments (`NK_unspecified, NK_nullable` in the example). If the two vectors differ, or if no nullability was computed for the argument, the call is flagged as unsafe; on a mismatch we also print both the expected and the computed nullability.
Future improvements:
- Add the definitions of `__assert_nullability` and `NullabilityKind` as a preamble to every test, to avoid unnecessary repetition. This change will also involve adjusting the line-number offset so that line numbers output by the diagnoser match those seen by the user (i.e., the line count should not include the preamble).
- Add another type of diagnostic check in the testing framework so we can distinguish between regular unsafe warnings and warnings of the kind "failing `__assert_nullability` check".
PiperOrigin-RevId: 500745169
diff --git a/nullability_verification/pointer_nullability_diagnosis.cc b/nullability_verification/pointer_nullability_diagnosis.cc
index 66d7f3b..02abba7 100644
--- a/nullability_verification/pointer_nullability_diagnosis.cc
+++ b/nullability_verification/pointer_nullability_diagnosis.cc
@@ -5,6 +5,7 @@
#include "nullability_verification/pointer_nullability_diagnosis.h"
#include <optional>
+#include <string>
#include "nullability_verification/pointer_nullability.h"
#include "nullability_verification/pointer_nullability_matchers.h"
@@ -78,6 +79,100 @@
return false;
}
+NullabilityKind parseNullabilityKind(StringRef EnumName) {
+ return llvm::StringSwitch<NullabilityKind>(EnumName)
+ .Case("NK_nonnull", NullabilityKind::NonNull)
+ .Case("NK_nullable", NullabilityKind::Nullable)
+ .Case("NK_unspecified", NullabilityKind::Unspecified)
+ .Default(NullabilityKind::Unspecified);
+}
+
+std::string nullabilityToString(ArrayRef<NullabilityKind> Nullability) {
+ std::string Result = "[";
+ llvm::interleave(
+ Nullability,
+ [&](const NullabilityKind n) {
+ Result += getNullabilitySpelling(n).str();
+ },
+ [&] { Result += ", "; });
+ Result += "]";
+ return Result;
+}
+
+/// Evaluates the `__assert_nullability` call by comparing the expected
+/// nullability to the nullability computed by the dataflow analysis.
+///
+/// If the function being diagnosed is called `__assert_nullability`, we assume
+/// it is a call of the shape __assert_nullability<a, b, c, ...>(p), where `p`
+/// is an expression that contains pointers and a, b, c ... represent each of
+/// the NullabilityKinds in `p`'s expected nullability. An expression's
+/// nullability can be expressed as a vector of NullabilityKinds, where each
+/// vector element corresponds to one of the pointers contained in the
+/// expression.
+///
+/// For example:
+/// \code
+/// enum NullabilityKind {
+/// NK_nonnull,
+/// NK_nullable,
+/// NK_unspecified,
+/// };
+///
+/// template<NullabilityKind ...NK, typename T>
+/// void __assert_nullability(T&);
+///
+/// template<typename T0, typename T1>
+/// struct Struct2Arg {
+/// T0 arg0;
+/// T1 arg1;
+/// };
+///
+/// void target(Struct2Arg<int *, int * _Nullable> p) {
+/// __assert_nullability<NK_unspecified, NK_nullable>(p);
+/// }
+/// \endcode
+bool diagnoseAssertNullabilityCall(
+ const CallExpr* CE,
+ const TransferStateForDiagnostics<PointerNullabilityLattice>& State,
+ ASTContext& Ctx) {
+ auto* DRE = cast<DeclRefExpr>(CE->getCallee()->IgnoreImpCasts());
+
+ // Extract the expected nullability from the template parameter pack.
+ std::vector<NullabilityKind> Expected;
+ for (auto P : DRE->template_arguments()) {
+ if (P.getArgument().getKind() == TemplateArgument::Expression) {
+ if (auto* EnumDRE = dyn_cast<DeclRefExpr>(P.getSourceExpression())) {
+ Expected.push_back(parseNullabilityKind(EnumDRE->getDecl()->getName()));
+ }
+ }
+ }
+
+ // Compare the nullability computed by nullability analysis with the
+ // expected one.
+ const Expr* GivenExpr = CE->getArg(0);
+ Optional<ArrayRef<NullabilityKind>> MaybeComputed =
+ State.Lattice.getExprNullability(GivenExpr);
+ if (!MaybeComputed.has_value()) {
+ llvm::dbgs()
+ << "Could not evaluate __assert_nullability. Could not find the "
+ "nullability of the argument expression: ";
+ CE->dump();
+ return false;
+ }
+ if (MaybeComputed->vec() == Expected) return true;
+ // The computed and expected nullabilities differ. Print both to aid
+ // debugging.
+ llvm::dbgs() << "__assert_nullability failed at location: ";
+ CE->getExprLoc().print(llvm::dbgs(), Ctx.getSourceManager());
+ llvm::dbgs() << "\nExpression:\n";
+ GivenExpr->dump();
+ llvm::dbgs() << "Expected nullability: ";
+ llvm::dbgs() << nullabilityToString(Expected) << "\n";
+ llvm::dbgs() << "Computed nullability: ";
+ llvm::dbgs() << nullabilityToString(*MaybeComputed) << "\n";
+ return false;
+}
+
// TODO(b/233582219): Handle call expressions whose callee is not a decl (e.g.
// a function returned from another function), or when the callee cannot be
// interpreted as a function type (e.g. a pointer to a function pointer).
@@ -90,6 +185,16 @@
auto* CalleeType = Callee->getFunctionType();
if (!CalleeType) return std::nullopt;
+ if (auto* FD = Callee->getAsFunction()) {
+ if (FD->getDeclName().isIdentifier() &&
+ FD->getName() == "__assert_nullability" &&
+ !diagnoseAssertNullabilityCall(CE, State, *Result.Context)) {
+ // TODO: Handle __assert_nullability failures differently from regular
+ // diagnostic ([[unsafe]]) failures.
+ return llvm::Optional<CFGElement>(CFGStmt(CE));
+ }
+ }
+
auto ParamTypes = CalleeType->getAs<FunctionProtoType>()->getParamTypes();
ArrayRef<const Expr*> Args(CE->getArgs(), CE->getNumArgs());
if (isa<CXXOperatorCallExpr>(CE)) {