Handle more shapes of template instantiations by recursively computing nullability annotations.

This change has to do with propagating nullability information through template instantiations. Previously, we used if-statements to "pattern match" over method calls of template instantiated classes and propagate their nullability information. However, this pattern matching only handled very specific expression shapes.

We are replacing the type pattern-matching with a recursive algorithm that can handle many type shapes, including templates with arguments that have more than one pointer type (e.g. a nested pointer or pair of pointers), or have no pointer at all (e.g., an integer argument). We are also adding support for propagating the nullability of templated member calls (e.g, x.f, where x is a template-instantiated class).

We have added test cases to cover the new functionality. We test member call expressions from structs with varying number of typename and int arguments. We have also added a test case with a struct that uses another struct template in a member variable. Nested member calls (for example, x.f.g) are commented out as the current implementation does not handle them.

PiperOrigin-RevId: 490708002
diff --git a/nullability_verification/pointer_nullability_analysis.cc b/nullability_verification/pointer_nullability_analysis.cc
index 2eec8a7..6201365 100644
--- a/nullability_verification/pointer_nullability_analysis.cc
+++ b/nullability_verification/pointer_nullability_analysis.cc
@@ -38,49 +38,216 @@
 
 namespace {
 
-NullabilityKind getNullabilityFromTemplatedExpression(const Expr* E) {
-  // Stores the nullability of each template argument.
-  std::vector<NullabilityKind> NullabilityVector = {};
-  const TemplateTypeParmDecl* ReplacedParameter = nullptr;
+void getNullabilityAnnotationsFromTypeImpl(
+    QualType T, std::vector<NullabilityKind>& Result) {
+  if (auto ET = T->getAs<ElaboratedType>()) {
+    getNullabilityAnnotationsFromTypeImpl(ET->getNamedType(), Result);
+  } else if (auto TST = T->getAs<TemplateSpecializationType>()) {
+    for (auto TA : TST->template_arguments()) {
+      if (TA.getKind() == TemplateArgument::Type) {
+        getNullabilityAnnotationsFromTypeImpl(TA.getAsType(), Result);
+      }
+    }
+  } else if (auto AT = T->getAs<AttributedType>()) {
+    Optional<NullabilityKind> NK = AT->getImmediateNullability();
+    if (NK.has_value()) {
+      Result.push_back(AT->getImmediateNullability().value());
+      QualType MT = AT->getModifiedType();
+      if (auto PT = MT->getAs<PointerType>()) {
+        getNullabilityAnnotationsFromTypeImpl(PT->getPointeeType(), Result);
+      } else {
+        // TODO: Handle this unusual yet possible (e.g. through typedefs)
+        // case.
+        llvm::dbgs() << "\nThe type " << T
+                     << "contains a nullability annotation that is not "
+                     << "succeeded by a pointer type. "
+                     << "This occurence is not currently handled.\n";
+      }
+    } else {
+      getNullabilityAnnotationsFromTypeImpl(AT->getModifiedType(), Result);
+    }
+  } else if (auto PtrT = T->getAs<PointerType>()) {
+    Result.push_back(NullabilityKind::Unspecified);
+    getNullabilityAnnotationsFromTypeImpl(PtrT->getPointeeType(), Result);
+  }
+}
 
-  // If the expression is a member function call on an object whose type
-  // is a class template instantiation, propagate the sugar from template
-  // arguments to the member function return type.
-  //
-  // TODO: Handle more expression shapes.
-  if (auto MemberCall = dyn_cast<CXXMemberCallExpr>(E)) {
-    Expr* Object = MemberCall->getImplicitObjectArgument();
-    if (auto Member = dyn_cast<MemberExpr>(MemberCall->getCallee())) {
-      if (auto Method = dyn_cast<CXXMethodDecl>(Member->getMemberDecl())) {
-        if (auto TST = Object->getType()->getAs<TemplateSpecializationType>()) {
-          for (TemplateArgument TA : TST->template_arguments()) {
-            NullabilityKind ArgumentNullability = NullabilityKind::Unspecified;
-            if (TA.getKind() == TemplateArgument::Type) {
-              if (auto AT = TA.getAsType()->getAs<AttributedType>()) {
-                ArgumentNullability = AT->getImmediateNullability().value_or(
-                    NullabilityKind::Unspecified);
-              }
-            }
-            NullabilityVector.push_back(ArgumentNullability);
-          }
-        }
+/// Traverse over a type to get its nullability. For example, if T is the type
+/// Struct3Arg<int * _Nonnull, int, pair<int * _Nullable, int *>> * _Nonnull,
+/// the resulting nullability annotations will be {_Nonnull, _Nonnull,
+/// _Nullable, _Unknown}. Note that non-pointer elements (e.g., the second
+/// argument of Struct3Arg) do not get a nullability annotation.
+std::vector<NullabilityKind> getNullabilityAnnotationsFromType(QualType T) {
+  std::vector<NullabilityKind> Result;
+  getNullabilityAnnotationsFromTypeImpl(T, Result);
+  return Result;
+}
 
-        // Save the replaced template parameter.
-        // TODO: Handle cases where the template argument is nested inside the
-        // return type (e.g. vector<map<T**, T>>).
-        if (auto SubstTemplate =
-                Method->getReturnType()->getAs<SubstTemplateTypeParmType>()) {
-          ReplacedParameter = SubstTemplate->getReplacedParameter();
-        }
+unsigned countPointersInType(QualType T) {
+  if (auto ET = T->getAs<ElaboratedType>()) {
+    return countPointersInType(ET->getNamedType());
+  } else if (auto AT = T->getAs<AttributedType>()) {
+    return countPointersInType(AT->getModifiedType());
+  } else if (auto PtrT = T->getAs<PointerType>()) {
+    return 1 + countPointersInType(PtrT->getPointeeType());
+  }
+  return 0;
+}
+
+unsigned countPointersInType(TemplateArgument TA) {
+  if (TA.getKind() == TemplateArgument::Type) {
+    return countPointersInType(TA.getAsType());
+  }
+  return 0;
+}
+
+/// Use the nullability annotations of the base type to compute the nullability
+/// of a type that was originally written as a template type parameter.
+/// For example, consider the following code:
+///
+/// template <typename T0, typename T1>
+/// struct S {
+///   T0 arg0;
+///   T1 arg1;
+/// };
+/// void target(S<pair<int * _Nullable, int *> * _Nonnull, int * _Nullable> p) {
+///   p.arg0; // (*)
+/// }
+///
+/// Suppose we wish to find the nullability annotations of arg0. The nullability
+/// annotation list of Struct2Arg is {_Nonnull, _Nullable, _Unknown, _Nullable}.
+/// We use this list and information about S to infer that the
+/// nullability annotation list of arg0 is {_Nonnull, _Nullable, _Unknown}.
+ArrayRef<NullabilityKind> getNullabilityForTemplateParameter(
+    const SubstTemplateTypeParmType* STTPT,
+    ArrayRef<NullabilityKind> BaseNullabilityAnnotations, QualType BaseType) {
+  unsigned PointerCount = 0;
+  unsigned ArgIndex = STTPT->getIndex();
+  if (auto TST = BaseType->getAs<TemplateSpecializationType>()) {
+    for (auto TA : TST->template_arguments().take_front(ArgIndex)) {
+      PointerCount += countPointersInType(TA);
+    }
+    unsigned SliceSize =
+        countPointersInType(TST->template_arguments()[ArgIndex]);
+    return BaseNullabilityAnnotations.slice(PointerCount, SliceSize);
+  }
+  return ArrayRef<NullabilityKind>();
+}
+
+void substituteNullabilityAnnotationsInTemplateImpl(
+    std::vector<NullabilityKind>& Result, QualType T,
+    ArrayRef<NullabilityKind> BaseNullabilityAnnotations, QualType BaseType) {
+  if (auto ST = T->getAs<SubstTemplateTypeParmType>()) {
+    for (auto NK : getNullabilityForTemplateParameter(
+             ST, BaseNullabilityAnnotations, BaseType)) {
+      Result.push_back(NK);
+    }
+  } else if (auto PtrT = T->getAs<PointerType>()) {
+    Result.push_back(NullabilityKind::Unspecified);
+    substituteNullabilityAnnotationsInTemplateImpl(
+        Result, PtrT->getPointeeType(), BaseNullabilityAnnotations, BaseType);
+  } else if (auto ET = T->getAs<ElaboratedType>()) {
+    substituteNullabilityAnnotationsInTemplateImpl(
+        Result, ET->getNamedType(), BaseNullabilityAnnotations, BaseType);
+  } else if (auto TST = T->getAs<TemplateSpecializationType>()) {
+    for (auto TA : TST->template_arguments()) {
+      if (TA.getKind() == TemplateArgument::Type) {
+        substituteNullabilityAnnotationsInTemplateImpl(
+            Result, TA.getAsType(), BaseNullabilityAnnotations, BaseType);
       }
     }
   }
+}
 
-  NullabilityKind Nullability = NullabilityKind::Unspecified;
-  if (ReplacedParameter && ReplacedParameter->getDepth() == 0) {
-    Nullability = NullabilityVector[ReplacedParameter->getIndex()];
+/// Similar to getNullabilityForTemplateParameter, but here we get the
+/// nullability annotation for a type that *contains* another type that was
+/// originally written as a template type parameter. For example, consider the
+/// following code:
+///
+/// template <typename T0, typename T1>
+/// struct Struct2Arg {
+///   T1 *_Nullable getNullableT1Ptr();
+/// };
+/// void target(Struct2Arg<int *, int *_Nonnull> &x) {
+///   x.getNullableT1Ptr();
+/// }
+///
+/// Suppose we wish to find the nullability annotations of x.getNullableT1Ptr().
+/// The return type of this method call is T1 * _Nullable, so its outer
+/// nullability is "_Nullable". Then, we continue recursing over this type to
+/// find the rest of the nullability annotation. We call
+/// getNullabilityFromTemplateParameter to find that T1 has nullability
+/// annotation {_Nonnull}. Thus, our complete nullability annotation for this
+/// member call is {_Nullable, _Nonnull}.
+std::vector<NullabilityKind> substituteNullabilityAnnotationsInTemplate(
+    QualType T, ArrayRef<NullabilityKind> BaseNullabilityAnnotations,
+    QualType BaseType) {
+  std::vector<NullabilityKind> Result;
+  substituteNullabilityAnnotationsInTemplateImpl(
+      Result, T, BaseNullabilityAnnotations, BaseType);
+  return Result;
+}
+
+/// Get nullability annotations of the base type. For example, in the member
+/// expression x.f or the member call x.getF(), x is the base object and its
+/// type is the base type.
+std::vector<NullabilityKind> getBaseNullabilityAnnotations(const Expr* E) {
+  if (auto ME = dyn_cast<MemberExpr>(E)) {
+    return getBaseNullabilityAnnotations(ME->getBase());
+  } else if (auto MC = dyn_cast<CXXMemberCallExpr>(E)) {
+    return getBaseNullabilityAnnotations(MC->getImplicitObjectArgument());
+  } else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
+    return getNullabilityAnnotationsFromType(DRE->getType());
   }
-  return Nullability;
+  // TODO: Handle other expression shapes.
+  return std::vector<NullabilityKind>();
+}
+
+QualType getBaseType(const Expr* E) {
+  if (auto ME = dyn_cast<MemberExpr>(E)) {
+    return getBaseType(ME->getBase());
+  } else if (auto MC = dyn_cast<CXXMemberCallExpr>(E)) {
+    return getBaseType(MC->getImplicitObjectArgument());
+  } else if (auto DRE = dyn_cast<DeclRefExpr>(E)) {
+    return DRE->getType();
+  }
+  // TODO: Handle other expression shapes and base types.
+  else {
+    llvm::dbgs() << "\nWe cannot get this base type yet...\n";
+  }
+  return QualType();
+}
+
+std::vector<NullabilityKind> getNullabilityAnnotations(
+    const Expr* E, ArrayRef<NullabilityKind> BaseNullabilityAnnotations,
+    QualType BaseType) {
+  if (auto ME = dyn_cast<MemberExpr>(E)) {
+    return substituteNullabilityAnnotationsInTemplate(
+        ME->getType(), BaseNullabilityAnnotations, BaseType);
+  } else if (auto MC = dyn_cast<CXXMemberCallExpr>(E)) {
+    return substituteNullabilityAnnotationsInTemplate(
+        MC->getType(), BaseNullabilityAnnotations, BaseType);
+  }
+  // TODO: Handle other expression shapes.
+  return std::vector<NullabilityKind>();
+}
+
+/// Given an expression E that refers to a member variable or a member function
+/// of a template specialization, construct the nullability vector
+/// of its base type and use it to compute the nullability of E. E's nullability
+/// will itself be a vector; this is to account for cases in which E is
+/// composed of more than one pointer. We return the first element of E's
+/// nullability vector (i.e., E's "outer" nullability).
+NullabilityKind getNullabilityFromTemplatedExpression(const Expr* E) {
+  std::vector<NullabilityKind> BaseNullabilityAnnotations =
+      getBaseNullabilityAnnotations(E);
+  QualType BaseType = getBaseType(E);
+  std::vector<NullabilityKind> NullabilityAnnotations =
+      getNullabilityAnnotations(E, BaseNullabilityAnnotations, BaseType);
+  if (NullabilityAnnotations.empty()) {
+    return NullabilityKind::Unspecified;
+  }
+  return NullabilityAnnotations[0];
 }
 
 NullabilityKind getPointerNullability(const Expr* E, ASTContext& Ctx) {
@@ -88,7 +255,8 @@
   NullabilityKind Nullability =
       ExprType->getNullability(Ctx).value_or(NullabilityKind::Unspecified);
   if (Nullability == NullabilityKind::Unspecified) {
-    // Try to get nullability from the expression itself.
+    // If the type does not contain nullability information, try to gather it
+    // from the expression itself.
     Nullability = getNullabilityFromTemplatedExpression(E);
   }
   return Nullability;
diff --git a/nullability_verification/pointer_nullability_verification_test.cc b/nullability_verification/pointer_nullability_verification_test.cc
index e6485f5..f9e7a92 100644
--- a/nullability_verification/pointer_nullability_verification_test.cc
+++ b/nullability_verification/pointer_nullability_verification_test.cc
@@ -1431,7 +1431,82 @@
   )");
 }
 
-TEST(PointerNullabilityTest, MemberFunctionOfClassTemplateInstantiation) {
+TEST(PointerNullabilityTest, MemberExpressionOfClassTemplateInstantiation) {
+  // Struct with 2 arguments with nullable second argument.
+  checkDiagnostics(R"cc(
+    template <typename T0, typename T1>
+    struct Struct2Arg {
+      T0 arg0;
+      T1 arg1;
+    };
+    void target(Struct2Arg<int* _Nonnull, double* _Nullable> p) {
+      *p.arg0;
+      *p.arg1;  // [[unsafe]]
+    }
+  )cc");
+
+  // Struct with 5 arguments with interleaved nullable/nonnull/unknown.
+  checkDiagnostics(R"cc(
+    template <typename T0, typename T1, typename T2, typename T3, typename T4>
+    struct Struct5Arg {
+      T0 arg0;
+      T1 arg1;
+      T2 arg2;
+      T3 arg3;
+      T4 arg4;
+    };
+    void target(Struct5Arg<int* _Nullable, double* _Nonnull, float*,
+                           double* _Nullable, int* _Nonnull>
+                    p) {
+      *p.arg0;  // [[unsafe]]
+      *p.arg1;
+      *p.arg2;
+      *p.arg3;  // [[unsafe]]
+      *p.arg4;
+    }
+  )cc");
+
+  // Struct with interleaved int and typename arguments.
+  checkDiagnostics(R"cc(
+    template <typename T0, int I1, typename T2, int T3, typename T4>
+    struct Struct5Arg {
+      T0 arg0;
+      T2 arg2;
+      T4 arg4;
+    };
+    void target(Struct5Arg<int* _Nullable, 0, float*, 1, int* _Nullable> p) {
+      *p.arg0;  // [[unsafe]]
+      *p.arg2;
+      *p.arg4;  // [[unsafe]]
+    }
+  )cc");
+
+  // Struct template that uses another struct template in a member variable.
+  checkDiagnostics(R"cc(
+    template <typename T0, typename T1>
+    struct Struct2Arg {
+      T0 arg0;
+      T1 arg1;
+    };
+
+    template <typename TN0, typename TN1>
+    struct Struct2ArgNested {
+      Struct2Arg<TN1, Struct2Arg<TN0, TN1>>* arg0;
+      Struct2Arg<TN1, Struct2Arg<TN0, TN1>>* _Nullable arg1;
+    };
+    void target(Struct2ArgNested<int* _Nonnull, double* _Nullable> p) {
+      *p.arg0;
+      *p.arg1;  // [[unsafe]]
+
+      // TODO: The following lines currently crash at getBaseType()
+      //*p.arg0->arg0; // false-positive
+      //*p.arg0->arg1.arg0;
+      //*p.arg0->arg1.arg1; // false-positive
+    }
+  )cc");
+}
+
+TEST(PointerNullabilityTest, MemberCallExpressionOfClassTemplateInstantiation) {
   // Struct with one argument initialised as _Nullable.
   checkDiagnostics(R"cc(
     template <typename T0>