In the Blaze Query implementation, use Set and Map implementations backed by the same KeyExtractor that the Uniquifier implementation uses. This fixes a hypothetical issue where we were previously relying on Target#equals/hashCode.

RELNOTES: None
PiperOrigin-RevId: 159741545
diff --git a/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java
index 361514b..7fa0051 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/BlazeQueryEnvironment.java
@@ -16,12 +16,15 @@
 import com.google.common.base.Function;
 import com.google.common.base.Predicate;
 import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
 import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.cmdline.LabelSyntaxException;
+import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.cmdline.ResolvedTargets;
 import com.google.devtools.build.lib.cmdline.TargetParsingException;
+import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 import com.google.devtools.build.lib.events.ExtendedEventHandler;
 import com.google.devtools.build.lib.graph.Digraph;
 import com.google.devtools.build.lib.graph.Node;
@@ -42,6 +45,8 @@
 import com.google.devtools.build.lib.query2.engine.QueryException;
 import com.google.devtools.build.lib.query2.engine.QueryExpression;
 import com.google.devtools.build.lib.query2.engine.QueryUtil.MinDepthUniquifierImpl;
+import com.google.devtools.build.lib.query2.engine.QueryUtil.MutableKeyExtractorBackedMapImpl;
+import com.google.devtools.build.lib.query2.engine.QueryUtil.ThreadSafeMutableKeyExtractorBackedSetImpl;
 import com.google.devtools.build.lib.query2.engine.QueryUtil.UniquifierImpl;
 import com.google.devtools.build.lib.query2.engine.SkyframeRestartQueryException;
 import com.google.devtools.build.lib.query2.engine.ThreadSafeOutputFormatterCallback;
@@ -226,7 +231,7 @@
 
   @Override
   public Collection<Target> getFwdDeps(Iterable<Target> targets) {
-    Set<Target> result = new HashSet<>();
+    ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet();
     for (Target target : targets) {
       result.addAll(getTargetsFromNodes(getNode(target).getSuccessors()));
     }
@@ -235,7 +240,7 @@
 
   @Override
   public Collection<Target> getReverseDeps(Iterable<Target> targets) {
-    Set<Target> result = new HashSet<>();
+    ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet();
     for (Target target : targets) {
       result.addAll(getTargetsFromNodes(getNode(target).getPredecessors()));
     }
@@ -243,7 +248,8 @@
   }
 
   @Override
-  public Set<Target> getTransitiveClosure(Set<Target> targetNodes) {
+  public ThreadSafeMutableSet<Target> getTransitiveClosure(
+      ThreadSafeMutableSet<Target> targetNodes) {
     for (Target node : targetNodes) {
       checkBuilt(node);
     }
@@ -270,11 +276,10 @@
 
   @Override
   public void buildTransitiveClosure(QueryExpression caller,
-                                     Set<Target> targetNodes,
+                                     ThreadSafeMutableSet<Target> targetNodes,
                                      int maxDepth) throws QueryException, InterruptedException {
-    Set<Target> targets = targetNodes;
-    preloadTransitiveClosure(targets, maxDepth);
-    labelVisitor.syncWithVisitor(eventHandler, targets, keepGoing,
+    preloadTransitiveClosure(targetNodes, maxDepth);
+    labelVisitor.syncWithVisitor(eventHandler, targetNodes, keepGoing,
         loadingPhaseThreads, maxDepth, errorObserver, new GraphBuildingObserver());
 
     if (errorObserver.hasErrors()) {
@@ -283,8 +288,24 @@
   }
 
   @Override
-  public Set<Target> getNodesOnPath(Target from, Target to) {
-    return getTargetsFromNodes(graph.getShortestPath(getNode(from), getNode(to)));
+  public Iterable<Target> getNodesOnPath(Target from, Target to) {
+    ImmutableList.Builder<Target> builder = ImmutableList.builder();
+    for (Node<Target> node : graph.getShortestPath(getNode(from), getNode(to))) {
+      builder.add(node.getLabel());
+    }
+    return builder.build();
+  }
+
+  @ThreadSafe
+  @Override
+  public ThreadSafeMutableSet<Target> createThreadSafeMutableSet() {
+    return new ThreadSafeMutableKeyExtractorBackedSetImpl<>(
+        TargetKeyExtractor.INSTANCE, Target.class);
+  }
+
+  @Override
+  public <V> MutableMap<Target, V> createMutableMap() {
+    return new MutableKeyExtractorBackedMapImpl<>(TargetKeyExtractor.INSTANCE);
   }
 
   @Override
@@ -297,7 +318,7 @@
     return new MinDepthUniquifierImpl<>(TargetKeyExtractor.INSTANCE, /*concurrencyLevel=*/ 1);
   }
 
-  private void preloadTransitiveClosure(Set<Target> targets, int maxDepth)
+  private void preloadTransitiveClosure(ThreadSafeMutableSet<Target> targets, int maxDepth)
       throws InterruptedException {
     if (maxDepth >= MAX_DEPTH_FULL_SCAN_LIMIT && transitivePackageLoader != null) {
       // Only do the full visitation if "maxDepth" is large enough. Otherwise, the benefits of
@@ -350,15 +371,15 @@
   // TODO(bazel-team): rename this to getDependentFiles when all implementations
   // of QueryEnvironment is fixed.
   @Override
-  public Set<Target> getBuildFiles(
+  public ThreadSafeMutableSet<Target> getBuildFiles(
       final QueryExpression caller,
-      Set<Target> nodes,
+      ThreadSafeMutableSet<Target> nodes,
       boolean buildFiles,
       boolean subincludes,
       boolean loads)
       throws QueryException {
-    Set<Target> dependentFiles = new LinkedHashSet<>();
-    Set<Package> seenPackages = new HashSet<>();
+    ThreadSafeMutableSet<Target> dependentFiles = createThreadSafeMutableSet();
+    Set<PackageIdentifier> seenPackages = new HashSet<>();
     // Keep track of seen labels, to avoid adding a fake subinclude label that also exists as a
     // real target.
     Set<Label> seenLabels = new HashSet<>();
@@ -367,7 +388,7 @@
     // extensions) for package "pkg", to "buildfiles".
     for (Target x : nodes) {
       Package pkg = x.getPackage();
-      if (seenPackages.add(pkg)) {
+      if (seenPackages.add(pkg.getPackageIdentifier())) {
         if (buildFiles) {
           addIfUniqueLabel(getNode(pkg.getBuildFile()), seenLabels, dependentFiles);
         }
@@ -438,8 +459,8 @@
   }
 
   /** Given a set of target nodes, returns the targets. */
-  private static Set<Target> getTargetsFromNodes(Iterable<Node<Target>> input) {
-    Set<Target> result = new LinkedHashSet<>();
+  private ThreadSafeMutableSet<Target> getTargetsFromNodes(Iterable<Node<Target>> input) {
+    ThreadSafeMutableSet<Target> result = createThreadSafeMutableSet();
     for (Node<Target> node : input) {
       result.add(node.getLabel());
     }