Use previous inference results in PointerNullabilityAnalysis.
Allows for additional propagation of inferred nullability, e.g. from one implementation's return value or parameters to another implementation's interaction with those decls.
PiperOrigin-RevId: 576133969
Change-Id: I241f413fede47ab773cfe77b78435a58c552a279
diff --git a/nullability/BUILD b/nullability/BUILD
index 4ca6a1c..3ba5bcf 100644
--- a/nullability/BUILD
+++ b/nullability/BUILD
@@ -15,6 +15,8 @@
"@absl//absl/log:check",
"@llvm-project//clang:analysis",
"@llvm-project//clang:ast",
+ "@llvm-project//clang:basic",
+ "@llvm-project//llvm:Support",
],
)
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index b620bad..1432ac8 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -16,6 +16,7 @@
"//nullability:pointer_nullability_analysis",
"//nullability:pointer_nullability_lattice",
"//nullability:type_nullability",
+ "@absl//absl/container:flat_hash_map",
"@absl//absl/log:check",
"@llvm-project//clang:analysis",
"@llvm-project//clang:ast",
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index d210e7f..5fdff82 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -11,6 +11,7 @@
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "nullability/inference/inferable.h"
#include "nullability/inference/inference.proto.h"
@@ -52,6 +53,10 @@
using ::clang::dataflow::Environment;
using ::clang::dataflow::Formula;
+using ConcreteNullabilityCache =
+ absl::flat_hash_map<const Decl *,
+ std::optional<const PointerTypeNullability>>;
+
std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &Decl) {
auto [It, Inserted] = Cache.try_emplace(&Decl);
if (Inserted) {
@@ -174,7 +179,7 @@
// was made in the previous round or there was no previous round.
const Formula &getInferableSlotsAsInferredOrUnknownConstraint(
std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
- const PreviousInferences &PreviousInferences, USRCache &USRCache,
+ USRCache &USRCache, const PreviousInferences &PreviousInferences,
dataflow::Arena &A, const Decl &CurrentFunc) {
const Formula *Constraint = &A.makeLiteral(true);
std::string_view USR = getOrGenerateUSR(USRCache, CurrentFunc);
@@ -254,8 +259,8 @@
// Emit evidence of the parameter's nullability. First, calculate that
// nullability based on InferableSlots for the caller being assigned to
- // Unknown, to reflect the current annotations and not all possible
- // annotations for them.
+ // Unknown or their previously-inferred value, to reflect the current
+ // annotations and not all possible annotations for them.
NullabilityKind ArgNullability =
getNullability(*PV, Env, &InferableSlotsConstraint);
Evidence::Kind ArgEvidenceKind;
@@ -331,6 +336,49 @@
return Evidence::ANNOTATED_NULLABLE;
}
}
+
+// Returns a function that the analysis can use to override Decl nullability
+// values from the source code being analyzed with previously inferred
+// nullabilities.
+//
+// In practice, this should only override the default nullability for Decls that
+// do not spell out a nullability in source code, because we only pass in
+// inferences from the previous round which are non-trivial and annotations
+// "inferred" by reading an annotation from source code in the previous round
+// were marked trivial.
+auto getConcreteNullabilityOverrideFromPreviousInferences(
+ ConcreteNullabilityCache &Cache, USRCache &USRCache,
+ const PreviousInferences &PreviousInferences) {
+ return [&](const Decl &D) -> std::optional<const PointerTypeNullability *> {
+ auto [It, Inserted] = Cache.try_emplace(&D);
+ if (Inserted) {
+ std::optional<const Decl *> fingerprintedDecl;
+ Slot Slot;
+ if (auto *FD = clang::dyn_cast_or_null<FunctionDecl>(&D)) {
+ fingerprintedDecl = (ValueDecl *)FD;
+ Slot = SLOT_RETURN_TYPE;
+ } else if (auto *PD = clang::dyn_cast_or_null<ParmVarDecl>(&D)) {
+ if (auto *Parent = clang::dyn_cast_or_null<FunctionDecl>(
+ PD->getParentFunctionOrMethod())) {
+ fingerprintedDecl = (ValueDecl *)Parent;
+ Slot = paramSlot(PD->getFunctionScopeIndex());
+ }
+ }
+ if (!fingerprintedDecl) return std::nullopt;
+ auto fp =
+ fingerprint(getOrGenerateUSR(USRCache, **fingerprintedDecl), Slot);
+ if (PreviousInferences.Nullable.contains(fp)) {
+ It->second.emplace(NullabilityKind::Nullable);
+ } else if (PreviousInferences.Nonnull.contains(fp)) {
+ It->second.emplace(NullabilityKind::NonNull);
+ } else {
+ It->second = std::nullopt;
+ }
+ }
+ if (!It->second) return std::nullopt;
+ return &*It->second;
+ };
+}
} // namespace
llvm::Error collectEvidenceFromImplementation(
@@ -365,9 +413,14 @@
}
const auto &InferableSlotsConstraint =
getInferableSlotsAsInferredOrUnknownConstraint(
- InferableSlots, PreviousInferences, USRCache, AnalysisContext.arena(),
+ InferableSlots, USRCache, PreviousInferences, AnalysisContext.arena(),
Decl);
+ ConcreteNullabilityCache ConcreteNullabilityCache;
+ Analysis.assignNullabilityOverride(
+ getConcreteNullabilityOverrideFromPreviousInferences(
+ ConcreteNullabilityCache, USRCache, PreviousInferences));
+
return dataflow::runDataflowAnalysis(
*ControlFlowContext, Analysis, Environment,
[&](const CFGElement &Element,
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index eca2f3d..7ec9411 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -17,6 +17,7 @@
#include "clang/Basic/LLVM.h"
#include "clang/Testing/TestAST.h"
#include "third_party/llvm/llvm-project/clang/unittests/Analysis/FlowSensitive/TestingSupport.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
@@ -537,6 +538,93 @@
IsSupersetOf(ExpectedSecondRoundResults)));
}
+TEST(CollectEvidenceFromImplementationTest,
+ AnalysisUsesPreviousInferencesForSlotsOutsideTargetImplementation) {
+ static constexpr llvm::StringRef Src = R"cc(
+ int* returnsToBeNonnull(int* a) {
+ return a;
+ }
+ int* target(int* q) {
+ *q;
+ return returnsToBeNonnull(q);
+ }
+ )cc";
+ std::string TargetUsr = "c:@F@target#*I#";
+ std::string ReturnsToBeNonnullUsr = "c:@F@returnsToBeNonnull#*I#";
+ const llvm::DenseMap<int, std::vector<testing::Matcher<const Evidence&>>>
+ ExpectedNewResultsPerRound = {
+ {{0,
+ {evidence(
+ paramSlot(0), Evidence::UNCHECKED_DEREFERENCE,
+ AllOf(functionNamed("target"),
+ // Double-check that target's usr is as expected before
+ // we use it to create SlotFingerprints.
+ ResultOf([](Symbol S) { return S.usr(); }, TargetUsr)))}},
+ {1,
+ {evidence(
+ paramSlot(0), Evidence::NONNULL_ARGUMENT,
+ AllOf(functionNamed("returnsToBeNonnull"),
+ // Double-check that returnsToBeNonnull's usr is as
+ // expected before we use it to create SlotFingerprints.
+ ResultOf([](Symbol S) { return S.usr(); },
+ ReturnsToBeNonnullUsr)))}},
+ {2,
+ {
+ // No new evidence from target's implementation in this round,
+ // but in a full-TU analysis, this would be the round where we
+ // decide returnsToBeNonnull returns Nonnull, based on the
+ // now-Nonnull argument that is the only return value.
+ }},
+ {3,
+ {evidence(SLOT_RETURN_TYPE, Evidence::NONNULL_RETURN,
+ functionNamed("target"))}}}};
+
+ // Assert first round results because they don't rely on previous inference
+ // propagation at all and in this case are test setup and preconditions.
+ auto FirstRoundResults = collectEvidenceFromTargetFunction(Src);
+ ASSERT_THAT(FirstRoundResults,
+ IsSupersetOf(ExpectedNewResultsPerRound.at(0)));
+ for (const auto& E : ExpectedNewResultsPerRound.at(1)) {
+ ASSERT_THAT(FirstRoundResults, Not(Contains(E)));
+ }
+
+ auto SecondRoundResults = collectEvidenceFromTargetFunction(
+ Src, {.Nonnull = {fingerprint(TargetUsr, paramSlot(0))}});
+ EXPECT_THAT(SecondRoundResults,
+ AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(1))));
+ for (const auto& E : ExpectedNewResultsPerRound.at(2)) {
+ ASSERT_THAT(SecondRoundResults, Not(Contains(E)));
+ }
+
+ auto ThirdRoundResults = collectEvidenceFromTargetFunction(
+ Src, {.Nonnull = {fingerprint(TargetUsr, paramSlot(0)),
+ fingerprint(ReturnsToBeNonnullUsr, paramSlot(0))}});
+ EXPECT_THAT(ThirdRoundResults,
+ AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(1)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(2))));
+ for (const auto& E : ExpectedNewResultsPerRound.at(3)) {
+ ASSERT_THAT(ThirdRoundResults, Not(Contains(E)));
+ }
+
+ auto FourthRoundResults = collectEvidenceFromTargetFunction(
+ Src,
+ {.Nonnull = {
+ fingerprint(TargetUsr, paramSlot(0)),
+ fingerprint(ReturnsToBeNonnullUsr, paramSlot(0)),
+ // As noted in the Evidence matcher list above, we don't infer the
+ // return type of returnsToBeNonnull from only collecting evidence
+ // from target's implementation, but for the sake of this test, let's
+ // pretend we collected evidence from the entire TU.
+ fingerprint(ReturnsToBeNonnullUsr, SLOT_RETURN_TYPE)}});
+ EXPECT_THAT(FourthRoundResults,
+ AllOf(IsSupersetOf(ExpectedNewResultsPerRound.at(0)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(1)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(2)),
+ IsSupersetOf(ExpectedNewResultsPerRound.at(3))));
+}
+
TEST(CollectEvidenceFromDeclarationTest, VariableDeclIgnored) {
llvm::StringLiteral Src = "Nullable<int *> target;";
EXPECT_THAT(collectEvidenceFromTargetDecl(Src), IsEmpty());
diff --git a/nullability/pointer_nullability_analysis.cc b/nullability/pointer_nullability_analysis.cc
index adcd9c8..0395559 100644
--- a/nullability/pointer_nullability_analysis.cc
+++ b/nullability/pointer_nullability_analysis.cc
@@ -494,7 +494,7 @@
// If nullability for the decl D has been overridden, patch N to reflect it.
// (N is the nullability of an access to D).
-void overrideNullabilityFromDecl(const ValueDecl *D,
+void overrideNullabilityFromDecl(const Decl *D,
PointerNullabilityLattice &Lattice,
TypeNullability &N) {
// For now, overrides are always for pointer values only, and override only
@@ -706,8 +706,11 @@
// TODO(mboehme): Instead of relying on Clang to propagate nullability sugar
// to the `CallExpr`'s type, we should extract nullability directly from the
// callee `Expr .
- return substituteNullabilityAnnotationsInFunctionTemplate(CE->getType(),
- CE);
+ auto Nullability =
+ substituteNullabilityAnnotationsInFunctionTemplate(CE->getType(), CE);
+ overrideNullabilityFromDecl(CE->getCalleeDecl(), State.Lattice,
+ Nullability);
+ return Nullability;
});
}
diff --git a/nullability/pointer_nullability_analysis.h b/nullability/pointer_nullability_analysis.h
index ca94c8a..65ba3bd 100644
--- a/nullability/pointer_nullability_analysis.h
+++ b/nullability/pointer_nullability_analysis.h
@@ -5,19 +5,21 @@
#ifndef CRUBIT_NULLABILITY_POINTER_NULLABILITY_ANALYSIS_H_
#define CRUBIT_NULLABILITY_POINTER_NULLABILITY_ANALYSIS_H_
+#include <optional>
#include <utility>
#include "nullability/pointer_nullability_lattice.h"
#include "nullability/type_nullability.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
#include "clang/AST/Type.h"
#include "clang/Analysis/FlowSensitive/Arena.h"
#include "clang/Analysis/FlowSensitive/CFGMatchSwitch.h"
#include "clang/Analysis/FlowSensitive/DataflowAnalysis.h"
#include "clang/Analysis/FlowSensitive/DataflowEnvironment.h"
-#include "clang/Analysis/FlowSensitive/Formula.h"
#include "clang/Analysis/FlowSensitive/Value.h"
+#include "llvm/ADT/FunctionExtras.h"
namespace clang {
namespace tidy {
@@ -67,6 +69,13 @@
PointerTypeNullability assignNullabilityVariable(const ValueDecl *D,
dataflow::Arena &);
+ void assignNullabilityOverride(
+ llvm::unique_function<
+ std::optional<const PointerTypeNullability *>(const Decl &) const>
+ Override) {
+ NFS.ConcreteNullabilityOverride = std::move(Override);
+ }
+
void transfer(const CFGElement &Elt, PointerNullabilityLattice &Lattice,
dataflow::Environment &Env);
diff --git a/nullability/pointer_nullability_lattice.h b/nullability/pointer_nullability_lattice.h
index 2a27d8e..0b7b906 100644
--- a/nullability/pointer_nullability_lattice.h
+++ b/nullability/pointer_nullability_lattice.h
@@ -6,6 +6,7 @@
#define CRUBIT_NULLABILITY_POINTER_NULLABILITY_LATTICE_H_
#include <functional>
+#include <optional>
#include <ostream>
#include "absl/container/flat_hash_map.h"
@@ -14,18 +15,26 @@
#include "clang/AST/Expr.h"
#include "clang/Analysis/FlowSensitive/DataflowAnalysisContext.h"
#include "clang/Analysis/FlowSensitive/DataflowLattice.h"
+#include "clang/Basic/LLVM.h"
+#include "llvm/ADT/FunctionExtras.h"
namespace clang::tidy::nullability {
-
class PointerNullabilityLattice {
public:
struct NonFlowSensitiveState {
absl::flat_hash_map<const Expr *, TypeNullability> ExprToNullability;
// Overridden symbolic nullability for pointer-typed decls.
// These are set by PointerNullabilityAnalysis::assignNullabilityVariable,
- // and take precedence over the declared type.
+ // and take precedence over the declared type and over any result from
+ // ConcreteNullabilityOverride.
absl::flat_hash_map<const ValueDecl *, PointerTypeNullability>
DeclTopLevelNullability;
+ // Returns overriding concrete nullability for decls. This is set by
+ // PointerNullabilityAnalysis::assignNullabilityOverride, and the result, if
+ // present, takes precedence over the declared type.
+ llvm::unique_function<std::optional<const PointerTypeNullability *>(
+ const Decl &) const>
+ ConcreteNullabilityOverride = [](const Decl &) { return std::nullopt; };
};
PointerNullabilityLattice(NonFlowSensitiveState &NFS) : NFS(NFS) {}
@@ -54,11 +63,18 @@
}
// Returns overridden nullability information associated with a declaration.
- // For now we only track top-level decl nullability symbolically.
- const PointerTypeNullability *getDeclNullability(const ValueDecl *D) const {
- auto It = NFS.DeclTopLevelNullability.find(D);
- if (It == NFS.DeclTopLevelNullability.end()) return nullptr;
- return &It->second;
+ // For now we only track top-level decl nullability symbolically and check for
+ // concrete nullability override results.
+ const PointerTypeNullability *getDeclNullability(const Decl *D) const {
+ if (!D) return nullptr;
+ if (const auto *VD = dyn_cast_or_null<ValueDecl>(D)) {
+ auto It = NFS.DeclTopLevelNullability.find(VD);
+ if (It != NFS.DeclTopLevelNullability.end()) return &It->second;
+ }
+ if (const std::optional<const PointerTypeNullability *> N =
+ NFS.ConcreteNullabilityOverride(*D))
+ return *N;
+ return nullptr;
}
bool operator==(const PointerNullabilityLattice &Other) const { return true; }