In the Blaze Query implementation, use Set and Map implementations backed by the same KeyExtractor that the Uniquifier implementation uses. This fixes a hypothetical issue where we were previously relying on Target#equals/hashCode.

RELNOTES: None
PiperOrigin-RevId: 159741545
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index 8a22522..a2cf884 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -62,17 +62,21 @@
 import com.google.devtools.build.lib.pkgcache.PathPackageLocator;
 import com.google.devtools.build.lib.pkgcache.TargetPatternEvaluator;
 import com.google.devtools.build.lib.profiler.AutoProfiler;
+import com.google.devtools.build.lib.query2.AbstractBlazeQueryEnvironment.TargetKeyExtractor;
 import com.google.devtools.build.lib.query2.engine.AllRdepsFunction;
 import com.google.devtools.build.lib.query2.engine.Callback;
 import com.google.devtools.build.lib.query2.engine.FunctionExpression;
 import com.google.devtools.build.lib.query2.engine.KeyExtractor;
 import com.google.devtools.build.lib.query2.engine.MinDepthUniquifier;
 import com.google.devtools.build.lib.query2.engine.OutputFormatterCallback;
+import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap;
 import com.google.devtools.build.lib.query2.engine.QueryEvalResult;
 import com.google.devtools.build.lib.query2.engine.QueryException;
 import com.google.devtools.build.lib.query2.engine.QueryExpression;
 import com.google.devtools.build.lib.query2.engine.QueryExpressionMapper;
 import com.google.devtools.build.lib.query2.engine.QueryUtil.MinDepthUniquifierImpl;
+import com.google.devtools.build.lib.query2.engine.QueryUtil.MutableKeyExtractorBackedMapImpl;
+import com.google.devtools.build.lib.query2.engine.QueryUtil.ThreadSafeMutableKeyExtractorBackedSetImpl;
 import com.google.devtools.build.lib.query2.engine.QueryUtil.UniquifierImpl;
 import com.google.devtools.build.lib.query2.engine.RdepsFunction;
 import com.google.devtools.build.lib.query2.engine.StreamableQueryEnvironment;
@@ -110,7 +114,6 @@
 import java.util.Deque;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -430,30 +433,6 @@
     return result.build();
   }
 
-  private Map<Target, Collection<Target>> targetifyKeys(Map<SkyKey, Collection<Target>> input)
-      throws InterruptedException {
-    Map<SkyKey, Target> targets = makeTargetsFromSkyKeys(input.keySet());
-    ImmutableMap.Builder<Target, Collection<Target>> resultBuilder = ImmutableMap.builder();
-    for (Map.Entry<SkyKey, Collection<Target>> entry : input.entrySet()) {
-      SkyKey key = entry.getKey();
-      Target target = targets.get(key);
-      if (target != null) {
-        resultBuilder.put(target, entry.getValue());
-      }
-    }
-    return resultBuilder.build();
-  }
-
-  private Map<Target, Collection<Target>> targetifyKeysAndValues(
-      Map<SkyKey, Iterable<SkyKey>> input) throws InterruptedException {
-    return targetifyKeys(targetifyValues(input));
-  }
-
-  private Map<Target, Collection<Target>> getRawFwdDeps(Iterable<Target> targets)
-      throws InterruptedException {
-    return targetifyKeysAndValues(graph.getDirectDeps(makeTransitiveTraversalKeys(targets)));
-  }
-
   private Map<SkyKey, Collection<Target>> getRawReverseDeps(
       Iterable<SkyKey> transitiveTraversalKeys) throws InterruptedException {
     return targetifyValues(graph.getReverseDeps(transitiveTraversalKeys));
@@ -482,22 +461,24 @@
         });
   }
 
