Add a hook for getting notified when a ParallelVisitor visitation is discovered or completed.
RELNOTES: None
PiperOrigin-RevId: 279320620
diff --git a/src/main/java/com/google/devtools/build/lib/concurrent/ParallelVisitor.java b/src/main/java/com/google/devtools/build/lib/concurrent/ParallelVisitor.java
index 0204c5b..decd9ef 100644
--- a/src/main/java/com/google/devtools/build/lib/concurrent/ParallelVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/concurrent/ParallelVisitor.java
@@ -55,6 +55,7 @@
private final int processResultsBatchSize;
protected final int resultBatchSize;
private final VisitingTaskExecutor executor;
+ private final VisitTaskStatusCallback visitTaskStatusCallback;
/**
* A queue to store pending visits. These should be unique wrt {@link
@@ -122,12 +123,14 @@
int processResultsBatchSize,
long minPendingTasks,
int batchCallbackSize,
- ExecutorService executor) {
+ ExecutorService executor,
+ VisitTaskStatusCallback visitTaskStatusCallback) {
this.callback = callback;
this.exceptionClass = exceptionClass;
this.visitBatchSize = visitBatchSize;
this.processResultsBatchSize = processResultsBatchSize;
this.resultBatchSize = batchCallbackSize;
+ this.visitTaskStatusCallback = visitTaskStatusCallback;
this.executor =
new VisitingTaskExecutor(executor, PARALLEL_VISITOR_ERROR_CLASSIFIER, batchCallbackSize);
this.minPendingTasks = minPendingTasks;
@@ -144,6 +147,22 @@
ParallelVisitor<InputT, VisitKeyT, OutputKeyT, OutputResultT, ExceptionT, CallbackT> create();
}
+ /** A hook for getting notified when a visitation is discovered or completed. */
+ public interface VisitTaskStatusCallback {
+ void onVisitTaskDiscovered();
+
+ void onVisitTaskCompleted();
+
+ VisitTaskStatusCallback NULL_INSTANCE =
+ new VisitTaskStatusCallback() {
+ @Override
+ public void onVisitTaskDiscovered() {}
+
+ @Override
+ public void onVisitTaskCompleted() {}
+ };
+ }
+
protected abstract Iterable<OutputResultT> outputKeysToOutputValues(
Iterable<OutputKeyT> targetKeys) throws ExceptionT, InterruptedException;
@@ -166,7 +185,7 @@
public void visitAndWaitForCompletion(Iterable<InputT> keys)
throws ExceptionT, InterruptedException {
- noteAndReturnUniqueVisitationKeys(preprocessInitialVisit(keys)).forEach(visitQueue::add);
+ noteAndReturnUniqueVisitationKeys(preprocessInitialVisit(keys)).forEach(this::addToVisitQueue);
executor.visitAndWaitForCompletion();
}
@@ -197,6 +216,11 @@
return builder.build();
}
+ private void addToVisitQueue(VisitKeyT visitKey) {
+ visitQueue.add(visitKey);
+ visitTaskStatusCallback.onVisitTaskDiscovered();
+ }
+
/** A {@link Runnable} which handles {@link ExceptionT} and {@link InterruptedException}. */
protected abstract static class Task<ExceptionT extends Exception> implements Runnable {
protected final Class<ExceptionT> exceptionClass;
@@ -248,7 +272,10 @@
executor.execute(
new GetAndProcessUniqueResultsTask(keysToUseForResultBatch, exceptionClass));
}
- noteAndReturnUniqueVisitationKeys(visit.keysToVisit).forEach(visitQueue::add);
+ noteAndReturnUniqueVisitationKeys(visit.keysToVisit)
+ .forEach(ParallelVisitor.this::addToVisitQueue);
+ keysToVisit.forEach(
+ key -> ParallelVisitor.this.visitTaskStatusCallback.onVisitTaskCompleted());
}
}
diff --git a/src/main/java/com/google/devtools/build/lib/query2/AbstractSkyKeyParallelVisitor.java b/src/main/java/com/google/devtools/build/lib/query2/AbstractSkyKeyParallelVisitor.java
index 35cb00a..5138a82 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/AbstractSkyKeyParallelVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/AbstractSkyKeyParallelVisitor.java
@@ -32,8 +32,9 @@
Uniquifier<SkyKey> visitationUniquifier,
Callback<T> callback,
int visitBatchSize,
- int processResultsBatchSize) {
- super(callback, visitBatchSize, processResultsBatchSize);
+ int processResultsBatchSize,
+ VisitTaskStatusCallback visitTaskStatusCallback) {
+ super(callback, visitBatchSize, processResultsBatchSize, visitTaskStatusCallback);
this.uniquifier = visitationUniquifier;
}
diff --git a/src/main/java/com/google/devtools/build/lib/query2/AbstractTargetOuputtingVisitor.java b/src/main/java/com/google/devtools/build/lib/query2/AbstractTargetOuputtingVisitor.java
index 13c8cf0..e9ad01b 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/AbstractTargetOuputtingVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/AbstractTargetOuputtingVisitor.java
@@ -38,7 +38,11 @@
protected final SkyQueryEnvironment env;
protected AbstractTargetOuputtingVisitor(SkyQueryEnvironment env, Callback<Target> callback) {
- super(callback, env.getVisitBatchSizeForParallelVisitation(), PROCESS_RESULTS_BATCH_SIZE);
+ super(
+ callback,
+ env.getVisitBatchSizeForParallelVisitation(),
+ PROCESS_RESULTS_BATCH_SIZE,
+ env.getVisitTaskStatusCallback());
this.env = env;
}
diff --git a/src/main/java/com/google/devtools/build/lib/query2/AbstractUnfilteredTTVDTCVisitor.java b/src/main/java/com/google/devtools/build/lib/query2/AbstractUnfilteredTTVDTCVisitor.java
index c777d3f..b178a2d 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/AbstractUnfilteredTTVDTCVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/AbstractUnfilteredTTVDTCVisitor.java
@@ -37,7 +37,8 @@
uniquifier,
callback,
env.getVisitBatchSizeForParallelVisitation(),
- processResultsBatchSize);
+ processResultsBatchSize,
+ env.getVisitTaskStatusCallback());
this.env = env;
}
diff --git a/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitorUtils.java b/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitorUtils.java
index cecc828..870ac5a 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitorUtils.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/ParallelVisitorUtils.java
@@ -86,7 +86,10 @@
extends ParallelVisitor<
SkyKey, VisitKeyT, OutputKeyT, OutputResultT, QueryException, Callback<OutputResultT>> {
public ParallelQueryVisitor(
- Callback<OutputResultT> callback, int visitBatchSize, int processResultsBatchSize) {
+ Callback<OutputResultT> callback,
+ int visitBatchSize,
+ int processResultsBatchSize,
+ VisitTaskStatusCallback visitTaskStatusCallback) {
super(
callback,
QueryException.class,
@@ -94,7 +97,8 @@
processResultsBatchSize,
3L * SkyQueryEnvironment.DEFAULT_THREAD_COUNT,
SkyQueryEnvironment.BATCH_CALLBACK_SIZE,
- FIXED_THREAD_POOL_EXECUTOR);
+ FIXED_THREAD_POOL_EXECUTOR,
+ visitTaskStatusCallback);
}
}
}
diff --git a/src/main/java/com/google/devtools/build/lib/query2/RBuildFilesVisitor.java b/src/main/java/com/google/devtools/build/lib/query2/RBuildFilesVisitor.java
index 79c36df..8e4a7b7 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/RBuildFilesVisitor.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/RBuildFilesVisitor.java
@@ -67,7 +67,11 @@
Uniquifier<SkyKey> resultUniquifier,
QueryExpressionContext<Target> context,
Callback<Target> callback) {
- super(callback, env.getVisitBatchSizeForParallelVisitation(), PROCESS_RESULTS_BATCH_SIZE);
+ super(
+ callback,
+ env.getVisitBatchSizeForParallelVisitation(),
+ PROCESS_RESULTS_BATCH_SIZE,
+ env.getVisitTaskStatusCallback());
this.env = env;
this.visitUniquifier = visitUniquifier;
this.resultUniquifier = resultUniquifier;
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index aa19514..d8388a9 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -48,6 +48,7 @@
import com.google.devtools.build.lib.collect.compacthashset.CompactHashSet;
import com.google.devtools.build.lib.concurrent.BlockingStack;
import com.google.devtools.build.lib.concurrent.MultisetSemaphore;
+import com.google.devtools.build.lib.concurrent.ParallelVisitor.VisitTaskStatusCallback;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.events.DelegatingEventHandler;
import com.google.devtools.build.lib.events.Event;
@@ -853,6 +854,10 @@
return ParallelSkyQueryUtils.VISIT_BATCH_SIZE;
}
+ public VisitTaskStatusCallback getVisitTaskStatusCallback() {
+ return VisitTaskStatusCallback.NULL_INSTANCE;
+ }
+
private Target getLoadTarget(Label label, Package pkg) {
return new FakeLoadTarget(label, pkg);
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/TraversalInfoRootPackageExtractor.java b/src/main/java/com/google/devtools/build/lib/skyframe/TraversalInfoRootPackageExtractor.java
index 6a30fa4..4d0da0e 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/TraversalInfoRootPackageExtractor.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/TraversalInfoRootPackageExtractor.java
@@ -119,7 +119,8 @@
processResultsBatchSize,
minPendingTasks,
resultBatchSize,
- PACKAGE_ID_COLLECTING_EXECUTOR);
+ PACKAGE_ID_COLLECTING_EXECUTOR,
+ VisitTaskStatusCallback.NULL_INSTANCE);
this.eventHandler = eventHandler;
this.repository = repository;
this.graph = graph;
diff --git a/src/test/java/com/google/devtools/build/lib/concurrent/ParallelVisitorTest.java b/src/test/java/com/google/devtools/build/lib/concurrent/ParallelVisitorTest.java
index 096dcaa..9e7a7f8 100644
--- a/src/test/java/com/google/devtools/build/lib/concurrent/ParallelVisitorTest.java
+++ b/src/test/java/com/google/devtools/build/lib/concurrent/ParallelVisitorTest.java
@@ -69,7 +69,8 @@
/*processResultsBatchSize=*/ 1,
/*minPendingTasks=*/ MIN_PENDING_TASKS,
/*batchCallbackSize=*/ BATCH_CALLBACK_SIZE,
- Executors.newFixedThreadPool(3));
+ Executors.newFixedThreadPool(3),
+ VisitTaskStatusCallback.NULL_INSTANCE);
this.invocationLatch = invocationLatch;
this.delayLatch = delayLatch;
}
@@ -140,7 +141,8 @@
processResultsBatchSize,
MIN_PENDING_TASKS,
BATCH_CALLBACK_SIZE,
- Executors.newFixedThreadPool(3));
+ Executors.newFixedThreadPool(3),
+ VisitTaskStatusCallback.NULL_INSTANCE);
this.successorMap = successors;
}