Rewrite constraint resolution with a different (hopefully correct) algorithm.

The previous algorithm considered *all* lifetime constraints (except those involving local or static) to be equality constraints. However, this is incorrect, because even if we cannot represent inequality constraints in the signature, inequality constraints between lifetimes that are not visible in the signature are still meaningful.

Moreover, the previous algorithm did not handle the case of unconstrained / outlived-by-static output lifetimes (which can be made 'static).

For example, consider:
```
'local1 >= 'local2
'param1 >= 'local2
```

The previous algorithm would incorrectly deduce that the function leaked a 'local1 lifetime through 'param1.

PiperOrigin-RevId: 474809494
diff --git a/lifetime_analysis/lifetime_constraints.cc b/lifetime_analysis/lifetime_constraints.cc
index abecc03..3289f6c 100644
--- a/lifetime_analysis/lifetime_constraints.cc
+++ b/lifetime_analysis/lifetime_constraints.cc
@@ -4,6 +4,10 @@
 
 #include "lifetime_analysis/lifetime_constraints.h"
 
+#include <llvm/ADT/DenseSet.h>
+
+#include <algorithm>
+
 #include "lifetime_annotations/lifetime_substitutions.h"
 
 namespace clang {
@@ -11,7 +15,7 @@
 namespace lifetimes {
 
 clang::dataflow::LatticeJoinEffect LifetimeConstraints::join(
-    const LifetimeConstraints& other) {
+    const LifetimeConstraints &other) {
   bool changed = false;
   for (auto p : other.outlives_constraints_) {
     changed |= outlives_constraints_.insert(p).second;
@@ -20,77 +24,145 @@
                  : clang::dataflow::LatticeJoinEffect::Unchanged;
 }
 
-llvm::Error LifetimeConstraints::ApplyToFunctionLifetimes(
-    FunctionLifetimes& function_lifetimes) {
-  // Since we do not support "outlives" annotations, we can simply
-  // find all the *connected* components on an undirected graph where there is
-  // an edge between lifetime a and lifetime b iff there is a constraint in
-  // either direction between a and b, ignoring `static` and `local` (which are
-  // handled separately).
-  // TODO(veluca): do this properly by collapsing SCCs and analyzing the
-  // resulting outlives DAG if/when we support "outlives" annotations.
-  llvm::DenseMap<Lifetime, llvm::SmallVector<Lifetime>> outlives_edges;
-  llvm::DenseSet<Lifetime> all_lifetimes;
-  // CCs that contain any of these lifetimes must be substituted with `$static`.
-  llvm::DenseSet<Lifetime> outlives_static;
-  llvm::DenseMap<Lifetime, Lifetime> is_outlived_by_local;
-  for (auto [shorter, longer] : outlives_constraints_) {
-    all_lifetimes.insert(longer);
-    all_lifetimes.insert(shorter);
-    if (shorter.IsVariable() && longer.IsVariable()) {
-      outlives_edges[longer].push_back(shorter);
-      outlives_edges[shorter].push_back(longer);
-    }
-    if (shorter == Lifetime::Static()) {
-      outlives_static.insert(longer);
-    }
-    if (longer.IsLocal()) {
-      is_outlived_by_local[shorter] = longer;
+namespace {
+
+// Simple Disjoint-Set-Union with path compression (but no union-by-rank). This
+// guarantees O(log n) time per operation.
+class LifetimeDSU {
+ public:
+  void MakeSet(Lifetime l) { parent_[l] = l; }
+  Lifetime Find(Lifetime l) {
+    if (l == parent_[l]) return l;
+    return parent_[l] = Find(parent_[l]);
+  }
+  void Union(Lifetime a, Lifetime b) {
+    a = Find(a);
+    b = Find(b);
+    if (a != b) {
+      parent_[a] = b;
     }
   }
 
-  llvm::DenseSet<Lifetime> visited;
-  LifetimeSubstitutions substitutions;
+ private:
+  llvm::DenseMap<Lifetime, Lifetime> parent_;
+};
 
-  for (const auto& lifetime : all_lifetimes) {
-    if (!lifetime.IsVariable() || visited.count(lifetime)) continue;
+}  // namespace
 
-    llvm::SmallVector<Lifetime> connected_component;
-    llvm::SmallVector<Lifetime> stack;
-    stack.push_back(lifetime);
-    bool cc_outlives_static = false;
-    bool cc_is_outlived_by_local = false;
-    Lifetime local_outliving_lifetime;
+llvm::Error LifetimeConstraints::ApplyToFunctionLifetimes(
+    FunctionLifetimes &function_lifetimes) {
+  // We want to make output-only lifetimes as long as possible; thus, we collect
+  // those separately.
+  llvm::DenseSet<Lifetime> output_lifetimes;
+  function_lifetimes.GetReturnLifetimes().Traverse(
+      [&output_lifetimes](Lifetime l, Variance) {
+        output_lifetimes.insert(l);
+      });
+
+  // Collect all "interesting" lifetimes, i.e. all lifetimes that appear in the
+  // function call.
+  llvm::DenseSet<Lifetime> all_lifetimes;
+  function_lifetimes.Traverse(
+      [&all_lifetimes](Lifetime l, Variance) { all_lifetimes.insert(l); });
+
+  // Compute the set of static, input or local lifetimes that must outlive the
+  // given lifetime (excluding the lifetime itself).
+  // This function ignores constraints of the form 'a <= 'static, as "outlived
+  // by 'static" is not a meaningful constraint.
+  // TODO(veluca): here we could certainly reduce complexity, for example by
+  // constructing the constraint graph instead of iterating over all constraints
+  // each time.
+  auto get_outliving_lifetimes = [&](Lifetime l) {
+    std::vector<Lifetime> stack{l};
+    llvm::DenseSet<Lifetime> visited;
     while (!stack.empty()) {
-      Lifetime cur = stack.back();
+      Lifetime v = stack.back();
       stack.pop_back();
-      if (visited.count(cur)) continue;
-      visited.insert(cur);
-      if (outlives_static.count(cur)) {
-        cc_outlives_static = true;
-      }
-      if (is_outlived_by_local.count(cur)) {
-        cc_is_outlived_by_local = true;
-        local_outliving_lifetime = is_outlived_by_local[cur];
-      }
-      connected_component.push_back(cur);
-      for (auto next : outlives_edges[cur]) {
-        stack.push_back(next);
+      if (visited.contains(v)) continue;
+      visited.insert(v);
+      for (auto [shorter, longer] : outlives_constraints_) {
+        if (shorter == v && longer != Lifetime::Static()) {
+          stack.push_back(longer);
+        }
       }
     }
-    if (cc_outlives_static && cc_is_outlived_by_local) {
+    visited.erase(l);
+    return visited;
+  };
+
+  LifetimeSubstitutions substitutions;
+
+  // Keep track of which lifetimes already have their final substitutions
+  // computed.
+  llvm::DenseSet<Lifetime> already_have_substitutions;
+
+  // First of all, substitute everything that outlives 'static with 'static.
+  for (Lifetime outlives_static : get_outliving_lifetimes(Lifetime::Static())) {
+    if (outlives_static.IsLocal()) {
       return llvm::createStringError(llvm::inconvertibleErrorCode(),
                                      "Function assigns local to static");
     }
-    Lifetime representative = cc_outlives_static ? Lifetime::Static()
-                              : cc_is_outlived_by_local
-                                  ? local_outliving_lifetime
-                                  : lifetime;
+    already_have_substitutions.insert(outlives_static);
+    substitutions.Add(outlives_static, Lifetime::Static());
+  }
 
-    for (Lifetime memb : connected_component) {
-      substitutions.Add(memb, representative);
+  LifetimeDSU dsu;
+  dsu.MakeSet(Lifetime::Static());
+  for (Lifetime lifetime : all_lifetimes) {
+    dsu.MakeSet(lifetime);
+  }
+
+  for (Lifetime lifetime : all_lifetimes) {
+    llvm::DenseSet<Lifetime> longer_lifetimes =
+        get_outliving_lifetimes(lifetime);
+    assert(!longer_lifetimes.contains(Lifetime::Static()));
+
+    // Replace unconstrained output lifetimes with 'static.
+    if (output_lifetimes.contains(lifetime) && longer_lifetimes.empty()) {
+      substitutions.Add(lifetime, Lifetime::Static());
+      already_have_substitutions.insert(lifetime);
+      continue;
+    }
+
+    // If constrained to be outlived by 'local, replace the lifetime with
+    // 'local, or error out if 'static.
+    auto local_it =
+        std::find_if(longer_lifetimes.begin(), longer_lifetimes.end(),
+                     [](Lifetime l) { return l.IsLocal(); });
+
+    if (local_it != longer_lifetimes.end()) {
+      substitutions.Add(lifetime, *local_it);
+      already_have_substitutions.insert(lifetime);
+      continue;
+    }
+
+    // Now all the longer lifetimes must be variable lifetimes. As we do not
+    // support inequalities, we simply state that they must be equivalent.
+    for (Lifetime longer : longer_lifetimes) {
+      if (already_have_substitutions.contains(longer)) continue;
+      dsu.Union(longer, lifetime);
     }
   }
+
+  // Everything that is equivalent to 'static must be replaced by 'static, not
+  // by an arbitrary lifetime in the equivalence set.
+  Lifetime cc_of_static = dsu.Find(Lifetime::Static());
+
+  for (Lifetime lifetime : all_lifetimes) {
+    if (already_have_substitutions.contains(lifetime) ||
+        !lifetime.IsVariable()) {
+      continue;
+    }
+
+    Lifetime cc = dsu.Find(lifetime);
+
+    if (cc == cc_of_static) {
+      substitutions.Add(lifetime, Lifetime::Static());
+    } else {
+      substitutions.Add(lifetime, cc);
+    }
+  }
+
   function_lifetimes.SubstituteLifetimes(substitutions);
 
   return llvm::Error::success();