Implement a "transitive" mode for memory dumping.

It dumps every object in the Skyframe transitive closure of a SkyValue.

Currently, it's very inefficient for multiple reasons:

* ConcurrentIdentitySet is slow when accessed from many threads. It should be sharded.
* MemoryAccountant cannot be accessed concurrently, so it's protected with a monitor. We should collect the numbers separately for each SkyValue and merge them later.

But at least the implementation of processTransitive() is pretty neat -- mirroring the Skyframe graph with a graph of CompletableFuture instances.

RELNOTES: None.
PiperOrigin-RevId: 623734658
Change-Id: I2eb43ac1ad30a831d1ffb524bc0f4352e3fe7872
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/commands/DumpCommand.java b/src/main/java/com/google/devtools/build/lib/runtime/commands/DumpCommand.java
index 8d608a3..b4627da 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/commands/DumpCommand.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/commands/DumpCommand.java
@@ -57,6 +57,7 @@
 import com.google.devtools.build.skyframe.NodeEntry;
 import com.google.devtools.build.skyframe.QueryableGraph.Reason;
 import com.google.devtools.build.skyframe.SkyKey;
+import com.google.devtools.build.skyframe.SkyValue;
 import com.google.devtools.common.options.Converter;
 import com.google.devtools.common.options.EnumConverter;
 import com.google.devtools.common.options.Option;
@@ -77,6 +78,9 @@
 import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
 import javax.annotation.Nullable;
 
 /** Implementation of the dump command. */
@@ -99,6 +103,8 @@
     SHALLOW,
     /** Dump objects reachable from a single SkyValue */
     DEEP,
+    /** Dump objects in the Skyframe transitive closure of a SkyValue */
+    TRANSITIVE,
   }
 
   /** How to display Skyframe memory use. */
@@ -161,6 +167,7 @@
         switch (word) {
           case "shallow" -> collectionMode = MemoryCollectionMode.SHALLOW;
           case "deep" -> collectionMode = MemoryCollectionMode.DEEP;
+          case "transitive" -> collectionMode = MemoryCollectionMode.TRANSITIVE;
           case "summary" -> displayMode = MemoryDisplayMode.SUMMARY;
           case "count" -> displayMode = MemoryDisplayMode.COUNT;
           case "bytes" -> displayMode = MemoryDisplayMode.BYTES;
@@ -617,7 +624,7 @@
       NodeEntry nodeEntry,
       MemoryMode mode,
       FieldCache fieldCache,
-      MemoryAccountant memoryAccountant,
+      ImmutableList<MemoryAccountant.Measurer> measurers,
       ConcurrentIdentitySet seen)
       throws InterruptedException {
     // Mark all objects accessible from direct dependencies. This will mutate seen, but that's OK.
@@ -636,16 +643,17 @@
 
     // Now traverse the objects reachable from the given SkyValue. Objects reachable from direct
     // dependencies are in "seen" and thus will not be counted.
-    return dumpRamReachable(nodeEntry, mode, fieldCache, memoryAccountant, seen);
+    return dumpRamReachable(nodeEntry, mode, fieldCache, measurers, seen);
   }
 
   private static Stats dumpRamReachable(
       NodeEntry nodeEntry,
       MemoryMode mode,
       FieldCache fieldCache,
-      MemoryAccountant memoryAccountant,
+      ImmutableList<MemoryAccountant.Measurer> measurers,
       ConcurrentIdentitySet seen)
       throws InterruptedException {
+    MemoryAccountant memoryAccountant = new MemoryAccountant(measurers);
     ObjectGraphTraverser traverser =
         new ObjectGraphTraverser(
             fieldCache, mode.reportTransient, seen, true, memoryAccountant, null, mode.needle);
@@ -653,6 +661,82 @@
     return memoryAccountant.getStats();
   }
 