-  /** Targets may not be in the graph because they are not in the universe or depend on cycles. */
-  private void warnIfMissingTargets(
-      Iterable<Target> targets, Set<Target> result) {
-    if (Iterables.size(targets) != result.size()) {
-      Set<Target> missingTargets = Sets.difference(ImmutableSet.copyOf(targets), result);
+  @Override
+  public ThreadSafeMutableSet<Target> getFwdDeps(Iterable<Target> targets)
+      throws InterruptedException {
+    Map<SkyKey, Target> targetsByKey = new HashMap<>(Iterables.size(targets));
+    for (Target target : targets) {
+      targetsByKey.put(TARGET_TO_SKY_KEY.apply(target), target);
+    }
+    Map<SkyKey, Collection<Target>> directDeps = targetifyValues(
+        graph.getDirectDeps(targetsByKey.keySet()));
+    if (targetsByKey.keySet().size() != directDeps.keySet().size()) {
+      Iterable<Label> missingTargets = Iterables.transform(
+          Sets.difference(targetsByKey.keySet(), directDeps.keySet()),
+          SKYKEY_TO_LABEL);
       eventHandler.handle(Event.warn("Targets were missing from graph: " + missingTargets));
     }
-  }
-
-  @Override
-  public Collection<Target> getFwdDeps(Iterable<Target> targets) throws InterruptedException {
-    Set<Target> result = new HashSet<>();
-    Map<Target, Collection<Target>> rawFwdDeps = getRawFwdDeps(targets);
-    warnIfMissingTargets(targets, rawFwdDeps.keySet());
-    for (Map.Entry<Target, Collection<Target>> entry : rawFwdDeps.entrySet()) {
-      result.addAll(filterFwdDeps(entry.getKey(), entry.getValue()));
+    ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet();
+    for (Map.Entry<SkyKey, Collection<Target>> entry : directDeps.entrySet()) {
+      result.addAll(filterFwdDeps(targetsByKey.get(entry.getKey()), entry.getValue()));
     }
     return result;
   }
@@ -555,35 +536,46 @@
   }
 
   @Override
-  public Set<Target> getTransitiveClosure(Set<Target> targets) throws InterruptedException {
-    Set<Target> visited = new HashSet<>();
-    Collection<Target> current = targets;
+  public ThreadSafeMutableSet<Target> getTransitiveClosure(ThreadSafeMutableSet<Target> targets)
+      throws InterruptedException {
+    ThreadSafeMutableSet<Target> visited = createThreadSafeMutableSet();
+    ThreadSafeMutableSet<Target> current = targets;
     while (!current.isEmpty()) {
-      Collection<Target> toVisit = Collections2.filter(current,
+      Iterable<Target> toVisit = Iterables.filter(current,
           Predicates.not(Predicates.in(visited)));
       current = getFwdDeps(toVisit);
-      visited.addAll(toVisit);
+      Iterables.addAll(visited, toVisit);
     }
-    return ImmutableSet.copyOf(visited);
+    return visited;
   }
 
   // Implemented with a breadth-first search.
   @Override
-  public Set<Target> getNodesOnPath(Target from, Target to) throws InterruptedException {
+  public ImmutableList<Target> getNodesOnPath(Target from, Target to)
+      throws InterruptedException {
     // Tree of nodes visited so far.
-    Map<Target, Target> nodeToParent = new HashMap<>();
+    Map<Label, Label> nodeToParent = new HashMap<>();
+    Map<Label, Target> labelToTarget = new HashMap<>();
     // Contains all nodes left to visit in a (LIFO) stack.
     Deque<Target> toVisit = new ArrayDeque<>();
     toVisit.add(from);
-    nodeToParent.put(from, null);
+    nodeToParent.put(from.getLabel(), null);
+    labelToTarget.put(from.getLabel(), from);
     while (!toVisit.isEmpty()) {
       Target current = toVisit.removeFirst();
       if (to.equals(current)) {
-        return ImmutableSet.copyOf(Digraph.getPathToTreeNode(nodeToParent, to));
+        List<Label> labelPath = Digraph.getPathToTreeNode(nodeToParent, to.getLabel());
+        ImmutableList.Builder<Target> targetPathBuilder = ImmutableList.builder();
+        for (Label label : labelPath) {
+          targetPathBuilder.add(Preconditions.checkNotNull(labelToTarget.get(label), label));
+        }
+        return targetPathBuilder.build();
       }
       for (Target dep : getFwdDeps(ImmutableList.of(current))) {
-        if (!nodeToParent.containsKey(dep)) {
-          nodeToParent.put(dep, current);
+        Label depLabel = dep.getLabel();
+        if (!nodeToParent.containsKey(depLabel)) {
+          nodeToParent.put(depLabel, current.getLabel());
+          labelToTarget.put(depLabel, dep);
           toVisit.addFirst(dep);
         }
       }
@@ -649,6 +641,18 @@
 
   @ThreadSafe
   @Override
+  public ThreadSafeMutableSet<Target> createThreadSafeMutableSet() {
+    return new ThreadSafeMutableKeyExtractorBackedSetImpl<>(
+        TargetKeyExtractor.INSTANCE, Target.class, DEFAULT_THREAD_COUNT);
+  }
+
+  @Override
+  public <V> MutableMap<Target, V> createMutableMap() {
+    return new MutableKeyExtractorBackedMapImpl<Target, Label, V>(TargetKeyExtractor.INSTANCE);
+  }
+
+  @ThreadSafe
+  @Override
   public Uniquifier<Target> createUniquifier() {
     return createTargetUniquifier();
   }
@@ -731,15 +735,15 @@
 
   @ThreadSafe
   @Override
-  public Set<Target> getBuildFiles(
+  public ThreadSafeMutableSet<Target> getBuildFiles(
       QueryExpression caller,
-      Set<Target> nodes,
+      ThreadSafeMutableSet<Target> nodes,
       boolean buildFiles,
       boolean subincludes,
       boolean loads)
       throws QueryException {
-    Set<Target> dependentFiles = new LinkedHashSet<>();
-    Set<Package> seenPackages = new HashSet<>();
+    ThreadSafeMutableSet<Target> dependentFiles = createThreadSafeMutableSet();
+    Set<PackageIdentifier> seenPackages = new HashSet<>();
     // Keep track of seen labels, to avoid adding a fake subinclude label that also exists as a
     // real target.
     Set<Label> seenLabels = new HashSet<>();
@@ -748,7 +752,7 @@
     // extensions) for package "pkg", to "buildfiles".
     for (Target x : nodes) {
       Package pkg = x.getPackage();
-      if (seenPackages.add(pkg)) {
+      if (seenPackages.add(pkg.getPackageIdentifier())) {
         if (buildFiles) {
           addIfUniqueLabel(pkg.getBuildFile(), seenLabels, dependentFiles);
         }
@@ -843,8 +847,10 @@
   }
 
   @Override
-  public void buildTransitiveClosure(QueryExpression caller, Set<Target> targets, int maxDepth)
-      throws QueryException, InterruptedException {
+  public void buildTransitiveClosure(
+      QueryExpression caller,
+      ThreadSafeMutableSet<Target> targets,
+      int maxDepth) throws QueryException, InterruptedException {
     // Everything has already been loaded, so here we just check for errors so that we can
     // pre-emptively throw/report if needed.
     Iterable<SkyKey> transitiveTraversalKeys = makeTransitiveTraversalKeys(targets);