Make some improvements to GraphConcurrencyTest: the versions used are now sensible, and the test now fails if an exception is thrown in a worker thread, instead of deadlocking because the countdown latches were never counted down as expected.

I don't know why the Mac Bazel tests are internally failing to build. Any ideas? I was very cargo-culty with the testutil library because I have no idea what's going on there with the duplicate packages.

--
MOS_MIGRATED_REVID=99733410
diff --git a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
index 1c9c881..0721341 100644
--- a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
@@ -15,17 +15,18 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
-import com.google.common.base.Throwables;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.concurrent.ExecutorUtil;
 import com.google.devtools.build.lib.concurrent.KeyedLocker;
 import com.google.devtools.build.lib.concurrent.RefCountedMultisetKeyedLocker;
-import com.google.devtools.build.lib.concurrent.ThrowableRecordingRunnableWrapper;
+import com.google.devtools.build.lib.testutil.TestRunnableWrapper;
 import com.google.devtools.build.lib.testutil.TestUtils;
 import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper;
 import com.google.devtools.build.skyframe.GraphTester.StringValue;
@@ -49,13 +50,22 @@
   private static final SkyFunctionName SKY_FUNCTION_NAME =
       SkyFunctionName.create("GraphConcurrencyTestKey");
   private ProcessableGraph graph;
-  private ThrowableRecordingRunnableWrapper wrapper;
-  protected abstract ProcessableGraph getGraph();
+  private TestRunnableWrapper wrapper;
+
+  // This code should really be in a @Before method, but @Before methods are executed from the
+  // top down, and this class's @Before method calls #getGraph, so makeGraph must have already
+  // been called.
+  protected abstract void makeGraph() throws Exception;
+
+  protected abstract ProcessableGraph getGraph(Version version) throws Exception;
+
+  private static final IntVersion startingVersion = new IntVersion(42);
 
   @Before
-  public void init() {
-    this.graph = getGraph();
-    this.wrapper = new ThrowableRecordingRunnableWrapper("GraphConcurrencyTest");
+  public void init() throws Exception {
+    makeGraph();
+    this.graph = getGraph(startingVersion);
+    this.wrapper = new TestRunnableWrapper("GraphConcurrencyTest");
   }
 
   private SkyKey key(String name) {
@@ -125,21 +135,20 @@
     } catch (InterruptedException e) {
       fail("Test failed: " + e.toString());
     }
-    entry.setValue(new StringValue("foo1"), new IntVersion(1));
+    entry.setValue(new StringValue("foo1"), startingVersion);
     countDownLatch3.countDown();
     entry.removeReverseDep(key("rdep"));
-    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
-    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
-    if (interrupted) {
-      Thread.currentThread().interrupt();
-      throw new InterruptedException();
-    }
+    wrapper.waitForTasksAndMaybeThrow();
+    assertFalse(ExecutorUtil.interruptibleShutdown(pool));
     assertEquals(new StringValue("foo1"), graph.get(key).getValue());
     assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
+
+    graph = getGraph(startingVersion.next());
+    NodeEntry sameEntry = Preconditions.checkNotNull(graph.get(key));
     // Mark the node as dirty again and check that the reverse deps have been preserved.
-    entry.markDirty(true);
-    startEvaluation(entry);
-    entry.setValue(new StringValue("foo2"), new IntVersion(2));
+    sameEntry.markDirty(true);
+    startEvaluation(sameEntry);
+    sameEntry.setValue(new StringValue("foo2"), startingVersion.next());
     assertEquals(new StringValue("foo2"), graph.get(key).getValue());
     assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps()));
   }
