// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "lifetime_analysis/lifetime_constraints.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

#include "lifetime_annotations/lifetime.h"
#include "lifetime_annotations/lifetime_substitutions.h"
#include "lifetime_annotations/pointee_type.h"
#include "lifetime_annotations/type_lifetimes.h"
#include "clang/AST/Type.h"
#include "clang/Analysis/FlowSensitive/DataflowLattice.h"
#include "clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Error.h"

namespace clang {
namespace tidy {
namespace lifetimes {

clang::dataflow::LatticeJoinEffect LifetimeConstraints::join(
    const LifetimeConstraints& other) {
  bool changed = false;
  for (auto p : other.outlives_constraints_) {
    changed |= outlives_constraints_.insert(p).second;
  }
  return changed ? clang::dataflow::LatticeJoinEffect::Changed
                 : clang::dataflow::LatticeJoinEffect::Unchanged;
}

namespace {

// Simple Disjoint-Set-Union with path compression (but no union-by-rank). This
// guarantees O(log n) time per operation.
class LifetimeDSU {
 public:
  void MakeSet(Lifetime l) { parent_[l] = l; }
  Lifetime Find(Lifetime l) {
    if (l == parent_[l]) return l;
    return parent_[l] = Find(parent_[l]);
  }
  void Union(Lifetime a, Lifetime b) {
    a = Find(a);
    b = Find(b);
    if (a != b) {
      parent_[a] = b;
    }
  }

 private:
  llvm::DenseMap<Lifetime, Lifetime> parent_;
};

}  // namespace

llvm::DenseSet<Lifetime> LifetimeConstraints::GetOutlivingLifetimes(
    const Lifetime l) const {
  // TODO(veluca): here we could certainly reduce complexity, for example by
  // constructing the constraint graph instead of iterating over all constraints
  // each time.
  std::vector<Lifetime> stack{l};
  llvm::DenseSet<Lifetime> visited;
  while (!stack.empty()) {
    Lifetime v = stack.back();
    stack.pop_back();
    if (visited.contains(v)) continue;
    visited.insert(v);
    for (auto [shorter, longer] : outlives_constraints_) {
      if (shorter == v) {
        stack.push_back(longer);
      }
    }
  }
  visited.erase(l);
  return visited;
}

llvm::Error LifetimeConstraints::ApplyToFunctionLifetimes(
    FunctionLifetimes& function_lifetimes) {
  // We want to make output-only lifetimes as long as possible; thus, we collect
  // those separately.
  llvm::DenseSet<Lifetime> output_lifetimes;
  function_lifetimes.GetReturnLifetimes().Traverse(
      [&output_lifetimes](Lifetime l, Variance) {
        output_lifetimes.insert(l);
      });

  // Collect all "interesting" lifetimes, i.e. all lifetimes that appear in the
  // function call.
  llvm::DenseSet<Lifetime> all_interesting_lifetimes;
  function_lifetimes.Traverse(
      [&all_interesting_lifetimes](Lifetime l, Variance) {
        all_interesting_lifetimes.insert(l);
      });

  LifetimeSubstitutions substitutions;

  // Keep track of which lifetimes already have their final substitutions
  // computed.
  llvm::DenseSet<Lifetime> already_have_substitutions;

  // First of all, substitute everything that outlives 'static with 'static.
  for (Lifetime outlives_static : GetOutlivingLifetimes(Lifetime::Static())) {
    if (outlives_static.IsLocal()) {
      return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                     "Function assigns local to static");
    }
    already_have_substitutions.insert(outlives_static);
    substitutions.Add(outlives_static, Lifetime::Static());
  }

  LifetimeDSU dsu;
  dsu.MakeSet(Lifetime::Static());
  for (Lifetime lifetime : all_interesting_lifetimes) {
    dsu.MakeSet(lifetime);
  }

