Use our java test runner in Bazel

RELNOTES[NEW]: A new java test runner that support XML output and test filtering is supported.
          It can be used by specifying --nolegacy_bazel_java_test or by speicifying the test_class
          attribute on a java_test.

--
MOS_MIGRATED_REVID=112028955
diff --git a/examples/java-native/src/test/java/com/example/myproject/BUILD b/examples/java-native/src/test/java/com/example/myproject/BUILD
index 9a91a82..b9d81a1 100644
--- a/examples/java-native/src/test/java/com/example/myproject/BUILD
+++ b/examples/java-native/src/test/java/com/example/myproject/BUILD
@@ -17,6 +17,16 @@
 )
 
 java_test(
+    name = "custom_with_test_class",
+    srcs = glob(["Test*.java"]),
+    test_class = "com.example.myproject.TestCustomGreeting",
+    deps = [
+        "//examples/java-native/src/main/java/com/example/myproject:custom-greeting",
+        "//third_party:junit4",
+    ],
+)
+
+java_test(
     name = "fail",
     srcs = ["Fail.java"],
     deps = ["//third_party:junit4"],
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
index 0dea311..b065913 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaRuleClasses.java
@@ -25,10 +25,13 @@
 import static com.google.devtools.build.lib.syntax.Type.STRING_LIST;
 
 import com.google.common.collect.ImmutableSet;
+import com.google.devtools.build.lib.Constants;
 import com.google.devtools.build.lib.analysis.BaseRuleClasses;
 import com.google.devtools.build.lib.analysis.RuleDefinition;
 import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
 import com.google.devtools.build.lib.bazel.rules.cpp.BazelCppRuleClasses;
+import com.google.devtools.build.lib.packages.Attribute;
+import com.google.devtools.build.lib.packages.AttributeMap;
 import com.google.devtools.build.lib.packages.ImplicitOutputsFunction;
 import com.google.devtools.build.lib.packages.PredicateWithMessage;
 import com.google.devtools.build.lib.packages.Rule;
@@ -38,6 +41,7 @@
 import com.google.devtools.build.lib.packages.RuleClass.PackageNameConstraint;
 import com.google.devtools.build.lib.packages.TriState;
 import com.google.devtools.build.lib.rules.java.JavaSemantics;
+import com.google.devtools.build.lib.syntax.Type;
 import com.google.devtools.build.lib.util.FileTypeSet;
 
 import java.util.Set;
@@ -50,6 +54,9 @@
   public static final PredicateWithMessage<Rule> JAVA_PACKAGE_NAMES = new PackageNameConstraint(
       PackageNameConstraint.ANY_SEGMENT, "java", "javatests");
 
+  protected static final String JUNIT_TESTRUNNER =
+      Constants.TOOLS_REPOSITORY + "//tools/jdk:TestRunner_deploy.jar";
+
   public static final ImplicitOutputsFunction JAVA_BINARY_IMPLICIT_OUTPUTS =
       fromFunctions(
           JavaSemantics.JAVA_BINARY_CLASS_JAR,
@@ -268,6 +275,7 @@
    * Base class for rule definitions producing Java binaries.
    */
   public static final class BaseJavaBinaryRule implements RuleDefinition {
+
     @Override
     public RuleClass build(Builder builder, final RuleDefinitionEnvironment env) {
       return builder
@@ -298,6 +306,23 @@
           </p>
           <!-- #END_BLAZE_RULE.ATTRIBUTE --> */
           .add(attr("jvm_flags", STRING_LIST))
+          /* <!-- #BLAZE_RULE($base_java_binary).ATTRIBUTE(use_testrunner) -->
+          Use the
+          <code>com.google.testing.junit.runner.GoogleTestRunner</code> class as the
+          main entry point for a Java program.
+          ${SYNOPSIS}
+
+          You can use this to override the default
+          behavior, which is to use <code>BazelTestRunner</code> for
+          <code>java_test</code> rules,
+          and not use it for <code>java_binary</code> rules.  It is unlikely
+          you will want to do this.  One use is for <code>AllTest</code>
+          rules that are invoked by another rule (to set up a database
+          before running the tests, for example).  The <code>AllTest</code>
+          rule must be declared as a <code>java_binary</code>, but should
+          still use the test runner as its main entry point.
+          <!-- #END_BLAZE_RULE.ATTRIBUTE --> */
+          .add(attr("use_testrunner", BOOLEAN).value(false))
           /* <!-- #BLAZE_RULE($base_java_binary).ATTRIBUTE(main_class) -->
           Name of class with <code>main()</code> method to use as entry point.
           ${SYNOPSIS}
@@ -322,6 +347,15 @@
           .add(attr("create_executable", BOOLEAN)
               .nonconfigurable("internal")
               .value(true))
+          .add(attr("$testsupport", LABEL).value(
+              new Attribute.ComputedDefault("use_testrunner") {
+                @Override
+                public Object getDefault(AttributeMap rule) {
+                  return rule.get("use_testrunner", Type.BOOLEAN)
+                    ? env.getLabel(JUNIT_TESTRUNNER)
+                    : null;
+                }
+              }))
           /* <!-- #BLAZE_RULE($base_java_binary).ATTRIBUTE(deploy_manifest_lines) -->
           ${SYNOPSIS}
           A list of lines to add to the <code>META-INF/manifest.mf</code> file generated for the
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
index 182c213..f55c0d9 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
@@ -16,11 +16,14 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
 import com.google.devtools.build.lib.actions.Artifact;
 import com.google.devtools.build.lib.analysis.RuleConfiguredTarget.Mode;
 import com.google.devtools.build.lib.analysis.RuleConfiguredTargetBuilder;
 import com.google.devtools.build.lib.analysis.RuleContext;
 import com.google.devtools.build.lib.analysis.Runfiles;
+import com.google.devtools.build.lib.analysis.RunfilesProvider;
 import com.google.devtools.build.lib.analysis.TransitiveInfoCollection;
 import com.google.devtools.build.lib.analysis.actions.CustomCommandLine;
 import com.google.devtools.build.lib.analysis.actions.TemplateExpansionAction;
@@ -38,6 +41,7 @@
 import com.google.devtools.build.lib.rules.java.JavaConfiguration;
 import com.google.devtools.build.lib.rules.java.JavaHelper;
 import com.google.devtools.build.lib.rules.java.JavaPrimaryClassProvider;
+import com.google.devtools.build.lib.rules.java.JavaRunfilesProvider;
 import com.google.devtools.build.lib.rules.java.JavaSemantics;
 import com.google.devtools.build.lib.rules.java.JavaTargetAttributes;
 import com.google.devtools.build.lib.rules.java.JavaUtil;
@@ -51,6 +55,7 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
+import java.util.Set;
 
 /**
  * Semantics for Bazel Java rules
@@ -80,14 +85,29 @@
     }
   }
   
-  private String getMainClassInternal(RuleContext ruleContext) {
-    return ruleContext.getRule().isAttrDefined("main_class", Type.STRING)
+  private String getMainClassInternal(RuleContext ruleContext, JavaCommon javaCommon) {
+    String mainClass = ruleContext.getRule().isAttrDefined("main_class", Type.STRING)
         ? ruleContext.attributes().get("main_class", Type.STRING) : "";
+    boolean createExecutable = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
+    boolean useTestrunner = ruleContext.attributes().get("use_testrunner", Type.BOOLEAN)
+        && !useLegacyJavaTest(ruleContext);
+
+    if (createExecutable) {
+      if (useTestrunner) {
+        mainClass = "com.google.testing.junit.runner.BazelTestRunner";
+      } else { /* java_binary or non-Junit java_test */
+        if (mainClass.isEmpty()) {
+          mainClass = javaCommon.determinePrimaryClass(javaCommon.getSrcsArtifacts());
+        }
+      }
+    }
+
+    return mainClass;
   }
 
   private void checkMainClass(RuleContext ruleContext, JavaCommon javaCommon) {
     boolean createExecutable = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
-    String mainClass = getMainClassInternal(ruleContext);
+    String mainClass = getMainClassInternal(ruleContext, javaCommon);
 
     if (!createExecutable && !mainClass.isEmpty()) {
       ruleContext.ruleError("main class must not be specified when executable is not created");
@@ -95,8 +115,7 @@
 
     if (createExecutable && mainClass.isEmpty()) {
       if (javaCommon.getSrcsArtifacts().isEmpty()) {
-        ruleContext.ruleError(
-            "need at least one of 'main_class', 'use_testrunner' or Java source files");
+        ruleContext.ruleError("need at least one of 'main_class' or Java source files");
       }
       mainClass = javaCommon.determinePrimaryClass(javaCommon.getSrcsArtifacts());
       if (mainClass == null) {
@@ -111,7 +130,7 @@
   @Override
   public String getMainClass(RuleContext ruleContext, JavaCommon javaCommon) {
     checkMainClass(ruleContext, javaCommon);
-    return getMainClassInternal(ruleContext);
+    return getMainClassInternal(ruleContext, javaCommon);
   }
 
   @Override
@@ -198,9 +217,30 @@
     }
   }
 
+  private TransitiveInfoCollection getTestSupport(RuleContext ruleContext) {
+    if (!isJavaBinaryOrJavaTest(ruleContext)) {
+      return null;
+    }
+    if (useLegacyJavaTest(ruleContext)) {
+      return null;
+    }
+
+    boolean createExecutable = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
+    if (createExecutable && ruleContext.attributes().get("use_testrunner", Type.BOOLEAN)) {
+      return Iterables.getOnlyElement(ruleContext.getPrerequisites("$testsupport", Mode.TARGET));
+    } else {
+      return null;
+    }
+  }
+
   @Override
   public void addRunfilesForBinary(RuleContext ruleContext, Artifact launcher,
       Runfiles.Builder runfilesBuilder) {
+    TransitiveInfoCollection testSupport = getTestSupport(ruleContext);
+    if (testSupport != null) {
+      runfilesBuilder.addTarget(testSupport, JavaRunfilesProvider.TO_RUNFILES);
+      runfilesBuilder.addTarget(testSupport, RunfilesProvider.DEFAULT_RUNFILES);
+    }
   }
 
   @Override
@@ -210,6 +250,13 @@
   @Override
   public void collectTargetsTreatedAsDeps(
       RuleContext ruleContext, ImmutableList.Builder<TransitiveInfoCollection> builder) {
+    TransitiveInfoCollection testSupport = getTestSupport(ruleContext);
+    if (testSupport != null) {
+      // TODO(bazel-team): The testsupport is used as the test framework
+      // and really only needs to be on the runtime, not compile-time
+      // classpath.
+      builder.add(testSupport);
+    }
   }
 
   @Override
@@ -230,17 +277,87 @@
       NestedSetBuilder<Artifact> filesBuilder,
       RuleConfiguredTargetBuilder ruleBuilder) {
     if (isJavaBinaryOrJavaTest(ruleContext)) {
-      boolean createExec = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
-      ruleBuilder.add(JavaPrimaryClassProvider.class, 
-          new JavaPrimaryClassProvider(createExec ? getMainClassInternal(ruleContext) : null));
+      ruleBuilder.add(
+          JavaPrimaryClassProvider.class,
+          new JavaPrimaryClassProvider(getPrimaryClass(ruleContext, javaCommon)));
     }
   }
 
+  // TODO(dmarting): simplify that logic when we remove the legacy Bazel java_test behavior.
+  private String getPrimaryClassLegacy(RuleContext ruleContext, JavaCommon javaCommon) {
+    boolean createExecutable = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
+    if (!createExecutable) {
+      return null;
+    }
+    return getMainClassInternal(ruleContext, javaCommon);
+  }
+
+  private String getPrimaryClassNew(RuleContext ruleContext, JavaCommon javaCommon) {
+    boolean createExecutable = ruleContext.attributes().get("create_executable", Type.BOOLEAN);
+    Set<Artifact> sourceFiles = ImmutableSet.copyOf(javaCommon.getSrcsArtifacts());
+
+    if (!createExecutable) {
+      return null;
+    }
+
+    boolean useTestrunner = ruleContext.attributes().get("use_testrunner", Type.BOOLEAN);
+
+    String testClass = ruleContext.getRule().isAttrDefined("test_class", Type.STRING)
+        ? ruleContext.attributes().get("test_class", Type.STRING) : "";
+
+    if (useTestrunner) {
+      if (testClass.isEmpty()) {
+        testClass = javaCommon.determinePrimaryClass(sourceFiles);
+        if (testClass == null) {
+          ruleContext.ruleError("cannot determine junit.framework.Test class "
+                    + "(Found no source file '" + ruleContext.getTarget().getName()
+                    + ".java' and package name doesn't include 'java' or 'javatests'. "
+                    + "You might want to rename the rule or add a 'test_class' "
+                    + "attribute.)");
+        }
+      }
+      return testClass;
+    } else {
+      if (!testClass.isEmpty()) {
+        ruleContext.attributeError("test_class", "this attribute is only meaningful to "
+            + "BazelTestRunner, but you are not using it (use_testrunner = 0)");
+      }
+
+      return getMainClassInternal(ruleContext, javaCommon);
+    }
+  }
+  
+  private String getPrimaryClass(RuleContext ruleContext, JavaCommon javaCommon) {
+    return useLegacyJavaTest(ruleContext) ? getPrimaryClassLegacy(ruleContext, javaCommon)
+        : getPrimaryClassNew(ruleContext, javaCommon);
+  }
   
   @Override
   public Iterable<String> getJvmFlags(
       RuleContext ruleContext, JavaCommon javaCommon, List<String> userJvmFlags) {
-    return userJvmFlags;
+    ImmutableList.Builder<String> jvmFlags = ImmutableList.builder();
+    jvmFlags.addAll(userJvmFlags);
+
+    if (!useLegacyJavaTest(ruleContext)) {
+      if (ruleContext.attributes().get("use_testrunner", Type.BOOLEAN)) {
+        String testClass = ruleContext.getRule().isAttrDefined("test_class", Type.STRING)
+            ? ruleContext.attributes().get("test_class", Type.STRING) : "";
+        if (testClass.isEmpty()) {
+          testClass = javaCommon.determinePrimaryClass(javaCommon.getSrcsArtifacts());
+        }
+
+        if (testClass == null) {
+          ruleContext.ruleError("cannot determine test class");
+        } else {
+          // Always run junit tests with -ea (enable assertion)
+          jvmFlags.add("-ea");
+          // "suite" is a misnomer.
+          jvmFlags.add("-Dbazel.test_suite=" +  ShellEscaper.escapeString(testClass));
+        }
+      }
+    }
+
+    return jvmFlags.build();
   }
 
   @Override
@@ -309,22 +426,29 @@
   @Override
   public List<String> getExtraArguments(RuleContext ruleContext, JavaCommon javaCommon) {
     if (ruleContext.getRule().getRuleClass().equals("java_test")) {
-      if (ruleContext.getConfiguration().getTestArguments().isEmpty()
-          && !ruleContext.attributes().isAttributeValueExplicitlySpecified("args")) {
-        ImmutableList.Builder<String> builder = ImmutableList.builder();
-        for (Artifact artifact : javaCommon.getSrcsArtifacts()) {
-          PathFragment path = artifact.getRootRelativePath();
-          String className = JavaUtil.getJavaFullClassname(FileSystemUtils.removeExtension(path));
-          if (className != null) {
-            builder.add(className);
+      if (useLegacyJavaTest(ruleContext)) {
+        if (ruleContext.getConfiguration().getTestArguments().isEmpty()
+            && !ruleContext.attributes().isAttributeValueExplicitlySpecified("args")) {
+          ImmutableList.Builder<String> builder = ImmutableList.builder();
+          for (Artifact artifact : javaCommon.getSrcsArtifacts()) {
+            PathFragment path = artifact.getRootRelativePath();
+            String className = JavaUtil.getJavaFullClassname(FileSystemUtils.removeExtension(path));
+            if (className != null) {
+              builder.add(className);
+            }
           }
+          return builder.build();
         }
-        return builder.build();
       }
     }
     return ImmutableList.<String>of();
   }
 
+  private boolean useLegacyJavaTest(RuleContext ruleContext) {
+    return !ruleContext.attributes().isAttributeValueExplicitlySpecified("test_class")
+        && ruleContext.getFragment(JavaConfiguration.class).useLegacyBazelJavaTest();
+  }
+
   @Override
   public String getJavaBuilderMainClass() {
     return JAVABUILDER_CLASS_NAME;
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaTestRule.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaTestRule.java
index e9d44be..c3c5b75 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaTestRule.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaTestRule.java
@@ -17,6 +17,7 @@
 import static com.google.devtools.build.lib.packages.Attribute.attr;
 import static com.google.devtools.build.lib.packages.BuildType.LABEL;
 import static com.google.devtools.build.lib.packages.BuildType.TRISTATE;
+import static com.google.devtools.build.lib.syntax.Type.BOOLEAN;
 import static com.google.devtools.build.lib.syntax.Type.STRING;
 
 import com.google.devtools.build.lib.analysis.BaseRuleClasses;
@@ -50,9 +51,45 @@
     return builder
         .requiresConfigurationFragments(JavaConfiguration.class, Jvm.class)
         .setImplicitOutputsFunction(BazelJavaRuleClasses.JAVA_BINARY_IMPLICIT_OUTPUTS)
-        .override(attr("main_class", STRING).value(JUNIT4_RUNNER))
         .override(attr("stamp", TRISTATE).value(TriState.NO))
+        .override(attr("use_testrunner", BOOLEAN).value(true))
         .override(attr(":java_launcher", LABEL).value(JavaSemantics.JAVA_LAUNCHER))
+        // TODO(dmarting): remove once we drop the legacy bazel java_test behavior.
+        .override(attr("main_class", STRING).value(JUNIT4_RUNNER))
+        /* <!-- #BLAZE_RULE(java_test).ATTRIBUTE(test_class) -->
+        The Java class to be loaded by the test runner.<br/>
+        ${SYNOPSIS}
+        <p>
+          By default, if this argument is not defined then the legacy mode is used and the
+          test arguments are used instead. Set the <code>--nolegacy_bazel_java_test</code> flag
+          to not fallback on the first argument.
+        </p>
+        <p>
+          This attribute specifies the name of a Java class to be run by
+          this test. It is rare to need to set this. If this argument is omitted, the Java class
+          whose name corresponds to the <code>name</code> of this
+          <code>java_test</code> rule will be used.
+        </p>
+        <p>
+          For JUnit3, the test class needs to either be a subclass of
+          <code>junit.framework.TestCase</code> or it needs to have a public
+          static <code>suite()</code> method that returns a
+          <code>junit.framework.Test</code> (or a subclass of <code>Test</code>).
+          For JUnit4, the class needs to be annotated with
+          <code>org.junit.runner.RunWith</code>.
+        </p>
+        <p>
+          This attribute allows several <code>java_test</code> rules to
+          share the same <code>Test</code>
+          (<code>TestCase</code>, <code>TestSuite</code>, ...).  Typically
+          additional information is passed to it
+          (e.g. via <code>jvm_flags=['-Dkey=value']</code>) so that its
+          behavior differs in each case, such as running a different
+          subset of the tests.  This attribute also enables the use of
+          Java tests outside the <code>javatests</code> tree.
+        </p>
+        <!-- #END_BLAZE_RULE.ATTRIBUTE --> */
+        .add(attr("test_class", STRING))
         .build();
   }
 
@@ -81,10 +118,9 @@
 ${ATTRIBUTE_DEFINITION}
 
 <p>
-See the section on <a href="#java_binary_args">java_binary()</a> arguments, with the <i>caveat</i>
-that there is no <code>main_class</code> argument. This rule also supports all
-<a href="common-definitions.html#common-attributes-tests">attributes common to all test rules
-(*_test)</a>.
+See the section on <a href="#java_binary_args">java_binary()</a> arguments. This rule also
+supports all <a href="common-definitions.html#common-attributes-tests">attributes common
+to all test rules (*_test)</a>.
 </p>
 
 <h4 id="java_test_examples">Examples</h4>
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaConfiguration.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaConfiguration.java
index 246253a..315cf67 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaConfiguration.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaConfiguration.java
@@ -142,12 +142,16 @@
   private final ImmutableList<Label> translationTargets;
   private final String javaCpu;
   private final JavaOptimizationMode javaOptimizationMode;
-
+  
   private final Label javaToolchain;
 
+  // TODO(dmarting): remove when we have rolled out the new behavior
+  private final boolean legacyBazelJavaTest;
+  
   JavaConfiguration(boolean generateJavaDeps,
       List<String> defaultJvmFlags, JavaOptions javaOptions, Label javaToolchain, String javaCpu,
-      ImmutableList<String> defaultJavaBuilderJvmOpts) throws InvalidConfigurationException {
+      ImmutableList<String> defaultJavaBuilderJvmOpts)
+          throws InvalidConfigurationException {
     this.commandLineJavacFlags =
         ImmutableList.copyOf(JavaHelper.tokenizeJavaOptions(javaOptions.javacOpts));
     this.javaLauncherLabel = javaOptions.javaLauncher;
@@ -169,6 +173,7 @@
     this.javaCpu = javaCpu;
     this.javaToolchain = javaToolchain;
     this.javaOptimizationMode = javaOptions.javaOptimizationMode;
+    this.legacyBazelJavaTest = javaOptions.legacyBazelJavaTest;
 
     ImmutableList.Builder<Label> translationsBuilder = ImmutableList.builder();
     for (String s : javaOptions.translationTargets) {
@@ -336,4 +341,12 @@
   public JavaOptimizationMode getJavaOptimizationMode() {
     return javaOptimizationMode;
   }
+  
+  /**
+   * Returns true if java_test in Bazel should behave in legacy mode that existed before we
+   * open-sourced our test runner.
+   */
+  public boolean useLegacyBazelJavaTest() {
+    return legacyBazelJavaTest;
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaOptions.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaOptions.java
index df524dc..2a091ca 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaOptions.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaOptions.java
@@ -389,6 +389,12 @@
       help = "Applies desired link-time optimizations to Java binaries and tests.")
   public JavaOptimizationMode javaOptimizationMode;
 
+  @Option(name = "legacy_bazel_java_test",
+      defaultValue = "true",
+      category = "undocumented",
+      help = "Use the legacy mode of Bazel for java_test.")
+  public boolean legacyBazelJavaTest;
+
   @Override
   public FragmentOptions getHost(boolean fallback) {
     JavaOptions host = (JavaOptions) getDefault();
diff --git a/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java b/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
index 9fbf932..2d7a0ad 100644
--- a/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
+++ b/src/test/java/com/google/devtools/build/lib/analysis/mock/BazelAnalysisMock.java
@@ -88,7 +88,7 @@
         "filegroup(name='extdir', srcs=glob(['jdk/jre/lib/ext/*']))",
         // "dummy" is needed so that RedirectChaser stops here
         "filegroup(name='java', srcs = ['jdk/jre/bin/java', 'dummy'])",
-        "exports_files(['JavaBuilder_deploy.jar','SingleJar_deploy.jar',",
+        "exports_files(['JavaBuilder_deploy.jar','SingleJar_deploy.jar','TestRunner_deploy.jar',",
         "               'JavaBuilderCanary_deploy.jar', 'ijar', 'GenClass_deploy.jar'])");
 
 
diff --git a/src/test/shell/bazel/BUILD b/src/test/shell/bazel/BUILD
index b6f2dba..8d0be17 100644
--- a/src/test/shell/bazel/BUILD
+++ b/src/test/shell/bazel/BUILD
@@ -52,6 +52,7 @@
         "//src:bazel",
         "//src/java_tools/buildjar:JavaBuilder_deploy.jar",
         "//src/java_tools/buildjar/java/com/google/devtools/build/buildjar/genclass:GenClass_deploy.jar",
+        "//src/java_tools/junitrunner/java/com/google/testing/junit/runner:Runner_deploy.jar",
         "//src/java_tools/singlejar:SingleJar_deploy.jar",
         "//src/main/tools:namespace-sandbox",
         "//src/main/tools:process-wrapper",
diff --git a/src/test/shell/bazel/bazel_example_test.sh b/src/test/shell/bazel/bazel_example_test.sh
index 1596431..d20139b 100755
--- a/src/test/shell/bazel/bazel_example_test.sh
+++ b/src/test/shell/bazel/bazel_example_test.sh
@@ -70,6 +70,13 @@
   assert_test_fails "${java_native_tests}:resource-fail"
 }
 
+function test_java_test_with_junitrunner() {
+  # Test with junitrunner.
+  setup_javatest_support
+  local java_native_tests=//examples/java-native/src/test/java/com/example/myproject
+  assert_test_ok "${java_native_tests}:custom_with_test_class"
+}
+
 function test_java_test_with_workspace_name() {
   local java_pkg=examples/java-native/src/main/java/com/example/myproject
   # Use named workspace and test if we can still execute hello-world
diff --git a/src/test/shell/bazel/test-setup.sh b/src/test/shell/bazel/test-setup.sh
index 4130740..b9cc243 100755
--- a/src/test/shell/bazel/test-setup.sh
+++ b/src/test/shell/bazel/test-setup.sh
@@ -370,6 +370,7 @@
   ln -s "${javabuilder_path}" tools/jdk/JavaBuilder_deploy.jar
   ln -s "${singlejar_path}"  tools/jdk/SingleJar_deploy.jar
   ln -s "${genclass_path}" tools/jdk/GenClass_deploy.jar
+  ln -s "${junitrunner_path}" tools/jdk/TestRunner_deploy.jar
   ln -s "${ijar_path}" tools/jdk/ijar
 
   touch WORKSPACE
diff --git a/src/test/shell/bazel/testenv.sh b/src/test/shell/bazel/testenv.sh
index b2f76ae..c03fc99 100755
--- a/src/test/shell/bazel/testenv.sh
+++ b/src/test/shell/bazel/testenv.sh
@@ -42,6 +42,7 @@
 langtools_path="${TEST_SRCDIR}/third_party/java/jdk/langtools/javac.jar"
 singlejar_path="${TEST_SRCDIR}/src/java_tools/singlejar/SingleJar_deploy.jar"
 genclass_path="${TEST_SRCDIR}/src/java_tools/buildjar/java/com/google/devtools/build/buildjar/genclass/GenClass_deploy.jar"
+junitrunner_path="${TEST_SRCDIR}/src/java_tools/junitrunner/java/com/google/testing/junit/runner/Runner_deploy.jar"
 ijar_path="${TEST_SRCDIR}/third_party/ijar/ijar"
 
 # Sandbox tools