Save injected metadata in RemoteActionFileSystem

So that spawn outputs can be accessed among Spawns within the same action using the `FileSystem` API.

This allows us to revert the hack we introduced in #12590. Also fixes the issue described in #15711.

Closes #15711.

Closes #16123.

PiperOrigin-RevId: 469133936
Change-Id: Ide5bcfa0fe2c6a3806d333cd61270e411aa78f80
diff --git a/src/main/java/com/google/devtools/build/lib/exec/AbstractSpawnStrategy.java b/src/main/java/com/google/devtools/build/lib/exec/AbstractSpawnStrategy.java
index ce65d86..f7fd7b3 100644
--- a/src/main/java/com/google/devtools/build/lib/exec/AbstractSpawnStrategy.java
+++ b/src/main/java/com/google/devtools/build/lib/exec/AbstractSpawnStrategy.java
@@ -52,6 +52,7 @@
 import com.google.devtools.build.lib.server.FailureDetails.Spawn.Code;
 import com.google.devtools.build.lib.util.CommandFailureUtils;
 import com.google.devtools.build.lib.util.io.FileOutErr;
+import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.PathFragment;
 import java.io.IOException;
@@ -353,5 +354,11 @@
         throw e.toExecException();
       }
     }
+
+    @Nullable
+    @Override
+    public FileSystem getActionFileSystem() {
+      return actionExecutionContext.getActionFileSystem();
+    }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/exec/SpawnRunner.java b/src/main/java/com/google/devtools/build/lib/exec/SpawnRunner.java
index 178cac7..9b1fc84 100644
--- a/src/main/java/com/google/devtools/build/lib/exec/SpawnRunner.java
+++ b/src/main/java/com/google/devtools/build/lib/exec/SpawnRunner.java
@@ -31,6 +31,7 @@
 import com.google.devtools.build.lib.actions.cache.MetadataInjector;
 import com.google.devtools.build.lib.events.ExtendedEventHandler;
 import com.google.devtools.build.lib.util.io.FileOutErr;
+import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.PathFragment;
 import java.io.IOException;
@@ -270,6 +271,10 @@
 
     /** Throws if rewinding is enabled and lost inputs have been detected. */
     void checkForLostInputs() throws LostInputsExecException;
+
+    /** Returns the action-scoped file system, or {@code null} if it doesn't exist. */
+    @Nullable
+    FileSystem getActionFileSystem();
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/exec/StandaloneTestStrategy.java b/src/main/java/com/google/devtools/build/lib/exec/StandaloneTestStrategy.java
index 09dbe44..1261a89 100644
--- a/src/main/java/com/google/devtools/build/lib/exec/StandaloneTestStrategy.java
+++ b/src/main/java/com/google/devtools/build/lib/exec/StandaloneTestStrategy.java
@@ -26,22 +26,17 @@
 import com.google.devtools.build.lib.actions.ActionExecutionContext;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputHelper;
-import com.google.devtools.build.lib.actions.Artifact;
-import com.google.devtools.build.lib.actions.Artifact.DerivedArtifact;
 import com.google.devtools.build.lib.actions.Artifact.SpecialArtifact;
-import com.google.devtools.build.lib.actions.Artifact.TreeFileArtifact;
 import com.google.devtools.build.lib.actions.ArtifactPathResolver;
 import com.google.devtools.build.lib.actions.EnvironmentalExecException;
 import com.google.devtools.build.lib.actions.ExecException;
 import com.google.devtools.build.lib.actions.ExecutionRequirements;
-import com.google.devtools.build.lib.actions.FileArtifactValue;
 import com.google.devtools.build.lib.actions.SimpleSpawn;
 import com.google.devtools.build.lib.actions.Spawn;
 import com.google.devtools.build.lib.actions.SpawnContinuation;
 import com.google.devtools.build.lib.actions.SpawnMetrics;
 import com.google.devtools.build.lib.actions.SpawnResult;
 import com.google.devtools.build.lib.actions.TestExecException;
-import com.google.devtools.build.lib.actions.cache.MetadataHandler;
 import com.google.devtools.build.lib.analysis.actions.SpawnAction;
 import com.google.devtools.build.lib.analysis.test.TestAttempt;
 import com.google.devtools.build.lib.analysis.test.TestResult;