  for (Lifetime lifetime : all_interesting_lifetimes) {
    llvm::DenseSet<Lifetime> longer_lifetimes = GetOutlivingLifetimes(lifetime);
    longer_lifetimes.erase(Lifetime::Static());

    // If constrained to be outlived by 'local, replace the lifetime with
    // 'local, or error out if 'static.
    auto local_it =
        std::find_if(longer_lifetimes.begin(), longer_lifetimes.end(),
                     [](Lifetime l) { return l.IsLocal(); });

    if (local_it != longer_lifetimes.end()) {
      substitutions.Add(lifetime, *local_it);
      already_have_substitutions.insert(lifetime);
      continue;
    }

    // Now all the longer lifetimes must be variable lifetimes. As we do not
    // support inequalities, we simply state that they must be equivalent.
    for (Lifetime longer : longer_lifetimes) {
      if (already_have_substitutions.contains(longer)) continue;
      if (!all_interesting_lifetimes.contains(longer)) continue;
      dsu.Union(longer, lifetime);
    }
  }

  // Everything that is equivalent to 'static must be replaced by 'static, not
  // by an arbitrary lifetime in the equivalence set.
  Lifetime cc_of_static = dsu.Find(Lifetime::Static());

  for (Lifetime lifetime : all_interesting_lifetimes) {
    if (already_have_substitutions.contains(lifetime) ||
        !lifetime.IsVariable()) {
      continue;
    }

    Lifetime cc = dsu.Find(lifetime);

    if (cc == cc_of_static) {
      substitutions.Add(lifetime, Lifetime::Static());
    } else {
      substitutions.Add(lifetime, cc);
    }
  }

  function_lifetimes.SubstituteLifetimes(substitutions);

