Cleanup AbstractActionInputPrefetcher and split RemoteActionInputFetcherTest.

PiperOrigin-RevId: 450386875
diff --git a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
index cc2675f..bd30aea 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
@@ -18,6 +18,7 @@
 import static com.google.devtools.build.lib.remote.util.RxFutures.toListenableFuture;
 import static com.google.devtools.build.lib.remote.util.RxUtils.mergeBulkTransfer;
 import static com.google.devtools.build.lib.remote.util.RxUtils.toTransferResult;
+import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableSet;
@@ -25,6 +26,7 @@
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.ActionInputPrefetcher;
+import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.actions.FileArtifactValue;
 import com.google.devtools.build.lib.actions.MetadataProvider;
 import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
@@ -46,10 +48,8 @@
 import com.google.devtools.build.lib.vfs.Path;
 import io.reactivex.rxjava3.core.Completable;
 import io.reactivex.rxjava3.core.Flowable;
-import io.reactivex.rxjava3.functions.Function;
 import java.io.IOException;
 import java.util.concurrent.atomic.AtomicBoolean;
-import javax.annotation.Nullable;
 
 /**
  * Abstract implementation of {@link ActionInputPrefetcher} which implements the orchestration of
@@ -98,16 +98,15 @@
     }
   }
 
-  protected abstract boolean shouldDownloadInput(
-      ActionInput input, @Nullable FileArtifactValue metadata);
+  protected abstract boolean shouldDownloadFile(Path path, FileArtifactValue metadata);
 
   /**
-   * Downloads the {@code input} to the given path via the metadata.
+   * Downloads file to the given path via its metadata.
    *
-   * @param path the destination which the input should be written to.
+   * @param tempPath the temporary path which the input should be written to.
    */
-  protected abstract ListenableFuture<Void> downloadInput(
-      Path path, ActionInput input, FileArtifactValue metadata) throws IOException;
+  protected abstract ListenableFuture<Void> doDownloadFile(
+      Path tempPath, FileArtifactValue metadata) throws IOException;
 
   protected void prefetchVirtualActionInput(VirtualActionInput input) throws IOException {}
 
@@ -142,34 +141,37 @@
 
   private Completable prefetchInput(MetadataProvider metadataProvider, ActionInput input)
       throws IOException {
+    if (input instanceof Artifact && ((Artifact) input).isSourceArtifact()) {
+      return Completable.complete();
+    }
+
     if (input instanceof VirtualActionInput) {
       prefetchVirtualActionInput((VirtualActionInput) input);
       return Completable.complete();
     }
 
-    FileArtifactValue metadata = metadataProvider.getMetadata(input);
-    if (!shouldDownloadInput(input, metadata)) {
-      return Completable.complete();
-    }
-
     Path path = execRoot.getRelative(input.getExecPath());
-    return downloadFileIfNot(path, (p) -> downloadInput(p, input, metadata));
+    FileArtifactValue metadata = metadataProvider.getMetadata(input);
+    return downloadFileRx(path, metadata);
   }
 
   /**
-   * Downloads file into the {@code path} with given downloader.
+   * Downloads file into the {@code path} with its metadata.
    *
    * <p>The file will be written into a temporary file and moved to the final destination after the
    * download finished.
    */
