Make sorting of graphless query more efficient.

PiperOrigin-RevId: 319806006
diff --git a/src/main/java/com/google/devtools/build/lib/cmdline/Label.java b/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
index c067736..1214e3d 100644
--- a/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
+++ b/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
@@ -648,6 +648,9 @@
    */
   @Override
   public int compareTo(Label other) {
+    if (this == other) {
+      return 0;
+    }
     return ComparisonChain.start()
         .compare(packageIdentifier, other.packageIdentifier)
         .compare(name, other.name)
diff --git a/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java b/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
index eecec56..38afc5b 100644
--- a/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
+++ b/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
@@ -228,7 +228,15 @@
   }
 
   @Override
+  @SuppressWarnings("ReferenceEquality") // Performance optimization.
   public int compareTo(PackageIdentifier that) {
+    // Fast-paths for the common case of the same package or a package in the same repository.
+    if (this == that) {
+      return 0;
+    }
+    if (repository == that.repository) {
+      return pkgName.compareTo(that.pkgName);
+    }
     return ComparisonChain.start()
         .compare(repository.toString(), that.repository.toString())
         .compare(pkgName, that.pkgName)
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD b/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
index 6c2474ef..5f75194 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
@@ -12,9 +12,11 @@
     name = "engine",
     srcs = glob(["*.java"]),
     deps = [
+        "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/collect/compacthashset",
         "//src/main/java/com/google/devtools/build/lib/concurrent",
         "//src/main/java/com/google/devtools/build/lib/graph",
+        "//src/main/java/com/google/devtools/build/lib/packages",
         "//src/main/java/com/google/devtools/build/lib/profiler",
         "//third_party:guava",
         "//third_party:jsr305",
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
index 32843e9..0661017 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
@@ -14,8 +14,11 @@
 package com.google.devtools.build.lib.query2.engine;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSortedSet;
 import com.google.common.collect.Iterables;
+import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.collect.compacthashset.CompactHashSet;
+import com.google.devtools.build.lib.packages.Target;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture;
@@ -26,6 +29,7 @@
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -67,7 +71,7 @@
     }
   }
 
-  private static class OrderedAggregateAllOutputFormatterCallbackImpl<T>
+  private static final class OrderedAggregateAllOutputFormatterCallbackImpl<T>
       extends AggregateAllOutputFormatterCallback<T, Set<T>> {
     private final Set<T> resultSet;
     private final List<T> resultList;
@@ -78,7 +82,7 @@
     }
 
     @Override
