Replace EvaluableGraph#createIfAbsent with the potentially more efficient EvaluableGraph#createIfAbsentBatch.

--
MOS_MIGRATED_REVID=104534858
diff --git a/src/main/java/com/google/devtools/build/skyframe/EvaluableGraph.java b/src/main/java/com/google/devtools/build/skyframe/EvaluableGraph.java
index a61de6e..7527146 100644
--- a/src/main/java/com/google/devtools/build/skyframe/EvaluableGraph.java
+++ b/src/main/java/com/google/devtools/build/skyframe/EvaluableGraph.java
@@ -15,6 +15,8 @@
 
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 
+import java.util.Map;
+
 /**
  * Interface between a single version of the graph and the evaluator. Supports mutation of that
  * single version of the graph.
@@ -22,8 +24,8 @@
 @ThreadSafe
 interface EvaluableGraph extends QueryableGraph {
   /**
-   * Creates a new node with the specified key if it does not exist yet. Returns the node entry
-   * (either the existing one or the one just created), never {@code null}.
+   * Like {@link QueryableGraph#getBatch}, except it creates a new node for each key not already
+   * present in the graph. Thus, the returned map will have an entry for each key in {@code keys}.
    */
-  NodeEntry createIfAbsent(SkyKey key);
+  Map<SkyKey, NodeEntry> createIfAbsentBatch(Iterable<SkyKey> keys);
 }
diff --git a/src/main/java/com/google/devtools/build/skyframe/InMemoryGraph.java b/src/main/java/com/google/devtools/build/skyframe/InMemoryGraph.java
index f2faaef..5a589ba 100644
--- a/src/main/java/com/google/devtools/build/skyframe/InMemoryGraph.java
+++ b/src/main/java/com/google/devtools/build/skyframe/InMemoryGraph.java
@@ -69,13 +69,21 @@
     return builder.build();
   }
 
-  @Override
-  public NodeEntry createIfAbsent(SkyKey key) {
+  protected NodeEntry createIfAbsent(SkyKey key) {
     NodeEntry newval = keepEdges ? new InMemoryNodeEntry() : new EdgelessInMemoryNodeEntry();
     NodeEntry oldval = nodeMap.putIfAbsent(key, newval);
     return oldval == null ? newval : oldval;
   }
 
+  @Override
+  public Map<SkyKey, NodeEntry> createIfAbsentBatch(Iterable<SkyKey> keys) {
+    ImmutableMap.Builder<SkyKey, NodeEntry> builder = ImmutableMap.builder();
+    for (SkyKey key : keys) {
+      builder.put(key, createIfAbsent(key));
+    }
+    return builder.build();
+  }
+
   /** Only done nodes exist to the outside world. */
   private static final Predicate<NodeEntry> NODE_DONE_PREDICATE =
       new Predicate<NodeEntry>() {
diff --git a/src/main/java/com/google/devtools/build/skyframe/InMemoryMemoizingEvaluator.java b/src/main/java/com/google/devtools/build/skyframe/InMemoryMemoizingEvaluator.java
index 5fa80a0..d7bfb95 100644
--- a/src/main/java/com/google/devtools/build/skyframe/InMemoryMemoizingEvaluator.java
+++ b/src/main/java/com/google/devtools/build/skyframe/InMemoryMemoizingEvaluator.java
@@ -222,10 +222,7 @@
     if (valuesToInject.isEmpty()) {
       return;
     }
-    for (Entry<SkyKey, SkyValue> entry : valuesToInject.entrySet()) {
-      ParallelEvaluator.injectValue(
-          entry.getKey(), entry.getValue(), version, graph, dirtyKeyTracker);
-    }
+    ParallelEvaluator.injectValues(valuesToInject, version, graph, dirtyKeyTracker);
     // Start with a new map to avoid bloat since clear() does not downsize the map.
     valuesToInject = new HashMap<>();
   }
diff --git a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
index 7cc430b..1e555e8 100644
--- a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
+++ b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
@@ -680,17 +680,16 @@
       this.skyKey = skyKey;
     }
 
