Properly cancel repository downloads

When an asynchronous repository download task is cancelled, an interrupt signal is sent to the download thread, but the download does not necessarily stop immediately. Avoid restarting a download until the previous one has actually stopped, so that the new download can clean up old data without crashing (on Windows).

Fixes #21773

Closes #23837.

PiperOrigin-RevId: 686175953
Change-Id: I8d75f905b739d38b6cb430d5b5e84fda9a2d14e3
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java
index 0626f1e..b23a92e 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/DownloadManager.java
@@ -48,6 +48,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.Phaser;
 import javax.annotation.Nullable;
 
 /**
@@ -123,9 +124,14 @@
       Path output,
       ExtendedEventHandler eventHandler,
       Map<String, String> clientEnv,
-      String context) {
+      String context,
+      Phaser downloadPhaser) {
     return executorService.submit(
         () -> {
+          if (downloadPhaser.register() != 0) {
+            // Not in download phase, must already have been cancelled.
+            throw new InterruptedException();
+          }
           try (SilentCloseable c = Profiler.instance().profile("fetching: " + context)) {
             return downloadInExecutor(
                 originalUrls,
@@ -138,6 +144,8 @@
                 eventHandler,
                 clientEnv,
                 context);
+          } finally {
+            downloadPhaser.arrive();
           }
         });
   }
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java
index b8dce6a..25b155e 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/starlark/StarlarkBaseExternalContext.java
@@ -90,6 +90,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.Phaser;
 import javax.annotation.Nullable;
 import net.starlark.java.annot.Param;
 import net.starlark.java.annot.ParamType;
@@ -114,9 +115,10 @@
    *
    * <p>The main property of such tasks is that they should under no circumstances keep running
    * after fetching the repository is finished, whether successfully or not. To this end, the {@link
-   * #cancel()} method must stop all such work.
+   * #cancel()} method may be called to interrupt the work and {@link #close()} must be called to
+   * wait for all such work to finish.
    */
-  private interface AsyncTask {
+  private interface AsyncTask extends SilentCloseable {
     /** Returns a user-friendly description of the task. */
     String getDescription();
 
@@ -126,11 +128,21 @@
     /**
      * Cancels the task, if not done yet. Returns false if the task was still in progress.
      *
+     * <p>Note that the task may still be running after this method returns; it has only been sent
+     * an interrupt signal. Call {@link #close()} to wait for the task to finish.
+     *
      * <p>No means of error reporting is provided. Any errors should be reported by other means. The
      * only possible error reported as a consequence of calling this method is one that tells the
      * user that they didn't wait for an async task they should have waited for.
      */
     boolean cancel();
+
+    /**
+     * Waits uninterruptibly until the task is no longer running, even in case it was cancelled but
+     * its underlying thread is still running.
+     */
+    @Override
+    void close();
   }
 
   /** Max. length of command line args added as a profiler description. */
@@ -203,7 +215,12 @@
     // Wait for all (cancelled) async tasks to complete before cleaning up the working directory.
     // This is necessary because downloads may still be in progress and could end up writing to the
     // working directory during deletion, which would cause an error.
+    // Note that just calling executorService.close() doesn't suffice as it considers tasks to be
+    // completed immediately after they are cancelled, without waiting for their underlying thread
+    // to complete.
     executorService.close();
+    asyncTasks.forEach(AsyncTask::close);
+
     if (shouldDeleteWorkingDirectoryOnClose(wasSuccessful)) {
       workingDirectory.deleteTree();
     }
@@ -519,6 +536,7 @@
     private final Optional<Checksum> checksum;
     private final RepositoryFunctionException checksumValidation;
     private final Future<Path> future;
+    private final Phaser downloadPhaser;
     private final Location location;
 
     private PendingDownload(
@@ -528,6 +546,7 @@
         Optional<Checksum> checksum,
         RepositoryFunctionException checksumValidation,
         Future<Path> future,
+        Phaser downloadPhaser,
         Location location) {
       this.executable = executable;
       this.allowFail = allowFail;
@@ -535,6 +554,7 @@
       this.checksum = checksum;
       this.checksumValidation = checksumValidation;
       this.future = future;
+      this.downloadPhaser = downloadPhaser;
       this.location = location;
     }
 
@@ -553,6 +573,18 @@
       return !future.cancel(true);
     }
 
