Refactor template nullability functions using TypeVisitors.

Refactor substituteNullabilityAnnotationsInTemplate and countPointersInType to use clang TypeVisitor subclasses instead of chains of `if (auto X = T->getAs<...>())` checks. This makes the code more idiomatic and readable.

PiperOrigin-RevId: 491886720
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index a2416e3..2732345 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -98,15 +98,48 @@
   return std::move(AnnotationVisitor).getNullabilityAnnotations();
 }
 
-unsigned countPointersInType(QualType T) {
-  if (auto ET = T->getAs<ElaboratedType>()) {
-    return countPointersInType(ET->getNamedType());
-  } else if (auto AT = T->getAs<AttributedType>()) {
-    return countPointersInType(AT->getModifiedType());
-  } else if (auto PtrT = T->getAs<PointerType>()) {
-    return 1 + countPointersInType(PtrT->getPointeeType());
+/// Counts the pointer types nested in a type, e.g. 2 for `int**`.
+///
+/// Sugar without a dedicated Visit method (typedefs, substituted template
+/// parameters, ...) is looked through by VisitType, so `IntPtr*` with
+/// `using IntPtr = int*;` also counts 2.
+class CountPointersInTypeVisitor
+    : public TypeVisitor<CountPointersInTypeVisitor> {
+  unsigned Count = 0;
+
+ public:
+  unsigned getCount() const { return Count; }
+
+  void Visit(QualType T) {
+    // TypeVisitor::Visit dereferences its argument, so skip null types here.
+    if (!T.isNull()) TypeVisitor::Visit(T.getTypePtr());
+  }
+
+  void VisitElaboratedType(const ElaboratedType* ET) {
+    Visit(ET->getNamedType());
   }
-  return 0;
+
+  void VisitAttributedType(const AttributedType* AT) {
+    Visit(AT->getModifiedType());
+  }
+
+  void VisitPointerType(const PointerType* PT) {
+    ++Count;
+    Visit(PT->getPointeeType());
+  }
+
+  // Fallback for all other type classes: strip one level of sugar and keep
+  // walking. Non-sugar types desugar to themselves, which ends the walk.
+  void VisitType(const Type* T) {
+    QualType Desugared = T->getLocallyUnqualifiedSingleStepDesugaredType();
+    if (Desugared.getTypePtr() != T) Visit(Desugared);
+  }
+};
+
+unsigned countPointersInType(QualType T) {
+  CountPointersInTypeVisitor PointerCountVisitor;
+  PointerCountVisitor.Visit(T);
+  return PointerCountVisitor.getCount();
 }
 
 unsigned countPointersInType(TemplateArgument TA) {
@@ -149,30 +175,68 @@
   return ArrayRef<NullabilityKind>();
 }
 
-void substituteNullabilityAnnotationsInTemplateImpl(
-    std::vector<NullabilityKind>& Result, QualType T,
-    ArrayRef<NullabilityKind> BaseNullabilityAnnotations, QualType BaseType) {
-  if (auto ST = T->getAs<SubstTemplateTypeParmType>()) {
+/// Collects the nullability annotations of the pointers in a type that
+/// appears inside a template instantiation, outermost pointer first.
+/// Pointers spelled in the template itself get `Unspecified`; a type
+/// substituted for a template parameter contributes whatever
+/// getNullabilityForTemplateParameter returns for it, given
+/// `BaseNullabilityAnnotations` and `BaseType`.
+///
+/// Sugar without a dedicated Visit method (typedefs, attributes, ...) is
+/// looked through by VisitType, so e.g. a typedef of a pointer still yields
+/// an `Unspecified` entry.
+class SubstituteNullabilityAnnotationsInTemplateVisitor
+    : public TypeVisitor<SubstituteNullabilityAnnotationsInTemplateVisitor> {
+  QualType BaseType;
+  ArrayRef<NullabilityKind> BaseNullabilityAnnotations;
+  std::vector<NullabilityKind> NullabilityAnnotations;
+
+ public:
+  SubstituteNullabilityAnnotationsInTemplateVisitor(
+      QualType BaseType, ArrayRef<NullabilityKind> BaseNullabilityAnnotations)
+      : BaseType(BaseType),
+        BaseNullabilityAnnotations(BaseNullabilityAnnotations) {}
+
+  std::vector<NullabilityKind> getNullabilityAnnotations() && {
+    return std::move(NullabilityAnnotations);
+  }
+
+  void Visit(QualType T) {
+    // TypeVisitor::Visit dereferences its argument, so skip null types here.
+    if (!T.isNull()) TypeVisitor::Visit(T.getTypePtr());
+  }
+
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType* ST) {
     for (auto NK : getNullabilityForTemplateParameter(
              ST, BaseNullabilityAnnotations, BaseType)) {
-      Result.push_back(NK);
+      NullabilityAnnotations.push_back(NK);
     }
-  } else if (auto PtrT = T->getAs<PointerType>()) {
-    Result.push_back(NullabilityKind::Unspecified);
-    substituteNullabilityAnnotationsInTemplateImpl(
-        Result, PtrT->getPointeeType(), BaseNullabilityAnnotations, BaseType);
-  } else if (auto ET = T->getAs<ElaboratedType>()) {
-    substituteNullabilityAnnotationsInTemplateImpl(
-        Result, ET->getNamedType(), BaseNullabilityAnnotations, BaseType);
-  } else if (auto TST = T->getAs<TemplateSpecializationType>()) {
+  }
+
+  void VisitPointerType(const PointerType* PT) {
+    NullabilityAnnotations.push_back(NullabilityKind::Unspecified);
+    Visit(PT->getPointeeType());
+  }
+
+  void VisitElaboratedType(const ElaboratedType* ET) {
+    Visit(ET->getNamedType());
+  }
+
+  void VisitTemplateSpecializationType(const TemplateSpecializationType* TST) {
     for (auto TA : TST->template_arguments()) {
       if (TA.getKind() == TemplateArgument::Type) {
-        substituteNullabilityAnnotationsInTemplateImpl(
-            Result, TA.getAsType(), BaseNullabilityAnnotations, BaseType);
+        Visit(TA.getAsType());
       }
     }
   }
-}
+
+  // Fallback for all other type classes: strip one level of sugar and keep
+  // walking. Non-sugar types desugar to themselves, which ends the walk.
+  void VisitType(const Type* T) {
+    QualType Desugared = T->getLocallyUnqualifiedSingleStepDesugaredType();
+    if (Desugared.getTypePtr() != T) Visit(Desugared);
+  }
+};
 
 /// Similar to getNullabilityForTemplateParameter, but here we get the
 /// nullability annotation for a type that *contains* another type that was
@@ -197,10 +241,12 @@
 std::vector<NullabilityKind> substituteNullabilityAnnotationsInTemplate(
     QualType T, ArrayRef<NullabilityKind> BaseNullabilityAnnotations,
     QualType BaseType) {
-  std::vector<NullabilityKind> Result;
-  substituteNullabilityAnnotationsInTemplateImpl(
-      Result, T, BaseNullabilityAnnotations, BaseType);
-  return Result;
+  // Walk T once; the visitor owns the accumulated annotations, which are
+  // moved out of it (rvalue-qualified getter) rather than copied.
+  SubstituteNullabilityAnnotationsInTemplateVisitor Substituter(
+      BaseType, BaseNullabilityAnnotations);
+  Substituter.Visit(T);
+  return std::move(Substituter).getNullabilityAnnotations();
 }
 
 /// Get nullability annotations of the base type. For example, in the member