Add simple concurrency tests for EvaluableGraph implementations.

--
MOS_MIGRATED_REVID=96414434
diff --git a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
index 24cde37..66a1355 100644
--- a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
@@ -13,21 +13,49 @@
 // limitations under the License.
 package com.google.devtools.build.skyframe;
 
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import com.google.devtools.build.lib.concurrent.ExecutorUtil;
+import com.google.devtools.build.lib.concurrent.KeyedLocker;
+import com.google.devtools.build.lib.concurrent.RefCountedMultisetKeyedLocker;
+import com.google.devtools.build.lib.concurrent.ThrowableRecordingRunnableWrapper;
+import com.google.devtools.build.lib.testutil.TestUtils;
+import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper;
+import com.google.devtools.build.skyframe.GraphTester.StringValue;
+import com.google.devtools.build.skyframe.NodeEntry.DependencyState;
+
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
 /** Base class for concurrency sanity tests on {@link EvaluableGraph} implementations. */
 public abstract class GraphConcurrencyTest {
 
   private static final SkyFunctionName SKY_FUNCTION_NAME =
       new SkyFunctionName("GraphConcurrencyTestKey", /*isComputed=*/false);
   private ProcessableGraph graph;
-
+  private ThrowableRecordingRunnableWrapper wrapper;  // Records errors thrown by worker runnables.
   protected abstract ProcessableGraph getGraph();
 
   @Before
   public void init() {
     this.graph = getGraph();
+    this.wrapper = new ThrowableRecordingRunnableWrapper("GraphConcurrencyTest");  // New per test.
   }
 
   private SkyKey key(String name) {
@@ -39,5 +67,269 @@
     graph.createIfAbsent(key("cat"));
   }
 
-  // TODO(bazel-team): Add tests.
+  // Tests concurrently adding and removing reverse deps of a {@link NodeEntry} while the node
+  // transitions from not done to done, and checks which rdeps remain, also after markDirty.
+  @Test
+  public void testAddRemoveRdeps() throws Exception {
+    SkyKey key = key("foo");
+    final NodeEntry entry = graph.createIfAbsent(key);
+    // These numbers are arbitrary.
+    int numThreads = 50;
+    int numKeys = 100;
+    // Each worker does numIterations add/remove rounds, split into three chunks of chunkSize.
+    // The main thread sets the value once enough workers reach the end of the first chunk, and
+    // no worker starts its third chunk until that value has been set.
+    final int chunkSize = 40;
+    final int numIterations = chunkSize * 3;
+    // Released once all runnables have been submitted, so that the workers start together.
+    final CountDownLatch countDownLatch1 = new CountDownLatch(1);
+    // Each worker counts this down near the end of its first chunk. The main thread waits for
+    // only min(numThreads, numKeys) counts: at most numThreads workers run at once, and queued
+    // workers cannot start while the running ones are blocked on countDownLatch3.
+    final CountDownLatch countDownLatch2 = new CountDownLatch(Math.min(numThreads, numKeys));
+    // Released right after the main thread sets the node's value, so that no worker enters
+    // its third chunk before the value is set.
+    final CountDownLatch countDownLatch3 = new CountDownLatch(1);
+    ExecutorService pool = Executors.newFixedThreadPool(numThreads);
+    // Add one rdep before the node is done; the main thread removes it after setting the value.
+    assertEquals(DependencyState.NEEDS_SCHEDULING, entry.addReverseDepAndCheckIfDone(key("rdep")));
+    for (int i = 0; i < numKeys; i++) {
+      final int j = i;
+      Runnable r = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            countDownLatch1.await();
+            // Repeatedly add and remove this worker's rdep to interleave with the other workers.
+            for (int k = 1; k <= numIterations; k++) {
+              if (k == chunkSize) {
+                countDownLatch2.countDown();
+              }
+              entry.addReverseDepAndCheckIfDone(key("rdep" + j));
+              entry.removeReverseDep(key("rdep" + j));
+              if (k == chunkSize * 2) {
+                countDownLatch3.await();
+              }
+            }
+            entry.addReverseDepAndCheckIfDone(key("rdep" + j));  // Leave one rdep per worker.
+          } catch (InterruptedException e) {
+            fail("Test failed: " + e.toString());
+          }
+        }
+      };
+      pool.execute(wrapper.wrap(r));
+    }
+    countDownLatch1.countDown();
+    try {
+      countDownLatch2.await();
+    } catch (InterruptedException e) {
+      fail("Test failed: " + e.toString());
+    }
+    entry.setValue(new StringValue("foo1"), new IntVersion(1));
+    countDownLatch3.countDown();
+    entry.removeReverseDep(key("rdep"));
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    assertEquals(new StringValue("foo1"), graph.get(key).getValue());
+    assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
+    // Mark the node as dirty again and check that the reverse deps have been preserved.
+    entry.markDirty(true);
+    startEvaluation(entry);
+    entry.setValue(new StringValue("foo2"), new IntVersion(2));
+    assertEquals(new StringValue("foo2"), graph.get(key).getValue());
+    assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
+  }
+
+  // Tests adding inflight nodes with a given key while an existing node with the same key
+  // undergoes a transition from not done to done.
+  @Test
+  public void testAddingInflightNodes() throws Exception {
+    int numThreads = 50;
+    final KeyedLocker<SkyKey> locker =
+        new RefCountedMultisetKeyedLocker.Factory<SkyKey>().create();
+    ExecutorService pool = Executors.newFixedThreadPool(numThreads);
+    final int numKeys = 500;
+    // Add each key 10 times.
+    final Set<SkyKey> nodeCreated = Sets.newConcurrentHashSet();
+    final Set<SkyKey> valuesSet = Sets.newConcurrentHashSet();
+    for (int i = 0; i < 10; i++) {
+      for (int j = 0; j < numKeys; j++) {
+        final int keyNum = j;
+        final SkyKey key = key("foo" + keyNum);
+        Runnable r = new Runnable() {
+          public void run() {
+            NodeEntry entry;
+            try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) {
+              entry = graph.get(key);
+              if (entry == null) {
+                assertTrue(nodeCreated.add(key));
+              }
+              entry = graph.createIfAbsent(key);
+            }
+            // {@code entry.addReverseDepAndCheckIfDone(null)} should return NEEDS_SCHEDULING at
+            // most once.
+            if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) {
+              assertTrue(valuesSet.add(key));
+              // Set to done.
+              entry.setValue(new StringValue("bar" + keyNum), new IntVersion(keyNum));
+              assertThat(entry.isDone()).isTrue();
+            }
+            // This shouldn't cause any problems from the other threads.
+            graph.createIfAbsent(key);
+          }
+        };
+        pool.execute(wrapper.wrap(r));
+      }
+    }
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    // Check that all the values are as expected.
+    for (int i = 0; i < numKeys; i++) {
+      SkyKey key = key("foo" + i);
+      assertTrue(nodeCreated.contains(key));
+      assertTrue(valuesSet.contains(key));
+      assertThat(graph.get(key).getValue()).isEqualTo(new StringValue("bar" + i));
+      assertThat(graph.get(key).getVersion()).isEqualTo(new IntVersion(i));
+    }
+  }
+
+  /**
+   * Moves done nodes to dirty and back to done while concurrent {@link QueryableGraph#get} and
+   * {@link QueryableGraph#getBatch} calls read them, checking every read sees version 0 or 1.
+   */
+  @Test
+  public void testDoneToDirty() throws Exception {
+    final int numKeys = 1000;
+    int numThreads = 50;
+    final int numBatchRequests = 100;
+    // Create a bunch of done nodes.
+    for (int i = 0; i < numKeys; i++) {
+      NodeEntry entry = graph.createIfAbsent(key("foo" + i));
+      startEvaluation(entry);
+      entry.setValue(new StringValue("bar"), new IntVersion(0));
+    }
+
+    ExecutorService pool1 = Executors.newFixedThreadPool(numThreads);
+    ExecutorService pool2 = Executors.newFixedThreadPool(numThreads);
+    ExecutorService pool3 = Executors.newFixedThreadPool(numThreads);
+
+    // Only start all the threads once the batch requests are ready.
+    final CountDownLatch makeBatchCountDownLatch = new CountDownLatch(numBatchRequests);
+    // Require at least 5 get() calls and 5 getBatch() calls before any node is transitioned.
+    final CountDownLatch getBatchCountDownLatch = new CountDownLatch(5);
+    final CountDownLatch getCountDownLatch = new CountDownLatch(5);
+
+    for (int i = 0; i < numKeys; i++) {
+      final int keyNum = i;
+      // Transition the nodes from done to dirty and then back to done.
+      Runnable r1 = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            makeBatchCountDownLatch.await();
+            getBatchCountDownLatch.await();
+            getCountDownLatch.await();
+          } catch (InterruptedException e) {
+            fail("Test failed: " + e.toString());
+          }
+          NodeEntry entry = graph.get(key("foo" + keyNum));
+          entry.markDirty(true);
+          // Make some changes, like adding a dep and rdep.
+          entry.addReverseDepAndCheckIfDone(key("rdep"));
+          addTemporaryDirectDep(entry, key("dep"));
+          entry.signalDep();
+          // Move node from dirty back to done.
+          entry.setValue(new StringValue("bar" + keyNum), new IntVersion(1));
+        }
+      };
+
+      // Start a bunch of get() calls while the nodes go from done to dirty and back to done.
+      Runnable r2 = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            makeBatchCountDownLatch.await();
+          } catch (InterruptedException e) {
+            fail("Test failed: " + e.toString());
+          }
+          NodeEntry entry = graph.get(key("foo" + keyNum));
+          // Count down before asserting, so a failed assertion cannot keep r1 waiting forever.
+          getCountDownLatch.countDown();
+          assertNotEquals(null, entry);
+          // The version may be read while it changes from 0 to 1; it must be one of the two.
+          assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
+        }
+      };
+      pool1.execute(wrapper.wrap(r1));
+      pool2.execute(wrapper.wrap(r2));
+    }
+    Random r = new Random(TestUtils.getRandomSeed());
+    // Start a bunch of getBatch() calls while the nodes go from done to dirty and back to done.
+    for (int i = 0; i < numBatchRequests; i++) {
+      final List<SkyKey> batch = new ArrayList<>(numKeys);
+      // Pseudorandomly uniformly sample the powerset of the keys.
+      for (int j = 0; j < numKeys; j++) {
+        if (r.nextBoolean()) {
+          batch.add(key("foo" + j));
+        }
+      }
+      makeBatchCountDownLatch.countDown();
+      Runnable r3 = new Runnable() {
+        @Override
+        public void run() {
+          try {
+            makeBatchCountDownLatch.await();
+          } catch (InterruptedException e) {
+            fail("Test failed: " + e.toString());
+          }
+          Map<SkyKey, NodeEntry> batchMap = graph.getBatch(batch);
+          getBatchCountDownLatch.countDown();
+          assertEquals(batch.size(), batchMap.size());
+          for (NodeEntry entry : batchMap.values()) {
+            // Batch requests are made at the same time that the version changes from 0 to 1.
+            // Check that there is no problem in requesting the version and that the number is sane.
+            assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
+          }
+        }
+      };
+      pool3.execute(wrapper.wrap(r3));
+    }
+    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool1);
+    interrupted |= ExecutorUtil.interruptibleShutdown(pool2);
+    interrupted |= ExecutorUtil.interruptibleShutdown(pool3);
+    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
+    if (interrupted) {
+      Thread.currentThread().interrupt();
+      throw new InterruptedException();
+    }
+    for (int i = 0; i < numKeys; i++) {
+      NodeEntry entry = graph.get(key("foo" + i));
+      assertThat(entry.getValue()).isEqualTo(new StringValue("bar" + i));
+      assertThat(entry.getVersion()).isEqualTo(new IntVersion(1));
+      for (SkyKey key : entry.getReverseDeps()) {
+        assertEquals(key("rdep"), key);
+      }
+      for (SkyKey key : entry.getDirectDeps()) {
+        assertEquals(key("dep"), key);
+      }
+    }
+  }
+
+  private static DependencyState startEvaluation(NodeEntry entry) {
+    return entry.addReverseDepAndCheckIfDone(/*reverseDep=*/null);
+  }
+
+  private static void addTemporaryDirectDep(NodeEntry entry, SkyKey key) {
+    GroupedListHelper<SkyKey> singleDep = new GroupedListHelper<>();
+    singleDep.add(key);  // The helper holds only this one key.
+    entry.addTemporaryDirectDeps(singleDep);
+  }
 }