Create PointerValues for function calls returning a pointer and initialise its nullability based on the return type of the function.
PiperOrigin-RevId: 468781668
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index cc4ba80..773d3c6 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -137,6 +137,23 @@
State.Env.setStorageLocation(*CastExpr, CastExprLoc);
}
+void transferCallExpr(const CallExpr* CallExpr,
+ const MatchFinder::MatchResult& Result,
+ TransferState<NoopLattice>& State) {
+ auto ReturnType = CallExpr->getType();
+ if (!ReturnType->isAnyPointerType()) return;
+
+ auto* PointerVal = getPointerValueFromExpr(CallExpr, State.Env);
+ if (!PointerVal) {
+ PointerVal = cast<PointerValue>(State.Env.createValue(ReturnType));
+ auto& CallExprLoc = State.Env.createStorageLocation(*CallExpr);
+ State.Env.setValue(CallExprLoc, *PointerVal);
+ State.Env.setStorageLocation(*CallExpr, CallExprLoc);
+ }
+ initPointerFromAnnotations(*PointerVal, ReturnType, State.Env,
+ *Result.Context);
+}
+
auto buildTransferer() {
return MatchSwitchBuilder<TransferState<NoopLattice>>()
// Handles initialization of the null states of pointers
@@ -145,6 +162,7 @@
.CaseOf<Expr>(isAddrOf(), transferNotNullPointer)
.CaseOf<Expr>(isNullPointerLiteral(), transferNullPointer)
.CaseOf<MemberExpr>(isMemberOfPointerType(), transferPointer)
+ .CaseOf<CallExpr>(isCallExpr(), transferCallExpr)
// Handles comparison between 2 pointers
.CaseOf<BinaryOperator>(isPointerCheckBinOp(),
transferNullCheckComparison)
diff --git a/nullability_verification/pointer_nullability_matchers.cc b/nullability_verification/pointer_nullability_matchers.cc
index 300547a..15e5f9f 100644
--- a/nullability_verification/pointer_nullability_matchers.cc
+++ b/nullability_verification/pointer_nullability_matchers.cc
@@ -13,6 +13,7 @@
using ast_matchers::anyOf;
using ast_matchers::binaryOperator;
+using ast_matchers::callExpr;
using ast_matchers::cxxThisExpr;
using ast_matchers::declRefExpr;
using ast_matchers::expr;
@@ -53,6 +54,8 @@
}
Matcher<Stmt> isPointerArrow() { return memberExpr(isArrow()); }
Matcher<Stmt> isCXXThisExpr() { return cxxThisExpr(); }
+Matcher<Stmt> isCallExpr() { return callExpr(); }
+
} // namespace nullability
} // namespace tidy
} // namespace clang
diff --git a/nullability_verification/pointer_nullability_matchers.h b/nullability_verification/pointer_nullability_matchers.h
index c5562ae..c666321 100644
--- a/nullability_verification/pointer_nullability_matchers.h
+++ b/nullability_verification/pointer_nullability_matchers.h
@@ -20,6 +20,8 @@
ast_matchers::internal::Matcher<Stmt> isPointerDereference();
ast_matchers::internal::Matcher<Stmt> isPointerCheckBinOp();
ast_matchers::internal::Matcher<Stmt> isImplicitCastPointerToBool();
+ast_matchers::internal::Matcher<Stmt> isCallExpr();
+
} // namespace nullability
} // namespace tidy
} // namespace clang
diff --git a/nullability_verification/pointer_nullability_verification_test.cc b/nullability_verification/pointer_nullability_verification_test.cc
index a69630b..c4504f0 100644
--- a/nullability_verification/pointer_nullability_verification_test.cc
+++ b/nullability_verification/pointer_nullability_verification_test.cc
@@ -1304,6 +1304,91 @@
)");
}
+TEST(PointerNullabilityTest, CallExprWithPointerReturnType) {
+ // free function
+ checkDiagnostics(R"(
+ int * _Nonnull makeNonnull();
+ int * _Nullable makeNullable();
+ int *makeUnannotated();
+ void target() {
+ *makeNonnull();
+ *makeNullable(); // [[unsafe]]
+ *makeUnannotated();
+ }
+ )");
+
+ // member function
+ checkDiagnostics(R"(
+ struct Foo {
+ int * _Nonnull makeNonnull();
+ int * _Nullable makeNullable();
+ int *makeUnannotated();
+ };
+ void target(Foo foo) {
+ *foo.makeNonnull();
+ *foo.makeNullable(); // [[unsafe]]
+ *foo.makeUnannotated();
+ }
+ )");
+
+ // function pointer
+ checkDiagnostics(R"(
+ void target(int * _Nonnull (*makeNonnull)(),
+ int * _Nullable (*makeNullable)(),
+ int * (*makeUnannotated)()) {
+ *makeNonnull();
+ *makeNullable(); // [[unsafe]]
+ *makeUnannotated();
+ }
+ )");
+
+ // pointer to function pointer
+ checkDiagnostics(R"(
+ void target(int * _Nonnull (**makeNonnull)(),
+ int * _Nullable (**makeNullable)(),
+ int * (**makeUnannotated)()) {
+ *(*makeNonnull)();
+ *(*makeNullable)(); // [[unsafe]]
+ *(*makeUnannotated)();
+ }
+ )");
+
+ // function returning a function pointer which returns a pointer
+ checkDiagnostics(R"(
+ typedef int * _Nonnull (*MakeNonnullT)();
+ typedef int * _Nullable (*MakeNullableT)();
+ typedef int * (*MakeUnannotatedT)();
+ void target(MakeNonnullT (*makeNonnull)(),
+ MakeNullableT (*makeNullable)(),
+ MakeUnannotatedT (*makeUnannotated)()) {
+ *(*makeNonnull)()();
+ *(*makeNullable)()(); // [[unsafe]]
+ *(*makeUnannotated)()();
+ }
+ )");
+
+ // function called in loop
+ //
+ // TODO(b/233582219): Fix false negative. The pointer is only null-checked and
+ // therefore safe to dereference on the first iteration of the loop. On
+ // subsequent iterations of the loop, the pointer dereference is unsafe due to
+ // the lack of null check. The diagnoser currently fails to catch the
+ // unsafe dereference as it only evaluates the statement once.
+ checkDiagnostics(R"(
+ int * _Nullable makeNullable();
+ bool makeBool();
+ void target() {
+ bool first = true;
+ while(true) {
+ int *x = makeNullable();
+ if (first && x == nullptr) return;
+ first = false;
+ *x; // false-negative
+ }
+ }
+ )");
+}
+
} // namespace
} // namespace nullability
} // namespace tidy