Use a ConcurrentPathTrie to track paths to download in RemoteOutputChecker.

This makes it possible to accurately download action outputs in a followup change to RemoteExecutionService (it can't query the RemoteOutputChecker for individual files inside a tree artifact because the respective TreeFileArtifact doesn't exist yet).

PiperOrigin-RevId: 541716962
Change-Id: I7e78f54b490c5e23dc59fcd1bced8aeea8c36411
diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 47c76d1..9421f76 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -196,6 +196,7 @@
         "//src/main/java/com/google/devtools/build/lib/clock",
         "//src/main/java/com/google/devtools/build/lib/packages",
         "//src/main/java/com/google/devtools/build/lib/remote/options",
+        "//src/main/java/com/google/devtools/build/lib/remote/util",
         "//src/main/java/com/google/devtools/build/lib/skyframe:coverage_report_value",
         "//third_party:guava",
         "//third_party:jsr305",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputChecker.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputChecker.java
index c5c7e19..d8ffdff 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputChecker.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputChecker.java
@@ -19,10 +19,8 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.Artifact;
-import com.google.devtools.build.lib.actions.Artifact.TreeFileArtifact;
 import com.google.devtools.build.lib.actions.FileArtifactValue.RemoteFileArtifactValue;
 import com.google.devtools.build.lib.actions.RemoteArtifactChecker;
 import com.google.devtools.build.lib.analysis.AnalysisResult;
@@ -36,7 +34,7 @@
 import com.google.devtools.build.lib.analysis.test.TestProvider;
 import com.google.devtools.build.lib.clock.Clock;
 import com.google.devtools.build.lib.remote.options.RemoteOutputsMode;
-import java.util.Set;
+import com.google.devtools.build.lib.remote.util.ConcurrentPathTrie;
 import java.util.function.Supplier;
 import java.util.regex.Pattern;
 import javax.annotation.Nullable;
@@ -55,8 +53,7 @@
   private final CommandMode commandMode;
   private final boolean downloadToplevel;
   private final ImmutableList<Pattern> patternsToDownload;
-  private final Set<ActionInput> toplevelArtifactsToDownload = Sets.newConcurrentHashSet();
-  private final Set<ActionInput> inputsToDownload = Sets.newConcurrentHashSet();
+  private final ConcurrentPathTrie pathsToDownload = new ConcurrentPathTrie();
 
   public RemoteOutputChecker(
       Clock clock,
@@ -134,8 +131,7 @@
       var artifactsToBuild =
           TopLevelArtifactHelper.getAllArtifactsToBuild(target, topLevelArtifactContext)
               .getImportantArtifacts();
-      toplevelArtifactsToDownload.addAll(artifactsToBuild.toList());
-
+      addOutputsToDownload(artifactsToBuild.toList());
       addRunfiles(target);
     }
   }
@@ -154,21 +150,21 @@
       if (runfile.isSourceArtifact()) {
         continue;
       }
-      toplevelArtifactsToDownload.add(runfile);
+      addOutputToDownload(runfile);
     }
     for (var symlink : runfiles.getSymlinks().toList()) {
       var artifact = symlink.getArtifact();
       if (artifact.isSourceArtifact()) {
         continue;
       }
-      toplevelArtifactsToDownload.add(artifact);
+      addOutputToDownload(artifact);
     }
     for (var symlink : runfiles.getRootSymlinks().toList()) {
       var artifact = symlink.getArtifact();
       if (artifact.isSourceArtifact()) {
         continue;
       }
-      toplevelArtifactsToDownload.add(artifact);
+      addOutputToDownload(artifact);
     }
   }
 
@@ -176,13 +172,13 @@
     TestProvider testProvider = checkNotNull(target.getProvider(TestProvider.class));
     if (downloadToplevel && commandMode == CommandMode.TEST) {
       // In test mode, download the outputs of the test runner action.
-      toplevelArtifactsToDownload.addAll(testProvider.getTestParams().getOutputs());
+      addOutputsToDownload(testProvider.getTestParams().getOutputs());
     }
     if (commandMode == CommandMode.COVERAGE) {
       // In coverage mode, download the per-test and aggregated coverage files.
       // Do this even for MINIMAL, since coverage (unlike test) doesn't produce any observable
       // results other than outputs.
-      toplevelArtifactsToDownload.addAll(testProvider.getTestParams().getCoverageArtifacts());
+      addOutputsToDownload(testProvider.getTestParams().getCoverageArtifacts());
     }
   }
 
@@ -192,13 +188,23 @@
     }
     for (Artifact artifactToBuild : artifactsToBuild) {
       if (artifactToBuild.getArtifactOwner().equals(COVERAGE_REPORT_KEY)) {
-        toplevelArtifactsToDownload.add(artifactToBuild);
+        addOutputToDownload(artifactToBuild);
       }
     }
   }
 
-  public void addInputToDownload(ActionInput file) {
-    inputsToDownload.add(file);
+  private void addOutputsToDownload(Iterable<? extends ActionInput> files) {
+    for (ActionInput file : files) {
+      addOutputToDownload(file);
+    }
+  }
+
+  public void addOutputToDownload(ActionInput file) {
+    if (file instanceof Artifact && ((Artifact) file).isTreeArtifact()) {
+      pathsToDownload.addPrefix(file.getExecPath());
+    } else {
+      pathsToDownload.add(file.getExecPath());
+    }
   }
 
   private boolean shouldAddTopLevelTarget(@Nullable ConfiguredTarget configuredTarget) {
@@ -220,15 +226,7 @@
     }
   }
 
-  private boolean isTopLevelArtifact(ActionInput output) {
-    return isPartOfCollectedSet(output, toplevelArtifactsToDownload);
-  }
-
-  private boolean isInputToLocalAction(ActionInput output) {
-    return isPartOfCollectedSet(output, inputsToDownload);
-  }
-
-  private boolean matchesRegex(ActionInput output) {
+  private boolean matchesPattern(ActionInput output) {
     if (output instanceof Artifact && ((Artifact) output).isTreeArtifact()) {
       return false;
     }
@@ -242,19 +240,11 @@
     return false;
   }
 
-  private static boolean isPartOfCollectedSet(
-      ActionInput actionInput, Set<ActionInput> artifactSet) {
-    return artifactSet.contains(
-        actionInput instanceof TreeFileArtifact
-            ? ((Artifact) actionInput).getParent()
-            : actionInput);
-  }
-
   /**
    * Returns {@code true} if Bazel should download this {@link ActionInput} during spawn execution.
    */
   public boolean shouldDownloadOutput(ActionInput output) {
-    return isTopLevelArtifact(output) || isInputToLocalAction(output) || matchesRegex(output);
+    return pathsToDownload.contains(output.getExecPath()) || matchesPattern(output);
   }
 
   @Override