remote: fix race on download error. Fixes #5047

For downloading output files / directories we trigger all
downloads concurrently and asynchronously in the background
and after that wait for all downloads to finish. However, if
a download failed we did not wait for the remaining downloads
to finish but immediately started deleting partial downloads
and continued with local execution of the action.

That leads to two interesting bugs:
* The cleanup procedure races with the downloads that are still
in progress. As it tries to delete files and directories, new
files and directories are created and that will often
lead to "Directory not empty" errors as seen in #5047.
* The clean up procedure does not detect the race, succeeds and
subsequent local execution fails because not all files have
been deleted.

The solution is to always wait for all downloads to complete
before entering the cleanup routine. Ideally we would also
cancel all outstanding downloads, however, that's not as
straightfoward as it seems. That is, the j.u.c.Future API does
not provide a way to cancel a computation and also wait for
that computation actually having determinated. So we'd need
to introduce a separate mechanism to cancel downloads.

RELNOTES: None
PiperOrigin-RevId: 205980446
diff --git a/src/main/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCache.java b/src/main/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCache.java
index ef90223..66f29c5 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCache.java
@@ -13,8 +13,7 @@
 // limitations under the License.
 package com.google.devtools.build.lib.remote;
 
-import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
-
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
@@ -27,6 +26,7 @@
 import com.google.devtools.build.lib.concurrent.ThreadSafety;
 import com.google.devtools.build.lib.remote.TreeNodeRepository.TreeNode;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
+import com.google.devtools.build.lib.remote.util.Utils;
 import com.google.devtools.build.lib.util.io.FileOutErr;
 import com.google.devtools.build.lib.vfs.Dirent;
 import com.google.devtools.build.lib.vfs.FileStatus;
