Refactor diagnoser function to return CFGElement.
PiperOrigin-RevId: 487174212
diff --git a/nullability_verification/pointer_nullability_diagnosis.cc b/nullability_verification/pointer_nullability_diagnosis.cc
index d1aa1e0..6559beb 100644
--- a/nullability_verification/pointer_nullability_diagnosis.cc
+++ b/nullability_verification/pointer_nullability_diagnosis.cc
@@ -41,20 +41,20 @@
isNullableOrUntracked(E, Env);
}
-llvm::Optional<const Stmt*> diagnoseDereference(const UnaryOperator* UnaryOp,
- const MatchFinder::MatchResult&,
- const Environment& Env) {
+llvm::Optional<CFGElement> diagnoseDereference(const UnaryOperator* UnaryOp,
+ const MatchFinder::MatchResult&,
+ const Environment& Env) {
if (isNullableOrUntracked(UnaryOp->getSubExpr(), Env)) {
- return UnaryOp;
+ return llvm::Optional<CFGElement>(CFGStmt(UnaryOp));
}
return llvm::None;
}
-llvm::Optional<const Stmt*> diagnoseArrow(
- const MemberExpr* MemberExpr, const MatchFinder::MatchResult& Result,
- const Environment& Env) {
+llvm::Optional<CFGElement> diagnoseArrow(const MemberExpr* MemberExpr,
+ const MatchFinder::MatchResult& Result,
+ const Environment& Env) {
if (isNullableOrUntracked(MemberExpr->getBase(), Env)) {
- return MemberExpr;
+ return llvm::Optional<CFGElement>(CFGStmt(MemberExpr));
}
return llvm::None;
}
@@ -78,7 +78,7 @@
// TODO(b/233582219): Handle call expressions whose callee is not a decl (e.g.
// a function returned from another function), or when the callee cannot be
// interpreted as a function type (e.g. a pointer to a function pointer).
-llvm::Optional<const Stmt*> diagnoseCallExpr(
+llvm::Optional<CFGElement> diagnoseCallExpr(
const CallExpr* CE, const MatchFinder::MatchResult& Result,
const Environment& Env) {
auto* Callee = CE->getCalleeDecl();
@@ -96,11 +96,11 @@
}
return isIncompatibleArgumentList(ParamTypes, Args, Env, *Result.Context)
- ? llvm::Optional<const Stmt*>(CE)
+ ? llvm::Optional<CFGElement>(CFGStmt(CE))
: llvm::None;
}
-llvm::Optional<const Stmt*> diagnoseConstructExpr(
+llvm::Optional<CFGElement> diagnoseConstructExpr(
const CXXConstructExpr* CE, const MatchFinder::MatchResult& Result,
const Environment& Env) {
auto ConstructorParamTypes = CE->getConstructor()
@@ -110,11 +110,11 @@
ArrayRef<const Expr*> ConstructorArgs(CE->getArgs(), CE->getNumArgs());
return isIncompatibleArgumentList(ConstructorParamTypes, ConstructorArgs, Env,
*Result.Context)
- ? llvm::Optional<const Stmt*>(CE)
+ ? llvm::Optional<CFGElement>(CFGStmt(CE))
: llvm::None;
}
-llvm::Optional<const Stmt*> diagnoseReturn(
+llvm::Optional<CFGElement> diagnoseReturn(
const ReturnStmt* RS, const MatchFinder::MatchResult& Result,
const Environment& Env) {
auto ReturnType = cast<FunctionDecl>(Env.getDeclCtx())->getReturnType();
@@ -124,11 +124,11 @@
assert(ReturnExpr->getType()->isPointerType());
return isIncompatibleAssignment(ReturnType, ReturnExpr, Env, *Result.Context)
- ? llvm::Optional<const Stmt*>(RS)
+ ? llvm::Optional<CFGElement>(CFGStmt(RS))
: llvm::None;
}
-llvm::Optional<const Stmt*> diagnoseMemberInitializer(
+llvm::Optional<CFGElement> diagnoseMemberInitializer(
const CXXCtorInitializer* CI, const MatchFinder::MatchResult& Result,
const Environment& Env) {
assert(CI->isAnyMemberInitializer());
@@ -139,19 +139,12 @@
auto MemberInitExpr = CI->getInit();
return isIncompatibleAssignment(MemberType, MemberInitExpr, Env,
*Result.Context)
- // TODO(b/233582219): CtorInitializer is not compatible with the
- // return type as it is not a Stmt. Therefore, we currently return
- // the expression in the initializer. The return type should be
- // modified to work over different AST nodes. For example,
- // returning a SourceLocation or creating a Diagnostic base class
- // that will contain more information about the violation and store
- // the relevant AST nodes.
- ? llvm::Optional<const Stmt*>(MemberInitExpr)
+ ? llvm::Optional<CFGElement>(CFGInitializer(CI))
: llvm::None;
}
auto buildDiagnoser() {
- return CFGMatchSwitchBuilder<const Environment, llvm::Optional<const Stmt*>>()
+ return CFGMatchSwitchBuilder<const Environment, llvm::Optional<CFGElement>>()
// (*)
.CaseOfCFGStmt<UnaryOperator>(isPointerDereference(), diagnoseDereference)
// (->)
diff --git a/nullability_verification/pointer_nullability_diagnosis.h b/nullability_verification/pointer_nullability_diagnosis.h
index e688210..5547e77 100644
--- a/nullability_verification/pointer_nullability_diagnosis.h
+++ b/nullability_verification/pointer_nullability_diagnosis.h
@@ -30,14 +30,14 @@
///
/// TODO(b/233582219): Extend diagnosis to return more information, e.g. the
/// type of violation.
- llvm::Optional<const Stmt*> diagnose(const CFGElement* Elt, ASTContext& Ctx,
- const dataflow::Environment& Env) {
+ llvm::Optional<CFGElement> diagnose(const CFGElement* Elt, ASTContext& Ctx,
+ const dataflow::Environment& Env) {
return Diagnoser(*Elt, Ctx, Env);
}
private:
dataflow::CFGMatchSwitch<const dataflow::Environment,
- llvm::Optional<const Stmt*>>
+ llvm::Optional<CFGElement>>
Diagnoser;
};
diff --git a/nullability_verification/pointer_nullability_verification_test.cc b/nullability_verification/pointer_nullability_verification_test.cc
index b5af9d1..12a1a44 100644
--- a/nullability_verification/pointer_nullability_verification_test.cc
+++ b/nullability_verification/pointer_nullability_verification_test.cc
@@ -29,7 +29,7 @@
using ::testing::Test;
void checkDiagnostics(llvm::StringRef SourceCode) {
- std::vector<const Stmt *> Diagnostics;
+ std::vector<CFGElement> Diagnostics;
PointerNullabilityDiagnoser Diagnoser;
ASSERT_THAT_ERROR(
checkDataflow<PointerNullabilityAnalysis>(
@@ -58,9 +58,17 @@
ExpectedLines.insert(Line);
}
auto &SrcMgr = AnalysisData.ASTCtx.getSourceManager();
- for (auto *Stmt : Diagnostics) {
- ActualLines.insert(
- SrcMgr.getPresumedLineNumber(Stmt->getBeginLoc()));
+ for (auto Element : Diagnostics) {
+ if (Optional<CFGStmt> stmt = Element.getAs<CFGStmt>()) {
+ ActualLines.insert(SrcMgr.getPresumedLineNumber(
+ stmt->getStmt()->getBeginLoc()));
+ } else if (Optional<CFGInitializer> init =
+ Element.getAs<CFGInitializer>()) {
+ ActualLines.insert(SrcMgr.getPresumedLineNumber(
+ init->getInitializer()->getSourceLocation()));
+ } else {
+ ADD_FAILURE() << "this code should not be reached";
+ }
}
EXPECT_THAT(ActualLines, ContainerEq(ExpectedLines));
}),