| // Copyright 2019 The Bazel Authors. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| package com.google.devtools.build.lib.concurrent; |
| |
import static com.google.common.truth.Truth.assertThat;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
import com.google.devtools.build.lib.testutil.TestThread;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
| |
| /** Unit tests for {@link ParallelVisitor}. */ |
| @RunWith(JUnit4.class) |
| public class ParallelVisitorTest { |
| |
| private static final int BATCH_CALLBACK_SIZE = 10_000; |
| private static final long MIN_PENDING_TASKS = 3L; |
| |
| private static class SampleException extends Exception { |
| SampleException() { |
| super("sample exception"); |
| } |
| } |
| |
| @ThreadSafe |
| private interface TestCallback<T> extends ThreadSafeBatchCallback<T, SampleException> { |
| |
| @Override |
| void process(Iterable<T> partialResult) throws SampleException, InterruptedException; |
| } |
| |
| /** |
| * A dummy {@link ParallelVisitor} which waits for signal from a {@link CountDownLatch} when |
| * {@link #getVisitResult} is invoked. It allows us to test interruptibility. |
| */ |
| private static class DelayGettingVisitResultParallelVisitor |
| extends ParallelVisitor< |
| String, String, String, String, SampleException, TestCallback<String>> { |
| private final CountDownLatch invocationLatch; |
| private final CountDownLatch delayLatch; |
| |
| private DelayGettingVisitResultParallelVisitor( |
| CountDownLatch invocationLatch, CountDownLatch delayLatch) { |
| super( |
| targets -> {}, |
| SampleException.class, |
| /*visitBatchSize=*/ 1, |
| /*processResultsBatchSize=*/ 1, |
| /*minPendingTasks=*/ MIN_PENDING_TASKS, |
| /*batchCallbackSize=*/ BATCH_CALLBACK_SIZE, |
| Executors.newFixedThreadPool(3), |
| VisitTaskStatusCallback.NULL_INSTANCE); |
| this.invocationLatch = invocationLatch; |
| this.delayLatch = delayLatch; |
| } |
| |
| @Override |
| protected Visit getVisitResult(Iterable<String> values) |
| throws SampleException, InterruptedException { |
| invocationLatch.countDown(); |
| delayLatch.await(); |
| return new Visit(ImmutableList.of(), values); |
| } |
| |
| @Override |
| protected Iterable<String> preprocessInitialVisit(Iterable<String> visitationKeys) { |
| return visitationKeys; |
| } |
| |
| @Override |
| protected Iterable<String> outputKeysToOutputValues(Iterable<String> targetKeys) |
| throws SampleException, InterruptedException { |
| return ImmutableList.of(); |
| } |
| |
| @Override |
| protected Iterable<String> noteAndReturnUniqueVisitationKeys( |
| Iterable<String> prospectiveVisitationKeys) throws SampleException { |
| return ImmutableList.copyOf(prospectiveVisitationKeys); |
| } |
| } |
| |
| @Test |
| public void testInterrupt() throws Exception { |
| // This test verifies that visitations by ParallelVisitor can be interrupted. It also serves as |
| // a regression test of b/62221332. |
| CountDownLatch invocationLatch = new CountDownLatch(1); |
| CountDownLatch delayLatch = new CountDownLatch(1); |
| DelayGettingVisitResultParallelVisitor visitor = |
| new DelayGettingVisitResultParallelVisitor(invocationLatch, delayLatch); |
| ImmutableList<String> keysToVisit = ImmutableList.of("for_testing"); |
| |
| TestThread testThread = new TestThread(() -> visitor.visitAndWaitForCompletion(keysToVisit)); |
| testThread.start(); |
| |
| // Send an interrupt signal to the visitor after #visitAndWaitForCompletion is invoked. |
| invocationLatch.await(); |
| testThread.interrupt(); |
| |
| // Verify that the thread is interruptable (unit test will time out if it's not interruptable). |
| testThread.join(); |
| } |
| |
| private static class RecordingParallelVisitor |
| extends ParallelVisitor< |
| InputKey, String, String, String, SampleException, TestCallback<String>> { |
| private final ArrayList<Iterable<String>> visits = new ArrayList<>(); |
| private final ImmutableMultimap<String, String> successorMap; |
| private final Set<String> visited = Sets.newConcurrentHashSet(); |
| |
| private RecordingParallelVisitor( |
| ImmutableMultimap<String, String> successors, |
| RecordingCallback recordingCallback, |
| int visitBatchSize, |
| int processResultsBatchSize) { |
| super( |
| recordingCallback, |
| SampleException.class, |
| visitBatchSize, |
| processResultsBatchSize, |
| MIN_PENDING_TASKS, |
| BATCH_CALLBACK_SIZE, |
| Executors.newFixedThreadPool(3), |
| VisitTaskStatusCallback.NULL_INSTANCE); |
| this.successorMap = successors; |
| } |
| |
| @Override |
| protected Visit getVisitResult(Iterable<String> values) { |
| synchronized (this) { |
| visits.add(values); |
| } |
| return new Visit( |
| values, |
| Iterables.concat( |
| Iterables.transform( |
| values, |
| v -> Optional.ofNullable(successorMap.get(v)).orElse(ImmutableList.of())))); |
| } |
| |
| @Override |
| protected Iterable<String> noteAndReturnUniqueVisitationKeys( |
| Iterable<String> prospectiveVisitationKeys) { |
| return Iterables.filter(prospectiveVisitationKeys, visited::add); |
| } |
| |
| @Override |
| protected Iterable<String> outputKeysToOutputValues(Iterable<String> targetKeys) { |
| return targetKeys; |
| } |
| |
| @Override |
| protected Iterable<String> preprocessInitialVisit(Iterable<InputKey> visitationKeys) { |
| return Iterables.transform(visitationKeys, InputKey::extract); |
| } |
| } |
| |
| private static class RecordingCallback implements TestCallback<String> { |
| private final ArrayList<Iterable<String>> results = new ArrayList<>(); |
| |
| @Override |
| public synchronized void process(Iterable<String> partialResult) { |
| results.add(partialResult); |
| } |
| } |
| |
| private static class InputKey { |
| private final String str; |
| |
| private InputKey(String str) { |
| this.str = str; |
| } |
| |
| private static String extract(InputKey key) { |
| return key.str; |
| } |
| |
| @Override |
| public boolean equals(Object obj) { |
| if (!(obj instanceof InputKey)) { |
| return false; |
| } |
| InputKey other = (InputKey) obj; |
| return this.str.equals(other.str); |
| } |
| |
| @Override |
| public int hashCode() { |
| return str.hashCode(); |
| } |
| } |
| |
| @Test |
| public void testRespectsBatchSizes() throws Exception { |
| int visitBatchSize = 2; |
| int processResultsBatchSize = 1; |
| RecordingCallback callback = new RecordingCallback(); |
| RecordingParallelVisitor visitor = |
| new RecordingParallelVisitor( |
| ImmutableMultimap.<String, String>builder() |
| .putAll("k1", ImmutableList.of("k2", "k3", "k4", "k5")) |
| .putAll("k2", ImmutableList.of("k6", "k7", "k8", "k9")) |
| .putAll("k3", ImmutableList.of("k4", "k5", "k6", "k7", "k8", "k9")) |
| .build(), |
| callback, |
| visitBatchSize, |
| processResultsBatchSize); |
| visitor.visitAndWaitForCompletion(ImmutableList.of(new InputKey("k1"))); |
| |
| for (Iterable<String> visitBatch : visitor.visits) { |
| assertThat(Iterables.size(visitBatch)).isAtMost(visitBatchSize); |
| } |
| for (Iterable<String> resultBatch : callback.results) { |
| assertThat(Iterables.size(resultBatch)).isAtMost(processResultsBatchSize); |
| } |
| } |
| } |