Allow configuring additional profiler tasks to include in the JSON profile

Closes #7217.

PiperOrigin-RevId: 238656293
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/Profiler.java b/src/main/java/com/google/devtools/build/lib/profiler/Profiler.java
index f740286..a7b9cb0 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/Profiler.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/Profiler.java
@@ -19,6 +19,7 @@
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.clock.Clock;
 import com.google.devtools.build.lib.collect.Extrema;
@@ -386,67 +387,8 @@
     }
   }
 
-  /**
-   * Which {@link ProfilerTask}s are profiled.
-   */
-  public enum ProfiledTaskKinds {
-    /**
-     * Do not profile anything.
-     *
-     * <p>Performance is best with this case, but we lose critical path analysis and slowest
-     * operation tracking.
-     */
-    NONE {
-      @Override
-      boolean isProfiling(ProfilerTask type) {
-        return false;
-      }
-    },
-
-    /**
-     * Profile on a few, known-to-be-slow tasks.
-     *
-     * <p>Performance is somewhat decreased in comparison to {@link #NONE}, but we still track the
-     * slowest operations (VFS).
-     */
-    SLOWEST {
-      @Override
-      boolean isProfiling(ProfilerTask type) {
-        return type.collectsSlowestInstances();
-      }
-    },
-
-    /** A set of tasks that's useful for the Json trace output. */
-    ALL_FOR_TRACE {
-      @Override
-      boolean isProfiling(ProfilerTask type) {
-        return !type.isVfs()
-            // CRITICAL_PATH corresponds to writing the file.
-            && type != ProfilerTask.CRITICAL_PATH
-            && type != ProfilerTask.SKYFUNCTION
-            && type != ProfilerTask.ACTION_COMPLETE
-            && !type.isStarlark();
-      }
-    },
-
-    /**
-     * Profile all tasks.
-     *
-     * <p>This is in use when {@code --profile} is specified.
-     */
-    ALL {
-      @Override
-      boolean isProfiling(ProfilerTask type) {
-        return true;
-      }
-    };
-
-    /** Whether the Profiler collects data for the given task type. */
-    abstract boolean isProfiling(ProfilerTask type);
-  }
-
   private Clock clock;
-  private ProfiledTaskKinds profiledTaskKinds;
+  private ImmutableSet<ProfilerTask> profiledTasks;
   private volatile long profileStartTime;
   private volatile boolean recordAllDurations = false;
 
@@ -537,7 +479,7 @@
    * <p>Subsequent calls to beginTask/endTask will be recorded in the provided output stream. Please
    * note that stream performance is extremely important and buffered streams should be utilized.
    *
-   * @param profiledTaskKinds which kinds of {@link ProfilerTask}s to track
+   * @param profiledTasks the set of {@link ProfilerTask}s to track
    * @param stream output stream to store profile data. Note: passing unbuffered stream object
    *     reference may result in significant performance penalties
    * @param comment a comment to insert in the profile data
