Simple concurrency tests for EvaluableGraph implementations.
--
MOS_MIGRATED_REVID=96414434
diff --git a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
index 24cde37..66a1355 100644
--- a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
@@ -13,21 +13,49 @@
// limitations under the License.
package com.google.devtools.build.skyframe;
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import com.google.devtools.build.lib.concurrent.ExecutorUtil;
+import com.google.devtools.build.lib.concurrent.KeyedLocker;
+import com.google.devtools.build.lib.concurrent.RefCountedMultisetKeyedLocker;
+import com.google.devtools.build.lib.concurrent.ThrowableRecordingRunnableWrapper;
+import com.google.devtools.build.lib.testutil.TestUtils;
+import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper;
+import com.google.devtools.build.skyframe.GraphTester.StringValue;
+import com.google.devtools.build.skyframe.NodeEntry.DependencyState;
+
import org.junit.Before;
import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
/** Base class for concurrency sanity tests on {@link EvaluableGraph} implementations. */
public abstract class GraphConcurrencyTest {
private static final SkyFunctionName SKY_FUNCTION_NAME =
new SkyFunctionName("GraphConcurrencyTestKey", /*isComputed=*/false);
+ // Graph under test; assigned from getGraph() by init() before each test.
private ProcessableGraph graph;
-
+ // Wraps every runnable submitted to the tests' executors. After shutting an executor down, the
+ // tests pass wrapper.getFirstThrownError() to Throwables.propagateIfPossible so that failures on
+ // worker threads fail the test.
+ private ThrowableRecordingRunnableWrapper wrapper;
+ /** Returns the {@link ProcessableGraph} implementation under test; called by init(). */
protected abstract ProcessableGraph getGraph();
+ /**
+  * Obtains the graph under test via {@link #getGraph} and creates a fresh
+  * {@link ThrowableRecordingRunnableWrapper} before each test.
+  */
@Before
public void init() {
this.graph = getGraph();
+ this.wrapper = new ThrowableRecordingRunnableWrapper("GraphConcurrencyTest");
}
private SkyKey key(String name) {
@@ -39,5 +67,269 @@
graph.createIfAbsent(key("cat"));
}
- // TODO(bazel-team): Add tests.
+  // Tests adding and removing rdeps of a {@link NodeEntry} while the node transitions from not
+  // done to done.
+  @Test
+  public void testAddRemoveRdeps() throws Exception {
+    SkyKey key = key("foo");
+    final NodeEntry entry = graph.createIfAbsent(key);
+    // These numbers are arbitrary.
+    int numThreads = 50;
+    int numKeys = 100;
+    // Each worker runs three chunks of chunkSize iterations. The first chunk adds and removes an
+    // rdep before the node's value is set, the second overlaps with the main thread setting the
+    // value, and the third adds and removes the rdep after the value has been set.
+    final int chunkSize = 40;
+    final int numIterations = chunkSize * 3;
+    // Released once all runnables have been submitted to the executor.
+    final CountDownLatch countDownLatch1 = new CountDownLatch(1);
+    // Signals the main thread that sufficiently many workers have begun the second chunk. The
+    // minimum of numThreads and numKeys is used so that workers still queued behind the
+    // fixed-size pool cannot delay the count.
+    final CountDownLatch countDownLatch2 = new CountDownLatch(Math.min(numThreads, numKeys));
+    // Guarantees that the node's value is set before any worker enters the third chunk.
+    final CountDownLatch countDownLatch3 = new CountDownLatch(1);
+    ExecutorService pool = Executors.newFixedThreadPool(numThreads);
+    // Add a single rdep before the transition to done; it is removed once the value is set.
+    assertEquals(DependencyState.NEEDS_SCHEDULING, entry.addReverseDepAndCheckIfDone(key("rdep")));
+    for (int i = 0; i < numKeys; i++) {
+      final SkyKey rdep = key("rdep" + i);
+      Runnable r = new Runnable() {
+        @Override
+        public void run() {
+          boolean signaledSecondChunk = false;
+          try {
+            countDownLatch1.await();
+            // Add and remove the rdep a bunch of times to test interleaving.
+            for (int k = 1; k <= numIterations; k++) {
+              if (k == chunkSize) {
+                signaledSecondChunk = true;
+                countDownLatch2.countDown();
+              }
+              entry.addReverseDepAndCheckIfDone(rdep);
+              entry.removeReverseDep(rdep);
+              if (k == chunkSize * 2) {
+                countDownLatch3.await();
+              }
+            }
+            entry.addReverseDepAndCheckIfDone(rdep);
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new AssertionError("Test failed: " + e, e);
+          } finally {
+            // A worker that fails before reaching the second chunk must still count down.
+            // Otherwise the main thread would wait forever instead of surfacing the failure
+            // recorded by the wrapper.
+            if (!signaledSecondChunk) {
+              countDownLatch2.countDown();
+            }
+          }
+        }
+      };
+      pool.execute(wrapper.wrap(r));
+    }
+    try {
+      countDownLatch1.countDown();
+      countDownLatch2.await();
+      entry.setValue(new StringValue("foo1"), new IntVersion(1));
+    } finally {
+      // Always release the workers waiting before the third chunk, even if the main thread failed
+      // above, so they can finish instead of blocking forever.
+      countDownLatch3.countDown();
+    }
+    entry.removeReverseDep(key("rdep"));
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    assertEquals(new StringValue("foo1"), graph.get(key).getValue());
+    assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
+    // Mark the node as dirty again and check that the reverse deps have been preserved.
+    entry.markDirty(true);
+    startEvaluation(entry);
+    entry.setValue(new StringValue("foo2"), new IntVersion(2));
+    assertEquals(new StringValue("foo2"), graph.get(key).getValue());
+    assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
+  }
+
+  // Tests adding inflight nodes with a given key while an existing node with the same key
+  // undergoes a transition from not done to done.
+  @Test
+  public void testAddingInflightNodes() throws Exception {
+    int numThreads = 50;
+    final KeyedLocker<SkyKey> locker =
+        new RefCountedMultisetKeyedLocker.Factory<SkyKey>().create();
+    ExecutorService pool = Executors.newFixedThreadPool(numThreads);
+    final int numKeys = 500;
+    // Number of runnables submitted for each key.
+    int numRequestsPerKey = 10;
+    // Keys whose node a runnable observed to be absent (and therefore created) under the key lock.
+    final Set<SkyKey> nodeCreated = Sets.newConcurrentHashSet();
+    // Keys whose node a runnable started evaluating and then set a value on.
+    final Set<SkyKey> valuesSet = Sets.newConcurrentHashSet();
+    for (int i = 0; i < numRequestsPerKey; i++) {
+      for (int j = 0; j < numKeys; j++) {
+        final int keyNum = j;
+        final SkyKey key = key("foo" + keyNum);
+        Runnable r = new Runnable() {
+          @Override
+          public void run() {
+            NodeEntry entry;
+            // The absence check and the creation both happen under the key's lock, so only one
+            // runnable per key can observe a missing node.
+            try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) {
+              entry = graph.get(key);
+              if (entry == null) {
+                assertTrue(nodeCreated.add(key));
+              }
+              entry = graph.createIfAbsent(key);
+            }
+            // startEvaluation(entry), i.e. entry.addReverseDepAndCheckIfDone(null), should return
+            // NEEDS_SCHEDULING at most once per key.
+            if (startEvaluation(entry) == DependencyState.NEEDS_SCHEDULING) {
+              assertTrue(valuesSet.add(key));
+              // Set to done.
+              entry.setValue(new StringValue("bar" + keyNum), new IntVersion(keyNum));
+              assertThat(entry.isDone()).isTrue();
+            }
+            // This shouldn't cause any problems from the other threads.
+            graph.createIfAbsent(key);
+          }
+        };
+        pool.execute(wrapper.wrap(r));
+      }
+    }
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    // Check that all the values are as expected.
+    for (int i = 0; i < numKeys; i++) {
+      SkyKey skyKey = key("foo" + i);
+      assertTrue(nodeCreated.contains(skyKey));
+      assertTrue(valuesSet.contains(skyKey));
+      assertThat(graph.get(skyKey).getValue()).isEqualTo(new StringValue("bar" + i));
+      assertThat(graph.get(skyKey).getVersion()).isEqualTo(new IntVersion(i));
+    }
+  }
+
+  /**
+   * Creates many done nodes, then concurrently transitions each node from done to dirty and back
+   * to done while other threads read it through {@link QueryableGraph#get} and
+   * {@link QueryableGraph#getBatch}. Checks that readers only observe version 0 or 1 and that every
+   * node ends up done at version 1 with exactly the dep and rdep added during the transition.
+   */
+  @Test
+  public void testDoneToDirty() throws Exception {
+    final int numKeys = 1000;
+    int numThreads = 50;
+    final int numBatchRequests = 100;
+    // Create a bunch of done nodes.
+    for (int i = 0; i < numKeys; i++) {
+      NodeEntry entry = graph.createIfAbsent(key("foo" + i));
+      startEvaluation(entry);
+      entry.setValue(new StringValue("bar"), new IntVersion(0));
+    }
+
+    ExecutorService pool1 = Executors.newFixedThreadPool(numThreads);
+    ExecutorService pool2 = Executors.newFixedThreadPool(numThreads);
+    ExecutorService pool3 = Executors.newFixedThreadPool(numThreads);
+
+    // Only start all the threads once the batch requests are ready.
+    final CountDownLatch makeBatchCountDownLatch = new CountDownLatch(numBatchRequests);
+    // Do at least 5 single requests and batch requests before transitioning node.
+    final CountDownLatch getBatchCountDownLatch = new CountDownLatch(5);
+    final CountDownLatch getCountDownLatch = new CountDownLatch(5);
+
+    for (int i = 0; i < numKeys; i++) {
+      final int keyNum = i;
+      // Transition the nodes from done to dirty and then back to done.
+      Runnable r1 = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            makeBatchCountDownLatch.await();
+            getBatchCountDownLatch.await();
+            getCountDownLatch.await();
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new AssertionError("Test failed: " + e, e);
+          }
+          NodeEntry entry = graph.get(key("foo" + keyNum));
+          entry.markDirty(true);
+          // Make some changes, like adding a dep and rdep.
+          entry.addReverseDepAndCheckIfDone(key("rdep"));
+          addTemporaryDirectDep(entry, key("dep"));
+          entry.signalDep();
+          // Move node from dirty back to done.
+          entry.setValue(new StringValue("bar" + keyNum), new IntVersion(1));
+        }
+      };
+
+      // Start a bunch of get() calls while the node transitions from dirty to done and back.
+      Runnable r2 = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            makeBatchCountDownLatch.await();
+            NodeEntry entry = graph.get(key("foo" + keyNum));
+            assertThat(entry).isNotNull();
+            // Requests for the value are made at the same time that the version changes from 0 to
+            // 1. Check that there is no problem in requesting the version and that it is sane.
+            assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new AssertionError("Test failed: " + e, e);
+          } finally {
+            // Count down even if this request failed, so the r1 runnables are not blocked forever;
+            // the failure is still reported through the wrapper.
+            getCountDownLatch.countDown();
+          }
+        }
+      };
+      pool1.execute(wrapper.wrap(r1));
+      pool2.execute(wrapper.wrap(r2));
+    }
+    Random r = new Random(TestUtils.getRandomSeed());
+    // Start a bunch of getBatch() calls while the node transitions from dirty to done and back.
+    for (int i = 0; i < numBatchRequests; i++) {
+      final List<SkyKey> batch = new ArrayList<>(numKeys);
+      // Pseudorandomly uniformly sample the powerset of the keys.
+      for (int j = 0; j < numKeys; j++) {
+        if (r.nextBoolean()) {
+          batch.add(key("foo" + j));
+        }
+      }
+      makeBatchCountDownLatch.countDown();
+      Runnable r3 = new Runnable() {
+        @Override
+        public void run() {
+          Map<SkyKey, NodeEntry> batchMap;
+          try {
+            makeBatchCountDownLatch.await();
+            batchMap = graph.getBatch(batch);
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new AssertionError("Test failed: " + e, e);
+          } finally {
+            // Count down even if this request failed, so the r1 runnables are not blocked forever;
+            // the failure is still reported through the wrapper.
+            getBatchCountDownLatch.countDown();
+          }
+          assertEquals(batch.size(), batchMap.size());
+          for (NodeEntry entry : batchMap.values()) {
+            // Batch requests are made at the same time that the version changes from 0 to 1.
+            // Check that there is no problem in requesting the version and that it is sane.
+            assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
+          }
+        }
+      };
+      pool3.execute(wrapper.wrap(r3));
+    }
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool1);
+    interrupted |= ExecutorUtil.interruptibleShutdown(pool2);
+    interrupted |= ExecutorUtil.interruptibleShutdown(pool3);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    for (int i = 0; i < numKeys; i++) {
+      NodeEntry entry = graph.get(key("foo" + i));
+      assertThat(entry.getValue()).isEqualTo(new StringValue("bar" + i));
+      assertThat(entry.getVersion()).isEqualTo(new IntVersion(1));
+      // Each node started with no deps and no rdeps. During the dirty transition it gained exactly
+      // one rdep and one dep, so require exactly those rather than accepting empty iterables.
+      assertEquals(key("rdep"), Iterables.getOnlyElement(entry.getReverseDeps()));
+      assertEquals(key("dep"), Iterables.getOnlyElement(entry.getDirectDeps()));
+    }
+  }
+
+  /**
+   * Calls {@link NodeEntry#addReverseDepAndCheckIfDone} on {@code entry} with a {@code null}
+   * reverse dep, which these tests use to begin evaluating the entry, and returns the resulting
+   * {@link DependencyState}.
+   */
+  private DependencyState startEvaluation(NodeEntry entry) {
+    DependencyState state = entry.addReverseDepAndCheckIfDone(null);
+    return state;
+  }
+
+  /**
+   * Adds {@code key} to {@code entry}'s temporary direct deps by wrapping it in a
+   * {@link GroupedListHelper} and passing that to {@link NodeEntry#addTemporaryDirectDeps}.
+   */
+  private static void addTemporaryDirectDep(NodeEntry entry, SkyKey key) {
+    GroupedListHelper<SkyKey> newDeps = new GroupedListHelper<>();
+    newDeps.add(key);
+    entry.addTemporaryDirectDeps(newDeps);
+  }
}