Download tree artifacts for toplevel targets.

PiperOrigin-RevId: 482729247
Change-Id: I45749cb9868e4f7dd0d733b3f0a5243628a07e5e
diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 696b054..d09608a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -197,7 +197,6 @@
     srcs = ["ToplevelArtifactsDownloader.java"],
     deps = [
         ":abstract_action_input_prefetcher",
-        "//src/main/java/com/google/devtools/build/lib/actions",
         "//src/main/java/com/google/devtools/build/lib/actions:artifacts",
         "//src/main/java/com/google/devtools/build/lib/actions:file_metadata",
         "//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
@@ -205,6 +204,7 @@
         "//src/main/java/com/google/devtools/build/lib/skyframe:action_execution_value",
         "//src/main/java/com/google/devtools/build/lib/skyframe:tree_artifact_value",
         "//src/main/java/com/google/devtools/build/skyframe",
+        "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//third_party:flogger",
         "//third_party:guava",
         "//third_party:jsr305",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/ToplevelArtifactsDownloader.java b/src/main/java/com/google/devtools/build/lib/remote/ToplevelArtifactsDownloader.java
index 3858a8f..a2fd28a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/ToplevelArtifactsDownloader.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/ToplevelArtifactsDownloader.java
@@ -13,6 +13,7 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static com.google.common.util.concurrent.Futures.addCallback;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 
@@ -24,7 +25,6 @@
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.Artifact;
-import com.google.devtools.build.lib.actions.CompletionContext;
 import com.google.devtools.build.lib.actions.FileArtifactValue;
 import com.google.devtools.build.lib.analysis.AspectCompleteEvent;
 import com.google.devtools.build.lib.analysis.Runfiles;
@@ -35,6 +35,7 @@
 import com.google.devtools.build.lib.skyframe.ActionExecutionValue;
 import com.google.devtools.build.lib.skyframe.TreeArtifactValue;
 import com.google.devtools.build.skyframe.MemoizingEvaluator;
+import com.google.devtools.build.skyframe.SkyValue;
 import javax.annotation.Nullable;
 
 /**
@@ -60,8 +61,7 @@
       return;
     }
 
-    downloadTargetOutputs(
-        event.getCompletionContext(), event.getOutputGroups(), /* runfiles = */ null);
+    downloadTargetOutputs(event.getOutputGroups(), /* runfiles = */ null);
   }
 
   @Subscribe
@@ -72,30 +72,24 @@
     }
 
     downloadTargetOutputs(
-        event.getCompletionContext(),
         event.getOutputs(),
         event.getExecutableTargetData().getRunfiles());
   }
 
   private void downloadTargetOutputs(
-      CompletionContext completionContext,
-      ImmutableMap<String, ArtifactsInOutputGroup> outputGroups,
-      @Nullable Runfiles runfiles) {
+      ImmutableMap<String, ArtifactsInOutputGroup> outputGroups, @Nullable Runfiles runfiles) {
 
     var builder = ImmutableMap.<ActionInput, FileArtifactValue>builder();
-    for (ArtifactsInOutputGroup outputs : outputGroups.values()) {
-      if (!outputs.areImportant()) {
-        continue;
-      }
-      for (Artifact output : outputs.getArtifacts().toList()) {
-        var metadata = completionContext.getFileArtifactValue(output);
-        if (metadata != null) {
-          builder.put(output, metadata);
+    try {
+      for (ArtifactsInOutputGroup outputs : outputGroups.values()) {
+        if (!outputs.areImportant()) {
+          continue;
+        }
+        for (Artifact output : outputs.getArtifacts().toList()) {
+          appendArtifact(output, builder);
         }
       }
-    }
 
-    try {
       appendRunfiles(runfiles, builder);
     } catch (InterruptedException ignored) {
       Thread.currentThread().interrupt();
@@ -105,7 +99,9 @@
     var outputsAndMetadata = builder.buildKeepingLast();
     ListenableFuture<Void> future =
         actionInputPrefetcher.prefetchFiles(
-            outputsAndMetadata.keySet(),
+            outputsAndMetadata.keySet().stream()
+                .filter(ToplevelArtifactsDownloader::isNonTreeArtifact)
+                .collect(toImmutableSet()),
             new StaticMetadataProvider(outputsAndMetadata),
             Priority.LOW);
 
@@ -123,6 +119,10 @@
         directExecutor());
   }
 
+  private static boolean isNonTreeArtifact(ActionInput actionInput) {
+    return !(actionInput instanceof Artifact && ((Artifact) actionInput).isTreeArtifact());
+  }
+
   private void appendRunfiles(
       @Nullable Runfiles runfiles, ImmutableMap.Builder<ActionInput, FileArtifactValue> builder)
       throws InterruptedException {
@@ -131,21 +131,22 @@
     }
 
     for (Artifact runfile : runfiles.getArtifacts().toList()) {
-      var actionExecutionValue =
-          (ActionExecutionValue) memoizingEvaluator.getExistingValue(Artifact.key(runfile));
-      if (actionExecutionValue != null) {
-        if (runfile.isTreeArtifact()) {
-          TreeArtifactValue metadata = actionExecutionValue.getAllTreeArtifactValues().get(runfile);
-          if (metadata != null) {
-            builder.putAll(metadata.getChildValues());
-          }
-        } else {
-          FileArtifactValue metadata = actionExecutionValue.getAllFileValues().get(runfile);
-          if (metadata != null) {
-            builder.put(runfile, metadata);
-          }
-        }
+      appendArtifact(runfile, builder);
+    }
+  }
+
+  private void appendArtifact(
+      Artifact artifact, ImmutableMap.Builder<ActionInput, FileArtifactValue> builder)
+      throws InterruptedException {
+    SkyValue value = memoizingEvaluator.getExistingValue(Artifact.key(artifact));
+    if (value instanceof ActionExecutionValue) {
+      FileArtifactValue metadata = ((ActionExecutionValue) value).getAllFileValues().get(artifact);
+      if (metadata != null) {
+        builder.put(artifact, metadata);
       }
+    } else if (value instanceof TreeArtifactValue) {
+      builder.put(artifact, ((TreeArtifactValue) value).getMetadata());
+      builder.putAll(((TreeArtifactValue) value).getChildValues());
     }
   }
 }