blob: 3e76e52e314595c4c958c886b62a3208471a1c7c [file] [log] [blame]
// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "lifetime_analysis/lifetime_constraints.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>
#include "lifetime_annotations/lifetime.h"
#include "lifetime_annotations/lifetime_substitutions.h"
#include "lifetime_annotations/pointee_type.h"
#include "lifetime_annotations/type_lifetimes.h"
#include "clang/AST/Type.h"
#include "clang/Analysis/FlowSensitive/DataflowLattice.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Error.h"
namespace clang {
namespace tidy {
namespace lifetimes {
// Lattice join: merges every outlives constraint of `other` into this set and
// reports `Changed` iff at least one of them was not already present.
clang::dataflow::LatticeJoinEffect LifetimeConstraints::join(
    const LifetimeConstraints& other) {
  bool added_any = false;
  for (const auto& constraint : other.outlives_constraints_) {
    // Insert unconditionally; only remember whether the set grew.
    const bool inserted = outlives_constraints_.insert(constraint).second;
    added_any = added_any || inserted;
  }
  if (!added_any) {
    return clang::dataflow::LatticeJoinEffect::Unchanged;
  }
  return clang::dataflow::LatticeJoinEffect::Changed;
}
namespace {
// Simple disjoint-set union (union-find) over lifetimes, with path compression
// but no union-by-rank. Path compression alone gives O(log n) *amortized* time
// per operation; a single Find may still walk a chain as long as the set, which
// is why Find is written iteratively rather than recursively.
//
// Lifetimes should be registered with MakeSet before being passed to Find or
// Union.
class LifetimeDSU {
 public:
  // Registers `l` as a singleton set (resetting its parent if already present).
  void MakeSet(Lifetime l) { parent_[l] = l; }

  // Returns the representative of the set containing `l`. Afterwards, every
  // lifetime on the path from `l` to the representative points directly at
  // the representative (path compression).
  Lifetime Find(Lifetime l) {
    // Pass 1: walk parent links up to the root.
    Lifetime root = l;
    while (true) {
      Lifetime parent = parent_[root];
      if (parent == root) break;
      root = parent;
    }
    // Pass 2: re-point every node on the path directly at the root. The next
    // link is copied out before overwriting, so no map reference is held
    // across a map access.
    while (l != root) {
      Lifetime next = parent_[l];
      parent_[l] = root;
      l = next;
    }
    return root;
  }

  // Merges the sets containing `a` and `b`. The representative of `b`'s set
  // becomes the representative of the merged set.
  void Union(Lifetime a, Lifetime b) {
    a = Find(a);
    b = Find(b);
    if (a != b) {
      parent_[a] = b;
    }
  }