@@ -57,7 +52,6 @@
 import com.google.devtools.build.lib.server.FailureDetails.Execution.Code;
 import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
 import com.google.devtools.build.lib.server.FailureDetails.TestAction;
-import com.google.devtools.build.lib.skyframe.TreeArtifactValue;
 import com.google.devtools.build.lib.util.Pair;
 import com.google.devtools.build.lib.util.io.FileOutErr;
 import com.google.devtools.build.lib.vfs.FileStatus;
@@ -76,8 +70,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
 import javax.annotation.Nullable;
 
 /** Runs TestRunnerAction actions. */
@@ -147,7 +139,7 @@
             ImmutableMap.of(),
             /*inputs=*/ action.getInputs(),
             NestedSetBuilder.emptySet(Order.STABLE_ORDER),
-            createSpawnOutputs(action),
+            ImmutableSet.copyOf(action.getSpawnOutputs()),
             /*mandatoryOutputs=*/ ImmutableSet.of(),
             localResourcesSupplier);
     Path execRoot = actionExecutionContext.getExecRoot();
@@ -159,21 +151,6 @@
         action, actionExecutionContext, spawn, tmpDir, workingDirectory, execRoot);
   }
 
-  private ImmutableSet<ActionInput> createSpawnOutputs(TestRunnerAction action) {
-    ImmutableSet.Builder<ActionInput> builder = ImmutableSet.builder();
-    for (ActionInput output : action.getSpawnOutputs()) {
-      if (output.getExecPath().equals(action.getXmlOutputPath())) {
-        // HACK: Convert type of test.xml from BasicActionInput to DerivedArtifact. We want to
-        // inject metadata of test.xml if it is generated remotely and it's currently only possible
-        // to inject Artifact.
-        builder.add(createArtifactOutput(action, output.getExecPath()));
-      } else {
-        builder.add(output);
-      }
-    }
-    return builder.build();
-  }
-
   private static ImmutableList<Pair<String, Path>> renameOutputs(
       ActionExecutionContext actionExecutionContext,
       TestRunnerAction action,
@@ -326,83 +303,6 @@
         action, clientEnv, getTimeout(action), runfilesDir.relativeTo(execRoot), relativeTmpDir);
   }
 
-  static class TestMetadataHandler implements MetadataHandler {
-    private final MetadataHandler metadataHandler;
-    private final ImmutableSet<Artifact> outputs;
-    private final ConcurrentMap<Artifact, FileArtifactValue> fileMetadataMap =
-        new ConcurrentHashMap<>();
-
-    TestMetadataHandler(MetadataHandler metadataHandler, ImmutableSet<Artifact> outputs) {
-      this.metadataHandler = metadataHandler;
-      this.outputs = outputs;
-    }
-
-    @Nullable
-    @Override
-    public ActionInput getInput(String execPath) {
-      return metadataHandler.getInput(execPath);
-    }
-
-    @Nullable
-    @Override
-    public FileArtifactValue getMetadata(ActionInput input) throws IOException {
-      return metadataHandler.getMetadata(input);
-    }
-
-    @Override
-    public void setDigestForVirtualArtifact(Artifact artifact, byte[] digest) {
-      metadataHandler.setDigestForVirtualArtifact(artifact, digest);
-    }
-
-    @Override
-    public FileArtifactValue constructMetadataForDigest(
-        Artifact output, FileStatus statNoFollow, byte[] injectedDigest) throws IOException {
-      return metadataHandler.constructMetadataForDigest(output, statNoFollow, injectedDigest);
-    }
-
-    @Override
-    public ImmutableSet<TreeFileArtifact> getTreeArtifactChildren(SpecialArtifact treeArtifact) {
-      return metadataHandler.getTreeArtifactChildren(treeArtifact);
-    }
-
-    @Override
-    public TreeArtifactValue getTreeArtifactValue(SpecialArtifact treeArtifact) throws IOException {
-      return metadataHandler.getTreeArtifactValue(treeArtifact);
-    }
-
-    @Override
-    public void markOmitted(Artifact output) {
-      metadataHandler.markOmitted(output);
-    }
-
-    @Override
-    public boolean artifactOmitted(Artifact artifact) {
-      return metadataHandler.artifactOmitted(artifact);
-    }
-
-    @Override
-    public void resetOutputs(Iterable<? extends Artifact> outputs) {
-      metadataHandler.resetOutputs(outputs);
-    }
-
-    @Override
-    public void injectFile(Artifact output, FileArtifactValue metadata) {
-      if (outputs.contains(output)) {
-        metadataHandler.injectFile(output, metadata);
-      }
-      fileMetadataMap.put(output, metadata);
-    }
-
-    @Override
-    public void injectTree(SpecialArtifact output, TreeArtifactValue tree) {
-      metadataHandler.injectTree(output, tree);
-    }
-
-    public boolean fileInjected(Artifact output) {
-      return fileMetadataMap.containsKey(output);
-    }
-  }
-
   private TestAttemptContinuation beginTestAttempt(
       TestRunnerAction testAction,
       Spawn spawn,
@@ -420,25 +320,12 @@
               Reporter.outErrForReporter(actionExecutionContext.getEventHandler()), out);
     }
 