-    public final synchronized void processOutput(Iterable<T> partialResult) {
+    public synchronized void processOutput(Iterable<T> partialResult) {
       for (T element : partialResult) {
         if (resultSet.add(element)) {
           resultList.add(element);
@@ -95,6 +99,29 @@
     }
   }
 
+  private static final class LexicographicallySortedTargetAggregator
+      extends AggregateAllOutputFormatterCallback<Target, Set<Target>> {
+    private final Map<Label, Target> resultMap = new HashMap<>();
+
+    @Override
+    public synchronized void processOutput(Iterable<Target> partialResult) {
+      for (Target target : partialResult) {
+        resultMap.put(target.getLabel(), target);
+      }
+    }
+
+    @Override
+    public synchronized ImmutableSortedSet<Target> getResult() {
+      return ImmutableSortedSet.copyOf(
+          LexicographicallySortedTargetAggregator::compareTargetsByLabel, resultMap.values());
+    }
+
+    // A reference to this method is significantly more efficient than using Comparator#comparing.
+    private static int compareTargetsByLabel(Target t1, Target t2) {
+      return t1.getLabel().compareTo(t2.getLabel());
+    }
+  }
+
   /**
    * Returns a fresh {@link AggregateAllOutputFormatterCallback} instance whose
    * {@link AggregateAllCallback#getResult} returns all the elements of the result in the order they
@@ -106,6 +133,16 @@
   }
 
   /**
+   * Returns a fresh {@link AggregateAllOutputFormatterCallback} instance whose {@link
+   * AggregateAllCallback#getResult} returns all the targets in the result sorted lexicographically
+   * by {@link Label}.
+   */
+  public static AggregateAllOutputFormatterCallback<Target, Set<Target>>
+      newLexicographicallySortedTargetAggregator() {
+    return new LexicographicallySortedTargetAggregator();
+  }
+
+  /**
    * Returns a fresh {@link AggregateAllCallback} instance that aggregates all of the values into an
    * {@link ThreadSafeMutableSet}.
    */
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java b/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
index e66fa63..870b11f 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
@@ -15,13 +15,15 @@
 
 import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.FilteringQueryFunction;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture;
+import java.util.Iterator;
 import java.util.List;
 import java.util.regex.Pattern;
 import java.util.regex.PatternSyntaxException;
+import javax.annotation.Nullable;
 
 /**
  * An abstract class that provides generic regex filter expression. Actual expression are
@@ -104,7 +106,7 @@
 
   protected abstract String getPattern(List<Argument> args);
 
-  private static class FilteredCallback<T> implements Callback<T> {
+  private static final class FilteredCallback<T> implements Callback<T> {
     private final Callback<T> parentCallback;
     private final Predicate<T> retainIfTrue;
 
@@ -115,9 +117,18 @@
 
     @Override
     public void process(Iterable<T> partialResult) throws QueryException, InterruptedException {
-      Iterable<T> filter = Iterables.filter(partialResult, retainIfTrue);
-      if (!Iterables.isEmpty(filter)) {
-        parentCallback.process(filter);
+      // Consume as much of the iterable as possible here. The parent callback may be synchronized,
+      // so we can avoid calling it at all if everything gets filtered out.
+      Iterator<T> it = partialResult.iterator();
+      while (it.hasNext()) {
+        T next = it.next();
+        if (retainIfTrue.apply(next)) {
+          Iterator<T> rest =
+              Iterators.concat(
+                  Iterators.singletonIterator(next), Iterators.filter(it, retainIfTrue));
+          parentCallback.process(new InProgressIterable<>(rest, partialResult, retainIfTrue));
+          break;
+        }
       }
     }
 
@@ -125,5 +136,34 @@
     public String toString() {
       return "filtered parentCallback of : " + retainIfTrue;
     }
+
+    /**
+     * Specialized {@link Iterable} to resume an in-progress iteration on the first call to {@link
+     * #iterator}.
+     */
+    private static final class InProgressIterable<T> implements Iterable<T> {
+      @Nullable private Iterator<T> inProgress;
+      private final Iterable<T> original;
+      private final Predicate<T> retainIfTrue;
+
+      private InProgressIterable(
+          Iterator<T> inProgress, Iterable<T> original, Predicate<T> retainIfTrue) {
+        this.inProgress = inProgress;
+        this.original = original;
+        this.retainIfTrue = retainIfTrue;
+      }
+
+      @Override
+      public Iterator<T> iterator() {
+        synchronized (this) {
+          if (inProgress != null) {
+            Iterator<T> it = inProgress;
+            inProgress = null;
+            return it;
+          }
+        }
+        return Iterators.filter(original.iterator(), retainIfTrue);
+      }
+    }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
index 247290b..0de2e1a 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
@@ -20,7 +20,6 @@
 import com.google.common.collect.Collections2;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSortedSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -100,7 +99,6 @@
 import java.io.OutputStream;
 import java.nio.channels.ClosedByInterruptException;
 import java.util.Collection;
-import java.util.Comparator;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
@@ -372,7 +370,10 @@
               /*useGraphlessQuery=*/ graphlessQuery);
       QueryExpression expr = QueryExpression.parse(query, queryEnvironment);
       formatter.verifyCompatible(queryEnvironment, expr);
-      targets = QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnvironment);
+      targets =
+          graphlessQuery && queryOptions.forceSortForGraphlessGenquery
+              ? QueryUtil.newLexicographicallySortedTargetAggregator()
+              : QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnvironment);
       queryResult = queryEnvironment.evaluateQuery(expr, targets);
     } catch (SkyframeRestartQueryException e) {
       // Do not emit errors for skyframe restarts. They make output of the ConfiguredTargetFunction
@@ -389,12 +390,8 @@
         ruleContext.getConfiguration().getFragment(GenQueryConfiguration.class);
     GenQueryOutputStream outputStream =
         new GenQueryOutputStream(genQueryConfig.inMemoryCompressionEnabled());
+    Set<Target> result = targets.getResult();
     try {
-      Set<Target> result = targets.getResult();
-      if (graphlessQuery && queryOptions.forceSortForGraphlessGenquery) {
-        result =
-            ImmutableSortedSet.copyOf(Comparator.comparing(Target::getLabel), targets.getResult());
-      }
       QueryOutputUtils.output(
           queryOptions,
           queryResult,