Print an aggregated summary for sharded tests (#1870)

With a high number of shards, it becomes very hard to see which tests
are actually failing. This change introduces the `--print_shard_summary`
flag. If set, the pipeline emits one Buildkite annotation per sharded
platform with failures, containing a summary of all failing tests.
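
For illustration (the exact wiring depends on the pipeline setup), the
flag is passed to the pipeline generation step roughly as follows:

    python bazelci.py project_pipeline --print_shard_summary

If at least one task runs with more than one shard, the generated
pipeline gains a final "Print Test Summary for Shards" step that
invokes `bazelci.py print_shard_summary`.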

Example: https://buildkite.com/bazel/bazel-bazel-macos-ninja/builds/420

Related to
https://github.com/bazelbuild/continuous-integration/issues/1708
diff --git a/buildkite/bazelci.py b/buildkite/bazelci.py
index c98a48d..574d233 100755
--- a/buildkite/bazelci.py
+++ b/buildkite/bazelci.py
@@ -17,6 +17,8 @@
 import argparse
 import base64
 import codecs
+import collections
+import concurrent.futures
 import copy
 import datetime
 from glob import glob
@@ -38,6 +40,7 @@
 import tempfile
 import threading
 import time
+from typing import Sequence
 import urllib.error
 import urllib.request
 import yaml
@@ -81,6 +84,12 @@
     "bazel": "gs://bazel-kzips/",
 }[BUILDKITE_ORG]
 
+# We don't collect logs in the trusted org
+LOG_BUCKET = {
+    "bazel-testing": "https://storage.googleapis.com/bazel-testing-buildkite-artifacts",
+    "bazel": "https://storage.googleapis.com/bazel-untrusted-buildkite-artifacts",
+}[BUILDKITE_ORG]
+
 # Projects can opt out of receiving GitHub issues from --notify by adding `"do_not_notify": True` to their respective downstream entry.
 DOWNSTREAM_PROJECTS_PRODUCTION = {
     "Android Studio Plugin": {
@@ -324,7 +333,6 @@
     "rules_nodejs": {
         "git_repository": "https://github.com/bazelbuild/rules_nodejs.git",
         "pipeline_slug": "rules-nodejs-nodejs",
-        "disabled_reason": "https://github.com/bazelbuild/rules_nodejs/issues/3713"
     },
     "rules_perl": {
         "git_repository": "https://github.com/bazelbuild/rules_perl.git",
@@ -613,6 +621,9 @@
     re.compile(r"^bk-(trusted-)?macstudio-\d+$"),
 ]
 
+_TEST_BEP_FILE = "test_bep.json"
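+# Matches sharded job names such as "macOS (shard 2)"; group 1 is the base task name.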
+_SHARD_RE = re.compile(r"(.+) \(shard (\d+)\)")
+
 
 class BuildkiteException(Exception):
     """
@@ -1224,9 +1235,9 @@
         test_env_vars.append("BAZELISK_USER_AGENT")
 
         # Avoid "Network is unreachable" errors in IPv6-only environments
-        for e in ('JAVA_TOOL_OPTIONS', 'SSL_CERT_FILE'):
-          if os.getenv(e):
-              test_env_vars.append(e)
+        for e in ("JAVA_TOOL_OPTIONS", "SSL_CERT_FILE"):
+            if os.getenv(e):
+                test_env_vars.append(e)
 
         # We use one binary for all Linux platforms (because we also just release one binary for all
         # Linux versions and we have to ensure that it works on all of them).
@@ -1388,13 +1399,11 @@
                     project=project,
                 )
 
-            test_bep_file = os.path.join(tmpdir, "test_bep.json")
-            upload_thread = threading.Thread(
-                target=upload_test_logs_from_bep,
-                args=(test_bep_file, tmpdir, monitor_flaky_tests),
-            )
-            try:
-                upload_thread.start()
+            test_bep_file = os.path.join(tmpdir, _TEST_BEP_FILE)
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(
+                    upload_test_logs_from_bep, test_bep_file, tmpdir, monitor_flaky_tests
+                )
                 try:
                     execute_bazel_test(
                         bazel_version,
@@ -1410,8 +1419,9 @@
                         upload_json_profile(json_profile_out_test, tmpdir)
                     if capture_corrupted_outputs_dir_test:
                         upload_corrupted_outputs(capture_corrupted_outputs_dir_test, tmpdir)
-            finally:
-                upload_thread.join()
+
+                _ = future.result()
+                # TODO: print results
 
         if coverage_targets:
             (
@@ -2594,12 +2604,14 @@
 def upload_test_logs_from_bep(bep_file, tmpdir, monitor_flaky_tests):
     if local_run_only():
         return
+
     bazelci_agent_binary = download_bazelci_agent(tmpdir)
     execute_command(
         [
             bazelci_agent_binary,
             "artifact",
             "upload",
+            "--debug",  # Force BEP upload for non-flaky failures
             "--delay=5",
             "--mode=buildkite",
             "--build_event_json_file={}".format(bep_file),
@@ -2773,6 +2785,7 @@
     monitor_flaky_tests,
     use_but,
     notify,
+    print_shard_summary,
 ):
     task_configs = configs.get("tasks", None)
     if not task_configs:
@@ -2824,6 +2837,7 @@
 
     config_hashes = set()
     skipped_downstream_tasks = []
+    has_sharded_task = False
     for task, task_config in task_configs.items():
         platform = get_platform_for_task(task, task_config)
         task_name = task_config.get("name")
@@ -2863,6 +2877,9 @@
         except ValueError:
             raise BuildkiteException("Task {} has invalid shard value '{}'".format(task, shards))
 
+        if shards > 1:
+            has_sharded_task = True
+
         step = runner_step(
             platform=platform,
             task=task,
@@ -2900,6 +2917,7 @@
     all_downstream_pipeline_slugs = []
     for _, config in DOWNSTREAM_PROJECTS.items():
         all_downstream_pipeline_slugs.append(config["pipeline_slug"])
+
     # We update last green commit in the following cases:
     #   1. This job runs on master, stable or main branch (could be a custom build launched manually)
     #   2. We intend to run the same job in downstream with Bazel@HEAD (eg. google-bazel-presubmit)
@@ -2908,17 +2926,23 @@
     #      - uses a custom built Bazel binary (in Bazel Downstream Projects pipeline)
     #      - testing incompatible flags
     #      - running `bazelisk --migrate` in a non-downstream pipeline
-    if (
+    should_update_last_green = (
         current_branch_is_main_branch()
         and pipeline_slug in all_downstream_pipeline_slugs
         and not (is_pull_request() or use_but or use_bazelisk_migrate())
-    ):
+    )
+
+    actually_print_shard_summary = has_sharded_task and print_shard_summary
+
+    if should_update_last_green or actually_print_shard_summary:
+        pipeline_steps.append({"wait": None, "continue_on_failure": True})
+
+    if should_update_last_green:
         # We need to call "Try Update Last Green Commit" even if there are failures,
         # since we don't want a failing Buildifier step to block the update of
         # the last green commit for this project.
         # try_update_last_green_commit() ensures that we don't update the commit
         # if any build or test steps fail.
-        pipeline_steps.append({"wait": None, "continue_on_failure": True})
         pipeline_steps.append(
             create_step(
                 label="Try Update Last Green Commit",
@@ -2943,6 +2967,18 @@
         number = os.getenv("BUILDKITE_BUILD_NUMBER")
         pipeline_steps += get_steps_for_aggregating_migration_results(number, notify)
 
+    if actually_print_shard_summary:
+        pipeline_steps.append(
+            create_step(
+                label="Print Test Summary for Shards",
+                commands=[
+                    fetch_bazelcipy_command(),
+                    PLATFORMS[DEFAULT_PLATFORM]["python"] + " bazelci.py print_shard_summary",
+                ],
+                platform=DEFAULT_PLATFORM,
+            )
+        )
+
     print_pipeline_steps(pipeline_steps, handle_emergencies=not is_downstream_pipeline())
 
 
@@ -3694,6 +3730,317 @@
     return sha256.hexdigest()
 
 
+def print_shard_summary():
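+    # Download the BEP file of every failed sharded job, parse its test
+    # results, and emit one error annotation per base task.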
+    tmpdir = tempfile.mkdtemp()
+    try:
+        print_collapsed_group("Fetching test artifacts...")
+        all_test_artifacts = get_artifacts_for_failing_tests()
+        print_collapsed_group("Dwonloading & parsing BEP files...")
+        for base_task, current_test_artifacts in all_test_artifacts.items():
+            failures = []
+            for test_artifact in current_test_artifacts:
+                local_bep_path = test_artifact.download_bep(tmpdir)
+                if not local_bep_path:
+                    # TODO: propagate errors
+                    continue
+
+                for test_execution in parse_bep(local_bep_path):
+                    if test_execution.overall_status == "PASSED":
+                        continue
+
+                    failures.append(test_execution.Format(test_artifact.job_id))
+
+            if failures:
+                message = "\n".join(failures)
+                execute_command(
+                    [
+                        "buildkite-agent",
+                        "annotate",
+                        "--style=error",
+                        f"**{base_task} Failures**\n\n{message}",
+                        "--context",
+                        f"{base_task}",
+                    ]
+                )
+    finally:
+        shutil.rmtree(tmpdir)
+
+
+def get_log_path_for_label(label, shard, total_shards, attempt, total_attempts):
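+    # E.g. label "//foo:bar_test", shard 2 of 4, a single attempt ->
+    # "foo/bar_test/shard_2_of_4/test.log"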
+    parts = [label.lstrip("/").replace(":", "/")]
+    if total_shards > 1:
+        parts.append(f"shard_{shard}_of_{total_shards}")
+    if total_attempts > 1:
+        parts.append(f"test_attempts/attempt_{attempt}.log")
+    else:
+        parts.append("test.log")
+
+    return "/".join(parts)
+
+
+def get_artifacts_for_failing_tests():
+    org_slug = os.getenv("BUILDKITE_ORGANIZATION_SLUG")
+    pipeline_slug = os.getenv("BUILDKITE_PIPELINE_SLUG")
+    build_number = os.getenv("BUILDKITE_BUILD_NUMBER")
+
+    client = BuildkiteClient(org=org_slug, pipeline=pipeline_slug)
+    build_info = client.get_build_info(build_number)
+
+    paths = collections.defaultdict(list)
+    for job in build_info["jobs"]:
+        if job.get("state") in (None, "passed"):
+            continue
+
+        # This is a bit hacky, but saves us one API request per job (to check for BUILDKITE_PARALLEL_JOB)
+        match = _SHARD_RE.search(job.get("name", ""))
+        if not match:
+            continue
+
+        relative_bep_path, relative_log_paths = get_test_file_paths(job["id"])
+        # TODO: show build failures in the annotation, too?
+        if not relative_bep_path:
+            continue
+
+        base_task = match.group(1)
+        ta = TestArtifacts(
+            job_id=job["id"],
+            relative_bep_path=relative_bep_path,
+            relative_log_paths=relative_log_paths,
+        )
+        paths[base_task].append(ta)
+
+    return paths
+
+
+class TestArtifacts:
+    def __init__(self, job_id, relative_bep_path, relative_log_paths) -> None:
+        self.job_id = job_id
+        self.relative_bep_path = relative_bep_path
+        self.relative_log_paths = relative_log_paths
+
+    def download_bep(self, dest_dir: str) -> str:
+        job_dir = os.path.join(dest_dir, self.job_id)
+        os.makedirs(job_dir)
+
+        try:
+            execute_command(
+                [
+                    "buildkite-agent",
+                    "artifact",
+                    "download",
+                    f"*/{_TEST_BEP_FILE}",
+                    job_dir,
+                    "--step",
+                    self.job_id,
+                ]
+            )
+        except subprocess.CalledProcessError:
+            # TODO: handle exception
+            return None
+
+        return os.path.join(job_dir, self.relative_bep_path)
+
+
+def get_test_file_paths(job_id):
+    bep_path = None
+    log_paths = []
+
+    output = execute_command_and_get_output(
+        [
+            "buildkite-agent",
+            "artifact",
+            "search",
+            "*",
+            "--step",
+            job_id,
+        ],
+        fail_if_nonzero=False,
+    ).strip()
+
+    if not output or "no matches found" in output:
+        return None, []
+
+    for line in output.split("\n"):
+        parts = line.split(" ")
+        # Expected format:
+        # JOB_ID FILE_PATH TIMESTAMP
+        if len(parts) != 3:
+            continue
+
+        path = parts[1]
+        if path.endswith(_TEST_BEP_FILE):
+            bep_path = path
+        elif path.endswith(".log"):
+            log_paths.append(path)
+
+    return bep_path, log_paths
+
+
+def format_millis(millis):
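+    # E.g. [1500] -> "1.5s"; [1000, 2500] -> "3.5s (1.0s + 2.5s)"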
+    def fmt(ms):
+        return "{:.1f}s".format(ms / 1000)
+
+    if len(millis) == 1:
+        return fmt(millis[0])
+
+    total = sum(millis)
+    return f"{fmt(total)} ({' + '.join(fmt(ms) for ms in millis)})"
+
+
+def format_test_status(status):
+    cls = {"PASSED": "green", "FLAKY": "purple"}.get(status, "red")
+    return f"<span class='{cls}'>{status}</span>"
+
+
+# TODO here and below: use @dataclasses.dataclass(frozen=True) once Python has been updated on Docker machines
+class TestAttempt:
+    def __init__(self, number, status, millis) -> None:
+        self.number = number
+        self.status = status
+        self.millis = millis
+
+
+class TestShard:
+    def __init__(self, number, attempts) -> None:
+        self.number = number
+        self.attempts = attempts
+
+    def _get_detailed_overall_status(self):
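+        # Returns (status, attempts backing that status, total attempts);
+        # for FLAKY, the count is the number of failed attempts.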
+        counter = collections.Counter([a.status for a in self.attempts])
+        passed = counter["PASSED"]
+        no_attempts = len(self.attempts)
+        if passed == no_attempts:
+            return "PASSED", no_attempts, no_attempts
+        elif passed and passed < no_attempts:
+            return "FLAKY", no_attempts - passed, no_attempts
+        elif counter["FAILED"]:
+            return "FAILED", counter["FAILED"], no_attempts
+
+        [(status, count)] = counter.most_common(1)
+        return status, count, no_attempts
+
+    def get_details(self):
+        overall, bad_runs, total_runs = self._get_detailed_overall_status()
+        qualifier = "" if not bad_runs else f"{bad_runs} out of "
+        return overall, (
+            f"in {qualifier}{total_runs} runs over {format_millis(self.attempt_millis)}"
+        )
+
+    @property
+    def overall_status(self):
+        return self._get_detailed_overall_status()[0]
+
+    @property
+    def attempt_millis(self):
+        return [a.millis for a in self.attempts]
+
+
+class TestExecution:
+    def __init__(self, label, shards) -> None:
+        self.label = label
+        self.shards = shards
+
+    @property
+    def overall_status(self):
+        status_set = set(s.overall_status for s in self.shards)
+        if len(status_set) > 1:
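+            # Shards disagree: return the most severe status, in this priority order.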
+            for status in (
+                "FAILED",
+                "TIMEOUT",
+                "NO_STATUS",
+                "INCOMPLETE",
+                "REMOTE_FAILURE",
+                "FAILED_TO_BUILD",
+                "PASSED",
+            ):
+                if status in status_set:
+                    return status
+
+        return next(iter(status_set))
+
+    @property
+    def critical_path(self):
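+        # Formatted attempt durations of the slowest shard, which dominates
+        # the overall test runtime.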
+        max_millis = 0
+        path = None
+
+        for s in self.shards:
+            duration_millis = sum(s.attempt_millis)
+            if duration_millis > max_millis:
+                max_millis = duration_millis
+                path = s.attempt_millis
+
+        return format_millis(path)
+
+    def Format(self, job_id: str) -> str:
+        def get_log_url_for_shard(s):
+            local_log_path = get_log_path_for_label(
+                self.label,
+                s.number,
+                len(self.shards),
+                1,
+                len(s.attempts),
+            )
+            # TODO: check in relative_log_paths if log really exists?
+            return os.path.join(LOG_BUCKET, job_id, local_log_path)
+
+        def format_shard(s):
+            overall, statistics = s.get_details()
+            return (
+                f"{format_test_status(overall)} {statistics}: [log]({get_log_url_for_shard(s)})"
+            )
+
+        failing_shards = [s for s in self.shards if s.overall_status != "PASSED"]
+        if len(failing_shards) == 1:
+            [shard] = failing_shards
+            # TODO: show log links for failing attempts > 1?
+            return f"- {self.label} {format_shard(shard)}"
+
+        shard_info = "".join(
+            f"  - Shard {s.number}/{len(self.shards)}: {format_shard(s)}" for s in failing_shards
+        )
+        return f"- {self.label}\n{shard_info}"
+
+
+def parse_bep(path):
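+    # Group test attempts by label and shard, then build one TestExecution
+    # per label with its shards sorted by shard number.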
+    data = collections.defaultdict(dict)
+    for test, shard, attempt, status, millis in get_test_results_from_bep(path):
+        ta = TestAttempt(number=attempt, status=status, millis=millis)
+        if shard not in data[test]:
+            data[test][shard] = []
+
+        data[test][shard].append(ta)
+
+    tests = []
+    for test, attempts_per_shard in data.items():
+        shards = [
+            TestShard(number=shard, attempts=attempts_per_shard[shard])
+            for shard in sorted(attempts_per_shard.keys())
+        ]
+        tests.append(TestExecution(label=test, shards=shards))
+
+    return tests
+
+
+def get_test_results_from_bep(path):
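+    # BEP files are newline-delimited JSON; the substring check cheaply skips
+    # lines that cannot contain a testResult event before calling json.loads.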
+    with open(path, "rt") as f:
+        for line in f:
+            if "testResult" not in line:
+                continue
+
+            data = json.loads(line)
+            meta = data.get("id").get("testResult")
+            if not meta:
+                continue
+
+            yield (
+                meta["label"],
+                meta["shard"],
+                meta["attempt"],
+                data["testResult"]["status"],
+                int(data["testResult"]["testAttemptDurationMillis"]),
+            )
+
+
 def upload_bazel_binaries():
     """
     Uploads all Bazel binaries to a deterministic URL based on the current Git commit.
@@ -3895,6 +4242,7 @@
     project_pipeline.add_argument("--monitor_flaky_tests", type=bool, nargs="?", const=True)
     project_pipeline.add_argument("--use_but", type=bool, nargs="?", const=True)
     project_pipeline.add_argument("--notify", type=bool, nargs="?", const=True)
+    project_pipeline.add_argument("--print_shard_summary", type=bool, nargs="?", const=True)
 
     runner = subparsers.add_parser("runner")
     runner.add_argument("--task", action="store", type=str, default="")
@@ -3922,6 +4270,7 @@
     subparsers.add_parser("publish_binaries")
     subparsers.add_parser("try_update_last_green_commit")
     subparsers.add_parser("try_update_last_green_downstream_commit")
+    subparsers.add_parser("print_shard_summary")
 
     args = parser.parse_args(argv)
 
@@ -3968,6 +4317,7 @@
                 monitor_flaky_tests=args.monitor_flaky_tests,
                 use_but=args.use_but,
                 notify=args.notify,
+                print_shard_summary=args.print_shard_summary,
             )
         elif args.subparsers_name == "runner":
             # Fetch the repo in case we need to use file_config.
@@ -4016,6 +4366,8 @@
         elif args.subparsers_name == "try_update_last_green_downstream_commit":
             # Update the last green commit of the downstream pipeline
             try_update_last_green_downstream_commit()
+        elif args.subparsers_name == "print_shard_summary":
+            print_shard_summary()
         else:
             parser.print_help()
             return 2