Refactor template nullability functions using TypeVisitors.
Refactor substituteNullabilityAnnotationsInTemplate and countPointersInType to use TypeVisitor subclasses instead of if/else chains of getAs<> checks. This makes the code more idiomatic and readable.
PiperOrigin-RevId: 491886720
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index a2416e3..2732345 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -98,15 +98,41 @@
return std::move(AnnotationVisitor).getNullabilityAnnotations();
}
-unsigned countPointersInType(QualType T) {
- if (auto ET = T->getAs<ElaboratedType>()) {
- return countPointersInType(ET->getNamedType());
- } else if (auto AT = T->getAs<AttributedType>()) {
- return countPointersInType(AT->getModifiedType());
- } else if (auto PtrT = T->getAs<PointerType>()) {
- return 1 + countPointersInType(PtrT->getPointeeType());
+/// Counts the levels of pointer indirection in a type, looking through sugar
+/// (typedefs, elaborated and attributed types, ...) as getAs<> would.
+class CountPointersInTypeVisitor
+    : public TypeVisitor<CountPointersInTypeVisitor> {
+  unsigned Count = 0;
+
+ public:
+  unsigned getCount() const { return Count; }
+
+  void Visit(QualType T) { TypeVisitor::Visit(T.getTypePtr()); }
+
+  void VisitElaboratedType(const ElaboratedType* ET) {
+    Visit(ET->getNamedType());
 }
- return 0;
+
+  void VisitAttributedType(const AttributedType* AT) {
+    Visit(AT->getModifiedType());
+  }
+
+  void VisitPointerType(const PointerType* PT) {
+    ++Count;
+    Visit(PT->getPointeeType());
+  }
+
+  // Other types: strip one level of sugar and retry (no-op for non-sugar).
+  void VisitType(const Type* T) {
+    QualType Desugared = T->getLocallyUnqualifiedSingleStepDesugaredType();
+    if (Desugared.getTypePtr() != T) Visit(Desugared);
+  }
+};
+
+unsigned countPointersInType(QualType T) {
+  CountPointersInTypeVisitor PointerCountVisitor;
+  PointerCountVisitor.Visit(T);
+  return PointerCountVisitor.getCount();
 }
unsigned countPointersInType(TemplateArgument TA) {
@@ -149,30 +175,65 @@
return ArrayRef<NullabilityKind>();
}
-void substituteNullabilityAnnotationsInTemplateImpl(
- std::vector<NullabilityKind>& Result, QualType T,
- ArrayRef<NullabilityKind> BaseNullabilityAnnotations, QualType BaseType) {
- if (auto ST = T->getAs<SubstTemplateTypeParmType>()) {
+/// Builds the nullability vector for a type written inside a template: each
+/// pointer spelled in the template itself contributes
+/// NullabilityKind::Unspecified, and each substituted template parameter
+/// contributes what getNullabilityForTemplateParameter returns for it. Sugar
+/// (typedefs, attributes, ...) is looked through, as with the old getAs<> code.
+class SubstituteNullabilityAnnotationsInTemplateVisitor
+    : public TypeVisitor<SubstituteNullabilityAnnotationsInTemplateVisitor> {
+  QualType BaseType;
+  ArrayRef<NullabilityKind> BaseNullabilityAnnotations;
+  std::vector<NullabilityKind> NullabilityAnnotations;
+
+ public:
+  SubstituteNullabilityAnnotationsInTemplateVisitor(
+      QualType BaseType, ArrayRef<NullabilityKind> BaseNullabilityAnnotations)
+      : BaseType(BaseType),
+        BaseNullabilityAnnotations(BaseNullabilityAnnotations) {}
+
+  std::vector<NullabilityKind> getNullabilityAnnotations() && {
+    return std::move(NullabilityAnnotations);
+  }
+
+  void Visit(QualType T) { TypeVisitor::Visit(T.getTypePtr()); }
+
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType* ST) {
for (auto NK : getNullabilityForTemplateParameter(
ST, BaseNullabilityAnnotations, BaseType)) {
- Result.push_back(NK);
+      NullabilityAnnotations.push_back(NK);
}
- } else if (auto PtrT = T->getAs<PointerType>()) {
- Result.push_back(NullabilityKind::Unspecified);
- substituteNullabilityAnnotationsInTemplateImpl(
- Result, PtrT->getPointeeType(), BaseNullabilityAnnotations, BaseType);
- } else if (auto ET = T->getAs<ElaboratedType>()) {
- substituteNullabilityAnnotationsInTemplateImpl(
- Result, ET->getNamedType(), BaseNullabilityAnnotations, BaseType);
- } else if (auto TST = T->getAs<TemplateSpecializationType>()) {
+  }
+
+  void VisitPointerType(const PointerType* PT) {
+    NullabilityAnnotations.push_back(NullabilityKind::Unspecified);
+    Visit(PT->getPointeeType());
+  }
+
+  void VisitElaboratedType(const ElaboratedType* ET) {
+    Visit(ET->getNamedType());
+  }
+
+  // Same traversal as CountPointersInTypeVisitor::VisitAttributedType.
+  void VisitAttributedType(const AttributedType* AT) {
+    Visit(AT->getModifiedType());
+  }
+
+  // Any other type: strip one level of sugar (typedef, paren, ...) and retry.
+  // For non-sugar types (builtins, records, ...) this is a no-op.
+  void VisitType(const Type* T) {
+    QualType Desugared = T->getLocallyUnqualifiedSingleStepDesugaredType();
+    if (Desugared.getTypePtr() != T) Visit(Desugared);
+  }
+
+  void VisitTemplateSpecializationType(const TemplateSpecializationType* TST) {
for (auto TA : TST->template_arguments()) {
if (TA.getKind() == TemplateArgument::Type) {
- substituteNullabilityAnnotationsInTemplateImpl(
- Result, TA.getAsType(), BaseNullabilityAnnotations, BaseType);
+        Visit(TA.getAsType());
}
}
}
-}
+};
/// Similar to getNullabilityForTemplateParameter, but here we get the
/// nullability annotation for a type that *contains* another type that was
@@ -197,10 +241,10 @@
std::vector<NullabilityKind> substituteNullabilityAnnotationsInTemplate(
QualType T, ArrayRef<NullabilityKind> BaseNullabilityAnnotations,
QualType BaseType) {
- std::vector<NullabilityKind> Result;
- substituteNullabilityAnnotationsInTemplateImpl(
- Result, T, BaseNullabilityAnnotations, BaseType);
- return Result;
+  SubstituteNullabilityAnnotationsInTemplateVisitor Visitor(
+      BaseType, BaseNullabilityAnnotations);
+  Visitor.Visit(T);  // Walk T, accumulating its nullability annotations.
+  return std::move(Visitor).getNullabilityAnnotations();  // Moves them out.
}
/// Get nullability annotations of the base type. For example, in the member