Re-use previously computed deps for TransitiveBaseTraversalFunction#compute if we have them instead of re-computing them each time on a skyframe restart.

PiperOrigin-RevId: 186017079
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunction.java
index 5b6e9ed..bc6b13b 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunction.java
@@ -13,6 +13,7 @@
 // limitations under the License.
 package com.google.devtools.build.lib.skyframe;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Collections2;
 import com.google.common.collect.ImmutableSet;
@@ -32,6 +33,7 @@
 import com.google.devtools.build.lib.packages.PackageGroup;
 import com.google.devtools.build.lib.packages.Rule;
 import com.google.devtools.build.lib.packages.Target;
+import com.google.devtools.build.lib.util.GroupedList;
 import com.google.devtools.build.skyframe.SkyFunction;
 import com.google.devtools.build.skyframe.SkyFunctionException;
 import com.google.devtools.build.skyframe.SkyKey;
@@ -61,7 +63,6 @@
  * return.
  */
 abstract class TransitiveBaseTraversalFunction<TProcessedTargets> implements SkyFunction {
-
   /**
    * Returns a {@link SkyKey} corresponding to the traversal of a target specified by {@code label}
    * and its transitive dependencies.
@@ -127,9 +128,9 @@
     TargetAndErrorIfAny targetAndErrorIfAny = (TargetAndErrorIfAny) loadTargetResults;
     TProcessedTargets processedTargets = processTarget(label, targetAndErrorIfAny);
 
-    // Process deps from attributes.
-    Collection<SkyKey> labelDepKeys =
-        Collections2.transform(getLabelDeps(targetAndErrorIfAny.getTarget()), this::getKey);
+    // Process deps from attributes. It is essential that the last getValue(s) call we made to
+    // skyframe for building this node was for the corresponding PackageValue.
+    Collection<SkyKey> labelDepKeys = getLabelDepKeys(env, targetAndErrorIfAny);
 
     Map<SkyKey, ValueOrException2<NoSuchPackageException, NoSuchTargetException>> depMap =
         env.getValuesOrThrow(labelDepKeys, NoSuchPackageException.class,
@@ -139,9 +140,10 @@
       return null;
     }
 
-    // Process deps from aspects.
+    // Process deps from attributes. It is essential that the second-to-last getValue(s) call we
+    // made to skyframe for building this node was for the corresponding PackageValue.
     Iterable<SkyKey> labelAspectKeys =
-        getStrictLabelAspectKeys(targetAndErrorIfAny.getTarget(), depMap, env);
+        getStrictLabelAspectDepKeys(env, depMap, targetAndErrorIfAny);
     Set<Entry<SkyKey, ValueOrException2<NoSuchPackageException, NoSuchTargetException>>>
         labelAspectEntries = env.getValuesOrThrow(labelAspectKeys, NoSuchPackageException.class,
         NoSuchTargetException.class).entrySet();
@@ -153,6 +155,60 @@
     return computeSkyValue(targetAndErrorIfAny, processedTargets);
   }
 
+  private Collection<SkyKey> getLabelDepKeys(
+      SkyFunction.Environment env, TargetAndErrorIfAny targetAndErrorIfAny)
+      throws InterruptedException {
+    // As a performance optimization we may already know the deps we are  about to request from
+    // last time #compute was called. By requesting these from the environment, we can avoid
+    // repeating the label visitation step. For TransitiveBaseTraversalFunction#compute, the label
+    // deps dependency group is requested immediately after the package.
+    Collection<SkyKey> oldDepKeys = getDepsAfterLastPackageDep(env, /*offset=*/ 1);
+    return oldDepKeys == null
+        ? Collections2.transform(getLabelDeps(targetAndErrorIfAny.getTarget()), this::getKey)
+        : oldDepKeys;
+  }
+
+  private Iterable<SkyKey> getStrictLabelAspectDepKeys(
+      SkyFunction.Environment env,
+      Map<SkyKey, ValueOrException2<NoSuchPackageException, NoSuchTargetException>> depMap,
+      TargetAndErrorIfAny targetAndErrorIfAny)
+      throws InterruptedException {
+    // As a performance optimization we may already know the deps we are  about to request from
+    // last time #compute was called. By requesting these from the environment, we can avoid
+    // repeating the label visitation step. For TransitiveBaseTraversalFunction#compute, the label
+    // aspect deps dependency group is requested two groups after the package.
+    Collection<SkyKey> oldAspectDepKeys = getDepsAfterLastPackageDep(env, /*offset=*/ 2);
+    return oldAspectDepKeys == null
+        ? getStrictLabelAspectKeys(targetAndErrorIfAny.getTarget(), depMap, env)
+        : oldAspectDepKeys;
+  }
+
+  @Nullable
+  private static Collection<SkyKey> getDepsAfterLastPackageDep(
+      SkyFunction.Environment env, int offset) {
+    GroupedList<SkyKey> temporaryDirectDeps = env.getTemporaryDirectDeps();
+    if (temporaryDirectDeps == null) {
+      return null;
+    }
+    int lastPackageDepIndex = getLastPackageValueIndex(temporaryDirectDeps);
+    if (lastPackageDepIndex == -1
+        || temporaryDirectDeps.listSize() <= lastPackageDepIndex + offset) {
+      return null;
+    }
+    return temporaryDirectDeps.get(lastPackageDepIndex + offset);
+  }
+
+  private static int getLastPackageValueIndex(GroupedList<SkyKey> directDeps) {
+    int directDepsNumGroups = directDeps.listSize();
+    for (int i = directDepsNumGroups - 1; i >= 0; i--) {
+      List<SkyKey> depGroup = directDeps.get(i);
+      if (depGroup.size() == 1 && depGroup.get(0).functionName().equals(SkyFunctions.PACKAGE)) {
+        return i;
+      }
+    }
+    return -1;
+  }
+
   @Override
   public String extractTag(SkyKey skyKey) {
     return Label.print(argumentFromKey(skyKey));
@@ -269,14 +325,18 @@
     Target getTarget();
   }
 
