Check digest of local file to decide whether to re-download it.

Normally, for a non-clean build (outputBase is not clean), skyframe is able to detect modifications, invalid the action, and rerun. Before rerun, it will delete the stales outputs. So we only use `path.exists()` to decide whether we should download an input.

However, there are some files  under the outputBase are not tracked by skyframe. In that case, we could wrongly use a staled output.

PiperOrigin-RevId: 483352318
Change-Id: I7e100100e6c3218630c5dc9bf3f900b2de232e0e
diff --git a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
index 572fc24..81b7ebe 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
@@ -47,6 +47,7 @@
 import io.reactivex.rxjava3.core.Flowable;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -102,7 +103,26 @@
     this.patternsToDownload = patternsToDownload;
   }
 
-  protected abstract boolean shouldDownloadFile(Path path, FileArtifactValue metadata);
+  private boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
+    if (!path.exists()) {
+      return true;
+    }
+
+    // In the most cases, skyframe should be able to detect source files modifications and delete
+    // staled outputs before action execution. However, there are some cases where outputs are not
+    // tracked by skyframe. We compare the digest here to make sure we don't use staled files.
+    try {
+      byte[] digest = path.getFastDigest();
+      if (digest == null) {
+        digest = path.getDigest();
+      }
+      return !Arrays.equals(digest, metadata.getDigest());
+    } catch (IOException ignored) {
+      return true;
+    }
+  }
+
+  protected abstract boolean canDownloadFile(Path path, FileArtifactValue metadata);
 
   /**
    * Downloads file to the given path via its metadata.
@@ -189,13 +209,14 @@
     // TODO(tjgq): Only download individual files that were requested within the tree.
     // This isn't straightforward because multiple tree artifacts may share the same output tree
     // when a ctx.actions.symlink is involved.
-    if (treeMetadata == null || !shouldDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
+    if (treeMetadata == null || !canDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
       return Completable.complete();
     }
 
     PathFragment prefetchExecPath = treeMetadata.getMaterializationExecPath().orElse(execPath);
 
-    Completable prefetch = prefetchInputTree(provider, prefetchExecPath, treeFiles, priority);
+    Completable prefetch =
+        prefetchInputTree(provider, prefetchExecPath, treeFiles, treeMetadata, priority);
 
     // If prefetching to a different path, plant a symlink into it.
     if (!prefetchExecPath.equals(execPath)) {
@@ -207,6 +228,16 @@
     return prefetch;
   }
 
+  private boolean canDownloadAnyTreeFiles(
+      Iterable<TreeFileArtifact> treeFiles, FileArtifactValue metadata) {
+    for (TreeFileArtifact treeFile : treeFiles) {
+      if (canDownloadFile(treeFile.getPath(), metadata)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   private boolean shouldDownloadAnyTreeFiles(
       Iterable<TreeFileArtifact> treeFiles, FileArtifactValue metadata) {
     for (TreeFileArtifact treeFile : treeFiles) {
@@ -221,6 +252,7 @@
       MetadataProvider provider,
       PathFragment execPath,
       List<TreeFileArtifact> treeFiles,
+      FileArtifactValue treeMetadata,
       Priority priority) {
     Path treeRoot = execRoot.getRelative(execPath);
     HashMap<TreeFileArtifact, Path> treeFileTmpPathMap = new HashMap<>();
@@ -293,7 +325,15 @@
                     }
                   }
                 });
-    return downloadCache.executeIfNot(treeRoot, download);
+    return downloadCache.executeIfNot(
+        treeRoot,
+        Completable.defer(
+            () -> {
+              if (shouldDownloadAnyTreeFiles(treeFiles, treeMetadata)) {
+                return download;
+              }
+              return Completable.complete();
+            }));
   }
 
   private Completable prefetchInputFileOrSymlink(
@@ -306,7 +346,7 @@
     PathFragment execPath = input.getExecPath();
 
     FileArtifactValue metadata = metadataProvider.getMetadata(input);
-    if (metadata == null || !shouldDownloadFile(execRoot.getRelative(execPath), metadata)) {
+    if (metadata == null || !canDownloadFile(execRoot.getRelative(execPath), metadata)) {
       return Completable.complete();
     }
 
@@ -332,7 +372,7 @@
    * download finished.
    */
   private Completable downloadFileRx(Path path, FileArtifactValue metadata, Priority priority) {
-    if (!shouldDownloadFile(path, metadata)) {
+    if (!canDownloadFile(path, metadata)) {
       return Completable.complete();
     }
     return downloadFileNoCheckRx(path, metadata, priority);
@@ -373,7 +413,16 @@
             // Set eager=false here because we want cleanup the download *after* upstream is
             // disposed.
             /* eager= */ false);
-    return downloadCache.executeIfNot(path, download);
+
+    return downloadCache.executeIfNot(
+        finalPath,
+        Completable.defer(
+            () -> {
+              if (shouldDownloadFile(finalPath, metadata)) {
+                return download;
+              }
+              return Completable.complete();
+            }));
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
index 413e704..5ae39e9 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java
@@ -70,8 +70,8 @@
   }
 
   @Override
-  protected boolean shouldDownloadFile(Path path, FileArtifactValue metadata) {
-    return metadata.isRemote() && !path.exists();
+  protected boolean canDownloadFile(Path path, FileArtifactValue metadata) {
+    return metadata.isRemote();
   }
 
   @Override
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
index c570a88..2b0124d 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
@@ -19,8 +19,11 @@
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -157,6 +160,40 @@
   protected abstract AbstractActionInputPrefetcher createPrefetcher(Map<HashCode, byte[]> cas);
 
   @Test
+  public void prefetchFiles_fileExists_doNotDownload() throws IOException, InterruptedException {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact a = createRemoteArtifact("file", "hello world", metadata, cas);
+    FileSystemUtils.writeContent(a.getPath(), "hello world".getBytes(UTF_8));
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));
+
+    wait(prefetcher.prefetchFiles(metadata.keySet(), metadataProvider));
+
+    verify(prefetcher, never()).doDownloadFile(any(), any(), any(), any());
+    assertThat(prefetcher.downloadedFiles()).containsExactly(a.getPath());
+    assertThat(prefetcher.downloadsInProgress()).isEmpty();
+  }
+
+  @Test
+  public void prefetchFiles_fileExistsButContentMismatches_download()
+      throws IOException, InterruptedException {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    Artifact a = createRemoteArtifact("file", "hello world remote", metadata, cas);
+    FileSystemUtils.writeContent(a.getPath(), "hello world local".getBytes(UTF_8));
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    AbstractActionInputPrefetcher prefetcher = spy(createPrefetcher(cas));
+
+    wait(prefetcher.prefetchFiles(metadata.keySet(), metadataProvider));
+
+    verify(prefetcher).doDownloadFile(any(), eq(a.getExecPath()), any(), any());
+    assertThat(prefetcher.downloadedFiles()).containsExactly(a.getPath());
+    assertThat(prefetcher.downloadsInProgress()).isEmpty();
+    assertThat(FileSystemUtils.readContent(a.getPath(), UTF_8)).isEqualTo("hello world remote");
+  }
+
+  @Test
   public void prefetchFiles_downloadRemoteFiles() throws Exception {
     Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
     Map<HashCode, byte[]> cas = new HashMap<>();