Reduce garbage-collection churn when iterating over all elements of a GroupedList without regard to its group structure, and simplify the logic for prefetching old deps.

PiperOrigin-RevId: 187681887
diff --git a/src/main/java/com/google/devtools/build/lib/util/GroupedList.java b/src/main/java/com/google/devtools/build/lib/util/GroupedList.java
index 11bef3f..3bf690a 100644
--- a/src/main/java/com/google/devtools/build/lib/util/GroupedList.java
+++ b/src/main/java/com/google/devtools/build/lib/util/GroupedList.java
@@ -20,6 +20,7 @@
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.devtools.build.lib.collect.compacthashset.CompactHashSet;
+import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadHostile;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Iterator;
@@ -196,7 +197,7 @@
 
   @SuppressWarnings("unchecked")
   public Set<T> toSet() {
-    ImmutableSet.Builder<T> builder = ImmutableSet.builder();
+    ImmutableSet.Builder<T> builder = ImmutableSet.builderWithExpectedSize(numElements());
     for (Object obj : elements) {
       if (obj instanceof List) {
         builder.addAll((List<T>) obj);
@@ -262,6 +263,68 @@
     return first.equals(second) || CompactHashSet.create(first).containsAll(second);
   }
 
+  /**
+   * An iterator that visits every element of every group, in order, ignoring group boundaries.
+   *
+   * <p>Unlike iterating over the groups, singleton groups are not wrapped in a new {@link List}.
+   */
+  private class UngroupedIterator implements Iterator<T> {
+    private final Iterator<Object> iter = elements.iterator();
+    /** Number of elements returned so far; {@link #hasNext} compares it against {@code size}. */
+    private int counter = 0;
+    /** The list-valued group currently being traversed, or null if there is none. */
+    private List<T> currentGroup;
+    /** Index of the next element to return from {@code currentGroup}. */
+    private int listCounter = 0;
+
+    @Override
+    public boolean hasNext() {
+      return counter < size;
+    }
+
+    @SuppressWarnings("unchecked") // Cast of Object to List<T> or T.
+    @Override
+    public T next() {
+      counter++;
+      if (currentGroup != null && listCounter < currentGroup.size()) {
+        return currentGroup.get(listCounter++);
+      }
+      Object nextGroup = iter.next();
+      if (nextGroup instanceof List) {
+        currentGroup = (List<T>) nextGroup;
+        listCounter = 1;
+        // GroupedLists shouldn't have empty lists stored.
+        return currentGroup.get(0);
+      } else {
+        currentGroup = null;
+        return (T) nextGroup;
+      }
+    }
+  }
+
+  /**
+   * Returns an {@link Iterable} over every element of every group, in order, ignoring group
+   * structure.
+   *
+   * <p>The result is a lazy view backed by this list: nothing is copied up front, elements are
+   * not de-duplicated, and this list must not be modified while the view is being iterated. Its
+   * {@code toString()} lists the elements so that it stays readable in error messages.
+   */
+  @ThreadHostile
+  public Iterable<T> getAllElementsAsIterable() {
+    return new Iterable<T>() {
+      @Override
+      public Iterator<T> iterator() {
+        return new UngroupedIterator();
+      }
+
+      @Override
+      public String toString() {
+        return Iterables.toString(this);
+      }
+    };
+  }
+
   @Override
   public boolean equals(Object other) {
     if (other == null) {
@@ -326,11 +364,6 @@
       }
       return ImmutableList.of((T) obj);
     }
-
-    @Override
-    public void remove() {
-      throw new UnsupportedOperationException();
-    }
   }
 
   @Override
diff --git a/src/main/java/com/google/devtools/build/skyframe/InMemoryNodeEntry.java b/src/main/java/com/google/devtools/build/skyframe/InMemoryNodeEntry.java
index 76e8100..24dfeab 100644
--- a/src/main/java/com/google/devtools/build/skyframe/InMemoryNodeEntry.java
+++ b/src/main/java/com/google/devtools/build/skyframe/InMemoryNodeEntry.java
@@ -18,7 +18,6 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.util.GroupedList;
 import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper;
 import com.google.devtools.build.skyframe.KeyToConsolidate.Op;
@@ -220,7 +219,7 @@
 
   @Override
   public synchronized Iterable<SkyKey> getDirectDeps() {
-    return getGroupedDirectDeps().toSet();
+    return getGroupedDirectDeps().getAllElementsAsIterable(); // Lazy view; not de-duplicated.
   }
 
   /**
@@ -520,7 +519,7 @@
       throws InterruptedException {
     Preconditions.checkState(!isDone(), this);
     if (!isDirty()) {
-      return Iterables.concat(getTemporaryDirectDeps());
+      return getTemporaryDirectDeps().getAllElementsAsIterable();
     } else {
       // There may be duplicates here. Make sure everything is unique.
       ImmutableSet.Builder<SkyKey> result = ImmutableSet.builder();
diff --git a/src/main/java/com/google/devtools/build/skyframe/QueryableGraph.java b/src/main/java/com/google/devtools/build/skyframe/QueryableGraph.java
index 26422dd..8c5af09 100644
--- a/src/main/java/com/google/devtools/build/skyframe/QueryableGraph.java
+++ b/src/main/java/com/google/devtools/build/skyframe/QueryableGraph.java
@@ -50,6 +50,15 @@
           throws InterruptedException;
 
   /**
+   * Hints to the graph that {@link #getBatch} may later be called on {@code keys}. The default
+   * implementation ignores the hint; graphs that can load entries ahead of time may override it.
+   */
+  default void prefetchBatch(
+      @Nullable SkyKey requestor, Reason reason, Iterable<? extends SkyKey> keys) {
+    // Intentionally a no-op: this method is only a hint.
+  }
+
+  /**
    * Examines all the given keys. Returns an iterable of keys whose corresponding nodes are
    * currently available to be fetched.
    *
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
index 9646b07..d8c8f7a 100644
--- a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
@@ -20,8 +20,6 @@
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicates;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -151,20 +149,26 @@
       boolean assertDone,
       SkyKey keyForDebugging)
       throws InterruptedException {
-    Iterable<SkyKey> depKeysAsIterable = Iterables.concat(depKeys);
-    Iterable<SkyKey> keysToPrefetch = depKeysAsIterable;
+    Set<SkyKey> depKeysAsSet = null;
     if (PREFETCH_OLD_DEPS) {
-      ImmutableSet.Builder<SkyKey> keysToPrefetchBuilder = ImmutableSet.builder();
-      keysToPrefetchBuilder.addAll(depKeysAsIterable).addAll(oldDeps);
-      keysToPrefetch = keysToPrefetchBuilder.build();
+      if (!oldDeps.isEmpty()) {
+        // Create a set here so that filtering the old deps below is fast. Once we create this set,
+        // we may as well use it for the call to evaluatorContext#getBatchValues since we've
+        // precomputed the size.
+        depKeysAsSet = depKeys.toSet();
+        evaluatorContext
+            .getGraph()
+            .prefetchBatch(
+                requestor,
+                Reason.PREFETCH,
+                Iterables.filter(oldDeps, Predicates.not(Predicates.in(depKeysAsSet))));
+      }
     }
     Map<SkyKey, ? extends NodeEntry> batchMap =
-        evaluatorContext.getBatchValues(requestor, Reason.PREFETCH, keysToPrefetch);
-    if (PREFETCH_OLD_DEPS) {
-      batchMap =
-          ImmutableMap.<SkyKey, NodeEntry>copyOf(
-              Maps.filterKeys(batchMap, Predicates.in(ImmutableSet.copyOf(depKeysAsIterable))));
-    }
+        evaluatorContext.getBatchValues(
+            requestor,
+            Reason.PREFETCH,
+            depKeysAsSet == null ? depKeys.getAllElementsAsIterable() : depKeysAsSet);
     if (batchMap.size() != depKeys.numElements()) {
       throw new IllegalStateException(
           "Missing keys for "
@@ -316,7 +320,7 @@
             keySize - directDeps.size() - (bubbleErrorInfo == null ? 0 : bubbleErrorInfo.size()),
             0);
     ArrayList<SkyKey> missingKeys = new ArrayList<>(expectedMissingKeySize);
-    for (SkyKey key : Iterables.concat(depKeys)) {
+    for (SkyKey key : depKeys.getAllElementsAsIterable()) {
       SkyValue value = maybeGetValueFromErrorOrDeps(key);
       if (value == null) {
         missingKeys.add(key);
diff --git a/src/test/java/com/google/devtools/build/lib/util/GroupedListTest.java b/src/test/java/com/google/devtools/build/lib/util/GroupedListTest.java
index 63c99cc..fd5938e 100644
--- a/src/test/java/com/google/devtools/build/lib/util/GroupedListTest.java
+++ b/src/test/java/com/google/devtools/build/lib/util/GroupedListTest.java
@@ -154,6 +154,9 @@
     assertElementsEqual(compressed, allElts);
     assertElementsEqualInGroups(GroupedList.<String>create(compressed), elements);
     assertElementsEqualInGroups(groupedList, elements);
+    assertThat(groupedList.getAllElementsAsIterable())
+        .containsExactlyElementsIn(Iterables.concat(groupedList))
+        .inOrder();
   }
 
   @Test
@@ -186,6 +189,9 @@
     elements.remove(1);
     assertElementsEqualInGroups(GroupedList.<String>create(compressed), elements);
     assertElementsEqualInGroups(groupedList, elements);
+    assertThat(groupedList.getAllElementsAsIterable())
+        .containsExactlyElementsIn(Iterables.concat(groupedList))
+        .inOrder();
   }
 
   @Test