Make sorting of graphless query more efficient.
PiperOrigin-RevId: 319806006
diff --git a/src/main/java/com/google/devtools/build/lib/cmdline/Label.java b/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
index c067736..1214e3d 100644
--- a/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
+++ b/src/main/java/com/google/devtools/build/lib/cmdline/Label.java
@@ -648,6 +648,9 @@
*/
@Override
public int compareTo(Label other) {
+ if (this == other) {
+ return 0;
+ }
return ComparisonChain.start()
.compare(packageIdentifier, other.packageIdentifier)
.compare(name, other.name)
diff --git a/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java b/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
index eecec56..38afc5b 100644
--- a/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
+++ b/src/main/java/com/google/devtools/build/lib/cmdline/PackageIdentifier.java
@@ -228,7 +228,15 @@
}
@Override
+ @SuppressWarnings("ReferenceEquality") // Performance optimization.
public int compareTo(PackageIdentifier that) {
+ // Fast-paths for the common case of the same package or a package in the same repository.
+ if (this == that) {
+ return 0;
+ }
+ if (repository == that.repository) {
+ return pkgName.compareTo(that.pkgName);
+ }
return ComparisonChain.start()
.compare(repository.toString(), that.repository.toString())
.compare(pkgName, that.pkgName)
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD b/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
index 6c2474ef..5f75194 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/BUILD
@@ -12,9 +12,11 @@
name = "engine",
srcs = glob(["*.java"]),
deps = [
+ "//src/main/java/com/google/devtools/build/lib/cmdline",
"//src/main/java/com/google/devtools/build/lib/collect/compacthashset",
"//src/main/java/com/google/devtools/build/lib/concurrent",
"//src/main/java/com/google/devtools/build/lib/graph",
+ "//src/main/java/com/google/devtools/build/lib/packages",
"//src/main/java/com/google/devtools/build/lib/profiler",
"//third_party:guava",
"//third_party:jsr305",
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
index 32843e9..0661017 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryUtil.java
@@ -14,8 +14,11 @@
package com.google.devtools.build.lib.query2.engine;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
+import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.collect.compacthashset.CompactHashSet;
+import com.google.devtools.build.lib.packages.Target;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskCallable;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture;
@@ -26,6 +29,7 @@
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@@ -67,7 +71,7 @@
}
}
- private static class OrderedAggregateAllOutputFormatterCallbackImpl<T>
+ private static final class OrderedAggregateAllOutputFormatterCallbackImpl<T>
extends AggregateAllOutputFormatterCallback<T, Set<T>> {
private final Set<T> resultSet;
private final List<T> resultList;
@@ -78,7 +82,7 @@
}
@Override
- public final synchronized void processOutput(Iterable<T> partialResult) {
+ public synchronized void processOutput(Iterable<T> partialResult) {
for (T element : partialResult) {
if (resultSet.add(element)) {
resultList.add(element);
@@ -95,6 +99,29 @@
}
}
+ private static final class LexicographicallySortedTargetAggregator
+ extends AggregateAllOutputFormatterCallback<Target, Set<Target>> {
+ private final Map<Label, Target> resultMap = new HashMap<>();
+
+ @Override
+ public synchronized void processOutput(Iterable<Target> partialResult) {
+ for (Target target : partialResult) {
+ resultMap.put(target.getLabel(), target);
+ }
+ }
+
+ @Override
+ public synchronized ImmutableSortedSet<Target> getResult() {
+ return ImmutableSortedSet.copyOf(
+ LexicographicallySortedTargetAggregator::compareTargetsByLabel, resultMap.values());
+ }
+
+ // A reference to this method is significantly more efficient than using Comparator#comparing.
+ private static int compareTargetsByLabel(Target t1, Target t2) {
+ return t1.getLabel().compareTo(t2.getLabel());
+ }
+ }
+
/**
* Returns a fresh {@link AggregateAllOutputFormatterCallback} instance whose
* {@link AggregateAllCallback#getResult} returns all the elements of the result in the order they
@@ -106,6 +133,16 @@
}
/**
+ * Returns a fresh {@link AggregateAllOutputFormatterCallback} instance whose {@link
+ * AggregateAllCallback#getResult} returns all the targets in the result sorted lexicographically
+ * by {@link Label}.
+ */
+ public static AggregateAllOutputFormatterCallback<Target, Set<Target>>
+ newLexicographicallySortedTargetAggregator() {
+ return new LexicographicallySortedTargetAggregator();
+ }
+
+ /**
* Returns a fresh {@link AggregateAllCallback} instance that aggregates all of the values into an
* {@link ThreadSafeMutableSet}.
*/
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java b/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
index e66fa63..870b11f 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/RegexFilterExpression.java
@@ -15,13 +15,15 @@
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.FilteringQueryFunction;
import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture;
+import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
+import javax.annotation.Nullable;
/**
* An abstract class that provides generic regex filter expression. Actual expression are
@@ -104,7 +106,7 @@
protected abstract String getPattern(List<Argument> args);
- private static class FilteredCallback<T> implements Callback<T> {
+ private static final class FilteredCallback<T> implements Callback<T> {
private final Callback<T> parentCallback;
private final Predicate<T> retainIfTrue;
@@ -115,9 +117,18 @@
@Override
public void process(Iterable<T> partialResult) throws QueryException, InterruptedException {
- Iterable<T> filter = Iterables.filter(partialResult, retainIfTrue);
- if (!Iterables.isEmpty(filter)) {
- parentCallback.process(filter);
+ // Consume as much of the iterable as possible here. The parent callback may be synchronized,
+ // so we can avoid calling it at all if everything gets filtered out.
+ Iterator<T> it = partialResult.iterator();
+ while (it.hasNext()) {
+ T next = it.next();
+ if (retainIfTrue.apply(next)) {
+ Iterator<T> rest =
+ Iterators.concat(
+ Iterators.singletonIterator(next), Iterators.filter(it, retainIfTrue));
+ parentCallback.process(new InProgressIterable<>(rest, partialResult, retainIfTrue));
+ break;
+ }
}
}
@@ -125,5 +136,34 @@
public String toString() {
return "filtered parentCallback of : " + retainIfTrue;
}
+
+ /**
+ * Specialized {@link Iterable} to resume an in-progress iteration on the first call to {@link
+ * #iterator}.
+ */
+ private static final class InProgressIterable<T> implements Iterable<T> {
+ @Nullable private Iterator<T> inProgress;
+ private final Iterable<T> original;
+ private final Predicate<T> retainIfTrue;
+
+ private InProgressIterable(
+ Iterator<T> inProgress, Iterable<T> original, Predicate<T> retainIfTrue) {
+ this.inProgress = inProgress;
+ this.original = original;
+ this.retainIfTrue = retainIfTrue;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ synchronized (this) {
+ if (inProgress != null) {
+ Iterator<T> it = inProgress;
+ inProgress = null;
+ return it;
+ }
+ }
+ return Iterators.filter(original.iterator(), retainIfTrue);
+ }
+ }
}
}
diff --git a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
index 247290b..0de2e1a 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/genquery/GenQuery.java
@@ -20,7 +20,6 @@
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
@@ -100,7 +99,6 @@
import java.io.OutputStream;
import java.nio.channels.ClosedByInterruptException;
import java.util.Collection;
-import java.util.Comparator;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@@ -372,7 +370,10 @@
/*useGraphlessQuery=*/ graphlessQuery);
QueryExpression expr = QueryExpression.parse(query, queryEnvironment);
formatter.verifyCompatible(queryEnvironment, expr);
- targets = QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnvironment);
+ targets =
+ graphlessQuery && queryOptions.forceSortForGraphlessGenquery
+ ? QueryUtil.newLexicographicallySortedTargetAggregator()
+ : QueryUtil.newOrderedAggregateAllOutputFormatterCallback(queryEnvironment);
queryResult = queryEnvironment.evaluateQuery(expr, targets);
} catch (SkyframeRestartQueryException e) {
// Do not emit errors for skyframe restarts. They make output of the ConfiguredTargetFunction
@@ -389,12 +390,8 @@
ruleContext.getConfiguration().getFragment(GenQueryConfiguration.class);
GenQueryOutputStream outputStream =
new GenQueryOutputStream(genQueryConfig.inMemoryCompressionEnabled());
+ Set<Target> result = targets.getResult();
try {
- Set<Target> result = targets.getResult();
- if (graphlessQuery && queryOptions.forceSortForGraphlessGenquery) {
- result =
- ImmutableSortedSet.copyOf(Comparator.comparing(Target::getLabel), targets.getResult());
- }
QueryOutputUtils.output(
queryOptions,
queryResult,