-  private static class TargetAndErrorIfAnyImpl implements TargetAndErrorIfAny, LoadTargetResults {
+  @VisibleForTesting
+  static class TargetAndErrorIfAnyImpl implements TargetAndErrorIfAny, LoadTargetResults {
 
     private final boolean packageLoadedSuccessfully;
     @Nullable private final NoSuchTargetException errorLoadingTarget;
     private final Target target;
 
-    private TargetAndErrorIfAnyImpl(boolean packageLoadedSuccessfully,
-        @Nullable NoSuchTargetException errorLoadingTarget, Target target) {
+    @VisibleForTesting
+    TargetAndErrorIfAnyImpl(
+        boolean packageLoadedSuccessfully,
+        @Nullable NoSuchTargetException errorLoadingTarget,
+        Target target) {
       this.packageLoadedSuccessfully = packageLoadedSuccessfully;
       this.errorLoadingTarget = errorLoadingTarget;
       this.target = target;
@@ -304,7 +364,7 @@
     }
   }
 
-  private LoadTargetResults loadTarget(Environment env, Label label)
+  protected LoadTargetResults loadTarget(Environment env, Label label)
       throws NoSuchTargetException, NoSuchPackageException, InterruptedException {
     SkyKey packageKey = PackageValue.key(label.getPackageIdentifier());
     SkyKey targetKey = TargetMarkerValue.key(label);
diff --git a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
index 7c5deab..2b0984d 100644
--- a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
@@ -17,6 +17,7 @@
 import com.google.common.base.Function;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Maps;
+import com.google.devtools.build.lib.util.GroupedList;
 import com.google.devtools.build.skyframe.ValueOrExceptionUtils.BottomException;
 import java.util.Collections;
 import java.util.Map;
@@ -29,12 +30,27 @@
 @VisibleForTesting
 public abstract class AbstractSkyFunctionEnvironment implements SkyFunction.Environment {
   protected boolean valuesMissing = false;
+  @Nullable private final GroupedList<SkyKey> temporaryDirectDeps;
+
   private <E extends Exception> ValueOrException<E> getValueOrException(
       SkyKey depKey, Class<E> exceptionClass) throws InterruptedException {
     return ValueOrExceptionUtils.downconvert(
         getValueOrException(depKey, exceptionClass, BottomException.class), exceptionClass);
   }
 
+  public AbstractSkyFunctionEnvironment(@Nullable GroupedList<SkyKey> temporaryDirectDeps) {
+    this.temporaryDirectDeps = temporaryDirectDeps;
+  }
+
+  public AbstractSkyFunctionEnvironment() {
+    this(null);
+  }
+
+  @Override
+  public GroupedList<SkyKey> getTemporaryDirectDeps() {
+    return temporaryDirectDeps;
+  }
+
   private <E1 extends Exception, E2 extends Exception>
       ValueOrException2<E1, E2> getValueOrException(
           SkyKey depKey, Class<E1> exceptionClass1, Class<E2> exceptionClass2)
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java b/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
index 66e9408..9b7c29d 100644
--- a/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
+++ b/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
@@ -16,6 +16,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.devtools.build.lib.concurrent.ThreadSafety.ThreadSafe;
 import com.google.devtools.build.lib.events.ExtendedEventHandler;
+import com.google.devtools.build.lib.util.GroupedList;
 import java.util.Map;
 import javax.annotation.Nullable;
 
@@ -273,6 +274,16 @@
      */
     ExtendedEventHandler getListener();
 
+    /**
+     * A live view of deps known to have already been requested either through an earlier call to
+     * {@link SkyFunction#compute} or inferred during change pruning. Should return {@code null} if
+     * unknown.
+     */
+    @Nullable
+    default GroupedList<SkyKey> getTemporaryDirectDeps() {
+      return null;
+    }
+
     /** Returns whether we are currently in error bubbling. */
     @VisibleForTesting
     boolean inErrorBubblingForTesting();
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
index eb25426..b9830df 100644
--- a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
@@ -129,6 +129,7 @@
       Set<SkyKey> oldDeps,
       ParallelEvaluatorContext evaluatorContext)
       throws InterruptedException {
+    super(directDeps);
     this.skyKey = skyKey;
     this.oldDeps = oldDeps;
     this.evaluatorContext = evaluatorContext;
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunctionTest.java
new file mode 100644
index 0000000..665e706
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/TransitiveBaseTraversalFunctionTest.java
@@ -0,0 +1,134 @@
+// Copyright 2018 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.skyframe;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.devtools.build.lib.analysis.util.BuildViewTestCase;
+import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.cmdline.PackageIdentifier;
+import com.google.devtools.build.lib.packages.NoSuchPackageException;
+import com.google.devtools.build.lib.packages.NoSuchTargetException;
+import com.google.devtools.build.lib.packages.Package;
+import com.google.devtools.build.lib.skyframe.TransitiveBaseTraversalFunction.TargetAndErrorIfAnyImpl;
+import com.google.devtools.build.lib.testutil.TestRuleClassProvider;
+import com.google.devtools.build.lib.util.GroupedList;
+import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper;
+import com.google.devtools.build.lib.vfs.Path;
+import com.google.devtools.build.skyframe.SkyFunction;
+import com.google.devtools.build.skyframe.SkyKey;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mockito;
+
+/** Test for {@link TransitiveBaseTraversalFunction}. */
+@RunWith(JUnit4.class)
+public class TransitiveBaseTraversalFunctionTest extends BuildViewTestCase {
+  Label label;
+  TargetAndErrorIfAnyImpl targetAndErrorIfAny;
+
+  @Before
+  public void setUp() throws Exception {
+    // Create a basic package with a target //foo:foo.
+    label = Label.parseAbsolute("//foo:foo");
+    Package pkg =
+        scratchPackage(
+            "workspace",
+            label.getPackageIdentifier(),
+            "sh_library(name = '" + label.getName() + "')");
+    targetAndErrorIfAny =
+        new TargetAndErrorIfAnyImpl(
+            /*packageLoadedSuccessfully=*/ true,
+            /*errorLoadingTarget=*/ null,
+            pkg.getTarget(label.getName()));
+  }
+
+  @Test
+  public void assertNoLabelVisitationForTransitiveTraversalFunction() throws Exception {
+    assertNoLabelVisitationForFunction(
+        new TransitiveTraversalFunction() {
+          @Override
+          protected LoadTargetResults loadTarget(Environment env, Label label)
+              throws NoSuchTargetException, NoSuchPackageException, InterruptedException {
+            return targetAndErrorIfAny;
+          }
+        });
+  }
+
+  @Test
+  public void assertNoLabelVisitationForTransitiveTargetFunction() throws Exception {
+    assertNoLabelVisitationForFunction(
+        new TransitiveTargetFunction(TestRuleClassProvider.getRuleClassProvider()) {
+          @Override
+          protected LoadTargetResults loadTarget(Environment env, Label label)
+              throws NoSuchTargetException, NoSuchPackageException, InterruptedException {
+            return targetAndErrorIfAny;
+          }
+        });
+  }
+
+  private void assertNoLabelVisitationForFunction(TransitiveBaseTraversalFunction<?> function)
+      throws Exception {
+    // Create the GroupedList saying we had already requested two targets the last time we called
+    // #compute.
+    GroupedListHelper<SkyKey> helper = new GroupedListHelper<>();
+    SkyKey fakeDep1 = function.getKey(Label.parseAbsolute("//foo:bar"));
+    SkyKey fakeDep2 = function.getKey(Label.parseAbsolute("//foo:baz"));
+    helper.add(TargetMarkerValue.key(label));
+    helper.add(PackageValue.key(label.getPackageIdentifier()));
+    helper.startGroup();
+    // Note that these targets don't actually exist in the package we created initially. It doesn't
+    // matter for the purpose of this test, the original package was just to create some objects
+    // that we needed.
+    helper.add(fakeDep1);
+    helper.add(fakeDep2);
+    helper.endGroup();
+    GroupedList<SkyKey> groupedList = new GroupedList<>();
+    groupedList.append(helper);
+    AtomicBoolean wasOptimizationUsed = new AtomicBoolean(false);
+    SkyFunction.Environment mockEnv = Mockito.mock(SkyFunction.Environment.class);
+    when(mockEnv.getTemporaryDirectDeps()).thenReturn(groupedList);
+    when(mockEnv.getValuesOrThrow(
+            groupedList.get(2), NoSuchPackageException.class, NoSuchTargetException.class))
+        .thenAnswer(
+            (invocationOnMock) -> {
+              wasOptimizationUsed.set(true);
+              // It doesn't matter what this map is, we'll return false in the valuesMissing() call.
+              return ImmutableMap.of();
+            });
+    when(mockEnv.valuesMissing()).thenReturn(true);
+
+    // Run the compute function and check that we returned null.
+    assertThat(function.compute(function.getKey(label), mockEnv)).isNull();
+
+    // Verify that the mock was called with the arguments we expected.
+    assertThat(wasOptimizationUsed.get()).isTrue();
+  }
+
+  private Package scratchPackage(String workspaceName, PackageIdentifier packageId, String... lines)
+      throws Exception {
+    Path buildFile = scratch.file("" + packageId.getSourceRoot() + "/BUILD", lines);
+    Package.Builder externalPkg =
+        Package.newExternalPackageBuilder(
+            Package.Builder.DefaultHelper.INSTANCE, buildFile.getRelative("WORKSPACE"), "TESTING");
+    externalPkg.setWorkspaceName(workspaceName);
+    return pkgFactory.createPackageForTesting(
+        packageId, externalPkg.build(), buildFile, packageIdentifier -> buildFile, reporter);
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/skyframe/MemoizingEvaluatorTest.java b/src/test/java/com/google/devtools/build/skyframe/MemoizingEvaluatorTest.java
index 0b1ed35..b1518fc 100644
--- a/src/test/java/com/google/devtools/build/skyframe/MemoizingEvaluatorTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/MemoizingEvaluatorTest.java
@@ -53,6 +53,7 @@
 import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
 import java.lang.ref.WeakReference;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -214,6 +215,33 @@
   }
 
   @Test
+  public void testEnvProvidesTemporaryDirectDeps() throws Exception {
+    AtomicInteger counter = new AtomicInteger();
+    List<SkyKey> deps = Collections.synchronizedList(new ArrayList<>());
+    SkyKey topKey = toSkyKey("top");
+    SkyKey bottomKey = toSkyKey("bottom");
+    SkyValue bottomValue = new StringValue("bottom");
+    tester
+        .getOrCreate(topKey)
+        .setBuilder(
+            new NoExtractorFunction() {
+              @Override
+              public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedException {
+                if (counter.getAndIncrement() > 0) {
+                  deps.addAll(env.getTemporaryDirectDeps().get(0));
+                } else {
+                  assertThat(env.getTemporaryDirectDeps().listSize()).isEqualTo(0);
+                }
+                return env.getValue(bottomKey);
+              }
+            });
+    tester.getOrCreate(bottomKey).setConstantValue(bottomValue);
+    EvaluationResult<StringValue> result = tester.eval(/*keepGoing=*/ true, "top");
+    assertThat(result.get(topKey)).isEqualTo(bottomValue);
+    assertThat(deps).containsExactly(bottomKey);
+  }
+
+  @Test
   public void cachedErrorShutsDownThreadpool() throws Exception {
     // When a node throws an error on the first build,
     SkyKey cachedErrorKey = GraphTester.skyKey("error");