-    // We use TestMetadataHandler here mainly because the one provided by actionExecutionContext
-    // doesn't allow to inject undeclared outputs and test.xml is undeclared by the test action.
-    TestMetadataHandler testMetadataHandler = null;
-    if (actionExecutionContext.getMetadataHandler() != null) {
-      testMetadataHandler =
-          new TestMetadataHandler(
-              actionExecutionContext.getMetadataHandler(), testAction.getOutputs());
-    }
-
     long startTimeMillis = actionExecutionContext.getClock().currentTimeMillis();
     SpawnStrategyResolver resolver = actionExecutionContext.getContext(SpawnStrategyResolver.class);
     SpawnContinuation spawnContinuation;
     try {
       spawnContinuation =
-          resolver.beginExecution(
-              spawn,
-              actionExecutionContext
-                  .withFileOutErr(testOutErr)
-                  .withMetadataHandler(testMetadataHandler));
+          resolver.beginExecution(spawn, actionExecutionContext.withFileOutErr(testOutErr));
     } catch (InterruptedException e) {
       if (streamed != null) {
         streamed.close();
@@ -448,7 +335,6 @@
     }
     return new BazelTestAttemptContinuation(
         testAction,
-        testMetadataHandler,
         actionExecutionContext,
         spawn,
         resolvedPaths,
@@ -559,12 +445,6 @@
     return Durations.fromNanos(d.toNanos());
   }
 
-  private static Artifact.DerivedArtifact createArtifactOutput(
-      TestRunnerAction action, PathFragment outputPath) {
-    Artifact.DerivedArtifact testLog = (Artifact.DerivedArtifact) action.getTestLog();
-    return DerivedArtifact.create(testLog.getRoot(), outputPath, testLog.getArtifactOwner());
-  }
-
   /**
    * A spawn to generate a test.xml file from the test log. This is only used if the test does not
    * generate a test.xml file itself.
@@ -610,7 +490,7 @@
         /*inputs=*/ NestedSetBuilder.create(
             Order.STABLE_ORDER, action.getTestXmlGeneratorScript(), action.getTestLog()),
         /*tools=*/ NestedSetBuilder.emptySet(Order.STABLE_ORDER),
-        /*outputs=*/ ImmutableSet.of(createArtifactOutput(action, action.getXmlOutputPath())),
+        /*outputs=*/ ImmutableSet.of(ActionInputHelper.fromPath(action.getXmlOutputPath())),
         /*mandatoryOutputs=*/ null,
         SpawnAction.DEFAULT_RESOURCE_SET);
   }
