Constrain inferable slots to previously inferred nullabilities when available, instead of always to Unknown, when querying the nullability of other values.
PiperOrigin-RevId: 574892768
Change-Id: I5e62ec21d4631d8b33889ef2de84625c07abac6b
diff --git a/nullability/inference/BUILD b/nullability/inference/BUILD
index 3997845..b620bad 100644
--- a/nullability/inference/BUILD
+++ b/nullability/inference/BUILD
@@ -31,6 +31,7 @@
deps = [
":collect_evidence",
":inference_cc_proto",
+ ":slot_fingerprint",
"@llvm-project//clang:ast",
"@llvm-project//clang:ast_matchers",
"@llvm-project//clang:basic",
diff --git a/nullability/inference/collect_evidence.cc b/nullability/inference/collect_evidence.cc
index f2964cf..ba670ce 100644
--- a/nullability/inference/collect_evidence.cc
+++ b/nullability/inference/collect_evidence.cc
@@ -50,6 +50,7 @@
namespace clang::tidy::nullability {
using ::clang::dataflow::DataflowAnalysisContext;
using ::clang::dataflow::Environment;
+using ::clang::dataflow::Formula;
std::string_view getOrGenerateUSR(USRCache &Cache, const Decl &Decl) {
auto [It, Inserted] = Cache.try_emplace(&Decl);
@@ -171,20 +172,24 @@
// constraint representing these slots having a) the nullability inferred from
// the previous round for this slot or b) Unknown nullability if no inference
// was made in the previous round or there was no previous round.
-const dataflow::Formula *getInferableSlotsAsInferredOrUnknownConstraint(
+const Formula *getInferableSlotsAsInferredOrUnknownConstraint(
std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
- const dataflow::Environment &Env) {
+ USRCache &USRCache, const dataflow::Environment &Env) {
dataflow::Arena &A = Env.getDataflowAnalysisContext().arena();
- const dataflow::Formula *InferableSlotsUnknown = &A.makeLiteral(true);
+ const Formula *InferableSlotsUnknown = &A.makeLiteral(true);
+ std::string_view USR = getOrGenerateUSR(USRCache, *Env.getCurrentFunc());
for (auto &[Nullability, Slot] : InferableSlots) {
- // TODO: get access to the current function's USR and look up the
- // fingerprint for that USR and `Slot` in the previously inferred sets,
- // instead of assuming Unknown.
- InferableSlotsUnknown = &A.makeAnd(
- *InferableSlotsUnknown, A.makeAnd(A.makeNot(Nullability.isNullable(A)),
- A.makeNot(Nullability.isNonnull(A))));
+ SlotFingerprint Fingerprint = fingerprint(USR, Slot);
+ const Formula &Nullable = PreviouslyInferredNullable.contains(Fingerprint)
+ ? Nullability.isNullable(A)
+ : A.makeNot(Nullability.isNullable(A));
+ const Formula &Nonnull = PreviouslyInferredNonnull.contains(Fingerprint)
+ ? Nullability.isNonnull(A)
+ : A.makeNot(Nullability.isNonnull(A));
+ InferableSlotsUnknown =
+ &A.makeAnd(*InferableSlotsUnknown, A.makeAnd(Nullable, Nonnull));
}
return InferableSlotsUnknown;
}
@@ -207,7 +212,8 @@
std::vector<std::pair<PointerTypeNullability, Slot>> &InferableCallerSlots,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
- const CFGElement &Element, const dataflow::Environment &Env,
+ USRCache USRCache, const CFGElement &Element,
+ const dataflow::Environment &Env,
llvm::function_ref<EvidenceEmitter> Emit) {
// Is this CFGElement a call to a function?
auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -259,7 +265,7 @@
getNullability(*PV, Env,
getInferableSlotsAsInferredOrUnknownConstraint(
InferableCallerSlots, PreviouslyInferredNullable,
- PreviouslyInferredNonnull, Env));
+ PreviouslyInferredNonnull, USRCache, Env));
Evidence::Kind ArgEvidenceKind;
switch (ArgNullability) {
case NullabilityKind::Nullable:
@@ -279,7 +285,8 @@
std::vector<std::pair<PointerTypeNullability, Slot>> &InferableSlots,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
- const CFGElement &Element, const dataflow::Environment &Env,
+ USRCache USRCache, const CFGElement &Element,
+ const dataflow::Environment &Env,
llvm::function_ref<EvidenceEmitter> Emit) {
// Is this CFGElement a return statement?
auto CFGStmt = Element.getAs<clang::CFGStmt>();
@@ -297,7 +304,7 @@
getNullability(ReturnExpr, Env,
getInferableSlotsAsInferredOrUnknownConstraint(
InferableSlots, PreviouslyInferredNullable,
- PreviouslyInferredNonnull, Env));
+ PreviouslyInferredNonnull, USRCache, Env));
Evidence::Kind ReturnEvidenceKind;
switch (ReturnNullability) {
case NullabilityKind::Nullable:
@@ -317,13 +324,15 @@
std::vector<std::pair<PointerTypeNullability, Slot>> InferableSlots,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNullable,
const llvm::DenseSet<SlotFingerprint> &PreviouslyInferredNonnull,
- const CFGElement &Element, const Environment &Env,
+ USRCache USRCache, const CFGElement &Element, const Environment &Env,
llvm::function_ref<EvidenceEmitter> Emit) {
collectEvidenceFromDereference(InferableSlots, Element, Env, Emit);
collectEvidenceFromCallExpr(InferableSlots, PreviouslyInferredNullable,
- PreviouslyInferredNonnull, Element, Env, Emit);
+ PreviouslyInferredNonnull, USRCache, Element, Env,
+ Emit);
collectEvidenceFromReturn(InferableSlots, PreviouslyInferredNullable,
- PreviouslyInferredNonnull, Element, Env, Emit);
+ PreviouslyInferredNonnull, USRCache, Element, Env,
+ Emit);
// TODO: add more heuristic collections here
}
@@ -379,9 +388,10 @@
[&](const CFGElement &Element,
const dataflow::DataflowAnalysisState<
PointerNullabilityLattice> &State) {
- collectEvidenceFromElement(
- InferableSlots, PreviouslyInferredNullable,
- PreviouslyInferredNonnull, Element, State.Env, Emit);
+ collectEvidenceFromElement(InferableSlots,
+ PreviouslyInferredNullable,
+ PreviouslyInferredNonnull, USRCache,
+ Element, State.Env, Emit);
})
.takeError();
}
diff --git a/nullability/inference/collect_evidence_test.cc b/nullability/inference/collect_evidence_test.cc
index 5b08e1d..7c97878 100644
--- a/nullability/inference/collect_evidence_test.cc
+++ b/nullability/inference/collect_evidence_test.cc
@@ -9,6 +9,7 @@
#include <vector>
#include "nullability/inference/inference.proto.h"
+#include "nullability/inference/slot_fingerprint.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
@@ -16,6 +17,7 @@
#include "clang/Basic/LLVM.h"
#include "clang/Testing/TestAST.h"
#include "third_party/llvm/llvm-project/clang/unittests/Analysis/FlowSensitive/TestingSupport.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
@@ -29,10 +31,13 @@
using ::clang::ast_matchers::isTemplateInstantiation;
using ::clang::ast_matchers::match;
using ::testing::_;
+using ::testing::AllOf;
using ::testing::Contains;
using ::testing::ElementsAre;
using ::testing::IsEmpty;
+using ::testing::IsSupersetOf;
using ::testing::Not;
+using ::testing::ResultOf;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
@@ -66,7 +71,9 @@
}
std::vector<Evidence> collectEvidenceFromTargetFunction(
- llvm::StringRef Source) {
+ llvm::StringRef Source,
+ const llvm::DenseSet<SlotFingerprint>& PreviouslyInferredNullable = {},
+ const llvm::DenseSet<SlotFingerprint>& PreviouslyInferredNonnull = {}) {
std::vector<Evidence> Results;
clang::TestAST AST(getInputsWithAnnotationDefinitions(Source));
USRCache usr_cache;
@@ -75,7 +82,7 @@
*dataflow::test::findValueDecl(AST.context(), "target")),
evidenceEmitter([&](const Evidence& E) { Results.push_back(E); },
usr_cache),
- usr_cache);
+ usr_cache, PreviouslyInferredNullable, PreviouslyInferredNonnull);
if (Err) ADD_FAILURE() << toString(std::move(Err));
return Results;
}
@@ -488,6 +495,52 @@
EXPECT_THAT(Results, IsEmpty());
}
+TEST(CollectEvidenceFromImplementationTest, PropagatesPreviousInferences) {
+ static constexpr llvm::StringRef Src = R"cc(
+ void calledWithToBeNullable(int* x);
+ void calledWithToBeNonnull(int* a);
+ void target(int* p, int* q) {
+ target(nullptr, q);
+ calledWithToBeNullable(p);
+ *q;
+ calledWithToBeNonnull(q);
+ }
+ )cc";
+ std::string TargetUsr = "c:@F@target#*I#S0_#";
+ std::vector ExpectedBothRoundResults = {
+ evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+ AllOf(functionNamed("target"),
+ // Double-check that target's usr is as expected before we
+ // use it to create SlotFingerprints.
+ ResultOf([](Symbol S) { return S.usr(); }, TargetUsr))),
+ evidence(paramSlot(1), Evidence::UNCHECKED_DEREFERENCE,
+ functionNamed("target")),
+ };
+ std::vector ExpectedSecondRoundResults = {
+ evidence(paramSlot(0), Evidence::NULLABLE_ARGUMENT,
+ functionNamed("calledWithToBeNullable")),
+ evidence(paramSlot(0), Evidence::NONNULL_ARGUMENT,
+ functionNamed("calledWithToBeNonnull"))};
+
+ // Only proceed if we have the correct USR for target and the first round
+ // results contain the evidence needed to produce our expected inferences and
+ // do not contain the evidence only found from propagating inferences from the
+ // first round.
+ auto FirstRoundResults = collectEvidenceFromTargetFunction(Src);
+ ASSERT_THAT(FirstRoundResults, IsSupersetOf(ExpectedBothRoundResults));
+ for (const auto& E : ExpectedSecondRoundResults) {
+ ASSERT_THAT(FirstRoundResults, Not(Contains(E)));
+ }
+
+ EXPECT_THAT(
+ collectEvidenceFromTargetFunction(Src, /* PreviouslyInferredNullable= */
+ {fingerprint(TargetUsr, paramSlot(0))},
+ /* PreviouslyInferredNonnull= */
+ {fingerprint(TargetUsr, paramSlot(1))}),
+ AllOf(IsSupersetOf(ExpectedBothRoundResults),
+ IsSupersetOf(ExpectedSecondRoundResults)));
+}
+
TEST(CollectEvidenceFromDeclarationTest, VariableDeclIgnored) {
llvm::StringLiteral Src = "Nullable<int *> target;";
EXPECT_THAT(collectEvidenceFromTargetDecl(Src), IsEmpty());