Allow customization of the package batching strategy used by `RecursivePackageProviderBackedTargetPatternResolver`.

PiperOrigin-RevId: 374922066
diff --git a/src/main/java/com/google/devtools/build/lib/query2/BUILD b/src/main/java/com/google/devtools/build/lib/query2/BUILD
index d48a30b..36dddca 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/query2/BUILD
@@ -83,6 +83,7 @@
         "//src/main/java/com/google/devtools/build/lib/skyframe:detailed_exceptions",
         "//src/main/java/com/google/devtools/build/lib/skyframe:graph_backed_recursive_package_provider",
         "//src/main/java/com/google/devtools/build/lib/skyframe:ignored_package_prefixes_value",
+        "//src/main/java/com/google/devtools/build/lib/skyframe:package_identifier_batching_callback",
         "//src/main/java/com/google/devtools/build/lib/skyframe:package_lookup_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe:package_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe:prepare_deps_of_patterns_function",
diff --git a/src/main/java/com/google/devtools/build/lib/query2/PostAnalysisQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/PostAnalysisQueryEnvironment.java
index bb31bf1..bf88cce 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/PostAnalysisQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/PostAnalysisQueryEnvironment.java
@@ -65,6 +65,7 @@
 import com.google.devtools.build.lib.skyframe.PackageValue;
 import com.google.devtools.build.lib.skyframe.RecursivePackageProviderBackedTargetPatternResolver;
 import com.google.devtools.build.lib.skyframe.RecursivePkgValueRootPackageExtractor;
+import com.google.devtools.build.lib.skyframe.SimplePackageIdentifierBatchingCallback;
 import com.google.devtools.build.lib.skyframe.SkyFunctions;
 import com.google.devtools.build.lib.skyframe.SkyframeExecutor;
 import com.google.devtools.build.lib.skyframe.TargetPatternValue;
@@ -165,7 +166,8 @@
             graphBackedRecursivePackageProvider,
             eventHandler,
             FilteringPolicies.NO_FILTER,
-            MultisetSemaphore.unbounded());
+            MultisetSemaphore.unbounded(),
+            SimplePackageIdentifierBatchingCallback::new);
     checkSettings(settings);
   }
 
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index 5fa2fdf..677f485 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -42,6 +42,7 @@
 import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.cmdline.TargetParsingException;
 import com.google.devtools.build.lib.cmdline.TargetPattern;
+import com.google.devtools.build.lib.cmdline.TargetPatternResolver;
 import com.google.devtools.build.lib.collect.compacthashset.CompactHashSet;
 import com.google.devtools.build.lib.concurrent.BlockingStack;
 import com.google.devtools.build.lib.concurrent.MultisetSemaphore;
@@ -96,6 +97,7 @@
 import com.google.devtools.build.lib.skyframe.PackageValue;
 import com.google.devtools.build.lib.skyframe.PrepareDepsOfPatternsFunction;
 import com.google.devtools.build.lib.skyframe.RecursivePackageProviderBackedTargetPatternResolver;
+import com.google.devtools.build.lib.skyframe.SimplePackageIdentifierBatchingCallback;
 import com.google.devtools.build.lib.skyframe.TargetPatternValue;
 import com.google.devtools.build.lib.skyframe.TargetPatternValue.TargetPatternKey;
 import com.google.devtools.build.lib.skyframe.TransitiveTraversalValue;
@@ -159,12 +161,12 @@
   private final boolean visibilityDepsAreAllowed;
 
   // The following fields are set in the #beforeEvaluateQuery method.
-  private MultisetSemaphore<PackageIdentifier> packageSemaphore;
+  protected MultisetSemaphore<PackageIdentifier> packageSemaphore;
   protected WalkableGraph graph;
   protected InterruptibleSupplier<ImmutableSet<PathFragment>> ignoredPatternsSupplier;
   protected GraphBackedRecursivePackageProvider graphBackedRecursivePackageProvider;
   protected ListeningExecutorService executor;
-  private RecursivePackageProviderBackedTargetPatternResolver resolver;
+  private TargetPatternResolver<Target> resolver;
 
   public SkyQueryEnvironment(
       boolean keepGoing,
@@ -284,12 +286,16 @@
             /*workQueue=*/ new BlockingStack<Runnable>(),
             new ThreadFactoryBuilder().setNameFormat("QueryEnvironment %d").build()));
     }
-    resolver =
-        new RecursivePackageProviderBackedTargetPatternResolver(
-            graphBackedRecursivePackageProvider,
-            eventHandler,
-            FilteringPolicies.NO_FILTER,
-            packageSemaphore);
+    resolver = makeNewTargetPatternResolver();
+  }
+
+  protected TargetPatternResolver<Target> makeNewTargetPatternResolver() {
+    return new RecursivePackageProviderBackedTargetPatternResolver(
+        graphBackedRecursivePackageProvider,
+        eventHandler,
+        FilteringPolicies.NO_FILTER,
+        packageSemaphore,
+        SimplePackageIdentifierBatchingCallback::new);
   }
 
   /** Returns the TargetPatterns corresponding to {@code universeKey}. */
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/BUILD b/src/main/java/com/google/devtools/build/lib/skyframe/BUILD
index d65930c..c0b2af9 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/BUILD
@@ -155,6 +155,7 @@
         ":output_store",
         ":package_error_function",
         ":package_error_message_function",
