Add --task_tree and --task_tree_threshold options.

Options allow displaying some tasks in more detail, e.g. for inspecting what
exactly a Skylark user-defined function calls and how long that takes.

--
MOS_MIGRATED_REVID=104505599
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/ProfileInfo.java b/src/main/java/com/google/devtools/build/lib/profiler/ProfileInfo.java
index f7ec6db..b93b8d9 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/ProfileInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/ProfileInfo.java
@@ -18,10 +18,14 @@
 
 import com.google.common.base.Joiner;
 import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
 import com.google.common.collect.ListMultimap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.MultimapBuilder.ListMultimapBuilder;
+import com.google.common.collect.Ordering;
 import com.google.common.collect.Sets;
 import com.google.devtools.build.lib.util.VarInt;
 import com.google.devtools.build.lib.vfs.Path;
@@ -29,6 +33,7 @@
 import java.io.BufferedInputStream;
 import java.io.DataInputStream;
 import java.io.IOException;
+import java.io.PrintStream;
 import java.io.UnsupportedEncodingException;
 import java.nio.ByteBuffer;
 import java.util.ArrayDeque;
@@ -42,6 +47,8 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
 import java.util.zip.Inflater;
 import java.util.zip.InflaterInputStream;
 
@@ -153,6 +160,14 @@
     }
   }
 