@@ -160,44 +169,41 @@
       for (int j = 0; j < numKeys; j++) {
         final int keyNum = j;
         final SkyKey key = key("foo" + keyNum);
-        Runnable r = new Runnable() {
-          public void run() {
-            NodeEntry entry;
-            try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) {
-              entry = graph.get(key);
-              if (entry == null) {
-                assertTrue(nodeCreated.add(key));
+        Runnable r =
+            new Runnable() {
+              public void run() {
+                NodeEntry entry;
+                try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) {
+                  entry = graph.get(key);
+                  if (entry == null) {
+                    assertTrue(nodeCreated.add(key));
+                  }
+                  entry = graph.createIfAbsent(key);
+                }
+                // {@code entry.addReverseDepAndCheckIfDone(null)} should return NEEDS_SCHEDULING at
+                // most once.
+                if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) {
+                  assertTrue(valuesSet.add(key));
+                  // Set to done.
+                  entry.setValue(new StringValue("bar" + keyNum), startingVersion);
+                  assertThat(entry.isDone()).isTrue();
+                }
+                // This shouldn't cause any problems from the other threads.
+                graph.createIfAbsent(key);
               }
-              entry = graph.createIfAbsent(key);
-            }
-            // {@code entry.addReverseDepAndCheckIfDone(null)} should return NEEDS_SCHEDULING at
-            // most once.
-            if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) {
-              assertTrue(valuesSet.add(key));
-              // Set to done.
-              entry.setValue(new StringValue("bar" + keyNum), new IntVersion(keyNum));
-              assertThat(entry.isDone()).isTrue();
-            }
-            // This shouldn't cause any problems from the other threads.
-            graph.createIfAbsent(key);
-          }
-        };
+            };
         pool.execute(wrapper.wrap(r));
       }
     }
-    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool);
-    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
-    if (interrupted) {
-      Thread.currentThread().interrupt();
-      throw new InterruptedException();
-    }
+    wrapper.waitForTasksAndMaybeThrow();
+    assertFalse(ExecutorUtil.interruptibleShutdown(pool));
     // Check that all the values are as expected.
     for (int i = 0; i < numKeys; i++) {
       SkyKey key = key("foo" + i);
       assertTrue(nodeCreated.contains(key));
       assertTrue(valuesSet.contains(key));
       assertThat(graph.get(key).getValue()).isEqualTo(new StringValue("bar" + i));
-      assertThat(graph.get(key).getVersion()).isEqualTo(new IntVersion(i));
+      assertThat(graph.get(key).getVersion()).isEqualTo(startingVersion);
     }
   }
 
@@ -214,9 +220,12 @@
     for (int i = 0; i < numKeys; i++) {
       NodeEntry entry = graph.createIfAbsent(key("foo" + i));
       startEvaluation(entry);
-      entry.setValue(new StringValue("bar"), new IntVersion(0));
+      entry.setValue(new StringValue("bar"), startingVersion);
     }
 
+    assertNotNull(graph.get(key("foo" + 0)));
+    graph = getGraph(startingVersion.next());
+    assertNotNull(graph.get(key("foo" + 0)));
     ExecutorService pool1 = Executors.newFixedThreadPool(numThreads);
     ExecutorService pool2 = Executors.newFixedThreadPool(numThreads);
     ExecutorService pool3 = Executors.newFixedThreadPool(numThreads);
@@ -230,44 +239,47 @@
     for (int i = 0; i < numKeys; i++) {
       final int keyNum = i;
       // Transition the nodes from done to dirty and then back to done.
-      Runnable r1 = new Runnable() {
-        @Override
-        public void run() {
-          try {
-            makeBatchCountDownLatch.await();
-            getBatchCountDownLatch.await();
-            getCountDownLatch.await();
-          } catch (InterruptedException e) {
-            fail("Test failed: " + e.toString());
-          }
-          NodeEntry entry = graph.get(key("foo" + keyNum));
-          entry.markDirty(true);
-          // Make some changes, like adding a dep and rdep.
-          entry.addReverseDepAndCheckIfDone(key("rdep"));
-          addTemporaryDirectDep(entry, key("dep"));
-          entry.signalDep();
-          // Move node from dirty back to done.
-          entry.setValue(new StringValue("bar" + keyNum), new IntVersion(1));
-        }
-      };
+      Runnable r1 =
+          new Runnable() {
+            @Override
+            public void run() {
+              try {
+                makeBatchCountDownLatch.await();
+                getBatchCountDownLatch.await();
+                getCountDownLatch.await();
+              } catch (InterruptedException e) {
+                throw new AssertionError(e);
+              }
+              NodeEntry entry = graph.get(key("foo" + keyNum));
+              entry.markDirty(true);
+              // Make some changes, like adding a dep and rdep.
+              entry.addReverseDepAndCheckIfDone(key("rdep"));
+              addTemporaryDirectDep(entry, key("dep"));
+              entry.signalDep();
+              // Move node from dirty back to done.
+              entry.setValue(new StringValue("bar" + keyNum), startingVersion.next());
+            }
+          };
 
       // Start a bunch of get() calls while the node transitions from dirty to done and back.
-      Runnable r2 = new Runnable() {
-        @Override
-        public void run() {
-          try {
-            makeBatchCountDownLatch.await();
-          } catch (InterruptedException e) {
-            fail("Test failed: " + e.toString());
-          }
-          NodeEntry entry = graph.get(key("foo" + keyNum));
-          assertNotEquals(null, entry);
-          // Requests for the value are made at the same time that the version changes from 0 to 1.
-          // Check that there is no problem in requesting the version and that the number is sane.
-          assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
-          getCountDownLatch.countDown();
-        }
-      };
+      Runnable r2 =
+          new Runnable() {
+            @Override
+            public void run() {
+              try {
+                makeBatchCountDownLatch.await();
+              } catch (InterruptedException e) {
+                throw new AssertionError(e);
+              }
+              NodeEntry entry = graph.get(key("foo" + keyNum));
+              assertNotNull(entry);
+              // Requests for the value are made at the same time that the version increments from
+              // the base. Check that there is no problem in requesting the version and that the
+              // number is sane.
+              assertThat(entry.getVersion()).isAnyOf(startingVersion, startingVersion.next());
+              getCountDownLatch.countDown();
+            }
+          };
       pool1.execute(wrapper.wrap(r1));
       pool2.execute(wrapper.wrap(r2));
     }
