[nullability] Add `isSupportedRawPointerType()` and `isSupportedSmartPointerType()`.

Existing uses of `isSupportedPointerType()` have been replaced with
`isSupportedRawPointerType()` for now to ensure that we don't get any assertion
failures or crashes on code that isn't prepared to deal with smart pointers.

PiperOrigin-RevId: 581936520
Change-Id: Ib0ff768e3b17ed74d170d6d2fd2c30dabbecf117
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index f69afd4..b7398c7 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -266,9 +266,9 @@
   for (; ParamI < CalleeDecl.param_size(); ++ParamI, ++ArgI) {
     const auto *ParamDecl = CalleeDecl.getParamDecl(ParamI);
     const auto ParamType = ParamDecl->getType().getNonReferenceType();
-    if (!isSupportedPointerType(ParamType)) continue;
+    if (!isSupportedRawPointerType(ParamType)) continue;
     // the corresponding argument should also be a pointer.
-    CHECK(isSupportedPointerType(Expr.getArg(ArgI)->getType()));
+    CHECK(isSupportedRawPointerType(Expr.getArg(ArgI)->getType()));
 
     dataflow::PointerValue *PV =
         getPointerValueFromExpr(Expr.getArg(ArgI), Env);
@@ -351,7 +351,7 @@
   auto *ReturnStmt = dyn_cast_or_null<clang::ReturnStmt>(&Stmt);
   if (!ReturnStmt) return;
   auto *ReturnExpr = ReturnStmt->getRetValue();
-  if (!ReturnExpr || !isSupportedPointerType(ReturnExpr->getType())) return;
+  if (!ReturnExpr || !isSupportedRawPointerType(ReturnExpr->getType())) return;
 
   // Skip gathering evidence about the current function if the current function
   // is not an inference target.
@@ -386,7 +386,7 @@
     for (auto *Decl : DeclStmt->decls()) {
       if (auto *VarDecl = dyn_cast_or_null<clang::VarDecl>(Decl);
           VarDecl && VarDecl->hasInit()) {
-        bool DeclTypeSupported = isSupportedPointerType(VarDecl->getType());
+        bool DeclTypeSupported = isSupportedRawPointerType(VarDecl->getType());
         bool InitTypeSupported =
             isSupportedPointerType(VarDecl->getInit()->getType());
         if (!DeclTypeSupported) return;
@@ -415,9 +415,9 @@
       BinaryOperator &&
       BinaryOperator->getOpcode() == clang::BinaryOperatorKind::BO_Assign) {
     bool LhsSupported =
-        isSupportedPointerType(BinaryOperator->getLHS()->getType());
+        isSupportedRawPointerType(BinaryOperator->getLHS()->getType());
     bool RhsSupported =
-        isSupportedPointerType(BinaryOperator->getRHS()->getType());
+        isSupportedRawPointerType(BinaryOperator->getRHS()->getType());
     if (!LhsSupported) return;
     if (!RhsSupported) {
       // TODO: we could perhaps support pointer assignments to numeric
@@ -465,7 +465,7 @@
 }
 
 std::optional<Evidence::Kind> evidenceKindFromDeclaredType(QualType T) {
-  if (!isSupportedPointerType(T.getNonReferenceType())) return std::nullopt;
+  if (!isSupportedRawPointerType(T.getNonReferenceType())) return std::nullopt;
   auto Nullability = getNullabilityAnnotationsFromType(T);
   switch (Nullability.front().concrete()) {
     default:
@@ -544,7 +544,7 @@
   auto Parameters = Func->parameters();
   for (auto I = 0; I < Parameters.size(); ++I) {
     auto T = Parameters[I]->getType().getNonReferenceType();
-    if (isSupportedPointerType(T) && !evidenceKindFromDeclaredType(T)) {
+    if (isSupportedRawPointerType(T) && !evidenceKindFromDeclaredType(T)) {
       InferableSlots.push_back(
           std::make_pair(Analysis.assignNullabilityVariable(
                              Parameters[I], AnalysisContext.arena()),
diff --git a/nullability/inference/inferable.cc b/nullability/inference/inferable.cc
index 7347118..1db583c 100644
--- a/nullability/inference/inferable.cc
+++ b/nullability/inference/inferable.cc
@@ -15,7 +15,7 @@
 namespace {
 
 bool isInferable(QualType T) {
-  return isSupportedPointerType(T.getNonReferenceType());
+  return isSupportedRawPointerType(T.getNonReferenceType());
 }
 
 }  // namespace
diff --git a/nullability/pointer_nullability_analysis.cc b/nullability/pointer_nullability_analysis.cc
index be1e0f8..d7e897f 100644
--- a/nullability/pointer_nullability_analysis.cc
+++ b/nullability/pointer_nullability_analysis.cc
@@ -435,7 +435,7 @@
 
   if (ParamTy.isNull()) return;
   if (ParamTy->getPointeeType().isNull()) return;
-  if (!isSupportedPointerType(ParamTy->getPointeeType())) return;
+  if (!isSupportedRawPointerType(ParamTy->getPointeeType())) return;
   if (ParamTy->getPointeeType().isConstQualified()) return;
 
   // TODO(b/298200521): This should extend support to annotations that suggest
@@ -483,7 +483,7 @@
     }
   }
 
-  if (isSupportedPointerType(CallExpr->getType())) {
+  if (isSupportedRawPointerType(CallExpr->getType())) {
     // Create a pointer so that we can attach nullability to it and have the
     // nullability propagate with the pointer.
     auto *PointerVal = getPointerValueFromExpr(CallExpr, State.Env);
@@ -540,7 +540,7 @@
 void transferFlowSensitiveConstMemberCall(
     const CXXMemberCallExpr *MCE, const MatchFinder::MatchResult &Result,
     TransferState<PointerNullabilityLattice> &State) {
-  if (!isSupportedPointerType(MCE->getType())) return;
+  if (!isSupportedRawPointerType(MCE->getType())) return;
   dataflow::RecordStorageLocation *RecordLoc =
       dataflow::getImplicitObjectLocation(*MCE, State.Env);
   if (RecordLoc == nullptr) return;
@@ -560,7 +560,7 @@
   if (dataflow::RecordStorageLocation *RecordLoc =
           dataflow::getImplicitObjectLocation(*MCE, State.Env)) {
     for (const auto [Field, FieldLoc] : RecordLoc->children()) {
-      if (!isSupportedPointerType(Field->getType())) continue;
+      if (!isSupportedRawPointerType(Field->getType())) continue;
       Value *V = State.Env.createValue(Field->getType());
       State.Env.setValue(*FieldLoc, *V);
     }
@@ -830,8 +830,8 @@
   computeNullability(ASE, State, [&]() {
     auto &BaseNullability = getNullabilityForChild(ASE->getBase(), State);
     QualType BaseType = ASE->getBase()->getType();
-    CHECK(isSupportedPointerType(BaseType) || BaseType->isVectorType());
-    return isSupportedPointerType(BaseType)
+    CHECK(isSupportedRawPointerType(BaseType) || BaseType->isVectorType());
+    return isSupportedRawPointerType(BaseType)
                ? ArrayRef(BaseNullability).slice(1).vec()
                : BaseNullability;
   });
@@ -906,7 +906,8 @@
   if (!S) return;
 
   auto *E = dyn_cast<Expr>(S->getStmt());
-  if (E == nullptr || !E->isPRValue() || !isSupportedPointerType(E->getType()))
+  if (E == nullptr || !E->isPRValue() ||
+      !isSupportedRawPointerType(E->getType()))
     return;
 
   if (Env.getValue(*E) == nullptr)
@@ -982,7 +983,7 @@
                                        const Environment &Env2,
                                        Value &MergedVal,
                                        Environment &MergedEnv) {
-  if (!isSupportedPointerType(Type)) {
+  if (!isSupportedRawPointerType(Type)) {
     return false;
   }
 
diff --git a/nullability/pointer_nullability_diagnosis.cc b/nullability/pointer_nullability_diagnosis.cc
index b33fdf5..5f5b229 100644
--- a/nullability/pointer_nullability_diagnosis.cc
+++ b/nullability/pointer_nullability_diagnosis.cc
@@ -63,7 +63,7 @@
 SmallVector<PointerNullabilityDiagnostic> diagnoseTypeExprCompatibility(
     QualType DeclaredType, const Expr *E, const Environment &Env,
     ASTContext &Ctx) {
-  CHECK(isSupportedPointerType(DeclaredType));
+  CHECK(isSupportedRawPointerType(DeclaredType));
   return getNullabilityKind(DeclaredType, Ctx) == NullabilityKind::NonNull
              ? diagnoseNonnullExpected(E, Env)
              : SmallVector<PointerNullabilityDiagnostic>{};
@@ -96,7 +96,7 @@
   SmallVector<PointerNullabilityDiagnostic> Diagnostics;
   for (unsigned int I = 0; I < Args.size(); ++I) {
     auto ParamType = ParamTypes[I].getNonReferenceType();
-    if (isSupportedPointerType(ParamType))
+    if (isSupportedRawPointerType(ParamType))
       Diagnostics.append(
           diagnoseTypeExprCompatibility(ParamType, Args[I], Env, Ctx));
   }
@@ -260,12 +260,12 @@
   auto ReturnType = cast<FunctionDecl>(State.Env.getDeclCtx())->getReturnType();
 
   // TODO: Handle non-pointer return types.
-  if (!isSupportedPointerType(ReturnType)) {
+  if (!isSupportedRawPointerType(ReturnType)) {
     return {};
   }
 
   auto *ReturnExpr = RS->getRetValue();
-  CHECK(isSupportedPointerType(ReturnExpr->getType()));
+  CHECK(isSupportedRawPointerType(ReturnExpr->getType()));
 
   return diagnoseTypeExprCompatibility(ReturnType, ReturnExpr, State.Env,
                                        *Result.Context);
@@ -276,7 +276,7 @@
     const TransferStateForDiagnostics<PointerNullabilityLattice> &State) {
   CHECK(CI->isAnyMemberInitializer());
   auto MemberType = CI->getAnyMember()->getType();
-  if (!isSupportedPointerType(MemberType)) return {};
+  if (!isSupportedRawPointerType(MemberType)) return {};
 
   auto *MemberInitExpr = CI->getInit();
   return diagnoseTypeExprCompatibility(MemberType, MemberInitExpr, State.Env,
diff --git a/nullability/pointer_nullability_matchers.cc b/nullability/pointer_nullability_matchers.cc
index fc3d74c..855a31a 100644
--- a/nullability/pointer_nullability_matchers.cc
+++ b/nullability/pointer_nullability_matchers.cc
@@ -46,7 +46,7 @@
 using ast_matchers::unless;
 using ast_matchers::internal::Matcher;
 
-Matcher<Stmt> isPointerExpr() { return expr(hasType(isSupportedPointer())); }
+Matcher<Stmt> isPointerExpr() { return expr(hasType(isSupportedRawPointer())); }
 Matcher<Stmt> isNullPointerLiteral() {
   return implicitCastExpr(anyOf(hasCastKind(CK_NullToPointer),
                                 hasCastKind(CK_NullToMemberPointer)));
@@ -63,13 +63,13 @@
   return implicitCastExpr(hasCastKind(CK_PointerToBoolean));
 }
 Matcher<Stmt> isMemberOfPointerType() {
-  return memberExpr(hasType(isSupportedPointer()));
+  return memberExpr(hasType(isSupportedRawPointer()));
 }
 Matcher<Stmt> isPointerArrow() { return memberExpr(isArrow()); }
 Matcher<Stmt> isCXXThisExpr() { return cxxThisExpr(); }
 Matcher<Stmt> isCallExpr() { return callExpr(); }
 Matcher<Stmt> isPointerReturn() {
-  return returnStmt(hasReturnValue(hasType(isSupportedPointer())));
+  return returnStmt(hasReturnValue(hasType(isSupportedRawPointer())));
 }
 Matcher<Stmt> isConstructExpr() { return cxxConstructExpr(); }
 Matcher<CXXCtorInitializer> isCtorMemberInitializer() {
@@ -92,7 +92,7 @@
           hasCastKind(CK_LValueToRValue),
           has(ignoringParenImpCasts(
               memberExpr(has(ignoringParenImpCasts(cxxThisExpr())),
-                         hasType(isSupportedPointer()),
+                         hasType(isSupportedRawPointer()),
                          hasDeclaration(decl().bind("member-decl"))))))))))))));
 }
 
diff --git a/nullability/pointer_nullability_matchers.h b/nullability/pointer_nullability_matchers.h
index e6641e1..87f5e8d 100644
--- a/nullability/pointer_nullability_matchers.h
+++ b/nullability/pointer_nullability_matchers.h
@@ -13,8 +13,8 @@
 namespace tidy {
 namespace nullability {
 
-AST_MATCHER(QualType, isSupportedPointer) {
-  return isSupportedPointerType(Node);
+AST_MATCHER(QualType, isSupportedRawPointer) {
+  return isSupportedRawPointerType(Node);
 }
 
 ast_matchers::internal::Matcher<Stmt> isPointerExpr();
diff --git a/nullability/type_nullability.cc b/nullability/type_nullability.cc
index af998f3..3bcf68a 100644
--- a/nullability/type_nullability.cc
+++ b/nullability/type_nullability.cc
@@ -28,7 +28,25 @@
 
 namespace clang::tidy::nullability {
 
-bool isSupportedPointerType(QualType T) { return T->isPointerType(); }
+bool isSupportedPointerType(QualType T) {
+  return isSupportedRawPointerType(T) || isSupportedSmartPointerType(T);
+}
+
+bool isSupportedRawPointerType(QualType T) { return T->isPointerType(); }
+
+bool isSupportedSmartPointerType(QualType T) {
+  // TODO(b/304963199): Add support for the `absl_nullability_compatible` tag.
+  const CXXRecordDecl *RD = T.getCanonicalType()->getAsCXXRecordDecl();
+  if (RD == nullptr) return false;
+
+  if (!RD->getDeclContext()->isStdNamespace()) return false;
+
+  const IdentifierInfo *ID = RD->getIdentifier();
+  if (ID == nullptr) return false;
+
+  StringRef Name = ID->getName();
+  return Name == "unique_ptr" || Name == "shared_ptr";
+}
 
 PointerTypeNullability PointerTypeNullability::createSymbolic(
     dataflow::Arena &A) {
diff --git a/nullability/type_nullability.h b/nullability/type_nullability.h
index 54d5e96..8e80c5f 100644
--- a/nullability/type_nullability.h
+++ b/nullability/type_nullability.h
@@ -49,6 +49,15 @@
 /// supporting pointer-to-member, ObjC pointers, `unique_ptr`, etc).
 bool isSupportedPointerType(QualType);
 
+/// Is this exactly a raw (non-smart) pointer type that we track outer
+/// nullability for?
+/// This unwraps sugar, i.e. it looks at the canonical type.
+bool isSupportedRawPointerType(QualType);
+
+/// Is this exactly a smart pointer type that we track outer nullability for?
+/// This unwraps sugar, i.e. it looks at the canonical type.
+bool isSupportedSmartPointerType(QualType);
+
 /// Describes the nullability contract of a pointer "slot" within a type.
 ///
 /// This may be concrete: nullable/non-null/unknown nullability.
diff --git a/nullability/type_nullability_test.cc b/nullability/type_nullability_test.cc
index e058bf3..819ee30 100644
--- a/nullability/type_nullability_test.cc
+++ b/nullability/type_nullability_test.cc
@@ -19,7 +19,7 @@
 namespace {
 using testing::ElementsAre;
 
-TEST(TypeNullabilityTest, IsSupportedPointerType) {
+TEST(TypeNullabilityTest, IsSupportedRawPointerType) {
   TestAST AST(R"cpp(
     using NotPointer = int;
     using Pointer = NotPointer*;
@@ -36,6 +36,12 @@
     template <class>
     struct Container;
     using ContainsPointers = Container<int*>;
+
+    namespace std {
+    template <typename T>
+    class unique_ptr;
+    }
+    using UniquePointer = std::unique_ptr<NotPointer>;
   )cpp");
 
   auto Underlying = [&](llvm::StringRef Name) {
@@ -44,14 +50,58 @@
     EXPECT_TRUE(Lookup.isSingleResult());
     return Lookup.find_first<TypeAliasDecl>()->getUnderlyingType();
   };
-  EXPECT_FALSE(isSupportedPointerType(Underlying("NotPointer")));
-  EXPECT_TRUE(isSupportedPointerType(Underlying("Pointer")));
-  EXPECT_TRUE(isSupportedPointerType(Underlying("FuncPointer")));
-  EXPECT_TRUE(isSupportedPointerType(Underlying("SugaredPointer")));
-  EXPECT_FALSE(isSupportedPointerType(Underlying("PointerDataMember")));
-  EXPECT_FALSE(isSupportedPointerType(Underlying("PointerMemberFunction")));
-  EXPECT_FALSE(isSupportedPointerType(Underlying("ObjCPointer")));
-  EXPECT_FALSE(isSupportedPointerType(Underlying("ContainsPointers")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("NotPointer")));
+  EXPECT_TRUE(isSupportedRawPointerType(Underlying("Pointer")));
+  EXPECT_TRUE(isSupportedRawPointerType(Underlying("FuncPointer")));
+  EXPECT_TRUE(isSupportedRawPointerType(Underlying("SugaredPointer")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("PointerDataMember")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("PointerMemberFunction")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("ObjCPointer")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("ContainsPointers")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("UniquePointer")));
+}
+
+TEST(TypeNullabilityTest, IsSupportedSmartPointerType) {
+  TestAST AST(R"cpp(
+    namespace std {
+    template <typename>
+    class unique_ptr;
+    template <typename>
+    class shared_ptr;
+    template <typename>
+    class weak_ptr;
+    }  // namespace std
+    template <typename>
+    class unique_ptr;
+
+    using NotPointer = int;
+    using UniquePointer = std::unique_ptr<NotPointer>;
+    using SharedPointer = std::shared_ptr<NotPointer>;
+    using WeakPointer = std::weak_ptr<NotPointer>;
+
+    using UniquePointerWrongNamespace = ::unique_ptr<NotPointer>;
+
+    using SugaredPointer = UniquePointer;
+
+    template <class>
+    struct Container;
+    using ContainsPointers = Container<std::unique_ptr<int>>;
+  )cpp");
+
+  auto Underlying = [&](llvm::StringRef Name) {
+    auto Lookup = AST.context().getTranslationUnitDecl()->lookup(
+        &AST.context().Idents.get(Name));
+    EXPECT_TRUE(Lookup.isSingleResult());
+    return Lookup.find_first<TypeAliasDecl>()->getUnderlyingType();
+  };
+  EXPECT_FALSE(isSupportedSmartPointerType(Underlying("NotPointer")));
+  EXPECT_TRUE(isSupportedSmartPointerType(Underlying("UniquePointer")));
+  EXPECT_TRUE(isSupportedSmartPointerType(Underlying("SharedPointer")));
+  EXPECT_FALSE(isSupportedSmartPointerType(Underlying("WeakPointer")));
+  EXPECT_FALSE(
+      isSupportedSmartPointerType(Underlying("UniquePointerWrongNamespace")));
+  EXPECT_TRUE(isSupportedSmartPointerType(Underlying("SugaredPointer")));
+  EXPECT_FALSE(isSupportedRawPointerType(Underlying("ContainsPointers")));
 }
 
 class GetNullabilityAnnotationsFromTypeTest : public ::testing::Test {