Fix data race in prefetcher.

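Individual file prefetches within a single prefetchFiles() call can run concurrently and race against each other, so they must synchronize when writing to the shared DirectoryContext. This also removes the TOCTTOU race between isWritable() and setWritable() on parent directories: the check and the update now happen atomically with respect to other prefetcher calls.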

Closes #17678.

PiperOrigin-RevId: 515024483
Change-Id: Ic8097979d06ab143b4d63f5e90f871f8cbf83959
diff --git a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
index 78bbfd7..4a54f13 100644
--- a/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
+++ b/src/main/java/com/google/devtools/build/lib/remote/AbstractActionInputPrefetcher.java
@@ -60,7 +60,6 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Deque;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
@@ -87,53 +86,124 @@
 
   private final Set<ActionInput> missingActionInputs = Sets.newConcurrentHashSet();
 
-  // Tracks the number of ongoing prefetcher calls temporarily making an output directory writable.
-  // Since concurrent calls may write to the same directory, it's not safe to make it non-writable
-  // until no other ongoing calls are writing to it.
-  private final ConcurrentHashMap<Path, Integer> temporarilyWritableDirectories =
+  private static final Object dummyValue = new Object(); // Placeholder value for set-like maps.
+
+  /**
+   * Tracks output directories temporarily made writable for prefetching. Since concurrent calls may
+   * write to the same directory, it's not safe to make it non-writable until no other ongoing
+   * prefetcher calls are writing to it. Entries are updated atomically via compute().
+   */
+  private final ConcurrentHashMap<Path, DirectoryState> temporarilyWritableDirectories =
       new ConcurrentHashMap<>();
 
-  /** Keeps track of output directories written to by a single prefetcher call. */
+  /** The state of a single temporarily writable directory; only mutated within compute(). */
+  private static final class DirectoryState {
+    /** The number of ongoing prefetcher calls touching this directory. */
+    int numCalls;
+    /** Whether output permissions must be set on the directory once the last call finishes. */
+    boolean mustSetOutputPermissions;
+  }
+
+  /**
+   * Tracks output directories written to by a single prefetcher call.
+   *
+   * <p>This makes it possible to set the output permissions on directories touched by the
+   * prefetcher call all at once, so that files prefetched within the same call don't repeatedly set
+   * output permissions on the same directory.
+   */
   private final class DirectoryContext {
-    private final HashSet<Path> dirs = new HashSet<>();
+    private final ConcurrentHashMap<Path, Object> dirs = new ConcurrentHashMap<>();
 
     /**
-     * Adds to the set of directories written to by the prefetcher call associated with this
-     * context.
+     * Makes a directory temporarily writable for the remainder of the prefetcher call associated
+     * with this context.
+     *
+     * <p>Safe to call concurrently from the individual file prefetches of the associated call:
+     * each directory is registered in {@code temporarilyWritableDirectories} at most once.
+     *
+     * @param dir the directory to create or make writable
+     * @param isDefinitelyTreeDir Whether this directory definitely belongs to a tree artifact.
+     *     Otherwise, whether it belongs to a tree artifact is inferred from its permissions.
+     * @throws IOException if the directory could not be created, inspected or made writable
      */
-    void add(Path dir) {
-      if (dirs.add(dir)) {
-        temporarilyWritableDirectories.compute(dir, (unused, count) -> count != null ? ++count : 1);
+    void createOrSetWritable(Path dir, boolean isDefinitelyTreeDir) throws IOException {
+      AtomicReference<IOException> caughtException = new AtomicReference<>();
+
+      // computeIfAbsent is atomic per key, so concurrent prefetches within the same call cannot
+      // both register the directory and double-count it in temporarilyWritableDirectories.
+      dirs.computeIfAbsent(
+          dir,
+          unused -> {
+            temporarilyWritableDirectories.compute(
+                dir,
+                (innerUnused, state) -> {
+                  if (state == null) {
+                    state = new DirectoryState();
+                    try {
+                      if (isDefinitelyTreeDir) {
+                        state.mustSetOutputPermissions = true;
+                        var ignored = dir.createWritableDirectory();
+                      } else {
+                        // If the directory is writable, it's a package and should be kept writable.
+                        // Otherwise, it must belong to a tree artifact, since the directory for a
+                        // tree is created in a non-writable state before prefetching begins, and
+                        // this is the first time the prefetcher is seeing it.
+                        state.mustSetOutputPermissions = !dir.isWritable();
+                        if (state.mustSetOutputPermissions) {
+                          dir.setWritable(true);
+                        }
+                      }
+                    } catch (IOException e) {
+                      caughtException.set(e);
+                      return null; // Leave the directory untracked.
+                    }
+                  } else if (isDefinitelyTreeDir) {
+                    // Another call may have inferred this to be a package directory, but the
+                    // output permissions of a tree directory must always be set afterwards.
+                    state.mustSetOutputPermissions = true;
+                  }
+                  ++state.numCalls;
+                  return state;
+                });
+
+            // A null value leaves the directory unregistered, so close() won't release it.
+            return caughtException.get() == null ? dummyValue : null;
+          });
+
+      if (caughtException.get() != null) {
+        throw caughtException.get();
       }
     }
 
     /**
      * Signals that the prefetcher call associated with this context has finished.
      *
-     * <p>The output permissions will be set on any directories written to by this call that are not
-     * being written to by other concurrent calls.
+     * <p>The output permissions will be set on any directories temporarily made writable by this
+     * call, if this is the last remaining call temporarily making them writable.
      */
     void close() throws IOException {
       AtomicReference<IOException> caughtException = new AtomicReference<>();
 
-      for (Path dir : dirs) {
+      for (Path dir : dirs.keySet()) {
         temporarilyWritableDirectories.compute(
             dir,
-            (unused, count) -> {
-              checkState(count != null);
-              if (--count == 0) {
-                try {
-                  dir.chmod(outputPermissions.getPermissionsMode());
-                } catch (IOException e) {
-                  // Store caught exceptions, but keep cleaning up the map.
-                  if (caughtException.get() == null) {
-                    caughtException.set(e);
-                  } else {
-                    caughtException.get().addSuppressed(e);
+            (unused, state) -> {
+              checkState(state != null);
+              if (--state.numCalls == 0) {
+                if (state.mustSetOutputPermissions) {
+                  try {
+                    dir.chmod(outputPermissions.getPermissionsMode());
+                  } catch (IOException e) {
+                    // Store caught exceptions, but keep cleaning up the map.
+                    if (caughtException.get() == null) {
+                      caughtException.set(e);
+                    } else {
+                      caughtException.get().addSuppressed(e);
+                    }
                   }
                 }
               }
-              return count > 0 ? count : null;
+              return state.numCalls > 0 ? state : null;
             });
       }
       dirs.clear();
@@ -520,24 +590,19 @@
       }
       while (!dirs.isEmpty()) {
         Path dir = dirs.pop();
-        dirCtx.add(dir);
         // Create directory or make existing directory writable.
-        var unused = dir.createWritableDirectory();
+        // We know with certainty that the directory belongs to a tree artifact.
+        dirCtx.createOrSetWritable(dir, /* isDefinitelyTreeDir= */ true);
       }
     } else {
-      // If the parent directory is not writable, temporarily make it so.
-      // This is needed when fetching a non-tree artifact nested inside a tree artifact, or a tree
-      // artifact inside a fileset (see b/254844173 for the latter).
-      // TODO(tjgq): Fix the TOCTTOU race between isWritable and setWritable. This requires keeping
-      // track of the original directory permissions. Note that nested artifacts are relatively rare
-      // and will eventually be disallowed (see issue #16729).
-      if (!parentDir.isWritable()) {
-        dirCtx.add(parentDir);
-        parentDir.setWritable(true);
-      }
+      // Temporarily make the parent directory writable if needed.
+      // We don't know with certainty that the directory does not belong to a tree artifact; it
+      // could if the fetched file is a non-tree artifact nested inside a tree artifact, or a
+      // tree artifact inside a fileset (see b/254844173 for the latter).
+      dirCtx.createOrSetWritable(parentDir, /* isDefinitelyTreeDir= */ false);
     }
 
-    // Set output permissions on files (tree subdirectories are handled in stopPrefetching),
+    // Set output permissions on files (tree subdirectories are handled in DirectoryContext#close),
     // matching the behavior of SkyframeActionExecutor#checkOutputs for artifacts produced by local
     // actions.
     tmpPath.chmod(outputPermissions.getPermissionsMode());
diff --git a/src/main/java/com/google/devtools/build/lib/testing/vfs/SpiedFileSystem.java b/src/main/java/com/google/devtools/build/lib/testing/vfs/SpiedFileSystem.java
index 3a7d303..31f13d0 100644
--- a/src/main/java/com/google/devtools/build/lib/testing/vfs/SpiedFileSystem.java
+++ b/src/main/java/com/google/devtools/build/lib/testing/vfs/SpiedFileSystem.java
@@ -50,4 +50,9 @@
   public OutputStream getOutputStream(PathFragment path, boolean append) throws IOException {
     return super.getOutputStream(path, append);
   }
+
+  @Override // Public pass-through, presumably so tests can verify() it on the spy.
+  public boolean createWritableDirectory(PathFragment path) throws IOException {
+    return super.createWritableDirectory(path);
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
index c4c9950..3db89fa 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
+++ b/src/test/java/com/google/devtools/build/lib/remote/ActionInputPrefetcherTestBase.java
@@ -25,6 +25,7 @@
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 import com.google.common.collect.ImmutableList;
@@ -44,19 +45,17 @@
 import com.google.devtools.build.lib.actions.FileArtifactValue.RemoteFileArtifactValue;
 import com.google.devtools.build.lib.actions.MetadataProvider;
 import com.google.devtools.build.lib.actions.util.ActionsTestUtil;
-import com.google.devtools.build.lib.clock.JavaClock;
 import com.google.devtools.build.lib.remote.util.StaticMetadataProvider;
 import com.google.devtools.build.lib.remote.util.TempPathGenerator;
 import com.google.devtools.build.lib.skyframe.TreeArtifactValue;
+import com.google.devtools.build.lib.testing.vfs.SpiedFileSystem;
 import com.google.devtools.build.lib.util.Pair;
 import com.google.devtools.build.lib.vfs.DigestHashFunction;
 import com.google.devtools.build.lib.vfs.Dirent;
-import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.FileSystemUtils;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.PathFragment;
 import com.google.devtools.build.lib.vfs.Symlinks;
-import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
@@ -72,14 +71,14 @@
 public abstract class ActionInputPrefetcherTestBase {
   protected static final DigestHashFunction HASH_FUNCTION = DigestHashFunction.SHA256;
 
-  protected FileSystem fs;
+  protected SpiedFileSystem fs;
   protected Path execRoot;
   protected ArtifactRoot artifactRoot;
   protected TempPathGenerator tempPathGenerator;
 
   @Before
   public void setUp() throws IOException {
-    fs = new InMemoryFileSystem(new JavaClock(), HASH_FUNCTION);
+    fs = SpiedFileSystem.createInMemorySpy();
     execRoot = fs.getPath("/exec");
     execRoot.createDirectoryAndParents();
     artifactRoot = ArtifactRoot.asDerivedRoot(execRoot, RootType.Output, "root");
@@ -426,6 +425,34 @@
   }
 
   @Test
+  public void prefetchFiles_treeFiles_minimizeFilesystemOperations() throws Exception {
+    Map<ActionInput, FileArtifactValue> metadata = new HashMap<>();
+    Map<HashCode, byte[]> cas = new HashMap<>();
+    // Make both files remote so that both are downloaded by the same prefetchFiles() call, which
+    // exercises the per-call deduplication of directory operations in DirectoryContext.
+    Pair<SpecialArtifact, ImmutableList<TreeFileArtifact>> treeAndChildren =
+        createRemoteTreeArtifact(
+            "dir",
+            /* localContentMap= */ ImmutableMap.of(),
+            /* remoteContentMap= */ ImmutableMap.of(
+                "subdir/file1", "content1", "subdir/file2", "content2"),
+            metadata,
+            cas);
+    SpecialArtifact tree = treeAndChildren.getFirst();
+    ImmutableList<TreeFileArtifact> children = treeAndChildren.getSecond();
+    Artifact firstChild = children.get(0);
+    Artifact secondChild = children.get(1);
+
+    MetadataProvider metadataProvider = new StaticMetadataProvider(metadata);
+    AbstractActionInputPrefetcher prefetcher = createPrefetcher(cas);
+
+    wait(prefetcher.prefetchFiles(ImmutableList.of(firstChild, secondChild), metadataProvider));
+
+    verify(fs, times(1)).createWritableDirectory(tree.getPath().asFragment());
+    verify(fs, times(1)).createWritableDirectory(tree.getPath().getChild("subdir").asFragment());
+  }
+
+  @Test
   public void prefetchFiles_multipleThreads_downloadIsCancelled() throws Exception {
     // Test shared downloads are cancelled if all threads/callers are interrupted
 
diff --git a/src/test/java/com/google/devtools/build/lib/remote/BUILD b/src/test/java/com/google/devtools/build/lib/remote/BUILD
index 0b4b0ca..aa4613d 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/remote/BUILD
@@ -84,6 +84,7 @@
         "//src/main/java/com/google/devtools/build/lib/remote/util",
         "//src/main/java/com/google/devtools/build/lib/runtime/commands",
         "//src/main/java/com/google/devtools/build/lib/skyframe:tree_artifact_value",
+        "//src/main/java/com/google/devtools/build/lib/testing/vfs:spied_filesystem",
         "//src/main/java/com/google/devtools/build/lib/util",
         "//src/main/java/com/google/devtools/build/lib/util:abrupt_exit_exception",
         "//src/main/java/com/google/devtools/build/lib/util:exit_code",