Rewrite getNullabilityAnnotationsFromType with a TypeVisitor.
We have refactored the function that gathers the nullability annotations from sugared types. Previously, this function used several if-statements to recurse over a type and build its vector of nullability annotations. Now, we do so using a custom TypeVisitor class, which is more idiomatic (and hopefully more readable too!).
PiperOrigin-RevId: 491417652
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index 274561d..43015f4 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -14,6 +14,7 @@
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/Type.h"
+#include "clang/AST/TypeVisitor.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Analysis/FlowSensitive/CFGMatchSwitch.h"
#include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
@@ -38,39 +39,54 @@
namespace {
-void getNullabilityAnnotationsFromTypeImpl(
- QualType T, std::vector<NullabilityKind>& Result) {
- if (auto ET = T->getAs<ElaboratedType>()) {
- getNullabilityAnnotationsFromTypeImpl(ET->getNamedType(), Result);
- } else if (auto TST = T->getAs<TemplateSpecializationType>()) {
+class GetNullabilityAnnotationsFromTypeVisitor
+ : public TypeVisitor<GetNullabilityAnnotationsFromTypeVisitor> {
+ std::vector<NullabilityKind> NullabilityAnnotations;
+
+ public:
+ std::vector<NullabilityKind> getNullabilityAnnotations() && {
+ return std::move(NullabilityAnnotations);
+ }
+
+ void Visit(QualType T) { TypeVisitor::Visit(T.getTypePtr()); }
+
+ void VisitElaboratedType(const ElaboratedType* ET) {
+ Visit(ET->getNamedType());
+ }
+
+ void VisitTemplateSpecializationType(const TemplateSpecializationType* TST) {
for (auto TA : TST->template_arguments()) {
if (TA.getKind() == TemplateArgument::Type) {
- getNullabilityAnnotationsFromTypeImpl(TA.getAsType(), Result);
+ Visit(TA.getAsType());
}
}
- } else if (auto AT = T->getAs<AttributedType>()) {
+ }
+
+ void VisitAttributedType(const AttributedType* AT) {
Optional<NullabilityKind> NK = AT->getImmediateNullability();
if (NK.has_value()) {
- Result.push_back(AT->getImmediateNullability().value());
+ NullabilityAnnotations.push_back(AT->getImmediateNullability().value());
QualType MT = AT->getModifiedType();
if (auto PT = MT->getAs<PointerType>()) {
- getNullabilityAnnotationsFromTypeImpl(PT->getPointeeType(), Result);
+ Visit(PT->getPointeeType());
} else {
// TODO: Handle this unusual yet possible (e.g. through typedefs)
// case.
- llvm::dbgs() << "\nThe type " << T
+ llvm::dbgs() << "\nThe type " << AT
<< "contains a nullability annotation that is not "
<< "succeeded by a pointer type. "
<< "This occurence is not currently handled.\n";
}
} else {
- getNullabilityAnnotationsFromTypeImpl(AT->getModifiedType(), Result);
+ Visit(AT->getModifiedType());
}
- } else if (auto PtrT = T->getAs<PointerType>()) {
- Result.push_back(NullabilityKind::Unspecified);
- getNullabilityAnnotationsFromTypeImpl(PtrT->getPointeeType(), Result);
}
-}
+
+ void VisitPointerType(const PointerType* PT) {
+ NullabilityAnnotations.push_back(NullabilityKind::Unspecified);
+ Visit(PT->getPointeeType());
+ }
+};
/// Traverse over a type to get its nullability. For example, if T is the type
/// Struct3Arg<int * _Nonnull, int, pair<int * _Nullable, int *>> * _Nonnull,
@@ -78,9 +94,9 @@
/// _Nullable, _Unknown}. Note that non-pointer elements (e.g., the second
/// argument of Struct3Arg) do not get a nullability annotation.
std::vector<NullabilityKind> getNullabilityAnnotationsFromType(QualType T) {
- std::vector<NullabilityKind> Result;
- getNullabilityAnnotationsFromTypeImpl(T, Result);
- return Result;
+ GetNullabilityAnnotationsFromTypeVisitor AnnotationVisitor;
+ AnnotationVisitor.Visit(T);
+ return std::move(AnnotationVisitor).getNullabilityAnnotations();
}
unsigned countPointersInType(QualType T) {