// blob: da5486832a7cf8230de498a1dc99bb0ba8e7f21c [file] [log] [blame]
// Copyright 2019 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.concurrent;
import static com.google.common.truth.Truth.assertThat;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.testutil.TestThread;
import java.util.ArrayList;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link ParallelVisitor}. */
@RunWith(JUnit4.class)
public class ParallelVisitorTest {
// Passed as the batchCallbackSize constructor argument of every visitor built in this test.
private static final int BATCH_CALLBACK_SIZE = 10_000;
// Passed as the minPendingTasks constructor argument of every visitor built in this test.
private static final long MIN_PENDING_TASKS = 3L;
/**
 * Checked exception used as the exception type parameter of the callbacks and visitors declared
 * in this test.
 */
private static class SampleException extends Exception {
  private static final String MESSAGE = "sample exception";

  SampleException() {
    super(MESSAGE);
  }
}
/**
 * A {@link BatchCallback} whose exception type is fixed to {@link SampleException}. It is the
 * callback type parameter of the visitors declared in this test.
 */
@ThreadSafe
private interface TestCallback<T> extends BatchCallback<T, SampleException> {
/** Handles one batch of partial results. */
@Override
void process(Iterable<T> partialResult) throws SampleException, InterruptedException;
}
/**
 * A minimal {@link ParallelVisitor} whose {@link #getVisitResult} counts down one {@link
 * CountDownLatch} and then blocks on a second one. The first latch tells a test that a visit has
 * started; the second keeps that visit from finishing, which lets the test exercise
 * interruptibility.
 */
private static class DelayGettingVisitResultParallelVisitor
    extends ParallelVisitor<
        String, String, String, String, SampleException, TestCallback<String>> {
  /** Counted down each time {@link #getVisitResult} is entered. */
  private final CountDownLatch invocationLatch;

  /** Awaited inside {@link #getVisitResult} before it returns. */
  private final CountDownLatch delayLatch;

  private DelayGettingVisitResultParallelVisitor(
      CountDownLatch invocationLatch, CountDownLatch delayLatch) {
    super(
        // The callback discards whatever it is given.
        ignoredBatch -> {},
        SampleException.class,
        /*visitBatchSize=*/ 1,
        /*processResultsBatchSize=*/ 1,
        /*minPendingTasks=*/ MIN_PENDING_TASKS,
        /*batchCallbackSize=*/ BATCH_CALLBACK_SIZE,
        Executors.newFixedThreadPool(3),
        VisitTaskStatusCallback.NULL_INSTANCE);
    this.delayLatch = delayLatch;
    this.invocationLatch = invocationLatch;
  }

  /** Returns the initial keys unchanged. */
  @Override
  protected Iterable<String> preprocessInitialVisit(Iterable<String> visitationKeys) {
    return visitationKeys;
  }

  /** Returns an immutable snapshot of the given keys; no de-duplication is performed. */
  @Override
  protected Iterable<String> noteAndReturnUniqueVisitationKeys(
      Iterable<String> prospectiveVisitationKeys) throws SampleException {
    return ImmutableList.copyOf(prospectiveVisitationKeys);
  }

  /**
   * Signals {@code invocationLatch}, then blocks until {@code delayLatch} is released (or the
   * calling thread is interrupted) before building the {@code Visit}.
   */
  @Override
  protected Visit getVisitResult(Iterable<String> values)
      throws SampleException, InterruptedException {
    invocationLatch.countDown();
    delayLatch.await();
    Iterable<String> empty = ImmutableList.of();
    return new Visit(empty, values);
  }

  /** Always produces an empty result. */
  @Override
  protected Iterable<String> outputKeysToOutputValues(Iterable<String> targetKeys)
      throws SampleException, InterruptedException {
    return ImmutableList.of();
  }
}
@Test
public void testInterrupt() throws Exception {
  // Checks that a visitation run by ParallelVisitor can be interrupted. Doubles as a regression
  // test for b/62221332.
  CountDownLatch visitStarted = new CountDownLatch(1);
  // Never counted down in this test, so the visit cannot finish by itself.
  CountDownLatch blockVisit = new CountDownLatch(1);
  DelayGettingVisitResultParallelVisitor blockedVisitor =
      new DelayGettingVisitResultParallelVisitor(visitStarted, blockVisit);
  ImmutableList<String> initialKeys = ImmutableList.of("for_testing");

  TestThread visitingThread =
      new TestThread(() -> blockedVisitor.visitAndWaitForCompletion(initialKeys));
  visitingThread.start();

  // Interrupt only once the visitor has actually entered #getVisitResult.
  visitStarted.await();
  visitingThread.interrupt();

  // If the visitation is not interruptible this join never returns and the test times out.
  visitingThread.join();
}
/**
 * A {@link ParallelVisitor} that walks a fixed successor graph and records every batch of keys
 * handed to {@link #getVisitResult}, so tests can inspect the batch sizes afterwards.
 */
private static class RecordingParallelVisitor
    extends ParallelVisitor<
        InputKey, String, String, String, SampleException, TestCallback<String>> {
  /** Every batch passed to {@link #getVisitResult}; appended to while holding {@code this}. */
  private final ArrayList<Iterable<String>> visits = new ArrayList<>();

  /** Maps each key to the keys reachable from it. */
  private final ImmutableMultimap<String, String> successorMap;

  /** Keys already returned by {@link #noteAndReturnUniqueVisitationKeys}. */
  private final Set<String> visited = Sets.newConcurrentHashSet();

  private RecordingParallelVisitor(
      ImmutableMultimap<String, String> successors,
      RecordingCallback recordingCallback,
      int visitBatchSize,
      int processResultsBatchSize) {
    super(
        recordingCallback,
        SampleException.class,
        visitBatchSize,
        processResultsBatchSize,
        MIN_PENDING_TASKS,
        BATCH_CALLBACK_SIZE,
        Executors.newFixedThreadPool(3),
        VisitTaskStatusCallback.NULL_INSTANCE);
    this.successorMap = successors;
  }

  /**
   * Records the batch, then returns a {@code Visit} built from the batch itself and the
   * concatenated successors of its members.
   */
  @Override
  protected Visit getVisitResult(Iterable<String> values) {
    synchronized (this) {
      visits.add(values);
    }
    // ImmutableMultimap#get returns an empty collection (never null) for an absent key, so the
    // lookup needs no null handling.
    return new Visit(values, Iterables.concat(Iterables.transform(values, successorMap::get)));
  }

  /**
   * Returns the keys that had not been seen before, marking them as seen.
   *
   * <p>The result is materialized eagerly. {@code Iterables.filter} only produces a lazy view, and
   * the predicate here mutates {@code visited}: returning the bare view would defer the "noting"
   * until the caller iterates and would yield nothing on a second iteration.
   */
  @Override
  protected Iterable<String> noteAndReturnUniqueVisitationKeys(
      Iterable<String> prospectiveVisitationKeys) {
    return ImmutableList.copyOf(Iterables.filter(prospectiveVisitationKeys, visited::add));
  }

  /** Output values are the output keys themselves. */
  @Override
  protected Iterable<String> outputKeysToOutputValues(Iterable<String> targetKeys) {
    return targetKeys;
  }

  /** Unwraps each {@link InputKey} to its underlying string. */
  @Override
  protected Iterable<String> preprocessInitialVisit(Iterable<InputKey> visitationKeys) {
    return Iterables.transform(visitationKeys, InputKey::extract);
  }
}
/** A {@link TestCallback} that stores every batch it receives, in arrival order. */
private static class RecordingCallback implements TestCallback<String> {
  /** Batches passed to {@link #process}; appended to while holding {@code this}. */
  private final ArrayList<Iterable<String>> results = new ArrayList<>();