-    private void enqueueChild(SkyKey skyKey, NodeEntry entry, SkyKey child, boolean dirtyParent) {
+    private void enqueueChild(SkyKey skyKey, NodeEntry entry, SkyKey child, NodeEntry childEntry,
+        boolean dirtyParent) {
       Preconditions.checkState(!entry.isDone(), "%s %s", skyKey, entry);
-
-      NodeEntry depEntry = graph.createIfAbsent(child);
       DependencyState dependencyState =
           dirtyParent
-              ? depEntry.checkIfDoneForDirtyReverseDep(skyKey)
-              : depEntry.addReverseDepAndCheckIfDone(skyKey);
+              ? childEntry.checkIfDoneForDirtyReverseDep(skyKey)
+              : childEntry.addReverseDepAndCheckIfDone(skyKey);
       switch (dependencyState) {
         case DONE:
-          if (entry.signalDep(depEntry.getVersion())) {
+          if (entry.signalDep(childEntry.getVersion())) {
             // This can only happen if there are no more children to be added.
             visitor.enqueueEvaluation(skyKey);
           }
@@ -786,8 +785,11 @@
           // than this node, so we are going to mark it clean (since the error transience node is
           // always the last dep).
           state.addTemporaryDirectDeps(GroupedListHelper.create(directDepsToCheck));
-          for (SkyKey directDep : directDepsToCheck) {
-            enqueueChild(skyKey, state, directDep, /*dirtyParent=*/ true);
+          for (Map.Entry<SkyKey, NodeEntry> e
+              : graph.createIfAbsentBatch(directDepsToCheck).entrySet()) {
+            SkyKey directDep = e.getKey();
+            NodeEntry directDepEntry = e.getValue();
+            enqueueChild(skyKey, state, directDep, directDepEntry, /*dirtyParent=*/ true);
           }
           return DirtyOutcome.ALREADY_PROCESSED;
         case VERIFIED_CLEAN:
@@ -958,8 +960,10 @@
         return;
       }
 
-      for (SkyKey newDirectDep : newDirectDeps) {
-        enqueueChild(skyKey, state, newDirectDep, /*dirtyParent=*/ false);
+      for (Map.Entry<SkyKey, NodeEntry> e : graph.createIfAbsentBatch(newDirectDeps).entrySet()) {
+        SkyKey newDirectDep = e.getKey();
+        NodeEntry newDirectDepEntry = e.getValue();
+        enqueueChild(skyKey, state, newDirectDep, newDirectDepEntry, /*dirtyParent=*/ false);
       }
       // It is critical that there is no code below this point.
     }
@@ -1129,17 +1133,19 @@
     // We unconditionally add the ErrorTransienceValue here, to ensure that it will be created, and
     // in the graph, by the time that it is needed. Creating it on demand in a parallel context sets
     // up a race condition, because there is no way to atomically create a node and set its value.
-    NodeEntry errorTransienceEntry = graph.createIfAbsent(ErrorTransienceValue.key());
+    SkyKey errorTransienceKey = ErrorTransienceValue.key();
+    NodeEntry errorTransienceEntry = Iterables.getOnlyElement(
+        graph.createIfAbsentBatch(ImmutableList.of(errorTransienceKey)).values());
     if (!errorTransienceEntry.isDone()) {
-      injectValue(
-          ErrorTransienceValue.key(),
-          new ErrorTransienceValue(),
+      injectValues(
+          ImmutableMap.of(errorTransienceKey, (SkyValue) new ErrorTransienceValue()),
           graphVersion,
           graph,
           dirtyKeyTracker);
     }
-    for (SkyKey skyKey : skyKeys) {
-      NodeEntry entry = graph.createIfAbsent(skyKey);
+    for (Map.Entry<SkyKey, NodeEntry> e : graph.createIfAbsentBatch(skyKeys).entrySet()) {
+      SkyKey skyKey = e.getKey();
+      NodeEntry entry = e.getValue();
       // This must be equivalent to the code in enqueueChild above, in order to be thread-safe.
       switch (entry.addReverseDepAndCheckIfDone(null)) {
         case NEEDS_SCHEDULING:
@@ -1753,38 +1759,42 @@
     return entry != null && entry.isDone();
   }
 
-  static void injectValue(
-      SkyKey key,
-      SkyValue value,
+  static void injectValues(
+      Map<SkyKey, SkyValue> injectionMap,
       Version version,
       EvaluableGraph graph,
       DirtyKeyTracker dirtyKeyTracker) {
-    Preconditions.checkNotNull(value, key);
-    NodeEntry prevEntry = graph.createIfAbsent(key);
-    DependencyState newState = prevEntry.addReverseDepAndCheckIfDone(null);
-    Preconditions.checkState(
-        newState != DependencyState.ALREADY_EVALUATING, "%s %s", key, prevEntry);
-    if (prevEntry.isDirty()) {
+    Map<SkyKey, NodeEntry> prevNodeEntries = graph.createIfAbsentBatch(injectionMap.keySet());
+    for (Map.Entry<SkyKey, SkyValue> injectionEntry : injectionMap.entrySet()) {
+      SkyKey key = injectionEntry.getKey();
+      SkyValue value = injectionEntry.getValue();
+      NodeEntry prevEntry = prevNodeEntries.get(key);
+      DependencyState newState = prevEntry.addReverseDepAndCheckIfDone(null);
       Preconditions.checkState(
-          newState == DependencyState.NEEDS_SCHEDULING, "%s %s", key, prevEntry);
-      // There was an existing entry for this key in the graph.
-      // Get the node in the state where it is able to accept a value.
+          newState != DependencyState.ALREADY_EVALUATING, "%s %s", key, prevEntry);
+      if (prevEntry.isDirty()) {
+        Preconditions.checkState(
+            newState == DependencyState.NEEDS_SCHEDULING, "%s %s", key, prevEntry);
+        // There was an existing entry for this key in the graph.
+        // Get the node in the state where it is able to accept a value.
 
-      // Check that the previous node has no dependencies. Overwriting a value with deps with an
-      // injected value (which is by definition deps-free) needs a little additional bookkeeping
-      // (removing reverse deps from the dependencies), but more importantly it's something that
-      // we want to avoid, because it indicates confusion of input values and derived values.
-      Preconditions.checkState(
-          prevEntry.noDepsLastBuild(), "existing entry for %s has deps: %s", key, prevEntry);
-      // Put the node into a "rebuilding" state and verify that there were no dirty deps remaining.
-      Preconditions.checkState(
-          prevEntry.markRebuildingAndGetAllRemainingDirtyDirectDeps().isEmpty(),
-          "%s %s",
-          key,
-          prevEntry);
+        // Check that the previous node has no dependencies. Overwriting a value with deps with an
+        // injected value (which is by definition deps-free) needs a little additional bookkeeping
+        // (removing reverse deps from the dependencies), but more importantly it's something that
+        // we want to avoid, because it indicates confusion of input values and derived values.
+        Preconditions.checkState(
+            prevEntry.noDepsLastBuild(), "existing entry for %s has deps: %s", key, prevEntry);
+        // Put the node into a "rebuilding" state and verify that there were no dirty deps
+        // remaining.
+        Preconditions.checkState(
+            prevEntry.markRebuildingAndGetAllRemainingDirtyDirectDeps().isEmpty(),
+            "%s %s",
+            key,
+            prevEntry);
+      }
+      prevEntry.setValue(value, version);
+      // Now that this key's injected value is set, it is no longer dirty.
+      dirtyKeyTracker.notDirty(key);
     }
-    prevEntry.setValue(value, version);
-    // Now that this key's injected value is set, it is no longer dirty.
-    dirtyKeyTracker.notDirty(key);
   }
 }
diff --git a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
index d466c8f..40cf5fb 100644
--- a/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/GraphConcurrencyTest.java
@@ -21,6 +21,7 @@
 import static org.junit.Assert.fail;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.concurrent.ExecutorUtil;
@@ -74,8 +75,8 @@
   }
 
   @Test
-  public void createIfAbsentSanity() {
-    graph.createIfAbsent(key("cat"));
+  public void createIfAbsentBatchSanity() {
+    graph.createIfAbsentBatch(ImmutableList.of(key("cat"), key("dog")));
   }
 
   // Tests adding and removing Rdeps of a {@link NodeEntry} while a node transitions from
@@ -83,7 +84,8 @@
   @Test
   public void testAddRemoveRdeps() throws Exception {
     SkyKey key = key("foo");
-    final NodeEntry entry = graph.createIfAbsent(key);
+    final NodeEntry entry = Iterables.getOnlyElement(
+        graph.createIfAbsentBatch(ImmutableList.of(key)).values());
     // These numbers are arbitrary.
     int numThreads = 50;
     int numKeys = numThreads;
@@ -164,37 +166,50 @@
     final KeyedLocker<SkyKey> locker = new RefCountedMultisetKeyedLocker<>();
     ExecutorService pool = Executors.newFixedThreadPool(numThreads);
     final int numKeys = 500;
-    // Add each key 10 times.
+    // Add each pair of keys 10 times.
     final Set<SkyKey> nodeCreated = Sets.newConcurrentHashSet();
     final Set<SkyKey> valuesSet = Sets.newConcurrentHashSet();
     for (int i = 0; i < 10; i++) {
       for (int j = 0; j < numKeys; j++) {
-        final int keyNum = j;
-        final SkyKey key = key("foo" + keyNum);
-        Runnable r =
-            new Runnable() {
-              public void run() {
-                NodeEntry entry;
-                try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) {
-                  entry = graph.get(key);
-                  if (entry == null) {
-                    assertTrue(nodeCreated.add(key));
+        for (int k = j + 1; k < numKeys; k++) {
+          final int keyNum1 = j;
+          final int keyNum2 = k;
+          final SkyKey key1 = key("foo" + keyNum1);
+          final SkyKey key2 = key("foo" + keyNum2);
+          final Iterable<SkyKey> keys = ImmutableList.of(key1, key2);
+          Runnable r =
+              new Runnable() {
+                public void run() {
+                  Map<SkyKey, NodeEntry> entries;
+                  try (KeyedLocker.AutoUnlocker unlocker1 = locker.lock(key1)) {
+                    try (KeyedLocker.AutoUnlocker unlocker2 = locker.lock(key2)) {
+                      for (SkyKey key : keys) {
+                        NodeEntry entry = graph.get(key);
+                        if (entry == null) {
+                          assertTrue(nodeCreated.add(key));
+                        }
+                      }
+                      entries = graph.createIfAbsentBatch(keys);
+                    }
                   }
-                  entry = graph.createIfAbsent(key);
+                  for (Integer keyNum : ImmutableList.of(keyNum1, keyNum2)) {
+                    SkyKey key = key("foo" + keyNum);
+                    NodeEntry entry = entries.get(key);
+                    // {@code entry.addReverseDepAndCheckIfDone(null)} should return
+                    // NEEDS_SCHEDULING at most once.
+                    if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) {
+                      assertTrue(valuesSet.add(key));
+                      // Set to done.
+                      entry.setValue(new StringValue("bar" + keyNum), startingVersion);
+                      assertThat(entry.isDone()).isTrue();
+                    }
+                  }
+                  // This shouldn't cause any problems from the other threads.
+                  graph.createIfAbsentBatch(keys);
                 }
-                // {@code entry.addReverseDepAndCheckIfDone(null)} should return NEEDS_SCHEDULING at
-                // most once.
-                if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) {
-                  assertTrue(valuesSet.add(key));
-                  // Set to done.
-                  entry.setValue(new StringValue("bar" + keyNum), startingVersion);
-                  assertThat(entry.isDone()).isTrue();
-                }
-                // This shouldn't cause any problems from the other threads.
-                graph.createIfAbsent(key);
-              }
-            };
-        pool.execute(wrapper.wrap(r));
+              };
+          pool.execute(wrapper.wrap(r));
+        }
       }
     }
     wrapper.waitForTasksAndMaybeThrow();
@@ -219,8 +234,13 @@
     int numThreads = 50;
     final int numBatchRequests = 100;
     // Create a bunch of done nodes.
+    ArrayList<SkyKey> keys = new ArrayList<>();
     for (int i = 0; i < numKeys; i++) {
-      NodeEntry entry = graph.createIfAbsent(key("foo" + i));
+      keys.add(key("foo" + i));
+    }
+    Map<SkyKey, NodeEntry> entries = graph.createIfAbsentBatch(keys);
+    for (int i = 0; i < numKeys; i++) {
+      NodeEntry entry = entries.get(key("foo" + i));
       startEvaluation(entry);
       entry.setValue(new StringValue("bar"), startingVersion);
     }
diff --git a/src/test/java/com/google/devtools/build/skyframe/NotifyingInMemoryGraph.java b/src/test/java/com/google/devtools/build/skyframe/NotifyingInMemoryGraph.java
index 7cdafc3..1ff4fce 100644
--- a/src/test/java/com/google/devtools/build/skyframe/NotifyingInMemoryGraph.java
+++ b/src/test/java/com/google/devtools/build/skyframe/NotifyingInMemoryGraph.java
@@ -33,8 +33,7 @@
     this.graphListener = new ErrorRecordingDelegatingListener(graphListener);
   }
 
-  @Override
-  public NodeEntry createIfAbsent(SkyKey key) {
+  protected NodeEntry createIfAbsent(SkyKey key) {
     graphListener.accept(key, EventType.CREATE_IF_ABSENT, Order.BEFORE, null);
     NodeEntry newval = getEntry(key);
     NodeEntry oldval = getNodeMap().putIfAbsent(key, newval);