@@ -282,38 +294,36 @@
         }
       }
       makeBatchCountDownLatch.countDown();
-      Runnable r3 = new Runnable() {
-        @Override
-        public void run() {
-          try {
-            makeBatchCountDownLatch.await();
-          } catch (InterruptedException e) {
-            fail("Test failed: " + e.toString());
-          }
-          Map<SkyKey, NodeEntry> batchMap = graph.getBatch(batch);
-          getBatchCountDownLatch.countDown();
-          assertEquals(batch.size(), batchMap.size());
-          for (NodeEntry entry : batchMap.values()) {
-            // Batch requests are made at the same time that the version changes from 0 to 1.
-            // Check that there is no problem in requesting the version and that the number is sane.
-            assertThat(entry.getVersion()).isAnyOf(new IntVersion(0), new IntVersion(1));
-          }
-        }
-      };
+      Runnable r3 =
+          new Runnable() {
+            @Override
+            public void run() {
+              try {
+                makeBatchCountDownLatch.await();
+              } catch (InterruptedException e) {
+                throw new AssertionError(e);
+              }
+              Map<SkyKey, NodeEntry> batchMap = graph.getBatch(batch);
+              getBatchCountDownLatch.countDown();
+              assertThat(batchMap).hasSize(batch.size());
+              for (NodeEntry entry : batchMap.values()) {
+                // Batch requests are made at the same time that the version increments from the
+                // base. Check that there is no problem in requesting the version and that the
+                // number is sane.
+                assertThat(entry.getVersion()).isAnyOf(startingVersion, startingVersion.next());
+              }
+            }
+          };
       pool3.execute(wrapper.wrap(r3));
     }
-    boolean interrupted = ExecutorUtil.interruptibleShutdown(pool1);
-    interrupted |= ExecutorUtil.interruptibleShutdown(pool2);
-    interrupted |= ExecutorUtil.interruptibleShutdown(pool3);
-    Throwables.propagateIfPossible(wrapper.getFirstThrownError());
-    if (interrupted) {
-      Thread.currentThread().interrupt();
-      throw new InterruptedException();
-    }
+    wrapper.waitForTasksAndMaybeThrow();
+    assertFalse(ExecutorUtil.interruptibleShutdown(pool1));
+    assertFalse(ExecutorUtil.interruptibleShutdown(pool2));
+    assertFalse(ExecutorUtil.interruptibleShutdown(pool3));
     for (int i = 0; i < numKeys; i++) {
       NodeEntry entry = graph.get(key("foo" + i));
       assertThat(entry.getValue()).isEqualTo(new StringValue("bar" + i));
-      assertThat(entry.getVersion()).isEqualTo(new IntVersion(1));
+      assertThat(entry.getVersion()).isEqualTo(startingVersion.next());
       for (SkyKey key : entry.getReverseDeps()) {
         assertEquals(key("rdep"), key);
       }