  return llvm::Error::success();
}

namespace {

enum LifetimeRequirement {
  kReplacementIsGe = 0x1,
  kReplacementIsLe = 0x2,
  kReplacementIsEq = 0x3,
};

// Computes the requirement corresponding to *composing* the two requirements
// together; for example, using a type containing a contravariant lifetime in
// contravariant position would result in a covariant lifetime.
// In general, this function behaves like multiplication where Ge = -1, Le = 1,
// Eq = 0.
LifetimeRequirement Compose(LifetimeRequirement a, LifetimeRequirement b) {
  if (a == LifetimeRequirement::kReplacementIsEq ||
      b == LifetimeRequirement::kReplacementIsEq) {
    return LifetimeRequirement::kReplacementIsEq;
  }
  if (a != b) {
    return LifetimeRequirement::kReplacementIsGe;
  }
  return LifetimeRequirement::kReplacementIsLe;
}

void AddConstraint(LifetimeRequirement req, Lifetime obj, Lifetime replacement,
                   LifetimeConstraints& constraints) {
  if (req & LifetimeRequirement::kReplacementIsLe) {
    constraints.AddOutlivesConstraint(replacement, obj);
  }
  if (req & LifetimeRequirement::kReplacementIsGe) {
    constraints.AddOutlivesConstraint(obj, replacement);
  }
}

void CollectLifetimeConstraints(const ValueLifetimes&, const ValueLifetimes&,
                                LifetimeRequirement, LifetimeConstraints&);

void CollectLifetimeConstraints(const FunctionLifetimes&,
                                const FunctionLifetimes&, LifetimeRequirement,
                                LifetimeConstraints&);

// Collects all the constraints that are required to use `replacement` as a
// replacement for `obj`, taking into account the requirements due to their
// positions (i.e. covariant/contravariant/invariant).
void CollectLifetimeConstraints(const ObjectLifetimes& obj,
                                const ObjectLifetimes& replacement,
                                LifetimeRequirement object_requirement,
                                LifetimeRequirement descendants_requirement,
                                LifetimeConstraints& constraints) {
  AddConstraint(object_requirement, obj.GetLifetime(),
                replacement.GetLifetime(), constraints);
  CollectLifetimeConstraints(obj.GetValueLifetimes(),
                             replacement.GetValueLifetimes(),
                             descendants_requirement, constraints);
}

void CollectLifetimeConstraints(const ValueLifetimes& obj,
                                const ValueLifetimes& replacement,
                                LifetimeRequirement requirement,
                                LifetimeConstraints& constraints) {
  assert(obj.Type().getCanonicalType() ==
         replacement.Type().getCanonicalType());
  if (!PointeeType(obj.Type()).isNull()) {
    LifetimeRequirement pointee_req =
        PointeeType(obj.Type()).isConstQualified()
            ? LifetimeRequirement::kReplacementIsLe
            : LifetimeRequirement::kReplacementIsEq;
    CollectLifetimeConstraints(obj.GetPointeeLifetimes(),
                               replacement.GetPointeeLifetimes(), requirement,
                               Compose(pointee_req, requirement), constraints);
  }
  if (obj.Type()->isRecordType()) {
    assert(obj.GetNumTemplateNestingLevels() ==
           replacement.GetNumTemplateNestingLevels());
    for (size_t depth = 0; depth < obj.GetNumTemplateNestingLevels(); depth++) {
      assert(obj.GetNumTemplateArgumentsAtDepth(depth) ==
             replacement.GetNumTemplateArgumentsAtDepth(depth));
      for (size_t idx = 0; idx < obj.GetNumTemplateArgumentsAtDepth(depth);
           idx++) {
        std::optional<ValueLifetimes> obj_arg =
            obj.GetTemplateArgumentLifetimes(depth, idx);
        std::optional<ValueLifetimes> replacement_arg =
            replacement.GetTemplateArgumentLifetimes(depth, idx);
        assert(obj_arg.has_value() == replacement_arg.has_value());
        if (obj_arg.has_value() && replacement_arg.has_value()) {
          CollectLifetimeConstraints(*obj_arg, *replacement_arg,
                                     LifetimeRequirement::kReplacementIsEq,
                                     constraints);
        }
      }
    }
    for (const auto& lftm_param : GetLifetimeParameters(obj.Type())) {
      // TODO(veluca): should lifetime parameters be invariant like template
      // parameters?
      AddConstraint(requirement, obj.GetLifetimeParameter(lftm_param),
                    replacement.GetLifetimeParameter(lftm_param), constraints);
    }
  }
  if (clang::isa<clang::FunctionProtoType>(obj.Type())) {
    CollectLifetimeConstraints(obj.GetFuncLifetimes(),
                               replacement.GetFuncLifetimes(), requirement,
                               constraints);
  }
}

void CollectLifetimeConstraints(const FunctionLifetimes& callable,
                                const FunctionLifetimes& replacement_callable,
                                LifetimeRequirement requirement,
                                LifetimeConstraints& constraints) {
  for (size_t i = 0; i < callable.GetNumParams(); i++) {
    CollectLifetimeConstraints(
        callable.GetParamLifetimes(i),
        replacement_callable.GetParamLifetimes(i),
        Compose(LifetimeRequirement::kReplacementIsGe, requirement),
        constraints);
  }
  CollectLifetimeConstraints(
      callable.GetReturnLifetimes(), replacement_callable.GetReturnLifetimes(),
      Compose(LifetimeRequirement::kReplacementIsLe, requirement), constraints);
  if (callable.IsNonStaticMethod()) {
    CollectLifetimeConstraints(
        callable.GetThisLifetimes(), replacement_callable.GetThisLifetimes(),
        Compose(LifetimeRequirement::kReplacementIsGe, requirement),
        constraints);
  }
}

}  // namespace

LifetimeConstraints LifetimeConstraints::ForCallableSubstitution(
    const FunctionLifetimes& callable,
    const FunctionLifetimes& replacement_callable) {
  LifetimeConstraints constraints =
      LifetimeConstraints::ForCallableSubstitutionFull(callable,
                                                       replacement_callable);

  llvm::DenseSet<Lifetime> all_lifetimes;
  callable.Traverse(
      [&all_lifetimes](Lifetime l, Variance) { all_lifetimes.insert(l); });

  LifetimeConstraints ret;
  for (auto l : all_lifetimes) {
    for (auto outliving : constraints.GetOutlivingLifetimes(l)) {
      if (all_lifetimes.contains(outliving)) {
        ret.AddOutlivesConstraint(l, outliving);
      }
    }
  }

  return ret;
}

LifetimeConstraints LifetimeConstraints::ForCallableSubstitutionFull(
    const FunctionLifetimes& callable,
    const FunctionLifetimes& replacement_callable) {
  LifetimeConstraints constraints;
  CollectLifetimeConstraints(callable, replacement_callable,
                             LifetimeRequirement::kReplacementIsLe,
                             constraints);
  return constraints;
}

}  // namespace lifetimes
}  // namespace tidy
}  // namespace clang