@@ -547,7 +489,7 @@
    * @param execStartTimeNanos execution start time in nanos obtained from {@code clock.nanoTime()}
    */
   public synchronized void start(
-      ProfiledTaskKinds profiledTaskKinds,
+      ImmutableSet<ProfilerTask> profiledTasks,
       OutputStream stream,
       Format format,
       String comment,
@@ -559,7 +501,7 @@
     Preconditions.checkState(!isActive(), "Profiler already active");
     initHistograms();
 
-    this.profiledTaskKinds = profiledTaskKinds;
+    this.profiledTasks = profiledTasks;
     this.clock = clock;
 
     // sanity check for current limitation on the number of supported types due
@@ -661,7 +603,7 @@
   }
 
   public boolean isProfiling(ProfilerTask type) {
-    return profiledTaskKinds.isProfiling(type);
+    return profiledTasks.contains(type);
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java b/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
index fbef84c..9e9d1ae 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/BlazeRuntime.java
@@ -18,6 +18,7 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.eventbus.SubscriberExceptionContext;
 import com.google.common.eventbus.SubscriberExceptionHandler;
@@ -46,7 +47,6 @@
 import com.google.devtools.build.lib.profiler.ProfilePhase;
 import com.google.devtools.build.lib.profiler.Profiler;
 import com.google.devtools.build.lib.profiler.Profiler.Format;
-import com.google.devtools.build.lib.profiler.Profiler.ProfiledTaskKinds;
 import com.google.devtools.build.lib.profiler.ProfilerTask;
 import com.google.devtools.build.lib.profiler.SilentCloseable;
 import com.google.devtools.build.lib.query2.AbstractBlazeQueryEnvironment;
@@ -279,7 +279,7 @@
       long waitTimeInMs) {
     OutputStream out = null;
     boolean recordFullProfilerData = false;
-    ProfiledTaskKinds profiledTasks = ProfiledTaskKinds.NONE;
+    ImmutableSet.Builder<ProfilerTask> profiledTasksBuilder = ImmutableSet.builder();
     Profiler.Format format = Profiler.Format.BINARY_BAZEL_FORMAT;
     Path profilePath = null;
     try {
@@ -300,20 +300,37 @@
         recordFullProfilerData = false;
         out = profilePath.getOutputStream();
         eventHandler.handle(Event.info("Writing tracer profile to '" + profilePath + "'"));
-        profiledTasks = ProfiledTaskKinds.ALL_FOR_TRACE;
+        for (ProfilerTask profilerTask : ProfilerTask.values()) {
+          if (!profilerTask.isVfs()
+              // CRITICAL_PATH corresponds to writing the file.
+              && profilerTask != ProfilerTask.CRITICAL_PATH
+              && profilerTask != ProfilerTask.SKYFUNCTION
+              && profilerTask != ProfilerTask.ACTION_COMPLETE
+              && !profilerTask.isStarlark()) {
+            profiledTasksBuilder.add(profilerTask);
+          }
+        }
+        profiledTasksBuilder.addAll(options.additionalProfileTasks);
       } else if (options.profilePath != null) {
         profilePath = workspace.getWorkspace().getRelative(options.profilePath);
 
         recordFullProfilerData = options.recordFullProfilerData;
         out = profilePath.getOutputStream();
         eventHandler.handle(Event.info("Writing profile data to '" + profilePath + "'"));
-        profiledTasks = ProfiledTaskKinds.ALL;
+        for (ProfilerTask profilerTask : ProfilerTask.values()) {
+          profiledTasksBuilder.add(profilerTask);
+        }
       } else if (options.alwaysProfileSlowOperations) {
         recordFullProfilerData = false;
         out = null;
-        profiledTasks = ProfiledTaskKinds.SLOWEST;
+        for (ProfilerTask profilerTask : ProfilerTask.values()) {
+          if (profilerTask.collectsSlowestInstances()) {
+            profiledTasksBuilder.add(profilerTask);
+          }
+        }
       }
-      if (profiledTasks != ProfiledTaskKinds.NONE) {
+      ImmutableSet<ProfilerTask> profiledTasks = profiledTasksBuilder.build();
+      if (!profiledTasks.isEmpty()) {
         Profiler profiler = Profiler.instance();
         profiler.start(
             profiledTasks,
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/CommonCommandOptions.java b/src/main/java/com/google/devtools/build/lib/runtime/CommonCommandOptions.java
index 0a8fac9..9452400 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/CommonCommandOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/CommonCommandOptions.java
@@ -16,11 +16,13 @@
 import static com.google.common.base.Strings.isNullOrEmpty;
 
 import com.google.devtools.build.lib.profiler.MemoryProfiler.MemoryProfileStableHeapParameters;
+import com.google.devtools.build.lib.profiler.ProfilerTask;
 import com.google.devtools.build.lib.runtime.CommandLineEvent.ToolCommandLineEvent;
 import com.google.devtools.build.lib.util.OptionsUtils;
 import com.google.devtools.build.lib.vfs.PathFragment;
 import com.google.devtools.common.options.Converter;
 import com.google.devtools.common.options.Converters;
+import com.google.devtools.common.options.EnumConverter;
 import com.google.devtools.common.options.Option;
 import com.google.devtools.common.options.OptionDocumentationCategory;
 import com.google.devtools.common.options.OptionEffectTag;
@@ -221,6 +223,16 @@
   public boolean enableCpuUsageProfiling;
 
   @Option(
+      name = "experimental_profile_additional_tasks",
+      converter = ProfilerTaskConverter.class,
+      defaultValue = "none",
+      allowMultiple = true,
+      documentationCategory = OptionDocumentationCategory.LOGGING,
+      effectTags = {OptionEffectTag.AFFECTS_OUTPUTS, OptionEffectTag.BAZEL_MONITORING},
+      help = "Specifies additional profile tasks to be included in the profile.")
+  public List<ProfilerTask> additionalProfileTasks;
+
+  @Option(
       name = "profile",
       defaultValue = "null",
       documentationCategory = OptionDocumentationCategory.LOGGING,
@@ -416,4 +428,11 @@
               + "one."
   )
   public boolean keepStateAfterBuild;
+
+  /** The option converter to check that the user can only specify legal profiler tasks. */
+  public static class ProfilerTaskConverter extends EnumConverter<ProfilerTask> {
+    public ProfilerTaskConverter() {
+      super(ProfilerTask.class, "profiler task");
+    }
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/profiler/AutoProfilerBenchmark.java b/src/test/java/com/google/devtools/build/lib/profiler/AutoProfilerBenchmark.java
index 448a21b..88eb503 100644
--- a/src/test/java/com/google/devtools/build/lib/profiler/AutoProfilerBenchmark.java
+++ b/src/test/java/com/google/devtools/build/lib/profiler/AutoProfilerBenchmark.java
@@ -15,8 +15,8 @@
 
 import com.google.caliper.BeforeExperiment;
 import com.google.caliper.Benchmark;
+import com.google.common.collect.ImmutableSet;
 import com.google.devtools.build.lib.clock.BlazeClock;
-import com.google.devtools.build.lib.profiler.Profiler.ProfiledTaskKinds;
 import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
 
 /**
@@ -29,7 +29,7 @@
   void startProfiler() throws Exception {
     Profiler.instance()
         .start(
-            ProfiledTaskKinds.ALL,
+            ImmutableSet.copyOf(ProfilerTask.values()),
             new InMemoryFileSystem().getPath("/out.dat").getOutputStream(),
             Profiler.Format.BINARY_BAZEL_FORMAT,
             "benchmark",
diff --git a/src/test/java/com/google/devtools/build/lib/profiler/ProfilerChartTest.java b/src/test/java/com/google/devtools/build/lib/profiler/ProfilerChartTest.java
index a8fe512..8463222 100644
--- a/src/test/java/com/google/devtools/build/lib/profiler/ProfilerChartTest.java
+++ b/src/test/java/com/google/devtools/build/lib/profiler/ProfilerChartTest.java
@@ -15,8 +15,8 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import com.google.common.collect.ImmutableSet;
 import com.google.devtools.build.lib.clock.BlazeClock;
-import com.google.devtools.build.lib.profiler.Profiler.ProfiledTaskKinds;
 import com.google.devtools.build.lib.profiler.analysis.ProfileInfo;
 import com.google.devtools.build.lib.profiler.chart.AggregatingChartCreator;
 import com.google.devtools.build.lib.profiler.chart.Chart;
@@ -252,7 +252,7 @@
     Profiler profiler = Profiler.instance();
     try (OutputStream out = cacheFile.getOutputStream()) {
       profiler.start(
-          ProfiledTaskKinds.ALL,
+          ImmutableSet.copyOf(ProfilerTask.values()),
           out,
           Profiler.Format.BINARY_BAZEL_FORMAT,
           "basic test",
diff --git a/src/test/java/com/google/devtools/build/lib/profiler/ProfilerTest.java b/src/test/java/com/google/devtools/build/lib/profiler/ProfilerTest.java
index 4c18785..eef9405 100644
--- a/src/test/java/com/google/devtools/build/lib/profiler/ProfilerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/profiler/ProfilerTest.java
@@ -20,10 +20,10 @@
 import static org.junit.Assert.fail;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.io.ByteStreams;
 import com.google.devtools.build.lib.clock.BlazeClock;
 import com.google.devtools.build.lib.clock.Clock;
-import com.google.devtools.build.lib.profiler.Profiler.ProfiledTaskKinds;
 import com.google.devtools.build.lib.profiler.Profiler.SlowTask;
 import com.google.devtools.build.lib.profiler.analysis.ProfileInfo;
 import com.google.devtools.build.lib.testutil.ManualClock;
@@ -75,11 +75,25 @@
     }
   }
 
-  private ByteArrayOutputStream start(ProfiledTaskKinds kinds, Profiler.Format format)
+  private ImmutableSet<ProfilerTask> getAllProfilerTasks() {
+    return ImmutableSet.copyOf(ProfilerTask.values());
+  }
+
+  private ImmutableSet<ProfilerTask> getSlowestProfilerTasks() {
+    ImmutableSet.Builder<ProfilerTask> profiledTasksBuilder = ImmutableSet.builder();
+    for (ProfilerTask profilerTask : ProfilerTask.values()) {
+      if (profilerTask.collectsSlowestInstances()) {
+        profiledTasksBuilder.add(profilerTask);
+      }
+    }
+    return profiledTasksBuilder.build();
+  }
+
+  private ByteArrayOutputStream start(ImmutableSet<ProfilerTask> tasks, Profiler.Format format)
       throws IOException {
     ByteArrayOutputStream buffer = new ByteArrayOutputStream();
     profiler.start(
-        kinds,
+        tasks,
         buffer,
         format,
         "test",
@@ -90,9 +104,9 @@
     return buffer;
   }
 
-  private void startUnbuffered(ProfiledTaskKinds kinds) throws IOException {
+  private void startUnbuffered(ImmutableSet<ProfilerTask> tasks) throws IOException {
     profiler.start(
-        kinds,
+        tasks,
         null,
         null,
         "test",
@@ -105,7 +119,7 @@
   @Test
   public void testProfilerActivation() throws Exception {
     assertThat(profiler.isActive()).isFalse();
-    start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     assertThat(profiler.isActive()).isTrue();
 
     profiler.stop();
@@ -114,7 +128,7 @@
 
   @Test
   public void testTaskDetails() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     try (SilentCloseable c = profiler.profile(ProfilerTask.ACTION, "action task")) {
       profiler.logEvent(ProfilerTask.INFO, "event");
     }
@@ -135,7 +149,7 @@
 
   @Test
   public void testProfiler() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     profiler.logSimpleTask(BlazeClock.instance().nanoTime(),
                            ProfilerTask.PHASE, "profiler start");
     try (SilentCloseable c = profiler.profile(ProfilerTask.ACTION, "complex task")) {
@@ -191,7 +205,7 @@
   public void testProfilerRecordingAllEvents() throws Exception {
     ByteArrayOutputStream buffer = new ByteArrayOutputStream();
     profiler.start(
-        ProfiledTaskKinds.ALL,
+        getAllProfilerTasks(),
         buffer,
         BINARY_BAZEL_FORMAT,
         "basic test",
@@ -221,7 +235,7 @@
     ByteArrayOutputStream buffer = new ByteArrayOutputStream();
 
     profiler.start(
-        ProfiledTaskKinds.SLOWEST,
+        getSlowestProfilerTasks(),
         buffer,
         BINARY_BAZEL_FORMAT,
         "test",
@@ -247,7 +261,7 @@
 
   @Test
   public void testSlowestTasks() throws Exception {
-    startUnbuffered(ProfiledTaskKinds.ALL);
+    startUnbuffered(getAllProfilerTasks());
     profiler.logSimpleTaskDuration(
         Profiler.nanoTimeMaybe(), Duration.ofSeconds(10), ProfilerTask.LOCAL_PARSE, "foo");
     Iterable<SlowTask> slowestTasks = profiler.getSlowestTasks();
@@ -259,7 +273,7 @@
 
   @Test
   public void testGetSlowestTasksCapped() throws Exception {
-    startUnbuffered(ProfiledTaskKinds.SLOWEST);
+    startUnbuffered(getSlowestProfilerTasks());
 
     // Add some fast tasks - these shouldn't show up in the slowest.
     for (int i = 0; i < ProfilerTask.VFS_STAT.slowestInstancesCount; i++) {
@@ -330,7 +344,7 @@
   public void testProfilerRecordsNothing() throws Exception {
     ByteArrayOutputStream buffer = new ByteArrayOutputStream();
     profiler.start(
-        ProfiledTaskKinds.NONE,
+        ImmutableSet.of(),
         buffer,
         BINARY_BAZEL_FORMAT,
         "test",
@@ -352,7 +366,7 @@
 
   @Test
   public void testConcurrentProfiling() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
 
     long id = Thread.currentThread().getId();
     Thread thread1 = new Thread() {
@@ -406,7 +420,7 @@
 
   @Test
   public void testPhaseTasks() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     Thread thread1 = new Thread() {
       @Override public void run() {
         for (int i = 0; i < 100; i++) {
@@ -468,7 +482,7 @@
 
   @Test
   public void testCorruptedFile() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     for (int i = 0; i < 100; i++) {
       try (SilentCloseable c = profiler.profile(ProfilerTask.INFO, "outer task " + i)) {
         clock.advanceMillis(1);
@@ -494,7 +508,7 @@
 
   @Test
   public void testUnsupportedProfilerRecord() throws Exception {
-    ByteArrayOutputStream buffer = start(ProfiledTaskKinds.ALL, BINARY_BAZEL_FORMAT);
+    ByteArrayOutputStream buffer = start(getAllProfilerTasks(), BINARY_BAZEL_FORMAT);
     try (SilentCloseable c = profiler.profile(ProfilerTask.INFO, "outer task")) {
       profiler.logEvent(ProfilerTask.PHASE, "inner task");
     }
@@ -550,7 +564,7 @@
       }
     };
     profiler.start(
-        ProfiledTaskKinds.ALL,
+        getAllProfilerTasks(),
         new ByteArrayOutputStream(),
         BINARY_BAZEL_FORMAT,
         "testResilenceToNonDecreasingNanoTimes",
@@ -565,7 +579,7 @@
   /** Checks that the histograms are cleared in the stop call. */
   @Test
   public void testEmptyTaskHistograms() throws Exception {
-    startUnbuffered(ProfiledTaskKinds.ALL);
+    startUnbuffered(getAllProfilerTasks());
     profiler.logSimpleTaskDuration(
         Profiler.nanoTimeMaybe(), Duration.ofSeconds(10), ProfilerTask.INFO, "foo");
     profiler.stop();
@@ -577,7 +591,7 @@
 
   @Test
   public void testTaskHistograms() throws Exception {
-    startUnbuffered(ProfiledTaskKinds.ALL);
+    startUnbuffered(getAllProfilerTasks());
     profiler.logSimpleTaskDuration(
         Profiler.nanoTimeMaybe(), Duration.ofSeconds(10), ProfilerTask.INFO, "foo");
     ImmutableList<StatRecorder> histograms = profiler.getTasksHistograms();
@@ -601,7 +615,7 @@
       }
     };
     profiler.start(
-        ProfiledTaskKinds.ALL,
+        getAllProfilerTasks(),
         failingOutputStream,
         BINARY_BAZEL_FORMAT,
         "basic test",
@@ -628,7 +642,7 @@
       }
     };
     profiler.start(
-        ProfiledTaskKinds.ALL,
+        getAllProfilerTasks(),
         failingOutputStream,
         JSON_TRACE_FILE_FORMAT,
         "basic test",