+    @Override
+    public void close() {
+      if (downloadPhaser.register() != 0) {
+        // Not in the download phase, either the download completed normally or
+        // it has completed after a cancellation.
+        return;
+      }
+      try (SilentCloseable c = Profiler.instance().profile("Cancelling download " + outputPath)) {
+        downloadPhaser.arriveAndAwaitAdvance();
+      }
+    }
+
     @StarlarkMethod(
         name = "wait",
         doc =
@@ -590,6 +622,8 @@
           Starlark.errorf(
               "Could not create output path %s: %s", pendingDownload.outputPath, e.getMessage()),
           Transience.PERSISTENT);
+    } finally {
+      pendingDownload.close();
     }
     if (pendingDownload.checksumValidation != null) {
       throw pendingDownload.checksumValidation;
@@ -758,6 +792,7 @@
       checkInOutputDirectory("write", outputPath);
       makeDirectories(outputPath.getPath());
     } catch (IOException e) {
+      Phaser downloadPhaser = new Phaser();
       download =
           new PendingDownload(
               executable,
@@ -766,9 +801,11 @@
               checksum,
               checksumValidation,
               Futures.immediateFailedFuture(e),
+              downloadPhaser,
               thread.getCallerLocation());
     }
     if (download == null) {
+      Phaser downloadPhaser = new Phaser();
       Future<Path> downloadFuture =
           downloadManager.startDownload(
               executorService,
@@ -781,7 +818,8 @@
               outputPath.getPath(),
               env.getListener(),
               envVariables,
-              identifyingStringForLogging);
+              identifyingStringForLogging,
+              downloadPhaser);
       download =
           new PendingDownload(
               executable,
@@ -790,6 +828,7 @@
               checksum,
               checksumValidation,
               downloadFuture,
+              downloadPhaser,
               thread.getCallerLocation());
       registerAsyncTask(download);
     }
@@ -996,6 +1035,7 @@
       downloadDirectory =
           workingDirectory.getFileSystem().getPath(tempDirectory.toFile().getAbsolutePath());
 
+      Phaser downloadPhaser = new Phaser();
       Future<Path> pendingDownload =
           downloadManager.startDownload(
               executorService,
@@ -1008,7 +1048,8 @@
               downloadDirectory,
               env.getListener(),
               envVariables,
-              identifyingStringForLogging);
+              identifyingStringForLogging,
+              downloadPhaser);
       // Ensure that the download is cancelled if the repo rule is restarted as it runs in its own
       // executor.
       PendingDownload pendingTask =
@@ -1019,6 +1060,7 @@
               checksum,
               checksumValidation,
               pendingDownload,
+              downloadPhaser,
               thread.getCallerLocation());
       registerAsyncTask(pendingTask);
       downloadedPath = downloadManager.finalizeDownload(pendingDownload);
diff --git a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java
index 9b900cc..341dca9 100644
--- a/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java
+++ b/src/test/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloaderTest.java
@@ -53,6 +53,7 @@
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
+import java.util.concurrent.Phaser;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.junit.After;
 import org.junit.Ignore;
@@ -785,6 +786,7 @@
       Map<String, String> clientEnv,
       String context)
       throws IOException, InterruptedException {
+    Phaser downloadPhaser = new Phaser();
     try (ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor()) {
       Future<Path> future =
           downloadManager.startDownload(
@@ -798,8 +800,12 @@
               output,
               eventHandler,
               clientEnv,
-              context);
-      return downloadManager.finalizeDownload(future);
+              context,
+              downloadPhaser);
+      Path downloadedPath = downloadManager.finalizeDownload(future);
+      // Should not be in the download phase.
+      assertThat(downloadPhaser.getPhase()).isNotEqualTo(0);
+      return downloadedPath;
     }
   }
 }