Provide a parallel implementation of the 'tests' query function.

Both 'test_suite' expansion and Callback#process work happen in async QueryTaskFutures. Contrast with the old implementation, where we would do test suite expansion in a single thread and then Callback#process those test targets on the same thread.

Also, fix a large CPU hotspot in BlazeTargetAccessor#getLabelListAttr. This hotspot becomes very obvious once we have the parallelism.

RELNOTES: None
PiperOrigin-RevId: 244964705
diff --git a/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java
index 4a06fec..030f419 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/AbstractBlazeQueryEnvironment.java
@@ -16,6 +16,7 @@
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.cmdline.Label;
@@ -38,6 +39,7 @@
 import java.util.Collection;
 import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.logging.Logger;
@@ -230,6 +232,23 @@
   public abstract Target getTarget(Label label)
       throws TargetNotFoundException, QueryException, InterruptedException;
 
+  /** Batch version of {@link #getTarget(Label)}. Missing targets are absent in the returned map. */
+  // TODO(http://b/128626678): Implement and use this in more places.
+  public Map<Label, Target> getTargets(Iterable<Label> labels)
+      throws InterruptedException, QueryException {
+    ImmutableMap.Builder<Label, Target> resultBuilder = ImmutableMap.builder();
+    for (Label label : labels) {
+      Target target;
+      try {
+        target = getTarget(label);
+      } catch (TargetNotFoundException e) {
+        continue;
+      }
+      resultBuilder.put(label, target);
+    }
+    return resultBuilder.build();
+  }
+
   protected boolean validateScope(Label label, boolean strict) throws QueryException {
     if (!labelFilter.apply(label)) {
       String error = String.format("target '%s' is not within the scope of the query", label);
diff --git a/src/main/java/com/google/devtools/build/lib/query2/BlazeTargetAccessor.java b/src/main/java/com/google/devtools/build/lib/query2/BlazeTargetAccessor.java
index 00c6e8b..0e09b42 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/BlazeTargetAccessor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/BlazeTargetAccessor.java
@@ -31,8 +31,8 @@
 import com.google.devtools.build.lib.query2.engine.QueryExpression;
 import com.google.devtools.build.lib.query2.engine.QueryVisibility;
 import com.google.devtools.build.lib.syntax.Type;
-import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
 /**
@@ -63,12 +63,11 @@
   }
 
   @Override
-  public List<Target> getLabelListAttr(
+  public Iterable<Target> getLabelListAttr(
       QueryExpression caller, Target target, String attrName, String errorMsgPrefix)
       throws QueryException, InterruptedException {
     Preconditions.checkArgument(target instanceof Rule);
 
-    List<Target> result = new ArrayList<>();
     Rule rule = (Rule) target;
 
     AggregatingAttributeMapper attrMap = AggregatingAttributeMapper.of(rule);
@@ -78,15 +77,26 @@
       return ImmutableList.of();
     }
 
-    for (Label label : attrMap.getReachableLabels(attrName, false)) {
-      try {
-        result.add(queryEnvironment.getTarget(label));
-      } catch (TargetNotFoundException e) {
-        queryEnvironment.reportBuildFileError(caller, errorMsgPrefix + e.getMessage());
+    Set<Label> labels = attrMap.getReachableLabels(attrName, false);
+    // TODO(nharmata): Figure out how to make use of the package semaphore in the transitive
+    // callsites of this method.
+    Map<Label, Target> labelTargetMap = queryEnvironment.getTargets(labels);
+    // Optimize for the common-case of no missing targets.
+    if (labelTargetMap.size() != labels.size()) {
+      for (Label label : labels) {
+        if (!labelTargetMap.containsKey(label)) {
+          // If a target was missing, fetch it directly for the sole purpose of getting a useful
+          // error message.
+          try {
+            queryEnvironment.getTarget(label);
+          } catch (TargetNotFoundException e) {
+            queryEnvironment.reportBuildFileError(caller, errorMsgPrefix + e.getMessage());
+          }
+        }
       }
-    }
 
-    return result;
+    }
+    return labelTargetMap.values();
   }
 
   @Override
diff --git a/src/main/java/com/google/devtools/build/lib/query2/ParallelSkyQueryUtils.java b/src/main/java/com/google/devtools/build/lib/query2/ParallelSkyQueryUtils.java
index 1f5e55a..9c767df 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/ParallelSkyQueryUtils.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/ParallelSkyQueryUtils.java
@@ -119,22 +119,22 @@
 
     Function<ThreadSafeMutableSet<Target>, QueryTaskFuture<Predicate<SkyKey>>>
         getTransitiveClosureAsyncFunction =
-        universeValue -> {
-          ThreadSafeAggregateAllSkyKeysCallback aggregateAllCallback =
-              new ThreadSafeAggregateAllSkyKeysCallback(concurrencyLevel);
-          return env.executeAsync(
-              () -> {
-                Callback<Target> visitorCallback =
-                    ParallelVisitor.createParallelVisitorCallback(
-                        new UnfilteredSkyKeyTTVDTCVisitor.Factory(
-                            env,
-                            env.createSkyKeyUniquifier(),
-                            processResultsBatchSize,
-                            aggregateAllCallback));
-                visitorCallback.process(universeValue);
-                return Predicates.in(aggregateAllCallback.getResult());
-              });
-        };
+            universeValue -> {
+              ThreadSafeAggregateAllSkyKeysCallback aggregateAllCallback =
+                  new ThreadSafeAggregateAllSkyKeysCallback(concurrencyLevel);
+              return env.execute(
+                  () -> {
+                    Callback<Target> visitorCallback =
+                        ParallelVisitor.createParallelVisitorCallback(
+                            new UnfilteredSkyKeyTTVDTCVisitor.Factory(
+                                env,
+                                env.createSkyKeyUniquifier(),
+                                processResultsBatchSize,
+                                aggregateAllCallback));
+                    visitorCallback.process(universeValue);
+                    return Predicates.in(aggregateAllCallback.getResult());
+                  });
+            };
 
     return env.transformAsync(universeValueFuture, getTransitiveClosureAsyncFunction);
   }
diff --git a/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitor.java b/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitor.java
index cf2987a..2c4d441 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitor.java
@@ -319,7 +319,7 @@
     public void process(Iterable<Target> partialResult)
         throws QueryException, InterruptedException {
       ParallelVisitor<?, ?> visitor = visitorFactory.create();
-      // TODO(nharmata): It's not ideal to have an operation like this in #process that blocks on
+      // TODO(b/131109214): It's not ideal to have an operation like this in #process that blocks on
       // another, potentially expensive computation. Refactor to something like "processAsync".
       visitor.visitAndWaitForCompletion(
           SkyQueryEnvironment.makeTransitiveTraversalKeysStrict(partialResult));
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index 5184e8e..b26cc88 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -34,7 +34,6 @@
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
 import com.google.common.collect.Streams;
-import com.google.common.util.concurrent.AsyncCallable;
 import com.google.common.util.concurrent.AsyncFunction;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -593,9 +592,10 @@
     }
   }
 
-  private <R> ListenableFuture<R> safeSubmitAsync(AsyncCallable<R> callable) {
+  @SuppressWarnings("unchecked")
+  private <R> ListenableFuture<R> safeSubmitAsync(QueryTaskAsyncCallable<R> callable) {
     try {
-      return Futures.submitAsync(callable, executor);
+      return Futures.submitAsync(() -> (ListenableFuture<R>) callable.call(), executor);
     } catch (RejectedExecutionException e) {
       return Futures.immediateCancelledFuture();
     }
@@ -609,17 +609,20 @@
       final Callback<Target> callback) {
     // TODO(bazel-team): As in here, use concurrency for the async #eval of other QueryEnvironment
     // implementations.
-    AsyncCallable<Void> task =
-        () -> (QueryTaskFutureImpl<Void>) expr.eval(SkyQueryEnvironment.this, context, callback);
-    return QueryTaskFutureImpl.ofDelegate(safeSubmitAsync(task));
+    return executeAsync(() -> expr.eval(SkyQueryEnvironment.this, context, callback));
   }
 
   @Override
-  public <R> QueryTaskFuture<R> executeAsync(QueryTaskCallable<R> callable) {
+  public <R> QueryTaskFuture<R> execute(QueryTaskCallable<R> callable) {
     return QueryTaskFutureImpl.ofDelegate(safeSubmit(callable));
   }
 
   @Override
+  public <R> QueryTaskFuture<R> executeAsync(QueryTaskAsyncCallable<R> callable) {
+    return QueryTaskFutureImpl.ofDelegate(safeSubmitAsync(callable));
+  }
+
+  @Override
   public <T1, T2> QueryTaskFuture<T2> transformAsync(
       QueryTaskFuture<T1> future,
       final Function<T1, QueryTaskFuture<T2>> function) {
@@ -860,6 +863,31 @@
     }
   }
 
+  @Override
+  public Map<Label, Target> getTargets(Iterable<Label> labels) throws InterruptedException {
+    Multimap<PackageIdentifier, Label> packageIdToLabelMap = ArrayListMultimap.create();
+    labels.forEach(label -> packageIdToLabelMap.put(label.getPackageIdentifier(), label));
+    Map<PackageIdentifier, Package> packageIdToPackageMap =
+        bulkGetPackages(packageIdToLabelMap.keySet());
+    ImmutableMap.Builder<Label, Target> resultBuilder = ImmutableMap.builder();
+    for (PackageIdentifier pkgId : packageIdToLabelMap.keySet()) {
+      Package pkg = packageIdToPackageMap.get(pkgId);
+      if (pkg == null) {
+        continue;
+      }
+      for (Label label : packageIdToLabelMap.get(pkgId)) {
+        Target target;
+        try {
+          target = pkg.getTarget(label.getName());
+        } catch (NoSuchTargetException e) {
+          continue;
+        }
+        resultBuilder.put(label, target);
+      }
+    }
+    return resultBuilder.build();
+  }
+
   @ThreadSafe
   public Map<PackageIdentifier, Package> bulkGetPackages(Iterable<PackageIdentifier> pkgIds)
       throws InterruptedException {
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java
index 6971d07..1d87384 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/AbstractQueryEnvironment.java
@@ -139,7 +139,7 @@
   }
 
   @Override
-  public <R> QueryTaskFuture<R> executeAsync(QueryTaskCallable<R> callable) {
+  public <R> QueryTaskFuture<R> execute(QueryTaskCallable<R> callable) {
     try {
       return immediateSuccessfulFuture(callable.call());
     } catch (QueryException e) {
@@ -150,6 +150,11 @@
   }
 
   @Override
+  public <R> QueryTaskFuture<R> executeAsync(QueryTaskAsyncCallable<R> callable) {
+    return callable.call();
+  }
+
+  @Override
   public <R> QueryTaskFuture<R> whenSucceedsCall(
       QueryTaskFuture<?> future, QueryTaskCallable<R> callable) {
     return whenAllSucceedCall(ImmutableList.of(future), callable);
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java
index afdbee3..d586574 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/QueryEnvironment.java
@@ -231,9 +231,9 @@
   /**
    * An asynchronous computation of part of a query evaluation.
    *
-   * <p>A {@link QueryTaskFuture} can only be produced from scratch via {@link #eval},
-   * {@link #executeAsync}, {@link #immediateSuccessfulFuture}, {@link #immediateFailedFuture}, and
-   * {@link #immediateCancelledFuture}.
+   * <p>A {@link QueryTaskFuture} can only be produced from scratch via {@link #eval}, {@link
+   * #execute}, {@link #immediateSuccessfulFuture}, {@link #immediateFailedFuture}, and {@link
+   * #immediateCancelledFuture}.
    *
    * <p>Combined with the helper methods like {@link #whenSucceedsCall} below, this is very similar
    * to Guava's {@link ListenableFuture}.
@@ -301,15 +301,31 @@
     T call() throws QueryException, InterruptedException;
   }
 
+  /** Like Guava's AsyncCallable, but for {@link QueryTaskFuture}. */
+  @ThreadSafe
+  public interface QueryTaskAsyncCallable<T> {
+    /**
+     * Returns a {@link QueryTaskFuture} whose completion encapsulates the result of the
+     * computation.
+     */
+    QueryTaskFuture<T> call();
+  }
+
   /**
    * Returns a {@link QueryTaskFuture} representing the given computation {@code callable} being
    * performed asynchronously.
    *
-   * <p>The returned {@link QueryTaskFuture} is considered "successful" for purposes of
-   * {@link #whenSucceedsCall}, {@link #whenAllSucceed}, and
-   * {@link QueryTaskFuture#getIfSuccessful} iff {@code callable#call} does not throw an exception.
+   * <p>The returned {@link QueryTaskFuture} is considered "successful" for purposes of {@link
+   * #whenSucceedsCall}, {@link #whenAllSucceed}, and {@link QueryTaskFuture#getIfSuccessful} iff
+   * {@code callable#call} does not throw an exception.
    */
-  <R> QueryTaskFuture<R> executeAsync(QueryTaskCallable<R> callable);
+  <R> QueryTaskFuture<R> execute(QueryTaskCallable<R> callable);
+
+  /**
+   * Returns a {@link QueryTaskFuture} representing both the given {@code callable} being performed
+   * asynchronously and also the returned {@link QueryTaskFuture} returned therein being completed.
+   */
+  <R> QueryTaskFuture<R> executeAsync(QueryTaskAsyncCallable<R> callable);
 
   /**
    * Returns a {@link QueryTaskFuture} representing the given computation {@code callable} being
@@ -362,9 +378,9 @@
    * The sole package-protected subclass of {@link QueryTaskFuture}.
    *
    * <p>Do not subclass this class; it's an implementation detail. {@link QueryExpression} and
-   * {@link QueryFunction} implementations should use {@link #eval} and {@link #executeAsync} to get
-   * access to {@link QueryTaskFuture} instances and the then use the helper methods like
-   * {@link #whenSucceedsCall} to transform them.
+   * {@link QueryFunction} implementations should use {@link #eval} and {@link #execute} to get
+   * access to {@link QueryTaskFuture} instances and the then use the helper methods like {@link
+   * #whenSucceedsCall} to transform them.
    */
   abstract class QueryTaskFutureImplBase<T> extends QueryTaskFuture<T> {
     protected QueryTaskFutureImplBase() {
@@ -534,7 +550,7 @@
      *
      * @throws IllegalArgumentException if target is not a rule (according to {@link #isRule})
      */
-    List<T> getLabelListAttr(
+    Iterable<T> getLabelListAttr(
         QueryExpression caller, T target, String attrName, String errorMsgPrefix)
         throws QueryException, InterruptedException;
 
diff --git a/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java b/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java
index 37e95d4..7e7cfdc 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/engine/TestsFunction.java
@@ -15,18 +15,18 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Argument;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ArgumentType;
-import com.google.devtools.build.lib.query2.engine.QueryEnvironment.MutableMap;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryFunction;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.QueryTaskFuture;
 import com.google.devtools.build.lib.query2.engine.QueryEnvironment.Setting;
-import com.google.devtools.build.lib.query2.engine.QueryEnvironment.ThreadSafeMutableSet;
+import com.google.devtools.build.lib.query2.engine.QueryEnvironment.TargetAccessor;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Set;
 
@@ -63,50 +63,233 @@
 
   @Override
   public <T> QueryTaskFuture<Void> eval(
-      final QueryEnvironment<T> env,
+      QueryEnvironment<T> env,
       QueryExpressionContext<T> context,
       QueryExpression expression,
       List<Argument> args,
-      final Callback<T> callback) {
-    final Closure<T> closure = new Closure<>(expression, env);
-    final Uniquifier<T> uniquifier = env.createUniquifier();
+      Callback<T> callback) {
+    Closure<T> closure = new Closure<>(expression, callback, env);
 
-    return env.eval(
-        args.get(0).getExpression(),
-        context,
-        new Callback<T>() {
-          @Override
-          public void process(Iterable<T> partialResult)
-              throws QueryException, InterruptedException {
-            for (T target : partialResult) {
-              if (env.getAccessor().isTestRule(target)) {
-                if (uniquifier.unique(target)) {
-                  callback.process(ImmutableList.of(target));
-                }
-              } else if (env.getAccessor().isTestSuite(target)) {
-                for (T test : closure.getTestsInSuite(target)) {
-                  T testTarget = env.getOrCreate(test);
-                  if (uniquifier.unique(testTarget)) {
-                    callback.process(ImmutableList.of(testTarget));
-                  }
-                }
-              }
-            }
+    // A callback that appropriately feeds top-level test and test_suite targets to 'closure'.
+    Callback<T> visitAllTestSuitesCallback =
+        partialResult -> {
+          PartitionResult<T> partitionResult = closure.partition(partialResult);
+          closure
+              .getUniqueTestSuites(partitionResult.testSuiteTargets)
+              .forEach(closure::visitUniqueTestsInUniqueSuite);
+          callback.process(closure.getUniqueTests(partitionResult.testTargets));
+        };
+
+    // Get a future that represents full evaluation of the argument expression.
+    QueryTaskFuture<Void> testSuiteVisitationStartedFuture =
+        env.eval(args.get(0).getExpression(), context, visitAllTestSuitesCallback);
+
+    return env.transformAsync(
+        // When this future is done, all top-level test_suite targets have already been fed to the
+        // 'closure', meaning that ...
+        testSuiteVisitationStartedFuture,
+        // ... 'closure.getTopLevelRecursiveVisitationFutures()' represents the full visitation of
+        // all these test_suite targets.
+        dummyVal -> env.whenAllSucceed(closure.getTopLevelRecursiveVisitationFutures()));
+  }
+
+  private static class PartitionResult<T> {
+    final ImmutableList<T> testTargets;
+    final ImmutableList<T> testSuiteTargets;
+    final ImmutableList<T> otherTargets;
+
+    private PartitionResult(
+        ImmutableList<T> testTargets,
+        ImmutableList<T> testSuiteTargets,
+        ImmutableList<T> otherTargets) {
+      this.testTargets = testTargets;
+      this.testSuiteTargets = testSuiteTargets;
+      this.otherTargets = otherTargets;
+    }
+  }
+
+  /** A closure over the state needed to do asynchronous test_suite visitation and expansion. */
+  @ThreadSafe
+  private static final class Closure<T> {
+    private final QueryExpression expression;
+    private final Callback<T> callback;
+    /** The environment in which this query is being evaluated. */
+    private final QueryEnvironment<T> env;
+
+    private final TargetAccessor<T> accessor;
+    private final boolean strict;
+    private final Uniquifier<T> testUniquifier;
+    private final Uniquifier<T> testSuiteUniquifier;
+    private final List<QueryTaskFuture<Void>> topLevelRecursiveVisitationFutures =
+        Collections.synchronizedList(new ArrayList<>());
+
+    private Closure(QueryExpression expression, Callback<T> callback, QueryEnvironment<T> env) {
+      this.expression = expression;
+      this.callback = callback;
+      this.env = env;
+      this.accessor = env.getAccessor();
+      this.strict = env.isSettingEnabled(Setting.TESTS_EXPRESSION_STRICT);
+      this.testUniquifier = env.createUniquifier();
+      this.testSuiteUniquifier = env.createUniquifier();
+    }
+
+    private Iterable<T> getUniqueTests(Iterable<T> tests) throws QueryException {
+      return testUniquifier.unique(tests);
+    }
+
+    private Iterable<T> getUniqueTestSuites(Iterable<T> testSuites) throws QueryException {
+      return testSuiteUniquifier.unique(testSuites);
+    }
+
+    private void visitUniqueTestsInUniqueSuite(T testSuite) {
+      topLevelRecursiveVisitationFutures.add(
+          env.executeAsync(() -> recursivelyVisitUniqueTestsInUniqueSuite(testSuite)));
+    }
+
+    /**
+     * Returns all the futures representing the work items entailed by all the previous calls to
+     * {@link #visitUniqueTestsInUniqueSuite}.
+     */
+    private ImmutableList<QueryTaskFuture<Void>> getTopLevelRecursiveVisitationFutures() {
+      return ImmutableList.copyOf(topLevelRecursiveVisitationFutures);
+    }
+
+    private QueryTaskFuture<Void> recursivelyVisitUniqueTestsInUniqueSuite(T testSuite) {
+      List<String> tagsAttribute = accessor.getStringListAttr(testSuite, "tags");
+      // Split the tags list into positive and negative tags
+      Set<String> requiredTags = new HashSet<>();
+      Set<String> excludedTags = new HashSet<>();
+      sortTagsBySense(tagsAttribute, requiredTags, excludedTags);
+
+      List<T> testsToProcess = new ArrayList<>();
+      List<T> testSuites;
+
+      try {
+        PartitionResult<T> partitionResult = partition(getPrerequisites(testSuite, "tests"));
+
+        for (T testTarget : partitionResult.testTargets) {
+          if (includeTest(requiredTags, excludedTags, testTarget)
+              && testUniquifier.unique(testTarget)) {
+            testsToProcess.add(testTarget);
           }
-        });
+        }
+
+        testSuites = testSuiteUniquifier.unique(partitionResult.testSuiteTargets);
+
+        // If strict mode is enabled, then give an error for any non-test, non-test-suite target.
+        if (strict) {
+          for (T otherTarget : partitionResult.otherTargets) {
+            env.reportBuildFileError(
+                expression,
+                "The label '"
+                    + accessor.getLabel(otherTarget)
+                    + "' in the test_suite '"
+                    + accessor.getLabel(testSuite)
+                    + "' does not refer to a test or test_suite "
+                    + "rule!");
+          }
+        }
+
+        // Add implicit dependencies on tests in same package, if any.
+        for (T target : getPrerequisites(testSuite, "$implicit_tests")) {
+          // The Package construction of $implicit_tests ensures that this check never fails, but we
+          // add it here anyway for compatibility with future code.
+          if (accessor.isTestRule(target)
+              && includeTest(requiredTags, excludedTags, target)
+              && testUniquifier.unique(target)) {
+            testsToProcess.add(target);
+          }
+        }
+      } catch (InterruptedException e) {
+        return env.immediateCancelledFuture();
+      } catch (QueryException e) {
+        return env.immediateFailedFuture(e);
+      }
+
+      // Process all tests, asynchronously.
+      QueryTaskFuture<Void> allTestsProcessedFuture =
+          env.execute(
+              () -> {
+                callback.process(testsToProcess);
+                return null;
+              });
+
+      // Visit all suites recursively, asynchronously.
+      QueryTaskFuture<Void> allTestSuitsVisitedFuture =
+          env.whenAllSucceed(
+              Iterables.transform(testSuites, this::recursivelyVisitUniqueTestsInUniqueSuite));
+
+      return env.whenAllSucceed(
+          ImmutableList.of(allTestsProcessedFuture, allTestSuitsVisitedFuture));
+    }
+
+    private PartitionResult<T> partition(Iterable<T> targets) {
+      ImmutableList.Builder<T> testTargetsBuilder = ImmutableList.builder();
+      ImmutableList.Builder<T> testSuiteTargetsBuilder = ImmutableList.builder();
+      ImmutableList.Builder<T> otherTargetsBuilder = ImmutableList.builder();
+
+      for (T target : targets) {
+        if (accessor.isTestRule(target)) {
+          testTargetsBuilder.add(target);
+        } else if (accessor.isTestSuite(target)) {
+          testSuiteTargetsBuilder.add(target);
+        } else {
+          otherTargetsBuilder.add(target);
+        }
+      }
+
+      return new PartitionResult<>(
+          testTargetsBuilder.build(), testSuiteTargetsBuilder.build(), otherTargetsBuilder.build());
+    }
+
+    /**
+     * Returns the set of rules named by the attribute 'attrName' of test_suite rule 'testSuite'.
+     * The attribute must be a list of labels. If a target cannot be resolved, then an error is
+     * reported to the environment (which may throw an exception if {@code keep_going} is disabled).
+     *
+     * @precondition env.getAccessor().isTestSuite(testSuite)
+     */
+    private Iterable<T> getPrerequisites(T testSuite, String attrName)
+        throws QueryException, InterruptedException {
+      return accessor.getLabelListAttr(
+          expression,
+          testSuite,
+          attrName,
+          "couldn't expand '"
+              + attrName
+              + "' attribute of test_suite "
+              + accessor.getLabel(testSuite)
+              + ": ");
+    }
+
+    /**
+     * Filters 'tests' (by mutation) according to the 'tags' attribute, specifically those that
+     * match ALL of the tags in tagsAttribute.
+     *
+     * @precondition {@code env.getAccessor().isTestSuite(testSuite)}
+     * @precondition {@code env.getAccessor().isTestRule(test)}
+     */
+    private boolean includeTest(Set<String> requiredTags, Set<String> excludedTags, T test) {
+      List<String> testTags = new ArrayList<>(accessor.getStringListAttr(test, "tags"));
+      testTags.add(accessor.getStringAttr(test, "size"));
+      return TestsFunction.includeTest(testTags, requiredTags, excludedTags);
+    }
   }
 
   // TODO(ulfjack): This must match the code in TestTargetUtils. However, we don't currently want
   // to depend on the packages library. Extract to a neutral place?
   /**
    * Decides whether to include a test in a test_suite or not.
+   *
    * @param testTags Collection of all tags exhibited by a given test.
    * @param positiveTags Tags declared by the suite. A test must match ALL of these.
    * @param negativeTags Tags declared by the suite. A test must match NONE of these.
    * @return false is the test is to be removed.
    */
-  private static boolean includeTest(Collection<String> testTags,
-      Collection<String> positiveTags, Collection<String> negativeTags) {
+  private static boolean includeTest(
+      Collection<String> testTags,
+      Collection<String> positiveTags,
+      Collection<String> negativeTags) {
     // Add this test if it matches ALL of the positive tags and NONE of the
     // negative tags in the tags attribute.
     for (String tag : negativeTags) {
@@ -151,129 +334,4 @@
       }
     }
   }
-
-  /** A closure over the temporary state needed to compute the expression. */
-  @ThreadSafe
-  private static final class Closure<T> {
-    private final QueryExpression expression;
-    /** A dynamically-populated mapping from test_suite rules to their tests. */
-    private final MutableMap<T, ThreadSafeMutableSet<T>> testsInSuite;
-
-    /** The environment in which this query is being evaluated. */
-    private final QueryEnvironment<T> env;
-
-    private final boolean strict;
-
-    private Closure(QueryExpression expression, QueryEnvironment<T> env) {
-      this.expression = expression;
-      this.env = env;
-      this.strict = env.isSettingEnabled(Setting.TESTS_EXPRESSION_STRICT);
-      this.testsInSuite = env.createMutableMap();
-    }
-
-    /**
-     * Computes and returns the set of test rules in a particular suite. Uses dynamic
-     * programming---a memoized version of {@link #computeTestsInSuite}.
-     *
-     * @precondition env.getAccessor().isTestSuite(testSuite)
-     */
-    private synchronized ThreadSafeMutableSet<T> getTestsInSuite(T testSuite)
-        throws QueryException, InterruptedException {
-      ThreadSafeMutableSet<T> tests = testsInSuite.get(testSuite);
-      if (tests == null) {
-        tests = env.createThreadSafeMutableSet();
-        testsInSuite.put(testSuite, tests); // break cycles by inserting empty set early.
-        computeTestsInSuite(testSuite, tests);
-      }
-      return tests;
-    }
-
-    /**
-     * Populates 'result' with all the tests associated with the specified 'testSuite'. Throws an
-     * exception if any target is missing.
-     *
-     * <p>CAUTION! Keep this logic consistent with {@code TestsSuiteConfiguredTarget}!
-     *
-     * @precondition env.getAccessor().isTestSuite(testSuite)
-     */
-    private void computeTestsInSuite(T testSuite, ThreadSafeMutableSet<T> result)
-        throws QueryException, InterruptedException {
-      List<T> testsAndSuites = new ArrayList<>();
-      // Note that testsAndSuites can contain input file targets; the test_suite rule does not
-      // restrict the set of targets that can appear in tests or suites.
-      testsAndSuites.addAll(getPrerequisites(testSuite, "tests"));
-
-      // 1. Add all tests
-      for (T test : testsAndSuites) {
-        if (env.getAccessor().isTestRule(test)) {
-          result.add(test);
-        } else if (strict && !env.getAccessor().isTestSuite(test)) {
-          // If strict mode is enabled, then give an error for any non-test, non-test-suite targets.
-          env.reportBuildFileError(expression, "The label '"
-              + env.getAccessor().getLabel(test) + "' in the test_suite '"
-              + env.getAccessor().getLabel(testSuite) + "' does not refer to a test or test_suite "
-              + "rule!");
-        }
-      }
-
-      // 2. Add implicit dependencies on tests in same package, if any.
-      for (T target : getPrerequisites(testSuite, "$implicit_tests")) {
-        // The Package construction of $implicit_tests ensures that this check never fails, but we
-        // add it here anyway for compatibility with future code.
-        if (env.getAccessor().isTestRule(target)) {
-          result.add(target);
-        }
-      }
-
-      // 3. Filter based on tags, size, env.
-      filterTests(testSuite, result);
-
-      // 4. Expand all suites recursively.
-      for (T suite : testsAndSuites) {
-        if (env.getAccessor().isTestSuite(suite)) {
-          result.addAll(getTestsInSuite(suite));
-        }
-      }
-    }
-
-    /**
-     * Returns the set of rules named by the attribute 'attrName' of test_suite rule 'testSuite'.
-     * The attribute must be a list of labels. If a target cannot be resolved, then an error is
-     * reported to the environment (which may throw an exception if {@code keep_going} is disabled).
-     *
-     * @precondition env.getAccessor().isTestSuite(testSuite)
-     */
-    private List<T> getPrerequisites(T testSuite, String attrName)
-        throws QueryException, InterruptedException {
-      return env.getAccessor().getLabelListAttr(expression, testSuite, attrName,
-          "couldn't expand '" + attrName
-          + "' attribute of test_suite " + env.getAccessor().getLabel(testSuite) + ": ");
-    }
-
-    /**
-     * Filters 'tests' (by mutation) according to the 'tags' attribute, specifically those that
-     * match ALL of the tags in tagsAttribute.
-     *
-     * @precondition {@code env.getAccessor().isTestSuite(testSuite)}
-     * @precondition {@code env.getAccessor().isTestRule(test)} for all test in tests
-     */
-    private void filterTests(T testSuite, ThreadSafeMutableSet<T> tests) {
-      List<String> tagsAttribute = env.getAccessor().getStringListAttr(testSuite, "tags");
-      // Split the tags list into positive and negative tags
-      Set<String> requiredTags = new HashSet<>();
-      Set<String> excludedTags = new HashSet<>();
-      sortTagsBySense(tagsAttribute, requiredTags, excludedTags);
-
-      Iterator<T> it = tests.iterator();
-      while (it.hasNext()) {
-        T test = it.next();
-        // TODO(ulfjack): This does not match the code used for TestSuite.
-        List<String> testTags = new ArrayList<>(env.getAccessor().getStringListAttr(test, "tags"));
-        testTags.add(env.getAccessor().getStringAttr(test, "size"));
-        if (!includeTest(testTags, requiredTags, excludedTags)) {
-          it.remove();
-        }
-      }
-    }
-  }
 }