@@ -170,68 +170,84 @@
   // TODO(olaola): will need to amend to include the TreeNodeRepository for updating.
   public void download(ActionResult result, Path execRoot, FileOutErr outErr)
       throws ExecException, IOException, InterruptedException {
-    try {
-      Context ctx = Context.current();
-      List<FuturePathBooleanTuple> fileDownloads =
-          Collections.synchronizedList(
-              new ArrayList<>(result.getOutputFilesCount() + result.getOutputDirectoriesCount()));
-      for (OutputFile file : result.getOutputFilesList()) {
-        Path path = execRoot.getRelative(file.getPath());
-        ListenableFuture<Void> download =
-            retrier.executeAsync(
-                () -> ctx.call(() -> downloadFile(path, file.getDigest(), file.getContent())));
-        fileDownloads.add(new FuturePathBooleanTuple(download, path, file.getIsExecutable()));
-      }
+    Context ctx = Context.current();
+    List<FuturePathBooleanTuple> fileDownloads =
+        Collections.synchronizedList(
+            new ArrayList<>(result.getOutputFilesCount() + result.getOutputDirectoriesCount()));
+    for (OutputFile file : result.getOutputFilesList()) {
+      Path path = execRoot.getRelative(file.getPath());
+      ListenableFuture<Void> download =
+          retrier.executeAsync(
+              () -> ctx.call(() -> downloadFile(path, file.getDigest(), file.getContent())));
+      fileDownloads.add(new FuturePathBooleanTuple(download, path, file.getIsExecutable()));
+    }
 
-      List<ListenableFuture<Void>> dirDownloads =
-          new ArrayList<>(result.getOutputDirectoriesCount());
-      for (OutputDirectory dir : result.getOutputDirectoriesList()) {
-        SettableFuture<Void> dirDownload = SettableFuture.create();
-        ListenableFuture<byte[]> protoDownload =
-            retrier.executeAsync(() -> ctx.call(() -> downloadBlob(dir.getTreeDigest())));
-        Futures.addCallback(
-            protoDownload,
-            new FutureCallback<byte[]>() {
-              @Override
-              public void onSuccess(byte[] b) {
-                try {
-                  Tree tree = Tree.parseFrom(b);
-                  Map<Digest, Directory> childrenMap = new HashMap<>();
-                  for (Directory child : tree.getChildrenList()) {
-                    childrenMap.put(digestUtil.compute(child), child);
-                  }
-                  Path path = execRoot.getRelative(dir.getPath());
-                  fileDownloads.addAll(downloadDirectory(path, tree.getRoot(), childrenMap, ctx));
-                  dirDownload.set(null);
-                } catch (IOException e) {
-                  dirDownload.setException(e);
+    List<ListenableFuture<Void>> dirDownloads = new ArrayList<>(result.getOutputDirectoriesCount());
+    for (OutputDirectory dir : result.getOutputDirectoriesList()) {
+      SettableFuture<Void> dirDownload = SettableFuture.create();
+      ListenableFuture<byte[]> protoDownload =
+          retrier.executeAsync(() -> ctx.call(() -> downloadBlob(dir.getTreeDigest())));
+      Futures.addCallback(
+          protoDownload,
+          new FutureCallback<byte[]>() {
+            @Override
+            public void onSuccess(byte[] b) {
+              try {
+                Tree tree = Tree.parseFrom(b);
+                Map<Digest, Directory> childrenMap = new HashMap<>();
+                for (Directory child : tree.getChildrenList()) {
+                  childrenMap.put(digestUtil.compute(child), child);
                 }
+                Path path = execRoot.getRelative(dir.getPath());
+                fileDownloads.addAll(downloadDirectory(path, tree.getRoot(), childrenMap, ctx));
+                dirDownload.set(null);
+              } catch (IOException e) {
+                dirDownload.setException(e);
               }
+            }
 
-              @Override
-              public void onFailure(Throwable t) {
-                dirDownload.setException(t);
-              }
-            },
-            MoreExecutors.directExecutor());
-        dirDownloads.add(dirDownload);
-      }
+            @Override
+            public void onFailure(Throwable t) {
+              dirDownload.setException(t);
+            }
+          },
+          MoreExecutors.directExecutor());
+      dirDownloads.add(dirDownload);
+    }
 
+    // Subsequently we need to wait for *every* download to finish, even if we already know that
+    // one failed. That's so that when exiting this method we can be sure that all downloads have
+    // finished and don't race with the cleanup routine.
+    // TODO(buchgr): Look into cancellation.
+
+    IOException downloadException = null;
+    try {
       fileDownloads.addAll(downloadOutErr(result, outErr, ctx));
-
-      for (ListenableFuture<Void> dirDownload : dirDownloads) {
-        // Block on all directory download futures, so that we can be sure that we have discovered
-        // all file downloads and can subsequently safely iterate over the list of file downloads.
+    } catch (IOException e) {
+      downloadException = e;
+    }
+    for (ListenableFuture<Void> dirDownload : dirDownloads) {
+      // Block on all directory download futures, so that we can be sure that we have discovered
+      // all file downloads and can subsequently safely iterate over the list of file downloads.
+      try {
         getFromFuture(dirDownload);
+      } catch (IOException e) {
+        downloadException = downloadException == null ? e : downloadException;
       }
+    }
 
-      for (FuturePathBooleanTuple download : fileDownloads) {
+    for (FuturePathBooleanTuple download : fileDownloads) {
+      try {
         getFromFuture(download.getFuture());
         if (download.getPath() != null) {
           download.getPath().setExecutable(download.isExecutable());
         }
+      } catch (IOException e) {
+        downloadException = downloadException == null ? e : downloadException;
       }
-    } catch (IOException downloadException) {
+    }
+
+    if (downloadException != null) {
       try {
         // Delete any (partially) downloaded output files, since any subsequent local execution
         // of this action may expect none of the output files to exist.
@@ -261,6 +277,11 @@
     }
   }
 
+  @VisibleForTesting
+  protected <T> T getFromFuture(ListenableFuture<T> f) throws IOException, InterruptedException {
+    return Utils.getFromFuture(f);
+  }
+
   /** Tuple of {@code ListenableFuture, Path, boolean}. */
   private static class FuturePathBooleanTuple {
     private final ListenableFuture<?> future;
diff --git a/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java b/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java
index b1b0d60..8e6269a 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/SimpleBlobStoreActionCache.java
@@ -14,8 +14,6 @@
 
 package com.google.devtools.build.lib.remote;
 
-import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
-
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
diff --git a/src/test/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCacheTests.java b/src/test/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCacheTests.java
index fe4881c..7afe9b8 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCacheTests.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/AbstractRemoteActionCacheTests.java
@@ -15,19 +15,46 @@
 
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.lib.testutil.MoreAsserts.assertThrows;
+import static org.junit.Assert.fail;
 
+import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningScheduledExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
 import com.google.devtools.build.lib.actions.ExecException;
 import com.google.devtools.build.lib.clock.JavaClock;
 import com.google.devtools.build.lib.remote.AbstractRemoteActionCache.UploadManifest;
+import com.google.devtools.build.lib.remote.TreeNodeRepository.TreeNode;
 import com.google.devtools.build.lib.remote.util.DigestUtil;
+import com.google.devtools.build.lib.remote.util.DigestUtil.ActionKey;
+import com.google.devtools.build.lib.remote.util.Utils;
+import com.google.devtools.build.lib.util.io.FileOutErr;
 import com.google.devtools.build.lib.vfs.DigestHashFunction;
 import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.FileSystemUtils;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
 import com.google.devtools.remoteexecution.v1test.ActionResult;
+import com.google.devtools.remoteexecution.v1test.Command;
+import com.google.devtools.remoteexecution.v1test.Digest;
+import com.google.devtools.remoteexecution.v1test.OutputFile;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import javax.annotation.Nullable;
+import org.junit.AfterClass;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -40,6 +67,13 @@
   private Path execRoot;
   private final DigestUtil digestUtil = new DigestUtil(DigestHashFunction.SHA256);
 
+  private static ListeningScheduledExecutorService retryService;
+
+  @BeforeClass
+  public static void beforeEverything() {
+    retryService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(1));
+  }
+
   @Before
   public void setUp() throws Exception {
     fs = new InMemoryFileSystem(new JavaClock(), DigestHashFunction.SHA256);
@@ -47,6 +81,11 @@
     execRoot.createDirectory();
   }
 
+  @AfterClass
+  public static void afterEverything() {
+    retryService.shutdownNow();
+  }
+
   @Test
   public void uploadSymlinkAsFile() throws Exception {
     ActionResult.Builder result = ActionResult.newBuilder();
@@ -92,4 +131,118 @@
         .hasMessageThat()
         .contains("Only regular files and directories may be uploaded to a remote cache.");
   }
+
+  @Test
+  public void onErrorWaitForRemainingDownloadsToComplete() throws Exception {
+    // If one or more downloads of output files / directories fail then the code should
+    // wait for all downloads to have been completed before it tries to clean up partially
+    // downloaded files.
+
+    Path stdout = fs.getPath("/execroot/stdout");
+    Path stderr = fs.getPath("/execroot/stderr");
+
+    Map<Digest, ListenableFuture<byte[]>> downloadResults = new HashMap<>();
+    Path file1 = fs.getPath("/execroot/file1");
+    Digest digest1 = digestUtil.compute("file1".getBytes("UTF-8"));
+    downloadResults.put(digest1, Futures.immediateFuture("file1".getBytes("UTF-8")));
+    Path file2 = fs.getPath("/execroot/file2");
+    Digest digest2 = digestUtil.compute("file2".getBytes("UTF-8"));
+    downloadResults.put(digest2, Futures.immediateFailedFuture(new IOException("download failed")));
+    Path file3 = fs.getPath("/execroot/file3");
+    Digest digest3 = digestUtil.compute("file3".getBytes("UTF-8"));
+    downloadResults.put(digest3, Futures.immediateFuture("file3".getBytes("UTF-8")));
+
+    RemoteOptions options = new RemoteOptions();
+    RemoteRetrier retrier = new RemoteRetrier(options, (e) -> false, retryService,
+        Retrier.ALLOW_ALL_CALLS);
+    List<ListenableFuture<?>> blockingDownloads = new ArrayList<>();
+    AtomicInteger numSuccess = new AtomicInteger();
+    AtomicInteger numFailures = new AtomicInteger();
+    AbstractRemoteActionCache cache = new DefaultRemoteActionCache(options, digestUtil, retrier) {
+      @Override
+      public ListenableFuture<Void> downloadBlob(Digest digest, OutputStream out) {
+        SettableFuture<Void> result = SettableFuture.create();
+        Futures.addCallback(downloadResults.get(digest), new FutureCallback<byte[]>() {
+          @Override
+          public void onSuccess(byte[] bytes) {
+            numSuccess.incrementAndGet();
+            try {
+              out.write(bytes);
+              out.close();
+              result.set(null);
+            } catch (IOException e) {
+              result.setException(e);
+            }
+          }
+
+          @Override
+          public void onFailure(Throwable throwable) {
+            numFailures.incrementAndGet();
+            result.setException(throwable);
+          }
+        }, MoreExecutors.directExecutor());
+        return result;
+      }
+
+      @Override
+      protected <T> T getFromFuture(ListenableFuture<T> f)
+          throws IOException, InterruptedException {
+        blockingDownloads.add(f);
+        return Utils.getFromFuture(f);
+      }
+    };
+
+    ActionResult result = ActionResult.newBuilder()
+        .setExitCode(0)
+        .addOutputFiles(OutputFile.newBuilder().setPath(file1.getPathString()).setDigest(digest1))
+        .addOutputFiles(OutputFile.newBuilder().setPath(file2.getPathString()).setDigest(digest2))
+        .addOutputFiles(OutputFile.newBuilder().setPath(file3.getPathString()).setDigest(digest3))
+        .build();
+    try {
+      cache.download(result, execRoot, new FileOutErr(stdout, stderr));
+      fail("Expected IOException");
+    } catch (IOException e) {
+      assertThat(numSuccess.get()).isEqualTo(2);
+      assertThat(numFailures.get()).isEqualTo(1);
+      assertThat(blockingDownloads).hasSize(3);
+      assertThat(Throwables.getRootCause(e)).hasMessageThat().isEqualTo("download failed");
+    }
+  }
+
+  private static class DefaultRemoteActionCache extends AbstractRemoteActionCache {
+
+    public DefaultRemoteActionCache(RemoteOptions options,
+        DigestUtil digestUtil, Retrier retrier) {
+      super(options, digestUtil, retrier);
+    }
+
+    @Override
+    public void ensureInputsPresent(TreeNodeRepository repository, Path execRoot, TreeNode root,
+        Command command) throws IOException, InterruptedException {
+      throw new UnsupportedOperationException();
+    }
+
+    @Nullable
+    @Override
+    ActionResult getCachedActionResult(ActionKey actionKey)
+        throws IOException, InterruptedException {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    void upload(ActionKey actionKey, Path execRoot, Collection<Path> files, FileOutErr outErr,
+        boolean uploadAction) throws ExecException, IOException, InterruptedException {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected ListenableFuture<Void> downloadBlob(Digest digest, OutputStream out) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void close() {
+      throw new UnsupportedOperationException();
+    }
+  }
 }