@@ -763,7 +643,6 @@
 
   private final class BazelTestAttemptContinuation extends TestAttemptContinuation {
     private final TestRunnerAction testAction;
-    @Nullable private final TestMetadataHandler testMetadataHandler;
     private final ActionExecutionContext actionExecutionContext;
     private final Spawn spawn;
     private final ResolvedPaths resolvedPaths;
@@ -776,7 +655,6 @@
 
     BazelTestAttemptContinuation(
         TestRunnerAction testAction,
-        @Nullable TestMetadataHandler testMetadataHandler,
         ActionExecutionContext actionExecutionContext,
         Spawn spawn,
         ResolvedPaths resolvedPaths,
@@ -787,7 +665,6 @@
         TestResultData.Builder testResultDataBuilder,
         ImmutableList<SpawnResult> spawnResults) {
       this.testAction = testAction;
-      this.testMetadataHandler = testMetadataHandler;
       this.actionExecutionContext = actionExecutionContext;
       this.spawn = spawn;
       this.resolvedPaths = resolvedPaths;
@@ -827,7 +704,6 @@
           if (!nextContinuation.isDone()) {
             return new BazelTestAttemptContinuation(
                 testAction,
-                testMetadataHandler,
                 actionExecutionContext,
                 spawn,
                 resolvedPaths,
@@ -918,7 +794,6 @@
           appendCoverageLog(coverageOutErr, fileOutErr);
           return new BazelCoveragePostProcessingContinuation(
               testAction,
-              testMetadataHandler,
               actionExecutionContext,
               spawn,
               resolvedPaths,
@@ -955,12 +830,6 @@
       }
 
       Path xmlOutputPath = resolvedPaths.getXmlOutputPath();
-      boolean testXmlGenerated = xmlOutputPath.exists();
-      if (!testXmlGenerated && testMetadataHandler != null) {
-        testXmlGenerated =
-            testMetadataHandler.fileInjected(
-                createArtifactOutput(testAction, testAction.getXmlOutputPath()));
-      }
 
       // If the test did not create a test.xml, and --experimental_split_xml_generation is enabled,
       // then we run a separate action to create a test.xml from test.log. We do this as a spawn
@@ -968,7 +837,7 @@
       // remote execution is enabled), and we do not want to have to download it.
       if (executionOptions.splitXmlGeneration
           && fileOutErr.getOutputPath().exists()
-          && !testXmlGenerated) {
+          && !xmlOutputPath.exists()) {
         Spawn xmlGeneratingSpawn =
             createXmlGeneratingSpawn(testAction, spawn.getEnvironment(), spawnResults.get(0));
         SpawnStrategyResolver spawnStrategyResolver =
@@ -979,10 +848,7 @@
         try {
           SpawnContinuation xmlContinuation =
               spawnStrategyResolver.beginExecution(
-                  xmlGeneratingSpawn,
-                  actionExecutionContext
-                      .withFileOutErr(xmlSpawnOutErr)
-                      .withMetadataHandler(testMetadataHandler));
+                  xmlGeneratingSpawn, actionExecutionContext.withFileOutErr(xmlSpawnOutErr));
           return new BazelXmlCreationContinuation(
               resolvedPaths, xmlSpawnOutErr, testResultDataBuilder, spawnResults, xmlContinuation);
         } catch (InterruptedException e) {
@@ -1080,7 +946,6 @@
 
   private final class BazelCoveragePostProcessingContinuation extends TestAttemptContinuation {
     private final ResolvedPaths resolvedPaths;
-    @Nullable private final TestMetadataHandler testMetadataHandler;
     private final FileOutErr fileOutErr;
     private final Closeable streamed;
     private final TestResultData.Builder testResultDataBuilder;
@@ -1092,7 +957,6 @@
 
     BazelCoveragePostProcessingContinuation(
         TestRunnerAction testAction,
-        @Nullable TestMetadataHandler testMetadataHandler,
         ActionExecutionContext actionExecutionContext,
         Spawn spawn,
         ResolvedPaths resolvedPaths,
@@ -1102,7 +966,6 @@
         ImmutableList<SpawnResult> primarySpawnResults,
         SpawnContinuation spawnContinuation) {
       this.testAction = testAction;
-      this.testMetadataHandler = testMetadataHandler;
       this.actionExecutionContext = actionExecutionContext;
       this.spawn = spawn;
       this.resolvedPaths = resolvedPaths;
@@ -1127,7 +990,6 @@
         if (!nextContinuation.isDone()) {
           return new BazelCoveragePostProcessingContinuation(
               testAction,
-              testMetadataHandler,
               actionExecutionContext,
               spawn,
               resolvedPaths,
@@ -1164,7 +1026,6 @@
 
       return new BazelTestAttemptContinuation(
           testAction,
-          testMetadataHandler,
           actionExecutionContext,
           spawn,
           resolvedPaths,
diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index 1129795..e0a6b09 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -101,6 +101,7 @@
         "//src/main/java/com/google/devtools/build/lib/vfs",
         "//src/main/java/com/google/devtools/build/lib/vfs:output_service",
         "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment",
+        "//src/main/java/com/google/devtools/build/skyframe",
         "//src/main/java/com/google/devtools/common/options",
         "//src/main/protobuf:failure_details_java_proto",
         "//third_party:auth",
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionFileSystem.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionFileSystem.java
index dc5018a..6c43cf6 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionFileSystem.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionFileSystem.java
@@ -17,9 +17,15 @@
 
 import static com.google.common.base.Preconditions.checkNotNull;
 
+import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputMap;
+import com.google.devtools.build.lib.actions.Artifact;
+import com.google.devtools.build.lib.actions.Artifact.SpecialArtifact;
+import com.google.devtools.build.lib.actions.Artifact.TreeFileArtifact;
 import com.google.devtools.build.lib.actions.FileArtifactValue;
 import com.google.devtools.build.lib.actions.FileArtifactValue.RemoteFileArtifactValue;
+import com.google.devtools.build.lib.actions.cache.MetadataInjector;
+import com.google.devtools.build.lib.skyframe.TreeArtifactValue;
 import com.google.devtools.build.lib.vfs.DelegateFileSystem;
 import com.google.devtools.build.lib.vfs.Dirent;
 import com.google.devtools.build.lib.vfs.FileStatus;
@@ -31,6 +37,8 @@
 import java.io.InputStream;
 import java.nio.channels.ReadableByteChannel;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
 import javax.annotation.Nullable;
 
 /**
@@ -43,12 +51,15 @@
  *
  * <p>This implementation only supports creating local action outputs.
  */
-class RemoteActionFileSystem extends DelegateFileSystem {
+public class RemoteActionFileSystem extends DelegateFileSystem {
 
   private final PathFragment execRoot;
   private final PathFragment outputBase;
   private final ActionInputMap inputArtifactData;
   private final RemoteActionInputFetcher inputFetcher;
+  private final Map<PathFragment, RemoteFileArtifactValue> injectedMetadata = new HashMap<>();
+
+  @Nullable private MetadataInjector metadataInjector = null;
 
   RemoteActionFileSystem(
       FileSystem localDelegate,
@@ -68,6 +79,35 @@
     return getRemoteInputMetadata(path.asFragment()) != null;
   }
 
+  public void updateContext(MetadataInjector metadataInjector) {
+    this.metadataInjector = metadataInjector;
+  }
+
+  void injectTree(SpecialArtifact tree, TreeArtifactValue metadata) {
+    checkNotNull(metadataInjector, "metadataInject is null");
+
+    for (Map.Entry<TreeFileArtifact, FileArtifactValue> entry :
+        metadata.getChildValues().entrySet()) {
+      FileArtifactValue childMetadata = entry.getValue();
+      if (childMetadata instanceof RemoteFileArtifactValue) {
+        TreeFileArtifact child = entry.getKey();
+        injectedMetadata.put(child.getExecPath(), (RemoteFileArtifactValue) childMetadata);
+      }
+    }
+
+    metadataInjector.injectTree(tree, metadata);
+  }
+
+  void injectFile(ActionInput file, RemoteFileArtifactValue metadata) {
+    checkNotNull(metadataInjector, "metadataInject is null");
+
+    injectedMetadata.put(file.getExecPath(), metadata);
+
+    if (file instanceof Artifact) {
+      metadataInjector.injectFile((Artifact) file, metadata);
+    }
+  }
+
   @Override
   public String getFileSystemType(PathFragment path) {
     return "remoteActionFS";
@@ -330,7 +370,8 @@
     if (m != null && m.isRemote()) {
       return (RemoteFileArtifactValue) m;
     }
-    return null;
+
+    return injectedMetadata.get(execPath);
   }
 
   private void downloadFileIfRemote(PathFragment path) throws IOException {
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
index 729c741..f4f6502 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -67,7 +67,7 @@
 
   @Override
   protected boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
-    return metadata.isRemote();
+    return metadata.isRemote() && !path.exists();
   }
 
   @Override
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
index cbab9f6..81e5f27 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteExecutionService.java
@@ -78,7 +78,6 @@
 import com.google.devtools.build.lib.actions.SpawnResult;
 import com.google.devtools.build.lib.actions.Spawns;
 import com.google.devtools.build.lib.actions.UserExecException;
-import com.google.devtools.build.lib.actions.cache.MetadataInjector;
 import com.google.devtools.build.lib.analysis.platform.PlatformUtils;
 import com.google.devtools.build.lib.buildtool.buildevent.BuildInterruptedEvent;
 import com.google.devtools.build.lib.events.Event;
@@ -111,6 +110,7 @@
 import com.google.devtools.build.lib.server.FailureDetails.RemoteExecution;
 import com.google.devtools.build.lib.skyframe.TreeArtifactValue;
 import com.google.devtools.build.lib.util.io.FileOutErr;
+import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.FileSystemUtils;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.PathFragment;
@@ -768,11 +768,15 @@
   }
 
   private void injectRemoteArtifact(
-      RemoteAction action, Artifact output, ActionResultMetadata metadata) throws IOException {
+      RemoteAction action, ActionInput output, ActionResultMetadata metadata) throws IOException {
+    FileSystem actionFileSystem = action.getSpawnExecutionContext().getActionFileSystem();
+    checkState(actionFileSystem instanceof RemoteActionFileSystem);
+
     RemoteActionExecutionContext context = action.getRemoteActionExecutionContext();
-    MetadataInjector metadataInjector = action.getSpawnExecutionContext().getMetadataInjector();
+    RemoteActionFileSystem remoteActionFileSystem = (RemoteActionFileSystem) actionFileSystem;
+
     Path path = remotePathResolver.outputPathToLocalPath(output);
-    if (output.isTreeArtifact()) {
+    if (output instanceof Artifact && ((Artifact) output).isTreeArtifact()) {
       DirectoryMetadata directory = metadata.directory(path);
       if (directory == null) {
         // A declared output wasn't created. It might have been an optional output and if not
@@ -797,7 +801,7 @@
                 context.getRequestMetadata().getActionId());
         tree.putChild(child, value);
       }
-      metadataInjector.injectTree(parent, tree.build());
+      remoteActionFileSystem.injectTree(parent, tree.build());
     } else {
       FileMetadata outputMetadata = metadata.file(path);
       if (outputMetadata == null) {
@@ -805,7 +809,7 @@
         // SkyFrame will make sure to fail.
         return;
       }
-      metadataInjector.injectFile(
+      remoteActionFileSystem.injectFile(
           output,
           new RemoteFileArtifactValue(
               DigestUtil.toBinaryDigest(outputMetadata.digest()),
@@ -1176,9 +1180,7 @@
           inMemoryOutputDigest = m.digest();
           inMemoryOutput = output;
         }
-        if (output instanceof Artifact) {
-          injectRemoteArtifact(action, (Artifact) output, metadata);
-        }
+        injectRemoteArtifact(action, output, metadata);
       }
 
       try (SilentCloseable c = Profiler.instance().profile("Remote.downloadInMemoryOutput")) {
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputService.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputService.java
index 13994d0..9e71413 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputService.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteOutputService.java
@@ -17,12 +17,14 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableCollection;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.devtools.build.lib.actions.Action;
 import com.google.devtools.build.lib.actions.ActionInputMap;
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.actions.ArtifactPathResolver;
 import com.google.devtools.build.lib.actions.FilesetOutputSymlink;
 import com.google.devtools.build.lib.actions.cache.MetadataHandler;
+import com.google.devtools.build.lib.actions.cache.MetadataInjector;
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.util.AbruptExitException;
 import com.google.devtools.build.lib.vfs.BatchStat;
@@ -31,6 +33,7 @@
 import com.google.devtools.build.lib.vfs.OutputService;
 import com.google.devtools.build.lib.vfs.PathFragment;
 import com.google.devtools.build.lib.vfs.Root;
+import com.google.devtools.build.skyframe.SkyFunction.Environment;
 import java.util.Map;
 import java.util.UUID;
 import javax.annotation.Nullable;
@@ -71,6 +74,15 @@
   }
 
   @Override
+  public void updateActionFileSystemContext(
+      FileSystem actionFileSystem,
+      Environment env,
+      MetadataInjector injector,
+      ImmutableMap<Artifact, ImmutableList<FilesetOutputSymlink>> filesets) {
+    ((RemoteActionFileSystem) actionFileSystem).updateContext(injector);
+  }
+
+  @Override
   public String getFilesSystemName() {
     return "remoteActionFS";
   }