+  public static final Ordering<Task> TASK_DURATION_ORDERING =
+      new Ordering<Task>() {
+        @Override
+        public int compare(Task o1, Task o2) {
+          return Long.compare(o1.durationNanos, o2.durationNanos);
+        }
+      };
+
   /**
    * Container for the profile record information.
    *
@@ -164,7 +179,7 @@
     public final int id;
     public final int parentId;
     public final long startTime;
-    public final long duration;
+    public final long durationNanos;
     public final ProfilerTask type;
     final CompactStatistics stats;
     // Contains statistic for a task and all subtasks. Populated only for root tasks.
@@ -176,13 +191,20 @@
     // Reference to the related task (e.g. ACTION_GRAPH->ACTION task relation).
     private Task relatedTask;
 
-    Task(long threadId, int id, int parentId, long startTime, long duration,
-         ProfilerTask type, int descIndex, CompactStatistics stats) {
+    Task(
+        long threadId,
+        int id,
+        int parentId,
+        long startTime,
+        long durationNanos,
+        ProfilerTask type,
+        int descIndex,
+        CompactStatistics stats) {
       this.threadId = threadId;
       this.id = id;
       this.parentId = parentId;
       this.startTime = startTime;
-      this.duration = duration;
+      this.durationNanos = durationNanos;
       this.type = type;
       this.descIndex = descIndex;
       this.stats = stats;
@@ -258,6 +280,77 @@
     }
 
     /**
+     * Produce a nicely indented tree of the task and its subtasks with execution time.
+     *
+     * <p>Execution times are in milliseconds.
+     *
+     * <p>Example:
+     *
+     * <pre>
+     * 636779 SKYLARK_USER_FN (259.593 ms) /path/file.bzl:42#function [
+     *   636810 SKYLARK_USER_FN (257.768 ms) /path/file.bzl:133#_other_function [
+     *     636974 SKYLARK_BUILTIN_FN (254.596 ms) some.package.PackageFactory$9#genrule []
+     *   2 subtree(s) omitted]
+     * ]
+     * </pre>
+     *
+     * @param durationThresholdMillis Tasks with a shorter duration than this threshold will be
+     *  skipped
+     * @return whether this task took longer than the threshold and was thus printed
+     */
+    public boolean printTaskTree(PrintStream out, long durationThresholdMillis) {
+      return printTaskTree(out, "", TimeUnit.MILLISECONDS.toNanos(durationThresholdMillis));
+    }
+
+    /**
+     * @see #printTaskTree(PrintStream, long)
+     */
+    private boolean printTaskTree(
+        PrintStream out, String indent, final long durationThresholdNanos) {
+      if (durationNanos < durationThresholdNanos) {
+        return false;
+      }
+      out.printf("%s%6d %s", indent, id, type);
+      out.printf(" (%5.3f ms) ", durationNanos / 1000000.0);
+      out.print(getDescription());
+
+      out.print(" [");
+      ImmutableList<Task> sortedSubTasks =
+          TASK_DURATION_ORDERING
+              .reverse()
+              .immutableSortedCopy(
+                  Iterables.filter(
+                      Arrays.asList(subtasks),
+                      new Predicate<Task>() {
+                        @Override
+                        public boolean apply(Task task) {
+                          return task.durationNanos >= durationThresholdNanos;
+                        }
+                      }));
+      String sep = "";
+      for (Task task : sortedSubTasks) {
+        out.print(sep);
+        out.println();
+        task.printTaskTree(out, indent + "  ", durationThresholdNanos);
+        sep = ",";
+      }
+      if (!sortedSubTasks.isEmpty()) {
+        out.println();
+        out.print(indent);
+      }
+      int skipped = subtasks.length - sortedSubTasks.size();
+      if (skipped > 0) {
+        out.printf("%d subtree(s) omitted", skipped);
+      }
+      out.print("]");
+
+      if (indent.equals("")) {
+        out.println();
+      }
+      return true;
+    }
+
+    /**
      * Tasks records by default sorted by their id. Since id was obtained using
      * AtomicInteger, this comparison will correctly sort tasks in time-ascending
      * order regardless of their origin thread.
@@ -513,7 +606,7 @@
       totalTime += attr.totalTime;
       if (task.type == type) {
         count++;
-        totalTime += (task.duration - task.getInheritedDuration());
+        totalTime += (task.durationNanos - task.getInheritedDuration());
       }
     }
     return new AggregateAttr(count, totalTime);
@@ -569,7 +662,7 @@
       duration = phaseTask.relatedTask.startTime - phaseTask.startTime;
     } else {
       Task lastTask = rootTasksById.get(rootTasksById.size() - 1);
-      duration = lastTask.startTime + lastTask.duration - phaseTask.startTime;
+      duration = lastTask.startTime + lastTask.durationNanos - phaseTask.startTime;
     }
     Preconditions.checkState(duration >= 0);
     return duration;
@@ -705,7 +798,7 @@
           }
         }
         if (actionTask.type == ProfilerTask.ACTION) {
-          long duration = actionTask.duration;
+          long duration = actionTask.durationNanos;
           if (ignoredTasks.contains(actionTask)) {
             duration = 0L;
           } else {
@@ -734,7 +827,7 @@
       if (task.type == CRITICAL_PATH) {
         CriticalPathEntry entry = null;
         for (Task shared : task.subtasks) {
-          entry = new CriticalPathEntry(shared, shared.duration, entry);
+          entry = new CriticalPathEntry(shared, shared.durationNanos, entry);
         }
         return entry;
       }
@@ -825,7 +918,7 @@
     Task related = parallelBuilderCompletionQueueTasks.get(actionTask);
     if (related != null) {
       Preconditions.checkState(related.type == ProfilerTask.ACTION_BUILDER);
-      long time = related.startTime - (actionTask.startTime + actionTask.duration);
+      long time = related.startTime - (actionTask.startTime + actionTask.durationNanos);
       Preconditions.checkState(time >= 0);
       return time;
     } else {
@@ -834,6 +927,23 @@
   }
 
   /**
+   * Searches for the task by its description. Linear in the number of tasks.
+   * @param description a regular expression pattern which will be matched against the task
+   * description
+   * @return an Iterable of Tasks matching the description
+   */
+  public Iterable<Task> findTasksByDescription(final Pattern description) {
+    return Iterables.filter(
+        allTasksById,
+        new Predicate<Task>() {
+          @Override
+          public boolean apply(Task task) {
+            return description.matcher(task.getDescription()).find();
+          }
+        });
+  }
+
+  /**
    * Returns an empty array used to store task statistics. Array index
    * corresponds to the ProfilerTask ordinal() value associated with the
    * given statistic. Absent statistics are stored as null.
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/chart/AggregatingChartCreator.java b/src/main/java/com/google/devtools/build/lib/profiler/chart/AggregatingChartCreator.java
index a3f29c7..98c58cd 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/chart/AggregatingChartCreator.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/chart/AggregatingChartCreator.java
@@ -128,7 +128,7 @@
    */
   private void createBar(Chart chart, Task task, ChartBarType type) {
     String label = task.type.description + ": " + task.getDescription();
-    chart.addBar(task.threadId, task.startTime, task.startTime + task.duration, type, label);
+    chart.addBar(task.threadId, task.startTime, task.startTime + task.durationNanos, type, label);
   }
 
   /**
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/chart/DetailedChartCreator.java b/src/main/java/com/google/devtools/build/lib/profiler/chart/DetailedChartCreator.java
index 4ee2f19..2477d3a 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/chart/DetailedChartCreator.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/chart/DetailedChartCreator.java
@@ -53,7 +53,7 @@
     for (Task task : info.allTasksById) {
       String label = task.type.description + ": " + task.getDescription();
       ChartBarType type = chart.lookUpType(task.type.description);
-      long stop = task.startTime + task.duration;
+      long stop = task.startTime + task.durationNanos;
       CriticalPathEntry entry = null;
 
       // for top level tasks, check if they are on the critical path
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/output/SkylarkHtml.java b/src/main/java/com/google/devtools/build/lib/profiler/output/SkylarkHtml.java
index 0e4dc6e..b2768bd 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/output/SkylarkHtml.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/output/SkylarkHtml.java
@@ -142,7 +142,7 @@
       lnPrintf("'%s': google.visualization.arrayToDataTable(", function);
       lnPrint("[['duration']");
       for (Task task : tasks.get(function)) {
-        printf(",[%f]", task.duration / 1000000.);
+        printf(",[%f]", task.durationNanos / 1000000.);
       }
       lnPrint("], false),");
     }
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseStatistics.java b/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseStatistics.java
index 7ef69c7..3f6e326 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseStatistics.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseStatistics.java
@@ -54,7 +54,7 @@
       for (Task task : taskList) {
         // Tasks on the phaseTask thread already accounted for in the phaseDuration.
         if (task.threadId != phaseTask.threadId) {
-          duration += task.duration;
+          duration += task.durationNanos;
         }
       }
       totalDurationNanos = duration;
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseVfsStatistics.java b/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseVfsStatistics.java
index 27856fc..4c52c47 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseVfsStatistics.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/statistics/PhaseVfsStatistics.java
@@ -111,7 +111,7 @@
         stat = new Stat();
       }
 
-      stat.duration += task.duration;
+      stat.duration += task.durationNanos;
       stat.count++;
       statsForType.put(path, stat);
     }
diff --git a/src/main/java/com/google/devtools/build/lib/profiler/statistics/TasksStatistics.java b/src/main/java/com/google/devtools/build/lib/profiler/statistics/TasksStatistics.java
index 1a422db..72afab4 100644
--- a/src/main/java/com/google/devtools/build/lib/profiler/statistics/TasksStatistics.java
+++ b/src/main/java/com/google/devtools/build/lib/profiler/statistics/TasksStatistics.java
@@ -13,10 +13,9 @@
 // limitations under the License.
 package com.google.devtools.build.lib.profiler.statistics;
 
+import com.google.devtools.build.lib.profiler.ProfileInfo;
 import com.google.devtools.build.lib.profiler.ProfileInfo.Task;
 
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 
 /**
@@ -91,28 +90,19 @@
   }
 
   /**
-   * @param name
-   * @param tasks
    * @return The set of statistics grouped in this class, computed from a list of {@link Task}s.
    */
   public static TasksStatistics create(String name, List<Task> tasks) {
-    Collections.sort(
-        tasks,
-        new Comparator<Task>() {
-          @Override
-          public int compare(Task o1, Task o2) {
-            return Long.compare(o1.duration, o2.duration);
-          }
-        });
+    tasks = ProfileInfo.TASK_DURATION_ORDERING.immutableSortedCopy(tasks);
     int count = tasks.size();
-    long min = tasks.get(0).duration;
-    long max = tasks.get(count - 1).duration;
+    long min = tasks.get(0).durationNanos;
+    long max = tasks.get(count - 1).durationNanos;
 
     int midIndex = count / 2;
     double median =
         tasks.size() % 2 == 0
-            ? (tasks.get(midIndex).duration + tasks.get(midIndex - 1).duration) / 2.0
-            : tasks.get(midIndex).duration;
+            ? (tasks.get(midIndex).durationNanos + tasks.get(midIndex - 1).durationNanos) / 2.0
+            : tasks.get(midIndex).durationNanos;
 
     // Compute standard deviation with a shift to avoid catastrophic cancellation
     // and also do it in milliseconds, as in nanoseconds it overflows
@@ -122,9 +112,9 @@
     final long shift = min;
 
     for (Task task : tasks) {
-      sum += task.duration;
-      self += task.duration - task.getInheritedDuration();
-      double taskDurationShiftMillis = toMilliSeconds(task.duration - shift);
+      sum += task.durationNanos;
+      self += task.durationNanos - task.getInheritedDuration();
+      double taskDurationShiftMillis = toMilliSeconds(task.durationNanos - shift);
       sumOfSquaredShiftedMillis += taskDurationShiftMillis * taskDurationShiftMillis;
     }
     double sumShiftedMillis = toMilliSeconds(sum - count * shift);
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/commands/ProfileCommand.java b/src/main/java/com/google/devtools/build/lib/runtime/commands/ProfileCommand.java
index 0d728ca..1a21b9c 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/commands/ProfileCommand.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/commands/ProfileCommand.java
@@ -13,10 +13,13 @@
 // limitations under the License.
 package com.google.devtools.build.lib.runtime.commands;
 
+import com.google.common.base.Joiner;
+import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.events.Event;
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.profiler.ProfileInfo;
 import com.google.devtools.build.lib.profiler.ProfileInfo.InfoListener;
+import com.google.devtools.build.lib.profiler.ProfileInfo.Task;
 import com.google.devtools.build.lib.profiler.ProfilePhase;
 import com.google.devtools.build.lib.profiler.ProfilerTask;
 import com.google.devtools.build.lib.profiler.output.HtmlCreator;
@@ -41,6 +44,7 @@
 import java.io.IOException;
 import java.io.PrintStream;
 import java.util.EnumMap;
+import java.util.regex.Pattern;
 
 /**
  * Command line wrapper for analyzing Blaze build profiles.
@@ -92,6 +96,24 @@
     )
     public boolean htmlDetails;
 
+    @Option(
+      name = "task_tree",
+      defaultValue = "null",
+      converter = Converters.RegexPatternConverter.class,
+      help =
+          "Print the tree of profiler tasks from all tasks matching the given regular expression."
+    )
+    public Pattern taskTree;
+
+    @Option(
+      name = "task_tree_threshold",
+      defaultValue = "50",
+      help =
+          "When printing a task tree, will skip tasks with a duration that is less than the"
+              + " given threshold in milliseconds."
+    )
+    public long taskTreeThreshold;
+
     @Option(name = "vfs_stats",
         defaultValue = "false",
         help = "If present, include VFS path statistics.")
@@ -132,8 +154,7 @@
       opts.vfsStatsLimit = 0;
     }
 
-    PrintStream out = new PrintStream(env.getReporter().getOutErr().getOutputStream());
-    try {
+    try (PrintStream out = new PrintStream(env.getReporter().getOutErr().getOutputStream())) {
       env.getReporter().handle(Event.warn(
           null, "This information is intended for consumption by Blaze developers"
               + " only, and may change at any time.  Script against it at your own risk"));
@@ -145,6 +166,11 @@
               profileFile, getInfoListener(env));
           ProfileInfo.aggregateProfile(info, getInfoListener(env));
 
+          if (opts.taskTree != null) {
+            printTaskTree(out, name, info, opts.taskTree, opts.taskTreeThreshold);
+            continue;
+          }
+
           PhaseSummaryStatistics phaseSummaryStatistics = new PhaseSummaryStatistics(info);
           EnumMap<ProfilePhase, PhaseStatistics> phaseStatistics =
               new EnumMap<>(ProfilePhase.class);
@@ -185,12 +211,38 @@
               null, "Failed to process file " + name + ": " + e.getMessage()));
         }
       }
-    } finally {
-      out.flush();
     }
     return ExitCode.SUCCESS;
   }
 
+  /**
+   * Prints trees rooted at tasks with a description matching a pattern.
+   * @see Task#printTaskTree(PrintStream, long)
+   */
+  private void printTaskTree(
+      PrintStream out,
+      String fileName,
+      ProfileInfo info,
+      Pattern taskPattern,
+      long taskDurationThreshold) {
+    Iterable<Task> tasks = info.findTasksByDescription(taskPattern);
+    if (Iterables.isEmpty(tasks)) {
+      out.printf("No tasks matching %s found in profile file %s.", taskPattern, fileName);
+      out.println();
+    } else {
+      int skipped = 0;
+      for (Task task : tasks) {
+        if (!task.printTaskTree(out, taskDurationThreshold)) {
+          skipped++;
+        }
+      }
+      if (skipped > 0) {
+        out.printf("Skipped %d matching task(s) below the duration threshold.", skipped);
+      }
+      out.println();
+    }
+  }
+
   private void dumpProfile(
       CommandEnvironment env, ProfileInfo info, PrintStream out, String dumpMode) {
     if (!dumpMode.contains("unsorted")) {
@@ -212,10 +264,22 @@
   }
 
   private void dumpTask(ProfileInfo.Task task, PrintStream out, int indent) {
-    StringBuilder builder = new StringBuilder(String.format(
-        "\n%s %s\nThread: %-6d  Id: %-6d  Parent: %d\nStart time: %-12s   Duration: %s",
-        task.type, task.getDescription(), task.threadId, task.id, task.parentId,
-        TimeUtilities.prettyTime(task.startTime), TimeUtilities.prettyTime(task.duration)));
+    StringBuilder builder =
+        new StringBuilder(
+            String.format(
+                Joiner.on('\n')
+                    .join(
+                        "",
+                        "%s %s",
+                        "Thread: %-6d  Id: %-6d  Parent: %d",
+                        "Start time: %-12s   Duration: %s"),
+                task.type,
+                task.getDescription(),
+                task.threadId,
+                task.id,
+                task.parentId,
+                TimeUtilities.prettyTime(task.startTime),
+                TimeUtilities.prettyTime(task.durationNanos)));
     if (task.hasStats()) {
       builder.append("\n");
       ProfileInfo.AggregateAttr[] stats = task.getStatAttrArray();
@@ -245,9 +309,15 @@
       }
     }
     out.println(
-        task.threadId + "|" + task.id + "|" + task.parentId + "|"
-        + task.startTime + "|" + task.duration + "|"
-        + aggregateString.toString().trim() + "|"
-        + task.type + "|" + task.getDescription());
+        Joiner.on('|')
+            .join(
+                task.threadId,
+                task.id,
+                task.parentId,
+                task.startTime,
+                task.durationNanos,
+                aggregateString.toString().trim(),
+                task.type,
+                task.getDescription()));
   }
 }