+  private static CompletableFuture<Void> processTransitive(
+      InMemoryGraph graph,
+      SkyKey skyKey,
+      MemoryMode mode,
+      FieldCache fieldCache,
+      MemoryAccountant memoryAccountant,
+      ConcurrentIdentitySet seen,
+      Map<SkyKey, CompletableFuture<Void>> futureMap) {
+    NodeEntry entry = graph.get(null, Reason.OTHER, skyKey);
+    SkyValue value;
+    ImmutableList<SkyKey> directDeps;
+
+    try {
+      value = entry.getValue();
+      directDeps = ImmutableList.copyOf(entry.getDirectDeps());
+    } catch (InterruptedException e) {
+      // This is ugly but will do for now
+      throw new IllegalStateException();
+    }
+
+    return CompletableFuture.supplyAsync(
+            () -> {
+              // First we create list of futures this node depends on:
+              List<CompletableFuture<Void>> futures = new ArrayList<>();
+
+              // We iterate over every direct dep,
+              for (SkyKey dep : directDeps) {
+                futures.add(
+                    // and if not already processed, we create a future for it
+                    futureMap.computeIfAbsent(
+                        dep,
+                        k ->
+                            processTransitive(
+                                graph, dep, mode, fieldCache, memoryAccountant, seen, futureMap)));
+              }
+
+              return ImmutableList.copyOf(futures);
+            })
+        .thenCompose(
+            // Then we merge the futures representing the direct deps into one that fires when all
+            // of them are done,
+            futures -> CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[] {})))
+        .thenAcceptAsync(
+            // and once that's the case, we iterate over the object graph of that one.
+            done -> {
+              ObjectGraphTraverser traverser =
+                  new ObjectGraphTraverser(
+                      fieldCache,
+                      mode.reportTransient,
+                      seen,
+                      true,
+                      memoryAccountant,
+                      null,
+                      mode.needle);
+              traverser.traverse(value);
+            });
+  }
+
+  private static Stats dumpRamTransitive(
+      InMemoryGraph graph,
+      SkyKey skyKey,
+      MemoryMode mode,
+      FieldCache fieldCache,
+      ImmutableList<MemoryAccountant.Measurer> measurers,
+      ConcurrentIdentitySet seen) {
+    MemoryAccountant memoryAccountant = new MemoryAccountant(measurers);
+    try {
+      processTransitive(
+              graph, skyKey, mode, fieldCache, memoryAccountant, seen, new ConcurrentHashMap<>())
+          .get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new IllegalStateException(e);
+    }
+    return memoryAccountant.getStats();
+  }
+
   private static Optional<BlazeCommandResult> dumpSkyframeMemory(
       CommandEnvironment env, DumpOptions dumpOptions, PrintStream out)
       throws InterruptedException {
@@ -676,17 +760,17 @@
     CollectionObjectTraverser collectionObjectTraverser = new CollectionObjectTraverser();
     FieldCache fieldCache =
         new FieldCache(ImmutableList.of(buildObjectTraverser, collectionObjectTraverser));
-    MemoryAccountant memoryAccountant =
-        new MemoryAccountant(ImmutableList.of(collectionObjectTraverser));
+    ImmutableList<MemoryAccountant.Measurer> measurers =
+        ImmutableList.of(collectionObjectTraverser);
 
     ConcurrentIdentitySet seen = getBuiltinsSet(env, fieldCache);
     Stats stats =
         switch (dumpOptions.memory.collectionMode) {
-          case DEEP ->
-              dumpRamReachable(nodeEntry, dumpOptions.memory, fieldCache, memoryAccountant, seen);
+          case DEEP -> dumpRamReachable(nodeEntry, dumpOptions.memory, fieldCache, measurers, seen);
           case SHALLOW ->
-              dumpRamShallow(
-                  graph, nodeEntry, dumpOptions.memory, fieldCache, memoryAccountant, seen);
+              dumpRamShallow(graph, nodeEntry, dumpOptions.memory, fieldCache, measurers, seen);
+          case TRANSITIVE ->
+              dumpRamTransitive(graph, skyKey, dumpOptions.memory, fieldCache, measurers, seen);
         };
 
     switch (dumpOptions.memory.displayMode) {
diff --git a/src/main/java/com/google/devtools/build/lib/util/MemoryAccountant.java b/src/main/java/com/google/devtools/build/lib/util/MemoryAccountant.java
index 6765c18..c2315a7 100644
--- a/src/main/java/com/google/devtools/build/lib/util/MemoryAccountant.java
+++ b/src/main/java/com/google/devtools/build/lib/util/MemoryAccountant.java
@@ -80,7 +80,7 @@
   }
 
   @Override
-  public void objectFound(Object o, String context) {
+  public synchronized void objectFound(Object o, String context) {
     if (context == null) {
       if (o.getClass().isArray()) {
         context = "[] " + o.getClass().getComponentType().getName();
diff --git a/src/test/shell/integration/dump_test.sh b/src/test/shell/integration/dump_test.sh
index 3895f86..556f06a 100755
--- a/src/test/shell/integration/dump_test.sh
+++ b/src/test/shell/integration/dump_test.sh
@@ -107,4 +107,20 @@
   expect_not_log "Needle reached by path:"
 }
 
+function test_memory_transitive() {
+  mkdir -p a
+  cat > a/BUILD <<'EOF'
+sh_library(name="a", srcs=["a.sh"], deps=["//b"])
+EOF
+
+  mkdir -p b
+  cat > b/BUILD <<'EOF'
+sh_library(name="b", srcs=["b.sh"], visibility=["//visibility:public"])
+EOF
+
+  bazel build --nobuild //a >& $TEST_log || fail "build failed"
+  bazel dump --memory=transitive,count:configured_target://a >& $TEST_log || fail "dump failed"
+  expect_log "InputFileConfiguredTarget: 2"
+}
+
 run_suite "Tests for 'bazel dump'"