  @Override
  public void process(Iterable<String> partialResult) {
    synchronized (this) {
      results.add(partialResult);
    }
  }
}
/**
 * Wrapper around a string, used as the input key type of {@link RecordingParallelVisitor}. Two
 * instances are equal exactly when their wrapped strings are equal.
 */
private static class InputKey {
  private final String str;

  private InputKey(String str) {
    this.str = str;
  }

  /** Returns the string wrapped by {@code key}. */
  private static String extract(InputKey key) {
    return key.str;
  }

  @Override
  public int hashCode() {
    return str.hashCode();
  }

  @Override
  public boolean equals(Object obj) {
    return obj instanceof InputKey && str.equals(((InputKey) obj).str);
  }
}
@Test
public void testRespectsBatchSizes() throws Exception {
  // Upper bounds handed to the visitor; every recorded batch must stay within them.
  int maxVisitBatch = 2;
  int maxResultBatch = 1;

  // Successor graph rooted at k1, with overlapping successor sets.
  ImmutableMultimap<String, String> graph =
      ImmutableMultimap.<String, String>builder()
          .putAll("k1", ImmutableList.of("k2", "k3", "k4", "k5"))
          .putAll("k2", ImmutableList.of("k6", "k7", "k8", "k9"))
          .putAll("k3", ImmutableList.of("k4", "k5", "k6", "k7", "k8", "k9"))
          .build();
  RecordingCallback recorder = new RecordingCallback();
  RecordingParallelVisitor recordingVisitor =
      new RecordingParallelVisitor(graph, recorder, maxVisitBatch, maxResultBatch);

  recordingVisitor.visitAndWaitForCompletion(ImmutableList.of(new InputKey("k1")));

  recordingVisitor.visits.forEach(
      batch -> assertThat(Iterables.size(batch)).isAtMost(maxVisitBatch));
  recorder.results.forEach(
      batch -> assertThat(Iterables.size(batch)).isAtMost(maxResultBatch));
}
}