Provide a parallel implementation of "e1 - e2 - e3" by noting its equivalence to "e1 - (e2 + e3)" and the fact that we already have a parallel implementation of "e2 + e3".
--
MOS_MIGRATED_REVID=139792288
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java b/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java
index fe72440..d1efe80 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/BinaryOperatorExpression.java
@@ -15,6 +15,7 @@
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
import com.google.devtools.build.lib.concurrent.MoreFutures;
import com.google.devtools.build.lib.query2.engine.Lexer.TokenKind;
import com.google.devtools.build.lib.util.Preconditions;
@@ -107,36 +108,78 @@
@Override
protected <T> void parEvalImpl(
+ QueryEnvironment<T> env,
+ VariableContext<T> context,
+ ThreadSafeCallback<T> callback,
+ ForkJoinPool forkJoinPool)
+ throws QueryException, InterruptedException {
+ if (operator == TokenKind.PLUS || operator == TokenKind.UNION) {
+ parEvalPlus(operands, env, context, callback, forkJoinPool);
+ } else if (operator == TokenKind.EXCEPT || operator == TokenKind.MINUS) {
+ parEvalMinus(operands, env, context, callback, forkJoinPool);
+ } else {
+ evalImpl(env, context, callback);
+ }
+ }
+
+ /**
+ * Evaluates an expression of the form "e1 + e2 + ... + eK" by evaluating all the subexpressions
+ * in parallel.
+ */
+ private static <T> void parEvalPlus(
+ ImmutableList<QueryExpression> operands,
final QueryEnvironment<T> env,
final VariableContext<T> context,
final ThreadSafeCallback<T> callback,
ForkJoinPool forkJoinPool)
- throws QueryException, InterruptedException {
- if (operator == TokenKind.PLUS || operator == TokenKind.UNION) {
- ArrayList<ForkJoinTask<Void>> tasks = new ArrayList<>(operands.size());
- for (final QueryExpression operand : operands) {
- tasks.add(ForkJoinTask.adapt(
- new Callable<Void>() {
- @Override
- public Void call() throws QueryException, InterruptedException {
- env.eval(operand, context, callback);
- return null;
- }
- }));
- }
- for (ForkJoinTask<?> task : tasks) {
- forkJoinPool.submit(task);
- }
- try {
- MoreFutures.waitForAllInterruptiblyFailFast(tasks);
- } catch (ExecutionException e) {
- Throwables.propagateIfPossible(
- e.getCause(), QueryException.class, InterruptedException.class);
- throw new IllegalStateException(e);
- }
- } else {
- evalImpl(env, context, callback);
+ throws QueryException, InterruptedException {
+ ArrayList<ForkJoinTask<Void>> tasks = new ArrayList<>(operands.size());
+ for (final QueryExpression operand : operands) {
+ tasks.add(ForkJoinTask.adapt(
+ new Callable<Void>() {
+ @Override
+ public Void call() throws QueryException, InterruptedException {
+ env.eval(operand, context, callback);
+ return null;
+ }
+ }));
}
+ for (ForkJoinTask<?> task : tasks) {
+ forkJoinPool.submit(task);
+ }
+ try {
+ MoreFutures.waitForAllInterruptiblyFailFast(tasks);
+ } catch (ExecutionException e) {
+ Throwables.propagateIfPossible(
+ e.getCause(), QueryException.class, InterruptedException.class);
+ throw new IllegalStateException(e);
+ }
+ }
+
+ /**
+ * Evaluates an expression of the form "e1 - e2 - ... - eK" by noting its equivalence to
+ * "e1 - (e2 + ... + eK)" and evaluating the subexpressions on the right-hand-side in parallel.
+ */
+ private static <T> void parEvalMinus(
+ ImmutableList<QueryExpression> operands,
+ QueryEnvironment<T> env,
+ VariableContext<T> context,
+ ThreadSafeCallback<T> callback,
+ ForkJoinPool forkJoinPool)
+ throws QueryException, InterruptedException {
+ final Set<T> lhsValue =
+ Sets.newConcurrentHashSet(QueryUtil.evalAll(env, context, operands.get(0)));
+ ThreadSafeCallback<T> subtractionCallback = new ThreadSafeCallback<T>() {
+ @Override
+ public void process(Iterable<T> partialResult) throws QueryException, InterruptedException {
+ for (T target : partialResult) {
+ lhsValue.remove(target);
+ }
+ }
+ };
+ parEvalPlus(
+ operands.subList(1, operands.size()), env, context, subtractionCallback, forkJoinPool);
+ callback.process(lhsValue);
}
@Override