Use the fact that we know the shard count of a test ahead of time to simplify our data structures around sharding a little bit.

PiperOrigin-RevId: 381007451
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/test/TestProvider.java b/src/main/java/com/google/devtools/build/lib/analysis/test/TestProvider.java
index d14c16e..3275913 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/test/TestProvider.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/test/TestProvider.java
@@ -52,6 +52,7 @@
   }
 
   /** A value class describing the properties of a test. */
+  // Non-final only for mocking.
   public static class TestParams {
     private final int runs;
     private final int shards;
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/AggregatingTestListener.java b/src/main/java/com/google/devtools/build/lib/runtime/AggregatingTestListener.java
index ddfa43d..988159a 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/AggregatingTestListener.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/AggregatingTestListener.java
@@ -236,9 +236,8 @@
                 .setConfigurationKey(testTarget.getConfigurationKey())
                 .build();
         TestResultAggregator aggregator = aggregators.get(actualKey);
-        TestSummary.Builder summaryBuilder = TestSummary.newBuilder();
+        TestSummary.Builder summaryBuilder = TestSummary.newBuilder(testTarget);
         summaryBuilder.mergeFrom(aggregator.aggregateAndReportSummary(skipTargetsOnFailure));
-        summaryBuilder.setTarget(testTarget);
         summary = summaryBuilder.build();
       } else {
         TestResultAggregator aggregator = aggregators.get(asKey(testTarget));
@@ -255,7 +254,7 @@
         // just use NO_STATUS for all tests with failed validations for simplicity here (absent -k).
         // Events published on BEP are not affected by this, but validation failures are published
         // as separate events and are additionally accounted in TargetSummary BEP messages.
-        TestSummary.Builder summaryBuilder = TestSummary.newBuilder();
+        TestSummary.Builder summaryBuilder = TestSummary.newBuilder(summary.getTarget());
         summaryBuilder.mergeFrom(summary);
         summaryBuilder.setStatus(
             skipTargetsOnFailure
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/TestResultAggregator.java b/src/main/java/com/google/devtools/build/lib/runtime/TestResultAggregator.java
index b08aa3e..1a120a3 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/TestResultAggregator.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/TestResultAggregator.java
@@ -65,8 +65,7 @@
       boolean skippedThisTest) {
     this.policy = policy;
     this.summary =
-        TestSummary.newBuilder()
-            .setTarget(target)
+        TestSummary.newBuilder(target)
             .setConfiguration(configuration)
             .setStatus(BlazeTestStatus.NO_STATUS)
             .setSkipped(skippedThisTest);
diff --git a/src/main/java/com/google/devtools/build/lib/runtime/TestSummary.java b/src/main/java/com/google/devtools/build/lib/runtime/TestSummary.java
index 77818ed..12c393a 100644
--- a/src/main/java/com/google/devtools/build/lib/runtime/TestSummary.java
+++ b/src/main/java/com/google/devtools/build/lib/runtime/TestSummary.java
@@ -13,13 +13,13 @@
 // limitations under the License.
 package com.google.devtools.build.lib.runtime;
 
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Preconditions;
-import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.ComparisonChain;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.MultimapBuilder;
 import com.google.devtools.build.lib.analysis.AliasProvider;
 import com.google.devtools.build.lib.analysis.ConfiguredTarget;
 import com.google.devtools.build.lib.analysis.config.BuildConfiguration;
@@ -49,6 +49,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.TreeMap;
+import java.util.stream.Stream;
 import javax.annotation.Nullable;
 
 /**
@@ -66,19 +67,19 @@
     private TestSummary summary;
     private boolean built;
 
-    private Builder() {
-      summary = new TestSummary();
+    private Builder(ConfiguredTarget target) {
+      summary = new TestSummary(target);
       built = false;
     }
 
     void mergeFrom(TestSummary existingSummary) {
       // Yuck, manually fill in fields.
-      summary.shardRunStatuses =
-          MultimapBuilder.hashKeys().arrayListValues().build(existingSummary.shardRunStatuses);
+      for (int i = 0; i < existingSummary.shardRunStatuses.size(); i++) {
+        summary.shardRunStatuses.get(i).addAll(existingSummary.shardRunStatuses.get(i));
+      }
       summary.firstStartTimeMillis = existingSummary.firstStartTimeMillis;
       summary.lastStopTimeMillis = existingSummary.lastStopTimeMillis;
       summary.totalRunDurationMillis = existingSummary.totalRunDurationMillis;
-      setTarget(existingSummary.target);
       setConfiguration(existingSummary.configuration);
       setStatus(existingSummary.status);
       addCoverageFiles(existingSummary.coverageFiles);
@@ -107,7 +108,7 @@
       if (built) {
         built = false;
         TestSummary lastSummary = summary;
-        summary = new TestSummary();
+        summary = new TestSummary(lastSummary.target);
         mergeFrom(lastSummary);
       }
     }
@@ -116,19 +117,13 @@
     // However, since it can alter the summary member, inlining it in an
     // assignment to a property of summary was unsafe.
     private void checkMutation(Object value) {
-      Preconditions.checkNotNull(value);
+      checkNotNull(value);
       checkMutation();
     }
 
-    public Builder setTarget(ConfiguredTarget target) {
-      checkMutation(target);
-      summary.target = target;
-      return this;
-    }
-
     public Builder setConfiguration(BuildConfiguration configuration) {
       checkMutation(configuration);
-      summary.configuration = Preconditions.checkNotNull(configuration, summary);
+      summary.configuration = checkNotNull(configuration, summary);
       return this;
     }
 
@@ -314,10 +309,10 @@
      *
      * @return an immutable view of the statuses associated with the shard, with the new element.
      */
-    public List<BlazeTestStatus> addShardStatus(int shardNumber, BlazeTestStatus status) {
-      Preconditions.checkState(summary.shardRunStatuses.put(shardNumber, status),
-          "shardRunStatuses must allow duplicate statuses");
-      return ImmutableList.copyOf(summary.shardRunStatuses.get(shardNumber));
+    public ImmutableList<BlazeTestStatus> addShardStatus(int shardNumber, BlazeTestStatus status) {
+      List<BlazeTestStatus> statuses = summary.shardRunStatuses.get(shardNumber);
+      statuses.add(status);
+      return ImmutableList.copyOf(statuses);
     }
 
     /**
@@ -341,8 +336,8 @@
      * incompletely-built TestSummary. Used to pass Builders around directly.
      */
     TestSummary peek() {
-      Preconditions.checkNotNull(summary.target, "Target cannot be null");
-      Preconditions.checkNotNull(summary.status, "Status cannot be null");
+      checkNotNull(summary.target, "Target cannot be null");
+      checkNotNull(summary.status, "Status cannot be null");
       return summary;
     }
 
@@ -358,12 +353,13 @@
     }
   }
 
-  private ConfiguredTarget target;
+  private final ConfiguredTarget target;
+  // Currently only populated if --runs_per_test_detects_flakes is enabled.
+  private final ImmutableList<ArrayList<BlazeTestStatus>> shardRunStatuses;
+
   private BuildConfiguration configuration;
   private BlazeTestStatus status;
   private boolean skipped;
-  // Currently only populated if --runs_per_test_detects_flakes is enabled.
-  private Multimap<Integer, BlazeTestStatus> shardRunStatuses = ArrayListMultimap.create();
   private int numCached;
   private int numLocalActionCached;
   private boolean actionRan;
@@ -384,14 +380,23 @@
   @Nullable private DetailedExitCode systemFailure;
 
   // Don't allow public instantiation; go through the Builder.
-  private TestSummary() {
+  private TestSummary(ConfiguredTarget target) {
+    this.target = target;
+    TestParams testParams = getTestParams();
+    shardRunStatuses =
+        createAndInitialize(
+            testParams.runsDetectsFlakes() ? Math.max(testParams.getShards(), 1) : 0);
   }
 
-  /**
-   * Creates a new Builder allowing construction of a new TestSummary object.
-   */
-  public static Builder newBuilder() {
-    return new Builder();
+  private static ImmutableList<ArrayList<BlazeTestStatus>> createAndInitialize(int sz) {
+    return Stream.generate(() -> new ArrayList<BlazeTestStatus>(1))
+        .limit(sz)
+        .collect(toImmutableList());
+  }
+
+  /** Creates a new Builder allowing construction of a new TestSummary object. */
+  public static Builder newBuilder(ConfiguredTarget target) {
+    return new Builder(target);
   }
 
   public Label getLabel() {
@@ -599,7 +604,7 @@
   @Override
   public BuildEventStreamProtos.BuildEvent asStreamProto(BuildEventContext converters) {
     PathConverter pathConverter = converters.pathConverter();
-    TestParams testParams = target.getProvider(TestProvider.class).getTestParams();
+    TestParams testParams = getTestParams();
     BuildEventStreamProtos.TestSummary.Builder summaryBuilder =
         BuildEventStreamProtos.TestSummary.newBuilder()
             .setOverallStatus(BuildEventStreamerUtils.bepStatus(status))
@@ -627,4 +632,8 @@
     }
     return GenericBuildEvent.protoChaining(this).setTestSummary(summaryBuilder.build()).build();
   }
+
+  private TestParams getTestParams() {
+    return checkNotNull(target.getProvider(TestProvider.class).getTestParams(), target);
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/runtime/TestResultAggregatorTest.java b/src/test/java/com/google/devtools/build/lib/runtime/TestResultAggregatorTest.java
index 8a72814..7211f79 100644
--- a/src/test/java/com/google/devtools/build/lib/runtime/TestResultAggregatorTest.java
+++ b/src/test/java/com/google/devtools/build/lib/runtime/TestResultAggregatorTest.java
@@ -48,6 +48,7 @@
   public void configureMockParams() {
     when(mockParams.runsDetectsFlakes()).thenReturn(false);
     when(mockParams.getTimeout()).thenReturn(TestTimeout.LONG);
+    when(mockParams.getShards()).thenReturn(1);
   }
 
   @Test
@@ -113,8 +114,8 @@
 
   @Test
   public void cancelConcurrentTests_cancellationAfterPassIgnored() {
-    TestResultAggregator underTest = createAggregatorWithTestRuns(2);
     when(mockParams.runsDetectsFlakes()).thenReturn(true);
+    TestResultAggregator underTest = createAggregatorWithTestRuns(2);
 
     underTest.testEvent(
         testResult(
diff --git a/src/test/java/com/google/devtools/build/lib/runtime/TestSummaryTest.java b/src/test/java/com/google/devtools/build/lib/runtime/TestSummaryTest.java
index e4aae44..7df867f 100644
--- a/src/test/java/com/google/devtools/build/lib/runtime/TestSummaryTest.java
+++ b/src/test/java/com/google/devtools/build/lib/runtime/TestSummaryTest.java
@@ -84,8 +84,7 @@
   private TestSummary.Builder getTemplateBuilder() {
     BuildConfiguration configuration = Mockito.mock(BuildConfiguration.class);
     when(configuration.checksum()).thenReturn("abcdef");
-    return TestSummary.newBuilder()
-        .setTarget(stubTarget)
+    return TestSummary.newBuilder(stubTarget)
         .setConfiguration(configuration)
         .setStatus(BlazeTestStatus.PASSED)
         .setNumCached(NOT_CACHED)
@@ -596,6 +595,9 @@
     ConfiguredTarget target = Mockito.mock(ConfiguredTarget.class);
     when(target.getLabel()).thenReturn(Label.create(path, targetName));
     when(target.getConfigurationChecksum()).thenReturn("abcdef");
+    TestParams mockParams = Mockito.mock(TestParams.class);
+    when(mockParams.getShards()).thenReturn(1);
+    when(target.getProvider(TestProvider.class)).thenReturn(new TestProvider(mockParams));
     return target;
   }
 
@@ -625,8 +627,7 @@
   private static TestSummary createTestSummary(ConfiguredTarget target, BlazeTestStatus status,
                                                int numCached) {
     ImmutableList<TestCase> emptyList = ImmutableList.of();
-    TestSummary summary = TestSummary.newBuilder()
-        .setTarget(target)
+    return TestSummary.newBuilder(target)
         .setStatus(status)
         .setNumCached(numCached)
         .setActionRan(true)
@@ -635,7 +636,6 @@
         .addFailedTestCases(emptyList, FailedTestCasesStatus.FULL)
         .addTestTimes(SMALL_TIMING)
         .build();
-    return summary;
   }
 
   private TestSummary createTestSummary(BlazeTestStatus status, int numCached) {