+        ":package_identifier_batching_callback",
         ":package_lookup_function",
         ":package_lookup_value",
         ":package_progress_receiver",
@@ -1705,7 +1706,10 @@
 
 java_library(
     name = "package_identifier_batching_callback",
-    srcs = ["PackageIdentifierBatchingCallback.java"],
+    srcs = [
+        "PackageIdentifierBatchingCallback.java",
+        "SimplePackageIdentifierBatchingCallback.java",
+    ],
     deps = [
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/concurrent",
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageIdentifierBatchingCallback.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageIdentifierBatchingCallback.java
index f7bd24b..daaac76 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageIdentifierBatchingCallback.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageIdentifierBatchingCallback.java
@@ -1,4 +1,4 @@
-// Copyright 2015 The Bazel Authors. All rights reserved.
+// Copyright 2021 The Bazel Authors. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -13,70 +13,27 @@
 // limitations under the License.
 package com.google.devtools.build.lib.skyframe;
 
-import com.google.common.collect.ImmutableList;
 import com.google.devtools.build.lib.cmdline.PackageIdentifier;
 import com.google.devtools.build.lib.concurrent.BatchCallback;
 import com.google.devtools.build.lib.concurrent.ParallelVisitor.UnusedException;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
-import javax.annotation.concurrent.GuardedBy;
 
 /**
  * A callback for {@link
  * com.google.devtools.build.lib.pkgcache.RecursivePackageProvider#streamPackagesUnderDirectory}
- * that buffers the PackageIdentifiers it receives into fixed-size batches that it delivers to a
- * supplied {@code BatchCallback<PackageIdentifier, RuntimeException>}.
+ * that buffers the {@link PackageIdentifier} instances it receives into bounded-size batches that
+ * it delivers to a supplied callback.
  *
- * <p>The final batch delivered to the delegate callback may be smaller than the fixed size; the
- * callback must be {@link #close() closed} to deliver this final batch.
+ * <p>This callback must be {@link #close() closed} to deliver this final batch.
  */
 @ThreadSafe
-public class PackageIdentifierBatchingCallback
-    implements BatchCallback<PackageIdentifier, UnusedException>, AutoCloseable {
+public interface PackageIdentifierBatchingCallback
+    extends BatchCallback<PackageIdentifier, UnusedException>, AutoCloseable {
+  void close() throws InterruptedException;
 
-  private final BatchCallback<PackageIdentifier, UnusedException> batchResults;
-  private final int batchSize;
-
-  @GuardedBy("this")
-  private ImmutableList.Builder<PackageIdentifier> packageIdentifiers;
-
-  @GuardedBy("this")
-  private int bufferedPackageIds;
-
-  public PackageIdentifierBatchingCallback(
-      BatchCallback<PackageIdentifier, UnusedException> batchResults, int batchSize) {
-    this.batchResults = batchResults;
-    this.batchSize = batchSize;
-    reset();
-  }
-
-  @Override
-  public synchronized void process(Iterable<PackageIdentifier> partialResult)
-      throws InterruptedException {
-    for (PackageIdentifier path : partialResult) {
-      packageIdentifiers.add(path);
-      bufferedPackageIds++;
-      if (bufferedPackageIds >= this.batchSize) {
-        flush();
-      }
-    }
-  }
-
-  @Override
-  public synchronized void close() throws InterruptedException {
-    flush();
-  }
-
-  @GuardedBy("this")
-  private void flush() throws InterruptedException {
-    if (bufferedPackageIds > 0) {
-      batchResults.process(packageIdentifiers.build());
-      reset();
-    }
-  }
-
-  @GuardedBy("this")
-  private void reset() {
-    packageIdentifiers = ImmutableList.builderWithExpectedSize(batchSize);
-    bufferedPackageIds = 0;
+  /** Factory for {@link PackageIdentifierBatchingCallback}. */
+  interface Factory {
+    PackageIdentifierBatchingCallback create(
+        BatchCallback<PackageIdentifier, UnusedException> batchResults, int maxBatchSize);
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/RecursivePackageProviderBackedTargetPatternResolver.java b/src/main/java/com/google/devtools/build/lib/skyframe/RecursivePackageProviderBackedTargetPatternResolver.java
index f89e4db..471d806 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/RecursivePackageProviderBackedTargetPatternResolver.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/RecursivePackageProviderBackedTargetPatternResolver.java
@@ -68,16 +68,19 @@
   private final RecursivePackageProvider recursivePackageProvider;
   private final ExtendedEventHandler eventHandler;
   private final MultisetSemaphore<PackageIdentifier> packageSemaphore;
+  private final PackageIdentifierBatchingCallback.Factory packageIdentifierBatchingCallbackFactory;
 
   public RecursivePackageProviderBackedTargetPatternResolver(
       RecursivePackageProvider recursivePackageProvider,
       ExtendedEventHandler eventHandler,
       FilteringPolicy policy,
-      MultisetSemaphore<PackageIdentifier> packageSemaphore) {
+      MultisetSemaphore<PackageIdentifier> packageSemaphore,
+      PackageIdentifierBatchingCallback.Factory packageIdentifierBatchingCallbackFactory) {
     this.recursivePackageProvider = recursivePackageProvider;
     this.eventHandler = eventHandler;
     this.policy = policy;
     this.packageSemaphore = packageSemaphore;
+    this.packageIdentifierBatchingCallbackFactory = packageIdentifierBatchingCallbackFactory;
   }
 
   @Override
@@ -245,7 +248,8 @@
 
     PathFragment pathFragment;
     try (PackageIdentifierBatchingCallback batchingCallback =
-        new PackageIdentifierBatchingCallback(getPackageTargetsCallback, MAX_PACKAGES_BULK_GET)) {
+        packageIdentifierBatchingCallbackFactory.create(
+            getPackageTargetsCallback, MAX_PACKAGES_BULK_GET)) {
       pathFragment = TargetPatternResolverUtil.getPathFragment(directory);
       recursivePackageProvider.streamPackagesUnderDirectory(
           batchingCallback,
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/SimplePackageIdentifierBatchingCallback.java b/src/main/java/com/google/devtools/build/lib/skyframe/SimplePackageIdentifierBatchingCallback.java
new file mode 100644
index 0000000..80ce116
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/SimplePackageIdentifierBatchingCallback.java
@@ -0,0 +1,74 @@
+// Copyright 2015 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.skyframe;
+
+import com.google.common.collect.ImmutableList;
+import com.google.devtools.build.lib.cmdline.PackageIdentifier;
+import com.google.devtools.build.lib.concurrent.BatchCallback;
+import com.google.devtools.build.lib.concurrent.ParallelVisitor.UnusedException;
+import javax.annotation.concurrent.GuardedBy;
+
+/**
+ * Simple implementation of {@link PackageIdentifierBatchingCallback} that naively shards a stream
+ * of {@link PackageIdentifier} instances, in order, into fixed-size batches. The final batch may be
+ * smaller than the others.
+ */
+public class SimplePackageIdentifierBatchingCallback implements PackageIdentifierBatchingCallback {
+  private final BatchCallback<PackageIdentifier, UnusedException> batchResults;
+  private final int batchSize;
+
+  @GuardedBy("this")
+  private ImmutableList.Builder<PackageIdentifier> packageIdentifiers;
+
+  @GuardedBy("this")
+  private int bufferedPackageIds;
+
+  public SimplePackageIdentifierBatchingCallback(
+      BatchCallback<PackageIdentifier, UnusedException> batchResults, int batchSize) {
+    this.batchResults = batchResults;
+    this.batchSize = batchSize;
+    reset();
+  }
+
+  @Override
+  public synchronized void process(Iterable<PackageIdentifier> partialResult)
+      throws InterruptedException {
+    for (PackageIdentifier path : partialResult) {
+      packageIdentifiers.add(path);
+      bufferedPackageIds++;
+      if (bufferedPackageIds >= this.batchSize) {
+        flush();
+      }
+    }
+  }
+
+  @Override
+  public synchronized void close() throws InterruptedException {
+    flush();
+  }
+
+  @GuardedBy("this")
+  private void flush() throws InterruptedException {
+    if (bufferedPackageIds > 0) {
+      batchResults.process(packageIdentifiers.build());
+      reset();
+    }
+  }
+
+  @GuardedBy("this")
+  private void reset() {
+    packageIdentifiers = ImmutableList.builderWithExpectedSize(batchSize);
+    bufferedPackageIds = 0;
+  }
+}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/SkyframeTargetPatternEvaluator.java b/src/main/java/com/google/devtools/build/lib/skyframe/SkyframeTargetPatternEvaluator.java
index 4b1094a..da99603 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/SkyframeTargetPatternEvaluator.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/SkyframeTargetPatternEvaluator.java
@@ -250,7 +250,8 @@
                   ImmutableMap.of(pkg.getPackageIdentifier(), pkg)),
               eventHandler,
               FilteringPolicies.NO_FILTER,
-              /* packageSemaphore= */ null);
+              /* packageSemaphore= */ null,
+              SimplePackageIdentifierBatchingCallback::new);
       AtomicReference<Collection<Target>> result = new AtomicReference<>();
       targetPattern.eval(
           resolver,
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/TargetPatternFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/TargetPatternFunction.java
index df3578e..4a3d93e 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/TargetPatternFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/TargetPatternFunction.java
@@ -66,7 +66,8 @@
               provider,
               env.getListener(),
               patternKey.getPolicy(),
-              MultisetSemaphore.<PackageIdentifier>unbounded());
+              MultisetSemaphore.<PackageIdentifier>unbounded(),
+              SimplePackageIdentifierBatchingCallback::new);
       ImmutableSet<PathFragment> excludedSubdirectories = patternKey.getExcludedSubdirectories();
       ResolvedTargets.Builder<Target> resolvedTargetsBuilder = ResolvedTargets.builder();
       BatchCallback<Target, RuntimeException> callback =