blob: d7544369ed2f2137a3c71291cdb43de7c38ce555 [file] [log] [blame]
// Copyright 2010 The Bazel Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.testing.junit.runner.sharding;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.runner.Description;
import org.junit.runner.manipulation.Filter;
/**
 * Implements the round-robin sharding strategy.
 *
 * <p>This is done by equally dividing up the tests across all the shards. Each test is numbered
 * (in order of display name), and the test number is taken modulo the number of shards and
 * checked against the shard index to see whether it should run on a particular shard.
 *
 * <p>This class does not override {@code equals} or {@code hashCode}. Tests that need to compare
 * two filters can inspect the package-private {@code testToShardMap}, {@code shardIndex} and
 * {@code totalShards} fields instead.
 */
public final class RoundRobinShardingFilter extends Filter {

  /** Unmodifiable map from each test description to its position in display-name order. */
  // VisibleForTesting
  final Map<Description, Integer> testToShardMap;

  /** Zero-based index of the shard whose tests this filter accepts. */
  // VisibleForTesting
  final int shardIndex;

  /** Total number of shards; the constructor guarantees this is greater than {@code shardIndex}. */
  // VisibleForTesting
  final int totalShards;

  /**
   * Creates a filter that accepts only the tests assigned to shard {@code shardIndex}.
   *
   * @param testDescriptions every individual test that may later be passed to {@link #shouldRun}
   * @param shardIndex zero-based index of this shard
   * @param totalShards total number of shards
   * @throws IllegalArgumentException if {@code shardIndex} is not in {@code [0, totalShards)}, or
   *     if {@code testDescriptions} contains a suite description
   */
  public RoundRobinShardingFilter(
      Collection<Description> testDescriptions, int shardIndex, int totalShards) {
    // shardIndex >= 0 together with totalShards > shardIndex also implies totalShards >= 1,
    // so the modulo in shouldRun can never divide by zero.
    if (shardIndex < 0 || totalShards <= shardIndex) {
      throw new IllegalArgumentException(
          "shardIndex must be in [0, totalShards), but shardIndex="
              + shardIndex
              + " and totalShards="
              + totalShards);
    }
    this.testToShardMap = buildTestToShardMap(testDescriptions);
    this.shardIndex = shardIndex;
    this.totalShards = totalShards;
  }

  /**
   * Numbers the given test descriptions by their position after sorting by display name.
   *
   * @param testDescriptions the individual tests to number
   * @return an unmodifiable map from each description to its zero-based position in the sorted
   *     order
   * @throws IllegalArgumentException if any description is a suite rather than a single test
   */
  private static Map<Description, Integer> buildTestToShardMap(
      Collection<Description> testDescriptions) {
    Map<Description, Integer> map = new HashMap<>();

    // Sorting this list is incredibly important to correctness. Otherwise,
    // "shuffled" suites would break the sharding protocol.
    List<Description> sortedDescriptions = new ArrayList<>(testDescriptions);
    Collections.sort(sortedDescriptions, new DescriptionComparator());

    // If we get two descriptions that are equal, the shard number for the second
    // one will overwrite the shard number for the first. Thus they'll run on the
    // same shard.
    int index = 0;
    for (Description description : sortedDescriptions) {
      if (!description.isTest()) {
        throw new IllegalArgumentException("Test suite should not be included in the set of tests "
            + "to shard: " + description.getDisplayName());
      }
      map.put(description, index);
      index++;
    }
    return Collections.unmodifiableMap(map);
  }

  /**
   * Decides whether {@code description} belongs to this shard.
   *
   * @param description a suite, or a test that was supplied to the constructor
   * @return {@code true} for any suite; for a test, whether its number modulo {@code totalShards}
   *     equals {@code shardIndex}
   * @throws IllegalArgumentException if {@code description} is a test that was not supplied to
   *     the constructor
   */
  @Override
  public boolean shouldRun(Description description) {
    if (description.isSuite()) {
      return true;
    }
    Integer testNumber = testToShardMap.get(description);
    if (testNumber == null) {
      throw new IllegalArgumentException("This filter keeps a mapping from each test "
          + "description to a shard, and the given description was not passed in when "
          + "filter was constructed: " + description);
    }
    return (testNumber % totalShards) == shardIndex;
  }

  /** Returns a fixed human-readable name for this filter. */
  @Override
  public String describe() {
    return "round robin sharding filter";
  }

  /** Orders descriptions lexicographically by display name. */
  // VisibleForTesting
  static class DescriptionComparator implements Comparator<Description> {
    @Override
    public int compare(Description d1, Description d2) {
      return d1.getDisplayName().compareTo(d2.getDisplayName());
    }
  }
}