Open sourcing junitrunner/java/com/google/testing/junit/runner/sharding/weighted.

--
MOS_MIGRATED_REVID=134046554
diff --git a/src/BUILD b/src/BUILD
index e78e52a..af59d9b 100644
--- a/src/BUILD
+++ b/src/BUILD
@@ -285,6 +285,7 @@
         "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding:srcs",
         "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/api:srcs",
         "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/testing:srcs",
+        "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted:srcs",
         "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/util:srcs",
         "//src/java_tools/singlejar:srcs",
         "//src/main/cpp:srcs",
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD
new file mode 100644
index 0000000..a35ed23
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BUILD
@@ -0,0 +1,24 @@
+DEFAULT_VISIBILITY = [
+    "//java/com/google/testing/junit/runner:__subpackages__",
+    "//javatests/com/google/testing/junit/runner:__subpackages__",
+    "//third_party/bazel/src/java_tools/junitrunner/java/com/google/testing/junit/runner:__subpackages__",
+    "//third_party/bazel/src/java_tools/junitrunner/javatests/com/google/testing/junit/runner:__subpackages__",
+]
+
+package(default_visibility = ["//src:__subpackages__"])
+
+# TODO(bazel-team): This should be testonly = 1.
+java_library(
+    name = "weighted",
+    srcs = glob(["*.java"]),
+    deps = [
+        "//java/com/google/testing/util",
+        "//src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/api",
+        "//third_party:junit4",
+    ],
+)
+
+filegroup(
+    name = "srcs",
+    srcs = glob(["**"]),
+)
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java
new file mode 100644
index 0000000..5f59792
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/BinStackingShardingFilterFactory.java
@@ -0,0 +1,82 @@
+// Copyright 2015 The Bazel Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.testing.junit.runner.sharding.weighted;
+
+import com.google.testing.junit.runner.sharding.api.ShardingFilterFactory;
+import com.google.testing.junit.runner.sharding.api.WeightStrategy;
+import com.google.testing.util.RuntimeCost;
+import java.util.Collection;
+import org.junit.Ignore;
+import org.junit.runner.Description;
+import org.junit.runner.manipulation.Filter;
+
+/**
+ * A factory that creates a {@link WeightedShardingFilter} that extracts the weight for a test from
+ * the {@link RuntimeCost} annotations present in descriptions of tests.
+ */
+public final class BinStackingShardingFilterFactory implements ShardingFilterFactory {
+  static final String DEFAULT_TEST_WEIGHT_PROPERTY = "test.sharding.default_weight";
+  static final int DEFAULT_TEST_WEIGHT = 1;
+
+  private final int defaultTestWeight;
+
+  public BinStackingShardingFilterFactory() {
+    this(getDefaultTestWeight());
+  }
+
+  // VisibleForTesting
+  BinStackingShardingFilterFactory(int defaultTestWeight) {
+    this.defaultTestWeight = defaultTestWeight;
+  }
+
+  static int getDefaultTestWeight() {
+    String property = System.getProperty(DEFAULT_TEST_WEIGHT_PROPERTY);
+    if (property != null) {
+      return Integer.parseInt(property);
+    }
+    return DEFAULT_TEST_WEIGHT;
+  }
+
+  @Override
+  public Filter createFilter(
+      Collection<Description> testDescriptions, int shardIndex, int totalShards) {
+    return new WeightedShardingFilter(
+        testDescriptions,
+        shardIndex,
+        totalShards,
+        new RuntimeCostWeightStrategy(defaultTestWeight));
+  }
+
+  static class RuntimeCostWeightStrategy implements WeightStrategy {
+
+    private final int defaultTestWeight;
+
+    RuntimeCostWeightStrategy(int defaultTestWeight) {
+      this.defaultTestWeight = defaultTestWeight;
+    }
+
+    @Override
+    public int getDescriptionWeight(Description description) {
+      RuntimeCost runtimeCost = description.getAnnotation(RuntimeCost.class);
+      Ignore ignore = description.getAnnotation(Ignore.class);
+
+      if (runtimeCost == null || ignore != null) {
+        return defaultTestWeight;
+      } else {
+        return runtimeCost.value();
+      }
+    }
+  }
+}
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java
new file mode 100644
index 0000000..bb3a8aa
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/sharding/weighted/WeightedShardingFilter.java
@@ -0,0 +1,154 @@
+// Copyright 2015 The Bazel Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.testing.junit.runner.sharding.weighted;
+
+import com.google.testing.junit.runner.sharding.api.WeightStrategy;
+import com.google.testing.util.RuntimeCost;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import org.junit.runner.Description;
+import org.junit.runner.manipulation.Filter;
+
+/**
+ * A sharding function that attempts to evenly use time on all available
+ * shards while considering the test's weight.
+ *
+ * <p>When all tests have the same weight the sharding function behaves
+ * similarly to round robin.
+ */
+public final class WeightedShardingFilter extends Filter {
+  private final Map<Description, Integer> testToShardMap;
+  private final int shardIndex;
+
+  public WeightedShardingFilter(Collection<Description> descriptions, int shardIndex,
+      int totalShards, WeightStrategy weightStrategy) {
+    if (shardIndex < 0 || totalShards <= shardIndex) {
+      throw new IllegalArgumentException();
+    }
+    this.shardIndex = shardIndex;
+    this.testToShardMap = buildTestToShardMap(descriptions, totalShards, weightStrategy);
+  }
+
+  @Override
+  public String describe() {
+    return "bin stacking filter";
+  }
+
+  @Override
+  public boolean shouldRun(Description description) {
+    if (description.isSuite()) {
+      return true;
+    }
+    Integer shardForTest = testToShardMap.get(description);
+    if (shardForTest == null) {
+      throw new IllegalArgumentException("This filter keeps a mapping from each test "
+          + "description to a shard, and the given description was not passed in when "
+          + "filter was constructed: " + description);
+    }
+    return shardForTest == shardIndex;
+  }
+
+  private static Map<Description, Integer> buildTestToShardMap(
+      Collection<Description> descriptions, int numShards, WeightStrategy weightStrategy) {
+    Map<Description, Integer> map = new HashMap<>();
+
+    // Sorting this list is incredibly important to correctness. Otherwise,
+    // "shuffled" suites would break the sharding protocol.
+    List<Description> sortedDescriptions = new ArrayList<>(descriptions);
+    Collections.sort(sortedDescriptions, new WeightClassAndTestNameComparator(weightStrategy));
+
+    PriorityQueue<Shard> queue = new PriorityQueue<>(numShards);
+    for (int i = 0; i < numShards; i++) {
+      queue.offer(new Shard(i));
+    }
+
+    // If we get two descriptions that are equal, the shard number for the second
+    // one will overwrite the shard number for the first.  Thus they'll run on the
+    // same shard.
+    for (Description description : sortedDescriptions) {
+      if (!description.isTest()) {
+        throw new IllegalArgumentException("Test suite should not be included in the set of tests "
+            + "to shard: " + description.getDisplayName());
+      }
+
+      Shard shard = queue.remove();
+      shard.addWeight(weightStrategy.getDescriptionWeight(description));
+      queue.offer(shard);
+      map.put(description, shard.getIndex());
+    }
+    return Collections.unmodifiableMap(map);
+  }
+
+  /**
+   * A comparator that sorts by weight in descending order, then by test case name.
+   */
+  private static class WeightClassAndTestNameComparator implements Comparator<Description> {
+
+    private final WeightStrategy weightStrategy;
+
+    WeightClassAndTestNameComparator(WeightStrategy weightStrategy) {
+      this.weightStrategy = weightStrategy;
+    }
+
+    @Override
+    public int compare(Description d1, Description d2) {
+      int weight1 = weightStrategy.getDescriptionWeight(d1);
+      int weight2 = weightStrategy.getDescriptionWeight(d2);
+      if (weight1 != weight2) {
+        // We consider the reverse order when comparing weights.
+        return -1 * compareInts(weight1, weight2);
+      }
+      return d1.getDisplayName().compareTo(d2.getDisplayName());
+    }
+  }
+
+  /**
+   * A bean representing the sum of {@link RuntimeCost}s assigned to a shard.
+   */
+  private static class Shard implements Comparable<Shard> {
+    private final int index;
+    private int weight = 0;
+
+    Shard(int index) {
+      this.index = index;
+    }
+
+    void addWeight(int weight) {
+      this.weight += weight;
+    }
+
+    int getIndex() {
+      return index;
+    }
+
+    @Override
+    public int compareTo(Shard other) {
+      if (weight != other.weight) {
+        return compareInts(weight, other.weight);
+      }
+      return compareInts(index, other.index);
+    }
+  }
+
+  private static int compareInts(int value1, int value2) {
+    return value1 - value2;
+  }
+}