-  protected Completable downloadFileIfNot(
-      Path path, Function<Path, ListenableFuture<Void>> downloader) {
+  public Completable downloadFileRx(Path path, FileArtifactValue metadata) {
+    if (!shouldDownloadFile(path, metadata)) {
+      return Completable.complete();
+    }
+
     AtomicBoolean completed = new AtomicBoolean(false);
     Completable download =
         Completable.using(
             tempPathGenerator::generateTempPath,
             tempPath ->
-                toCompletable(() -> downloader.apply(tempPath), directExecutor())
+                toCompletable(() -> doDownloadFile(tempPath, metadata), directExecutor())
                     .doOnComplete(
                         () -> {
                           finalizeDownload(tempPath, path);
@@ -186,6 +188,28 @@
     return downloadCache.executeIfNot(path, download);
   }
 
+  /**
+   * Downloads file into the {@code path} with its metadata.
+   *
+   * <p>The file will be written into a temporary file and moved to the final destination after the
+   * download finished.
+   */
+  public ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata) {
+    return toListenableFuture(downloadFileRx(path, metadata));
+  }
+
+  /**
+   * Download file to the {@code path} with given metadata. Blocking await for the download to
+   * complete.
+   *
+   * <p>The file will be written into a temporary file and moved to the final destination after the
+   * download finished.
+   */
+  public void downloadFile(Path path, FileArtifactValue metadata)
+      throws IOException, InterruptedException {
+    getFromFuture(downloadFileAsync(path, metadata));
+  }
+
   private void finalizeDownload(Path tmpPath, Path path) throws IOException {
     // The permission of output file is changed to 0555 after action execution. We manually change
     // the permission here for the downloaded file to keep this behaviour consistent.
@@ -202,16 +226,16 @@
     }
   }
 
-  ImmutableSet<Path> downloadedFiles() {
+  public ImmutableSet<Path> downloadedFiles() {
     return downloadCache.getFinishedTasks();
   }
 
-  ImmutableSet<Path> downloadsInProgress() {
+  public ImmutableSet<Path> downloadsInProgress() {
     return downloadCache.getInProgressTasks();
   }
 
   @VisibleForTesting
-  AsyncTaskCache.NoResult<Path> getDownloadCache() {
+  public AsyncTaskCache.NoResult<Path> getDownloadCache() {
     return downloadCache;
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD
index a8faaf8..ed33a59 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD
@@ -32,6 +32,7 @@
             "RemoteRetrier.java",
             "RemoteRetrierUtils.java",
             "Retrier.java",
+            "AbstractActionInputPrefetcher.java",
         ],
     ),
     exports = [
@@ -43,6 +44,7 @@
         ":ExecutionStatusException",
         ":ReferenceCountedChannel",
         ":Retrier",
+        ":abstract_action_input_prefetcher",
         "//src/main/java/com/google/devtools/build/lib:build-request-options",
         "//src/main/java/com/google/devtools/build/lib:runtime",
         "//src/main/java/com/google/devtools/build/lib/actions",
@@ -162,3 +164,24 @@
         "//third_party/grpc-java:grpc-jar",
     ],
 )
+
+java_library(
+    name = "abstract_action_input_prefetcher",
+    srcs = ["AbstractActionInputPrefetcher.java"],
+    deps = [
+        "//src/main/java/com/google/devtools/build/lib/actions",
+        "//src/main/java/com/google/devtools/build/lib/actions:artifacts",
+        "//src/main/java/com/google/devtools/build/lib/actions:file_metadata",
+        "//src/main/java/com/google/devtools/build/lib/events",
+        "//src/main/java/com/google/devtools/build/lib/profiler",
+        "//src/main/java/com/google/devtools/build/lib/remote/util",
+        "//src/main/java/com/google/devtools/build/lib/util:abrupt_exit_exception",
+        "//src/main/java/com/google/devtools/build/lib/util:detailed_exit_code",
+        "//src/main/java/com/google/devtools/build/lib/util:exit_code",
+        "//src/main/java/com/google/devtools/build/lib/vfs",
+        "//src/main/protobuf:failure_details_java_proto",
+        "//third_party:flogger",
+        "//third_party:guava",
+        "//third_party:rxjava3",
+    ],
+)
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
index eb8d255..5b0213a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -14,14 +14,11 @@
 package com.google.devtools.build.lib.remote;
 
 import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.devtools.build.lib.remote.util.RxFutures.toListenableFuture;
-import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
 
 import build.bazel.remote.execution.v2.Digest;
 import build.bazel.remote.execution.v2.RequestMetadata;
 import com.google.common.base.Preconditions;
 import com.google.common.util.concurrent.ListenableFuture;
-import com.google.devtools.build.lib.actions.ActionInput;
 import com.google.devtools.build.lib.actions.FileArtifactValue;
 import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
 import com.google.devtools.build.lib.actions.cache.VirtualActionInput.EmptyActionInput;
@@ -35,7 +32,6 @@
 import com.google.devtools.build.lib.vfs.Path;
 import io.reactivex.rxjava3.core.Completable;
 import java.io.IOException;
-import javax.annotation.Nullable;
 
 /**
  * Stages output files that are stored remotely to the local filesystem.
@@ -70,14 +66,20 @@
   }
 
   @Override
-  protected boolean shouldDownloadInput(ActionInput input, @Nullable FileArtifactValue metadata) {
-    return metadata != null && metadata.isRemote();
+  protected boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
+    return metadata.isRemote();
   }
 
   @Override
-  protected ListenableFuture<Void> downloadInput(
-      Path path, ActionInput input, FileArtifactValue metadata) throws IOException {
-    return downloadFileAsync(path, metadata);
+  protected ListenableFuture<Void> doDownloadFile(Path tempPath, FileArtifactValue metadata)
+      throws IOException {
+    checkArgument(metadata.isRemote(), "Cannot download file that is not a remote file.");
+    RequestMetadata requestMetadata =
+        TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId(), null);
+    RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata);
+
+    Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
+    return remoteCache.downloadFile(context, tempPath, digest);
   }
 
   @Override
@@ -101,25 +103,4 @@
     }
     return Completable.error(error);
   }
-
-  private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
-      throws IOException {
-    checkArgument(metadata.isRemote(), "Cannot download file that is not a remote file.");
-    RequestMetadata requestMetadata =
-        TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId(), null);
-    RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata);
-
-    Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
-
-    return remoteCache.downloadFile(context, path, digest);
-  }
-
-  /** Download file to the {@code path} with given metadata. */
-  public void downloadFile(Path path, FileArtifactValue metadata)
-      throws IOException, InterruptedException {
-    if (metadata.isRemote()) {
-      getFromFuture(
-          toListenableFuture(downloadFileIfNot(path, (p) -> downloadFileAsync(p, metadata))));
-    }
-  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
new file mode 100644
index 0000000..4e97c84
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
@@ -0,0 +1,346 @@
+// Copyright 2019 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote;
+
+import static com.google.common.base.Throwables.throwIfInstanceOf;
+import static com.google.common.truth.Truth.assertThat;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.spy;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.hash.HashCode;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.devtools.build.lib.actions.ActionInput;
+import com.google.devtools.build.lib.actions.Artifact;
+import com.google.devtools.build.lib.actions.ArtifactRoot;
+import com.google.devtools.build.lib.actions.ArtifactRoot.RootType;
+import com.google.devtools.build.lib.actions.FileArtifactValue;
+import com.google.devtools.build.lib.actions.FileArtifactValue.RemoteFileArtifactValue;
+import com.google.devtools.build.lib.actions.MetadataProvider;
+import com.google.devtools.build.lib.actions.util.ActionsTestUtil;
+import com.google.devtools.build.lib.clock.JavaClock;
+import com.google.devtools.build.lib.remote.common.BulkTransferException;
+import com.google.devtools.build.lib.remote.util.StaticMetadataProvider;
+import com.google.devtools.build.lib.remote.util.TempPathGenerator;
+import com.google.devtools.build.lib.vfs.DigestHashFunction;
+import com.google.devtools.build.lib.vfs.FileSystem;
+import com.google.devtools.build.lib.vfs.FileSystemUtils;
+import com.google.devtools.build.lib.vfs.Path;
+import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+import org.junit.Before;
+import org.junit.Test;
+
+/** Base test class for {@link AbstractActionInputPrefetcher} implementations. */
+public abstract class ActionInputPrefetcherTestBase {
+  protected static final DigestHashFunction HASH_FUNCTION = DigestHashFunction.SHA256;
+
+  protected FileSystem fs;
+  protected Path execRoot;
+  protected ArtifactRoot artifactRoot;
+  protected TempPathGenerator tempPathGenerator;
+
+  @Before
+  public void setUp() throws IOException {
+    fs = new InMemoryFileSystem(new JavaClock(), HASH_FUNCTION);
+    execRoot = fs.getPath("/exec");
+    execRoot.createDirectoryAndParents();
+    artifactRoot = ArtifactRoot.asDerivedRoot(execRoot, RootType.Output, "root");
+    artifactRoot.getRoot().asPath().createDirectoryAndParents();
+    Path tempDir = fs.getPath("/tmp");
+    tempDir.createDirectoryAndParents();
+    tempPathGenerator = new TempPathGenerator(tempDir);
+  }
+
+  protected Artifact createRemoteArtifact(
+      String pathFragment,
+      String contents,
+      Map<ActionInput, FileArtifactValue> metadata,
+      Map<HashCode, byte[]> cas) {
+    Path p = artifactRoot.getRoot().getRelative(pathFragment);
+    Artifact a = ActionsTestUtil.createArtifact(artifactRoot, p);
+    byte[] contentsBytes = contents.getBytes(UTF_8);
+    HashCode hashCode = HASH_FUNCTION.getHashFunction().hashBytes(contentsBytes);
+    FileArtifactValue f =
+        new RemoteFileArtifactValue(
+            hashCode.asBytes(), contentsBytes.length, /* locationIndex= */ 1, "action-id");
+    metadata.put(a, f);
+    cas.put(hashCode, contentsBytes);
+    return a;
+  }
+
+  protected abstract AbstractActionInputPrefetcher createPrefetcher(Map<HashCode, byte[]> cas);
+
+  @Test
+  public void prefetchFiles_downloadRemoteFiles() throws Exception {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cas);
+    Artifact a2 = createRemoteArtifact("file2", "fizz buzz", metadata, cas);
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    AbstractActionInputPrefetcher prefetcher = createPrefetcher(cas);
+
+    wait(prefetcher.prefetchFiles(metadata.keySet(), metadataProvider));
+
+    assertThat(FileSystemUtils.readContent(a1.getPath(), UTF_8)).isEqualTo("hello world");
+    assertThat(a1.getPath().isExecutable()).isTrue();
+    assertThat(FileSystemUtils.readContent(a2.getPath(), UTF_8)).isEqualTo("fizz buzz");
+    assertThat(a2.getPath().isExecutable()).isTrue();
+    assertThat(prefetcher.downloadedFiles()).hasSize(2);
+    assertThat(prefetcher.downloadedFiles()).containsAtLeast(a1.getPath(), a2.getPath());
+    assertThat(prefetcher.downloadsInProgress()).isEmpty();
+  }
+
+  @Test
+  public void prefetchFiles_missingFiles_fails() throws Exception {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Artifact a = createRemoteArtifact("file1", "hello world", metadata, /* cas= */ new HashMap<>());
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    AbstractActionInputPrefetcher prefetcher = createPrefetcher(new HashMap<>());
+
+    assertThrows(
+        BulkTransferException.class,
+        () -> wait(prefetcher.prefetchFiles(ImmutableList.of(a), metadataProvider)));
+
+    assertThat(prefetcher.downloadedFiles()).isEmpty();
+    assertThat(prefetcher.downloadsInProgress()).isEmpty();
+  }
+
+  @Test
+  public void prefetchFiles_ignoreNonRemoteFiles() throws Exception {
+    // Test that files that are not remote are not downloaded
+
+    Path p = execRoot.getRelative(artifactRoot.getExecPath()).getRelative("file1");
+    FileSystemUtils.writeContent(p, UTF_8, "hello world");
+    Artifact a = ActionsTestUtil.createArtifact(artifactRoot, p);
+    FileArtifactValue f = FileArtifactValue.createForTesting(a);
+    MetadataProvider metadataProvider = new StaticMetadataProvider(ImmutableMap.of(a, f));
+    AbstractActionInputPrefetcher prefetcher = createPrefetcher(new HashMap<>());
+
+    wait(prefetcher.prefetchFiles(ImmutableList.of(a), metadataProvider));
+
+    assertThat(prefetcher.downloadedFiles()).isEmpty();
+    assertThat(prefetcher.downloadsInProgress()).isEmpty();
+  }
+
+  @Test
+  public void prefetchFiles_multipleThreads_downloadIsCancelled() throws Exception {
+    // Test shared downloads are cancelled if all threads/callers are interrupted
+
+    // arrange
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cas);
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+
+    AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));
+    SettableFuture<Void> downloadThatNeverFinishes = SettableFuture.create();
+    mockDownload(prefetcher, cas, () -> downloadThatNeverFinishes);
+
+    Thread cancelledThread1 =
+        new Thread(
+            () -> {
+              try {
+                wait(prefetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
+              } catch (IOException | InterruptedException ignored) {
+                // do nothing
+              }
+            });
+
+    Thread cancelledThread2 =
+        new Thread(
+            () -> {
+              try {
+                wait(prefetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
+              } catch (IOException | InterruptedException ignored) {
+                // do nothing
+              }
+            });
+
+    // act
+    cancelledThread1.start();
+    cancelledThread2.start();
+    cancelledThread1.interrupt();
+    cancelledThread2.interrupt();
+    cancelledThread1.join();
+    cancelledThread2.join();
+
+    // assert
+    assertThat(downloadThatNeverFinishes.isCancelled()).isTrue();
+    assertThat(artifact.getPath().exists()).isFalse();
+    assertThat(tempPathGenerator.getTempDir().getDirectoryEntries()).isEmpty();
+  }
+
+  @Test
+  public void prefetchFiles_multipleThreads_downloadIsNotCancelledByOtherThreads()
+      throws Exception {
+    // Test multiple threads can share downloads, but do not cancel each other when interrupted
+
+    // arrange
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cas);
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    SettableFuture<Void> download = SettableFuture.create();
+    AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));
+    mockDownload(prefetcher, cas, () -> download);
+    Thread cancelledThread =
+        new Thread(
+            () -> {
+              try {
+                wait(prefetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
+              } catch (IOException | InterruptedException ignored) {
+                // do nothing
+              }
+            });
+
+    AtomicBoolean successful = new AtomicBoolean(false);
+    Thread successfulThread =
+        new Thread(
+            () -> {
+              try {
+                wait(prefetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
+                successful.set(true);
+              } catch (IOException | InterruptedException ignored) {
+                // do nothing
+              }
+            });
+    cancelledThread.start();
+    successfulThread.start();
+    while (true) {
+      if (prefetcher
+              .getDownloadCache()
+              .getSubscriberCount(execRoot.getRelative(artifact.getExecPath()))
+          == 2) {
+        break;
+      }
+    }
+
+    // act
+    cancelledThread.interrupt();
+    cancelledThread.join();
+    // simulate the download finishing
+    assertThat(download.isCancelled()).isFalse();
+    download.set(null);
+    successfulThread.join();
+
+    // assert
+    assertThat(successful.get()).isTrue();
+    assertThat(FileSystemUtils.readContent(artifact.getPath(), UTF_8)).isEqualTo("hello world");
+  }
+
+  @Test
+  public void downloadFile_downloadRemoteFiles() throws Exception {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cas);
+    AbstractActionInputPrefetcher prefetcher = createPrefetcher(cas);
+
+    prefetcher.downloadFile(a1.getPath(), metadata.get(a1));
+
+    assertThat(FileSystemUtils.readContent(a1.getPath(), UTF_8)).isEqualTo("hello world");
+    assertThat(a1.getPath().isExecutable()).isTrue();
+    assertThat(a1.getPath().isReadable()).isTrue();
+    assertThat(a1.getPath().isWritable()).isFalse();
+  }
+
+  @Test
+  public void downloadFile_onInterrupt_deletePartialDownloadedFile() throws Exception {
+    Semaphore startSemaphore = new Semaphore(0);
+    Semaphore endSemaphore = new Semaphore(0);
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cas);
+    AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));
+    mockDownload(
+        prefetcher,
+        cas,
+        () -> {
+          startSemaphore.release();
+          return SettableFuture.create(); // A future that never complete so we can interrupt later
+        });
+
+    AtomicBoolean interrupted = new AtomicBoolean(false);
+    Thread t =
+        new Thread(
+            () -> {
+              try {
+                prefetcher.downloadFile(a1.getPath(), metadata.get(a1));
+              } catch (IOException ignored) {
+                // Intentionally left empty
+              } catch (InterruptedException e) {
+                interrupted.set(true);
+              }
+              endSemaphore.release();
+            });
+    t.start();
+    startSemaphore.acquire();
+    t.interrupt();
+    endSemaphore.acquire();
+
+    assertThat(interrupted.get()).isTrue();
+    assertThat(a1.getPath().exists()).isFalse();
+    assertThat(tempPathGenerator.getTempDir().getDirectoryEntries()).isEmpty();
+  }
+
+  protected static void wait(ListenableFuture<Void> future)
+      throws IOException, InterruptedException {
+    try {
+      future.get();
+    } catch (ExecutionException e) {
+      Throwable cause = e.getCause();
+      if (cause != null) {
+        throwIfInstanceOf(cause, IOException.class);
+        throwIfInstanceOf(cause, InterruptedException.class);
+        throwIfInstanceOf(cause, RuntimeException.class);
+      }
+      throw new IOException(e);
+    } catch (InterruptedException e) {
+      future.cancel(/*mayInterruptIfRunning=*/ true);
+      throw e;
+    }
+  }
+
+  protected static void mockDownload(
+      AbstractActionInputPrefetcher prefetcher,
+      Map<HashCode, byte[]> cas,
+      Supplier<ListenableFuture<Void>> resultSupplier)
+      throws IOException {
+    doAnswer(
+            invocation -> {
+              Path path = invocation.getArgument(0);
+              FileArtifactValue metadata = invocation.getArgument(1);
+              byte[] content = cas.get(HashCode.fromBytes(metadata.getDigest()));
+              if (content == null) {
+                return Futures.immediateFailedFuture(new IOException("Not found"));
+              }
+              FileSystemUtils.writeContent(path, content);
+              return resultSupplier.get();
+            })
+        .when(prefetcher)
+        .doDownloadFile(any(), any());
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/BUILD b/src/test/java/com/google/devtools/build/lib/remote/BUILD
index 8cc38a4..44df625 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/remote/BUILD
@@ -69,6 +69,7 @@
         "//src/main/java/com/google/devtools/build/lib/exec:spawn_runner",
         "//src/main/java/com/google/devtools/build/lib/pkgcache",
         "//src/main/java/com/google/devtools/build/lib/remote",
+        "//src/main/java/com/google/devtools/build/lib/remote:abstract_action_input_prefetcher",
         "//src/main/java/com/google/devtools/build/lib/remote/common",
         "//src/main/java/com/google/devtools/build/lib/remote/disk",
         "//src/main/java/com/google/devtools/build/lib/remote/grpc",
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
index 1d0d3df..4b13e10 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcherTest.java
@@ -13,112 +13,52 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.common.base.Throwables.throwIfInstanceOf;
 import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertThrows;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 import build.bazel.remote.execution.v2.Digest;
-import com.google.common.base.Supplier;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import com.google.common.hash.HashCode;
-import com.google.common.util.concurrent.Futures;
-import com.google.common.util.concurrent.ListenableFuture;
-import com.google.common.util.concurrent.SettableFuture;
-import com.google.devtools.build.lib.actions.ActionInput;
-import com.google.devtools.build.lib.actions.Artifact;
-import com.google.devtools.build.lib.actions.ArtifactRoot;
-import com.google.devtools.build.lib.actions.ArtifactRoot.RootType;
-import com.google.devtools.build.lib.actions.FileArtifactValue;
-import com.google.devtools.build.lib.actions.FileArtifactValue.RemoteFileArtifactValue;
 import com.google.devtools.build.lib.actions.MetadataProvider;
 import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
 import com.google.devtools.build.lib.actions.util.ActionsTestUtil;
-import com.google.devtools.build.lib.clock.JavaClock;
-import com.google.devtools.build.lib.remote.common.BulkTransferException;
 import com.google.devtools.build.lib.remote.options.RemoteOptions;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
 import com.google.devtools.build.lib.remote.util.InMemoryCacheClient;
 import com.google.devtools.build.lib.remote.util.StaticMetadataProvider;
-import com.google.devtools.build.lib.remote.util.TempPathGenerator;
-import com.google.devtools.build.lib.vfs.DigestHashFunction;
-import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.FileSystemUtils;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.SyscallCache;
-import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
 import com.google.devtools.common.options.Options;
-import com.google.protobuf.ByteString;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Semaphore;
-import java.util.concurrent.atomic.AtomicBoolean;
-import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Tests for {@link RemoteActionInputFetcher}. */
 @RunWith(JUnit4.class)
-public class RemoteActionInputFetcherTest {
+public class RemoteActionInputFetcherTest extends ActionInputPrefetcherTestBase {
 
-  private static final DigestHashFunction HASH_FUNCTION = DigestHashFunction.SHA256;
-
-  private Path execRoot;
-  private TempPathGenerator tempPathGenerator;
-  private ArtifactRoot artifactRoot;
   private RemoteOptions options;
   private DigestUtil digestUtil;
 
-  @Before
+  @Override
   public void setUp() throws IOException {
-    FileSystem fs = new InMemoryFileSystem(new JavaClock(), HASH_FUNCTION);
-    execRoot = fs.getPath("/exec");
-    execRoot.createDirectoryAndParents();
-    Path tempDir = fs.getPath("/tmp");
-    tempDir.createDirectoryAndParents();
-    tempPathGenerator = new TempPathGenerator(tempDir);
+    super.setUp();
     Path dev = fs.getPath("/dev");
     dev.createDirectory();
     dev.setWritable(false);
-    artifactRoot = ArtifactRoot.asDerivedRoot(execRoot, RootType.Output, "root");
-    artifactRoot.getRoot().asPath().createDirectoryAndParents();
     options = Options.getDefaults(RemoteOptions.class);
     digestUtil = new DigestUtil(SyscallCache.NO_CACHE, HASH_FUNCTION);
   }
 
-  @Test
-  public void testFetching() throws Exception {
-    // arrange
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Map<Digest, ByteString> cacheEntries = new HashMap<>();
-    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
-    Artifact a2 = createRemoteArtifact("file2", "fizz buzz", metadata, cacheEntries);
-    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
-    RemoteCache remoteCache = newCache(options, digestUtil, cacheEntries);
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    // act
-    wait(actionInputFetcher.prefetchFiles(metadata.keySet(), metadataProvider));
-
-    // assert
-    assertThat(FileSystemUtils.readContent(a1.getPath(), StandardCharsets.UTF_8))
-        .isEqualTo("hello world");
-    assertThat(a1.getPath().isExecutable()).isTrue();
-    assertThat(FileSystemUtils.readContent(a2.getPath(), StandardCharsets.UTF_8))
-        .isEqualTo("fizz buzz");
-    assertThat(a2.getPath().isExecutable()).isTrue();
-    assertThat(actionInputFetcher.downloadedFiles()).hasSize(2);
-    assertThat(actionInputFetcher.downloadedFiles()).containsAtLeast(a1.getPath(), a2.getPath());
-    assertThat(actionInputFetcher.downloadsInProgress()).isEmpty();
+  @Override
+  protected AbstractActionInputPrefetcher createPrefetcher(Map<HashCode, byte[]> cas) {
+    RemoteCache remoteCache = newCache(options, digestUtil, cas);
+    return new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
   }
 
   @Test
@@ -159,287 +99,14 @@
     assertThat(actionInputFetcher.downloadsInProgress()).isEmpty();
   }
 
-  @Test
-  public void testFileNotFound() throws Exception {
-    // Test that we get an exception if an input file is missing
-
-    // arrange
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Artifact a =
-        createRemoteArtifact("file1", "hello world", metadata, /* cacheEntries= */ new HashMap<>());
-    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
-    RemoteCache remoteCache = newCache(options, digestUtil, new HashMap<>());
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    // act
-    assertThrows(
-        BulkTransferException.class,
-        () -> wait(actionInputFetcher.prefetchFiles(ImmutableList.of(a), metadataProvider)));
-
-    // assert
-    assertThat(actionInputFetcher.downloadedFiles()).isEmpty();
-    assertThat(actionInputFetcher.downloadsInProgress()).isEmpty();
-  }
-
-  @Test
-  public void testIgnoreNoneRemoteFiles() throws Exception {
-    // Test that files that are not remote are not downloaded
-
-    // arrange
-    Path p = execRoot.getRelative(artifactRoot.getExecPath()).getRelative("file1");
-    FileSystemUtils.writeContent(p, StandardCharsets.UTF_8, "hello world");
-    Artifact a = ActionsTestUtil.createArtifact(artifactRoot, p);
-    FileArtifactValue f = FileArtifactValue.createForTesting(a);
-    MetadataProvider metadataProvider = new StaticMetadataProvider(ImmutableMap.of(a, f));
-    RemoteCache remoteCache = newCache(options, digestUtil, new HashMap<>());
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    // act
-    wait(actionInputFetcher.prefetchFiles(ImmutableList.of(a), metadataProvider));
-
-    // assert
-    assertThat(actionInputFetcher.downloadedFiles()).isEmpty();
-    assertThat(actionInputFetcher.downloadsInProgress()).isEmpty();
-  }
-
-  @Test
-  public void testDownloadFile() throws Exception {
-    // arrange
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Map<Digest, ByteString> cacheEntries = new HashMap<>();
-    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
-    RemoteCache remoteCache = newCache(options, digestUtil, cacheEntries);
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    // act
-    actionInputFetcher.downloadFile(a1.getPath(), metadata.get(a1));
-
-    // assert
-    assertThat(FileSystemUtils.readContent(a1.getPath(), StandardCharsets.UTF_8))
-        .isEqualTo("hello world");
-    assertThat(a1.getPath().isExecutable()).isTrue();
-    assertThat(a1.getPath().isReadable()).isTrue();
-    assertThat(a1.getPath().isWritable()).isFalse();
-  }
-
-  @Test
-  public void testDownloadFile_onInterrupt_deletePartialDownloadedFile() throws Exception {
-    Semaphore startSemaphore = new Semaphore(0);
-    Semaphore endSemaphore = new Semaphore(0);
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Map<Digest, ByteString> cacheEntries = new HashMap<>();
-    Artifact a1 = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
-    RemoteCache remoteCache = mock(RemoteCache.class);
-    mockDownload(
-        remoteCache,
-        cacheEntries,
-        () -> {
-          startSemaphore.release();
-          return SettableFuture.create(); // A future that never complete so we can interrupt later
-        });
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    AtomicBoolean interrupted = new AtomicBoolean(false);
-    Thread t =
-        new Thread(
-            () -> {
-              try {
-                actionInputFetcher.downloadFile(a1.getPath(), metadata.get(a1));
-              } catch (IOException ignored) {
-                interrupted.set(false);
-              } catch (InterruptedException e) {
-                interrupted.set(true);
-              }
-              endSemaphore.release();
-            });
-    t.start();
-    startSemaphore.acquire();
-    t.interrupt();
-    endSemaphore.acquire();
-
-    assertThat(interrupted.get()).isTrue();
-    assertThat(a1.getPath().exists()).isFalse();
-    assertThat(tempPathGenerator.getTempDir().getDirectoryEntries()).isEmpty();
-  }
-
-  @Test
-  public void testPrefetchFiles_multipleThreads_downloadIsNotCancelledByOtherThreads()
-      throws Exception {
-    // Test multiple threads can share downloads, but do not cancel each other when interrupted
-
-    // arrange
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Map<Digest, ByteString> cacheEntries = new HashMap<>();
-    Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
-    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
-    SettableFuture<Void> download = SettableFuture.create();
-    RemoteCache remoteCache = mock(RemoteCache.class);
-    mockDownload(remoteCache, cacheEntries, () -> download);
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-    Thread cancelledThread =
-        new Thread(
-            () -> {
-              try {
-                wait(
-                    actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
-              } catch (IOException | InterruptedException ignored) {
-                // do nothing
-              }
-            });
-
-    AtomicBoolean successful = new AtomicBoolean(false);
-    Thread successfulThread =
-        new Thread(
-            () -> {
-              try {
-                wait(
-                    actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
-                successful.set(true);
-              } catch (IOException | InterruptedException ignored) {
-                // do nothing
-              }
-            });
-    cancelledThread.start();
-    successfulThread.start();
-    while (true) {
-      if (actionInputFetcher
-              .getDownloadCache()
-              .getSubscriberCount(execRoot.getRelative(artifact.getExecPath()))
-          == 2) {
-        break;
-      }
-    }
-
-    // act
-    cancelledThread.interrupt();
-    cancelledThread.join();
-    // simulate the download finishing
-    assertThat(download.isCancelled()).isFalse();
-    download.set(null);
-    successfulThread.join();
-
-    // assert
-    assertThat(successful.get()).isTrue();
-    assertThat(FileSystemUtils.readContent(artifact.getPath(), StandardCharsets.UTF_8))
-        .isEqualTo("hello world");
-  }
-
-  @Test
-  public void testPrefetchFiles_multipleThreads_downloadIsCancelled() throws Exception {
-    // Test shared downloads are cancelled if all threads/callers are interrupted
-
-    // arrange
-    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
-    Map<Digest, ByteString> cacheEntries = new HashMap<>();
-    Artifact artifact = createRemoteArtifact("file1", "hello world", metadata, cacheEntries);
-    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
-
-    SettableFuture<Void> download = SettableFuture.create();
-    RemoteCache remoteCache = mock(RemoteCache.class);
-    mockDownload(remoteCache, cacheEntries, () -> download);
-    RemoteActionInputFetcher actionInputFetcher =
-        new RemoteActionInputFetcher("none", "none", remoteCache, execRoot, tempPathGenerator);
-
-    Thread cancelledThread1 =
-        new Thread(
-            () -> {
-              try {
-                wait(
-                    actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
-              } catch (IOException | InterruptedException ignored) {
-                // do nothing
-              }
-            });
-
-    Thread cancelledThread2 =
-        new Thread(
-            () -> {
-              try {
-                wait(
-                    actionInputFetcher.prefetchFiles(ImmutableList.of(artifact), metadataProvider));
-              } catch (IOException | InterruptedException ignored) {
-                // do nothing
-              }
-            });
-
-    // act
-    cancelledThread1.start();
-    cancelledThread2.start();
-    cancelledThread1.interrupt();
-    cancelledThread2.interrupt();
-    cancelledThread1.join();
-    cancelledThread2.join();
-
-    // assert
-    assertThat(download.isCancelled()).isTrue();
-    assertThat(artifact.getPath().exists()).isFalse();
-    assertThat(tempPathGenerator.getTempDir().getDirectoryEntries()).isEmpty();
-  }
-
-  private Artifact createRemoteArtifact(
-      String pathFragment,
-      String contents,
-      Map<ActionInput, FileArtifactValue> metadata,
-      Map<Digest, ByteString> cacheEntries) {
-    Path p = artifactRoot.getRoot().getRelative(pathFragment);
-    Artifact a = ActionsTestUtil.createArtifact(artifactRoot, p);
-    byte[] b = contents.getBytes(StandardCharsets.UTF_8);
-    HashCode h = HASH_FUNCTION.getHashFunction().hashBytes(b);
-    FileArtifactValue f =
-        new RemoteFileArtifactValue(h.asBytes(), b.length, /* locationIndex= */ 1, "action-id");
-    metadata.put(a, f);
-    cacheEntries.put(DigestUtil.buildDigest(h.asBytes(), b.length), ByteString.copyFrom(b));
-    return a;
-  }
-
   private RemoteCache newCache(
-      RemoteOptions options, DigestUtil digestUtil, Map<Digest, ByteString> cacheEntries) {
-    Map<Digest, byte[]> cacheEntriesByteArray =
-        Maps.newHashMapWithExpectedSize(cacheEntries.size());
-    for (Map.Entry<Digest, ByteString> entry : cacheEntries.entrySet()) {
-      cacheEntriesByteArray.put(entry.getKey(), entry.getValue().toByteArray());
+      RemoteOptions options, DigestUtil digestUtil, Map<HashCode, byte[]> cas) {
+    Map<Digest, byte[]> cacheEntries = Maps.newHashMapWithExpectedSize(cas.size());
+    for (Map.Entry<HashCode, byte[]> entry : cas.entrySet()) {
+      cacheEntries.put(
+          DigestUtil.buildDigest(entry.getKey().asBytes(), entry.getValue().length),
+          entry.getValue());
     }
-    return new RemoteCache(new InMemoryCacheClient(cacheEntriesByteArray), options, digestUtil);
-  }
-
-  private static void wait(ListenableFuture<Void> future) throws IOException, InterruptedException {
-    try {
-      future.get();
-    } catch (ExecutionException e) {
-      Throwable cause = e.getCause();
-      if (cause != null) {
-        throwIfInstanceOf(cause, IOException.class);
-        throwIfInstanceOf(cause, InterruptedException.class);
-        throwIfInstanceOf(cause, RuntimeException.class);
-      }
-      throw new IOException(e);
-    } catch (InterruptedException e) {
-      future.cancel(/*mayInterruptIfRunning=*/ true);
-      throw e;
-    }
-  }
-
-  private static void mockDownload(
-      RemoteCache remoteCache,
-      Map<Digest, ByteString> cacheEntries,
-      Supplier<ListenableFuture<Void>> resultSupplier)
-      throws IOException {
-    when(remoteCache.downloadFile(any(), any(), any()))
-        .thenAnswer(
-            invocation -> {
-              Path path = invocation.getArgument(1);
-              Digest digest = invocation.getArgument(2);
-              ByteString content = cacheEntries.get(digest);
-              if (content == null) {
-                return Futures.immediateFailedFuture(new IOException("Not found"));
-              }
-              FileSystemUtils.writeContent(path, content.toByteArray());
-              return resultSupplier.get();
-            });
+    return new RemoteCache(new InMemoryCacheClient(cacheEntries), options, digestUtil);
   }
 }