 private:
  llvm::DenseMap<Lifetime, Lifetime> parent_;
};
} // namespace
// Returns every lifetime that, according to the recorded outlives constraints,
// must (transitively) outlive `l`. `l` itself is never part of the result, even
// when the constraints contain a cycle through it.
//
// The constraint graph is built once per call, so the traversal costs time
// linear in the number of constraints rather than rescanning all constraints
// for every visited lifetime.
llvm::DenseSet<Lifetime> LifetimeConstraints::GetOutlivingLifetimes(
    const Lifetime l) const {
  // Adjacency list: shorter -> lifetimes that must outlive it. Each list is
  // filled in the iteration order of `outlives_constraints_`, which keeps the
  // visiting order identical to a per-node scan of all constraints.
  llvm::DenseMap<Lifetime, std::vector<Lifetime>> longer_than;
  for (const auto& [shorter, longer] : outlives_constraints_) {
    longer_than[shorter].push_back(longer);
  }

  // Iterative depth-first search from `l` along "is outlived by" edges.
  std::vector<Lifetime> stack{l};
  llvm::DenseSet<Lifetime> visited;
  while (!stack.empty()) {
    Lifetime v = stack.back();
    stack.pop_back();
    if (!visited.insert(v).second) continue;
    auto it = longer_than.find(v);
    if (it == longer_than.end()) continue;
    stack.insert(stack.end(), it->second.begin(), it->second.end());
  }
  visited.erase(l);
  return visited;
}
// Resolves the recorded outlives constraints into concrete lifetime
// substitutions and applies them to `function_lifetimes`:
//  * every lifetime that must outlive 'static becomes 'static;
//  * a lifetime that must be outlived by a local lifetime becomes that local;
//  * the remaining lifetimes are grouped into equivalence classes (a lifetime
//    is merged with every interesting, not-yet-substituted lifetime that must
//    outlive it, since inequalities cannot be expressed in the result), and
//    each variable lifetime is replaced by its class representative.
// Returns an error if a local lifetime would have to outlive 'static; in that
// case `function_lifetimes` is not modified.
llvm::Error LifetimeConstraints::ApplyToFunctionLifetimes(
    FunctionLifetimes& function_lifetimes) {
  // TODO(review): output-only lifetimes (those appearing only in the return
  // type) could be made as long as possible; they currently receive no
  // special treatment.

  // Collect all "interesting" lifetimes, i.e. all lifetimes that appear in
  // `function_lifetimes`.
  llvm::DenseSet<Lifetime> all_interesting_lifetimes;
  function_lifetimes.Traverse(
      [&all_interesting_lifetimes](Lifetime l, Variance) {
        all_interesting_lifetimes.insert(l);
      });
  LifetimeSubstitutions substitutions;
  // Lifetimes whose final substitution has already been decided.
  llvm::DenseSet<Lifetime> already_have_substitutions;
  // First of all, substitute everything that outlives 'static with 'static.
  // A local lifetime can never outlive 'static, so that is an error.
  for (Lifetime outlives_static : GetOutlivingLifetimes(Lifetime::Static())) {
    if (outlives_static.IsLocal()) {
      return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                     "Function assigns local to static");
    }
    already_have_substitutions.insert(outlives_static);
    substitutions.Add(outlives_static, Lifetime::Static());
  }
  LifetimeDSU dsu;
  dsu.MakeSet(Lifetime::Static());
  for (Lifetime lifetime : all_interesting_lifetimes) {
    dsu.MakeSet(lifetime);
  }
  for (Lifetime lifetime : all_interesting_lifetimes) {
    llvm::DenseSet<Lifetime> longer_lifetimes = GetOutlivingLifetimes(lifetime);
    // Being outlived by 'static imposes no restriction.
    longer_lifetimes.erase(Lifetime::Static());
    // If some local lifetime must outlive this one, replace it with that local.
    // (A local that would have to outlive 'static was rejected above.)
    auto local_it =
        std::find_if(longer_lifetimes.begin(), longer_lifetimes.end(),
                     [](Lifetime l) { return l.IsLocal(); });
    if (local_it != longer_lifetimes.end()) {
      substitutions.Add(lifetime, *local_it);
      already_have_substitutions.insert(lifetime);
      continue;
    }
    // Now all the longer lifetimes must be variable lifetimes. As we do not
    // support inequalities, we simply state that they must be equivalent.
    for (Lifetime longer : longer_lifetimes) {
      if (already_have_substitutions.contains(longer)) continue;
      if (!all_interesting_lifetimes.contains(longer)) continue;
      dsu.Union(longer, lifetime);
    }
  }
  // Everything that is equivalent to 'static must be replaced by 'static, not
  // by an arbitrary lifetime in the equivalence set.
  Lifetime cc_of_static = dsu.Find(Lifetime::Static());
  for (Lifetime lifetime : all_interesting_lifetimes) {
    if (already_have_substitutions.contains(lifetime) ||
        !lifetime.IsVariable()) {
      continue;
    }
    Lifetime cc = dsu.Find(lifetime);
    if (cc == cc_of_static) {
      substitutions.Add(lifetime, Lifetime::Static());
    } else {
      substitutions.Add(lifetime, cc);
    }
  }
  function_lifetimes.SubstituteLifetimes(substitutions);
  return llvm::Error::success();
}
namespace {
// How the lifetime of a replacement must relate to the lifetime it replaces.
// The values are bit flags: kReplacementIsEq is kReplacementIsGe |
// kReplacementIsLe, i.e. an invariant position imposes both directions.
enum LifetimeRequirement {
  kReplacementIsGe = 0x1,
  kReplacementIsLe = 0x2,
  kReplacementIsEq = 0x3,
};
// Composes two requirements, e.g. a contravariant lifetime used inside a
// contravariant position ends up covariant. This behaves like multiplying
// signs, with Ge = -1, Le = +1 and Eq = 0.
LifetimeRequirement Compose(LifetimeRequirement a, LifetimeRequirement b) {
  // Invariance absorbs everything (multiplication by zero).
  const bool either_invariant = a == LifetimeRequirement::kReplacementIsEq ||
                                b == LifetimeRequirement::kReplacementIsEq;
  if (either_invariant) {
    return LifetimeRequirement::kReplacementIsEq;
  }
  // Equal signs give a positive product (Le), opposite signs a negative one
  // (Ge).
  return a == b ? LifetimeRequirement::kReplacementIsLe
                : LifetimeRequirement::kReplacementIsGe;
}
// Records the outlives constraint(s) that `req` imposes for using `replacement`
// where `obj` was expected. Because kReplacementIsEq has both flag bits set, an
// invariant requirement records both directions.
void AddConstraint(LifetimeRequirement req, Lifetime obj, Lifetime replacement,
                   LifetimeConstraints& constraints) {
  const bool replacement_le_obj =
      (req & LifetimeRequirement::kReplacementIsLe) != 0;
  const bool replacement_ge_obj =
      (req & LifetimeRequirement::kReplacementIsGe) != 0;
  if (replacement_le_obj) {
    // `obj` must outlive `replacement`.
    constraints.AddOutlivesConstraint(replacement, obj);
  }
  if (replacement_ge_obj) {
    // `replacement` must outlive `obj`.
    constraints.AddOutlivesConstraint(obj, replacement);
  }
}
void CollectLifetimeConstraints(const ValueLifetimes&, const ValueLifetimes&,
LifetimeRequirement, LifetimeConstraints&);
void CollectLifetimeConstraints(const FunctionLifetimes&,
const FunctionLifetimes&, LifetimeRequirement,
LifetimeConstraints&);
// Collects all the constraints that are required to use `replacement` as a
// replacement for `obj`, taking into account the requirements due to their
// positions (i.e. covariant/contravariant/invariant).
void CollectLifetimeConstraints(const ObjectLifetimes& obj,
const ObjectLifetimes& replacement,
LifetimeRequirement object_requirement,
LifetimeRequirement descendants_requirement,
LifetimeConstraints& constraints) {
AddConstraint(object_requirement, obj.GetLifetime(),
replacement.GetLifetime(), constraints);
CollectLifetimeConstraints(obj.GetValueLifetimes(),
replacement.GetValueLifetimes(),
descendants_requirement, constraints);
}
void CollectLifetimeConstraints(const ValueLifetimes& obj,
const ValueLifetimes& replacement,
LifetimeRequirement requirement,
LifetimeConstraints& constraints) {
assert(obj.Type().getCanonicalType() ==
replacement.Type().getCanonicalType());
if (!PointeeType(obj.Type()).isNull()) {
LifetimeRequirement pointee_req =
PointeeType(obj.Type()).isConstQualified()
? LifetimeRequirement::kReplacementIsLe
: LifetimeRequirement::kReplacementIsEq;
CollectLifetimeConstraints(obj.GetPointeeLifetimes(),
replacement.GetPointeeLifetimes(), requirement,
Compose(pointee_req, requirement), constraints);
}
if (obj.Type()->isRecordType()) {
assert(obj.GetNumTemplateNestingLevels() ==
replacement.GetNumTemplateNestingLevels());
for (size_t depth = 0; depth < obj.GetNumTemplateNestingLevels(); depth++) {
assert(obj.GetNumTemplateArgumentsAtDepth(depth) ==
replacement.GetNumTemplateArgumentsAtDepth(depth));
for (size_t idx = 0; idx < obj.GetNumTemplateArgumentsAtDepth(depth);
idx++) {
std::optional<ValueLifetimes> obj_arg =
obj.GetTemplateArgumentLifetimes(depth, idx);
std::optional<ValueLifetimes> replacement_arg =
replacement.GetTemplateArgumentLifetimes(depth, idx);
assert(obj_arg.has_value() == replacement_arg.has_value());
if (obj_arg.has_value() && replacement_arg.has_value()) {
CollectLifetimeConstraints(*obj_arg, *replacement_arg,
LifetimeRequirement::kReplacementIsEq,
constraints);
}
}
}
for (const auto& lftm_param : GetLifetimeParameters(obj.Type())) {
// TODO(veluca): should lifetime parameters be invariant like template
// parameters?
AddConstraint(requirement, obj.GetLifetimeParameter(lftm_param),
replacement.GetLifetimeParameter(lftm_param), constraints);
}
}
if (clang::isa<clang::FunctionProtoType>(obj.Type())) {
CollectLifetimeConstraints(obj.GetFuncLifetimes(),
replacement.GetFuncLifetimes(), requirement,
constraints);
}
}
void CollectLifetimeConstraints(const FunctionLifetimes& callable,
const FunctionLifetimes& replacement_callable,
LifetimeRequirement requirement,
LifetimeConstraints& constraints) {
for (size_t i = 0; i < callable.GetNumParams(); i++) {
CollectLifetimeConstraints(
callable.GetParamLifetimes(i),
replacement_callable.GetParamLifetimes(i),
Compose(LifetimeRequirement::kReplacementIsGe, requirement),
constraints);
}
CollectLifetimeConstraints(
callable.GetReturnLifetimes(), replacement_callable.GetReturnLifetimes(),
Compose(LifetimeRequirement::kReplacementIsLe, requirement), constraints);
if (callable.IsNonStaticMethod()) {
CollectLifetimeConstraints(
callable.GetThisLifetimes(), replacement_callable.GetThisLifetimes(),
Compose(LifetimeRequirement::kReplacementIsGe, requirement),
constraints);
}
}
} // namespace
// Returns the constraints among `callable`'s own lifetimes that are implied by
// using `replacement_callable` in its place. Computes the full constraint set
// first, then records an outlives constraint for every pair of `callable`
// lifetimes related by the transitive closure of that set.
LifetimeConstraints LifetimeConstraints::ForCallableSubstitution(
    const FunctionLifetimes& callable,
    const FunctionLifetimes& replacement_callable) {
  const LifetimeConstraints full =
      LifetimeConstraints::ForCallableSubstitutionFull(callable,
                                                       replacement_callable);

  // Lifetimes that appear in `callable`.
  llvm::DenseSet<Lifetime> callable_lifetimes;
  callable.Traverse([&](Lifetime lifetime, Variance) {
    callable_lifetimes.insert(lifetime);
  });

  LifetimeConstraints projected;
  for (Lifetime shorter : callable_lifetimes) {
    for (Lifetime longer : full.GetOutlivingLifetimes(shorter)) {
      if (!callable_lifetimes.contains(longer)) continue;
      projected.AddOutlivesConstraint(shorter, longer);
    }
  }
  return projected;
}
// Returns every outlives constraint, over the lifetimes of both callables,
// that is needed to use `replacement_callable` in place of `callable`.
LifetimeConstraints LifetimeConstraints::ForCallableSubstitutionFull(
    const FunctionLifetimes& callable,
    const FunctionLifetimes& replacement_callable) {
  // The callable itself is treated as sitting in a covariant position.
  const LifetimeRequirement top_level = LifetimeRequirement::kReplacementIsLe;
  LifetimeConstraints result;
  CollectLifetimeConstraints(callable, replacement_callable, top_level,
                             result);
  return result;
}
} // namespace lifetimes
} // namespace tidy
} // namespace clang