Improve the implementation of the persistent test runner.

PiperOrigin-RevId: 295556822
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
index c627275..f16132b 100644
--- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
@@ -38,7 +38,6 @@
     name = "persistent_test_runner",
     srcs = [
         "PersistentTestRunner.java",
-        "PersistentTestRunnerClassLoader.java",
         "SuiteTestRunner.java",
     ],
     deps = [
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
index f465063..0ad6d81 100644
--- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
@@ -15,6 +15,7 @@
 package com.google.testing.junit.runner;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.stream.Collectors.toList;
 
 import com.google.common.hash.HashCode;
 import com.google.common.hash.Hasher;
@@ -31,9 +32,11 @@
 import java.io.InputStream;
 import java.io.PrintStream;
 import java.net.URL;
+import java.net.URLClassLoader;
 import java.util.List;
+import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
-import javax.annotation.Nullable;
 
 /** A utility class for running Java tests using persistent workers. */
 final class PersistentTestRunner {
@@ -49,8 +52,9 @@
    * <p>Uses three different classloaders:
    *
    * <ul>
-   *   <li>A classloader for direct dependencies: loads the test's classes and the first two layers
-   *       of its dependencies.
+   *   <li>A classloader for direct dependencies: loads the classes of the test target and its
+   *       direct dependencies, *excluding* those dependencies that are also among the rest of the
+   *       transitive dependencies
    *   <li>A classloader for transitive dependencies: loads the remaining classes from the
    *       transitive dependencies.
    *   <li>The system classloader: initialized at JVM startup time; loads the test runner's classes.
@@ -60,30 +64,44 @@
    *
    * <p>The direct dependencies classloader is the current thread's classloader.
    *
-   * <p>The transitive dependencies classloader checks if a class was already loaded by the direct
-   * dependencies classloader if it did not succeed in loading the class itself. This is required
-   * for classes loaded by the transitive classloader that reference classes in the direct
-   * dependencies classloader.
+   * <p>For example, given the following dependency graph:
    *
-   * <p>The default class loading logic in {@link ClassLoader} applies in all other cases.
+   * <p>TestTarget / | \ * * * a b c | * | * / * d e | * f
+   *
+   * <p>the classloaders load the following: - the direct classloader loads the classes of
+   * TestTarget, a and c, which are direct deps - the transitive classloader loads the classes of b,
+   * d, e and f, which are transitive deps Note that b is loaded by the transitive classloader even
+   * if it is a direct dependency, because b is also a dependency of d, which is a 2nd level
+   * dependency of TestTarget.
+   *
+   * <p>Excluding the direct dependencies that are also present lower in the dependency tree from
+   * direct dependencies classloader presents two advantages: 1. reduces the number of classes
+   * loaded by directDepsClassLoader, making it faster to create/load 2. avoids additional custom
+   * logic for loading classes, since we are now sure that the transitive dependencies classLoader
+   * doesn't reference anything from its child classloader
    *
    * <p>The direct and transitive classloaders are rebuilt before every test run only if the
-   * combined hash of the jars to be loaded has changed.
+   * combined hash of the jars to be loaded has changed. If the transitive classloader has to be
+   * rebuilt, the direct classloader will also be rebuilt to preserve correctness.
    */
   static int runPersistentTestRunner(
       String suitClassName, String workspacePrefix, SuiteTestRunner suiteTestRunner) {
     PrintStream originalStdOut = System.out;
     PrintStream originalStdErr = System.err;
 
-    // TODO(elenairina): Remove this variable after cl/282553936 is released.
-    String legacyTestRuntimeClasspathFile = System.getenv("TEST_RUNTIME_CLASSPATH_FILE");
-
     String directClasspathFile = System.getenv("TEST_DIRECT_CLASSPATH_FILE");
     String transitiveClasspathFile = System.getenv("TEST_TRANSITIVE_CLASSPATH_FILE");
     String absolutePathPrefix = System.getenv("JAVA_RUNFILES") + File.separator + workspacePrefix;
 
-    PersistentTestRunnerClassLoader transitiveDepsClassLoader = null;
-    PersistentTestRunnerClassLoader directDepsClassLoader = null;
+    // Loads the classes of the test target and its direct dependencies *excluding* those
+    // dependencies that are also among the rest of the transitive dependencies.
+    URLClassLoader directDepsClassLoader = null;
+    // Loads all the classes in the transitive dependencies, excluding those loaded
+    // by directDepsClassLoader
+    URLClassLoader transitiveDepsClassLoader = null;
+
+    HashCode previousTransitiveCombinedHash = null;
+    HashCode previousDirectCombinedHash = null;
 
     // Reading the work requests and solving them in sequence is not a problem because Bazel creates
     // up to --worker_max_instances (defaults to 4) instances per worker key.
@@ -96,28 +114,39 @@
           break;
         }
 
-        if (legacyTestRuntimeClasspathFile != null) {
-          // Re-use the same classloader variable in the legacy case for simplicity.
-          directDepsClassLoader =
-              maybeRecreateClassLoader(
-                  getFilesWithAbsolutePathPrefixFromFile(
-                      legacyTestRuntimeClasspathFile, absolutePathPrefix),
-                  ClassLoader.getSystemClassLoader(),
-                  null);
-        } else {
-          transitiveDepsClassLoader =
-              maybeRecreateClassLoader(
-                  getFilesWithAbsolutePathPrefixFromFile(
-                      transitiveClasspathFile, absolutePathPrefix),
-                  ClassLoader.getSystemClassLoader(),
-                  transitiveDepsClassLoader);
+        Set<File> transitiveDeps =
+            getFilesWithAbsolutePathPrefixFromFile(transitiveClasspathFile, absolutePathPrefix);
+        Set<File> directDeps =
+            getFilesWithAbsolutePathPrefixFromFile(directClasspathFile, absolutePathPrefix);
+        // Filter out duplicated dependencies from the directDeps, to ensure that the
+        // transitiveDepsClassLoader doesn't reference anything from directDepsClassLoader.
+        directDeps = filterOutDupedDeps(directDeps, transitiveDeps);
 
+        HashCode transitiveCombinedHash = getCombinedHashForFiles(transitiveDeps);
+        HashCode directCombinedHash = getCombinedHashForFiles(directDeps);
+
+        boolean recreateTransitiveClassloader =
+            transitiveDepsClassLoader == null
+                || !transitiveCombinedHash.equals(previousTransitiveCombinedHash);
+
+        // if the parent needs to be re-created, the child needs to be recreated also
+        boolean recreateDirectClassloader =
+            recreateTransitiveClassloader
+                || directDepsClassLoader == null
+                || !directCombinedHash.equals(previousDirectCombinedHash);
+
+        previousTransitiveCombinedHash = transitiveCombinedHash;
+        previousDirectCombinedHash = directCombinedHash;
+
+        if (recreateTransitiveClassloader) {
+          transitiveDepsClassLoader =
+              new URLClassLoader(
+                  convertFileListToURLArray(transitiveDeps), ClassLoader.getSystemClassLoader());
+        }
+
+        if (recreateDirectClassloader) {
           directDepsClassLoader =
-              maybeRecreateClassLoader(
-                  getFilesWithAbsolutePathPrefixFromFile(directClasspathFile, absolutePathPrefix),
-                  transitiveDepsClassLoader,
-                  directDepsClassLoader);
-          transitiveDepsClassLoader.setChild(directDepsClassLoader);
+              new URLClassLoader(convertFileListToURLArray(directDeps), transitiveDepsClassLoader);
         }
 
         Thread.currentThread().setContextClassLoader(directDepsClassLoader);
@@ -172,58 +201,57 @@
     }
   }
 
-  /**
-   * Returns a classloader that loads the given jars, only if the combined hash of their content is
-   * different than the one for the given previous classloader.
-   *
-   * <p>Returns previousClassLoader if the hashes are the same.
-   *
-   * <p>Needs to be called before every test run to avoid having stale classes. A class already
-   * loaded in a classloader can not be unloaded. To overcome this a new classloader has to be
-   * created at every test run.
-   */
-  private static PersistentTestRunnerClassLoader maybeRecreateClassLoader(
-      List<File> runtimeJars,
-      ClassLoader parent,
-      @Nullable PersistentTestRunnerClassLoader previousClassLoader)
-      throws IOException {
-    HashCode combinedHash = getCombinedHashForFiles(runtimeJars);
-    if (previousClassLoader != null && combinedHash.equals(previousClassLoader.getChecksum())) {
-      return previousClassLoader;
-    }
+  private static HashCode getCombinedHashForFiles(Set<File> files) {
+    List<HashCode> hashesAsBytes =
+        files.stream()
+            .parallel()
+            .map(file -> getFileHash(file))
+            .filter(Objects::nonNull)
+            .collect(toList());
 
-    return new PersistentTestRunnerClassLoader(
-        convertFileListToURLArray(runtimeJars), parent, combinedHash);
-  }
-
-  private static HashCode getCombinedHashForFiles(List<File> files) throws IOException {
+    // Update the hasher separately because Hasher.putBytes() is not safe for parallel operations
     Hasher hasher = Hashing.sha256().newHasher();
-    for (File file : files) {
-      hasher.putBytes(getFileHash(file).asBytes());
+    for (HashCode hash : hashesAsBytes) {
+      hasher.putBytes(hash.asBytes());
     }
     return hasher.hash();
   }
 
-  private static HashCode getFileHash(File file) throws IOException {
-    InputStream inputStream = new FileInputStream(file);
-    HashingInputStream hashingStream = new HashingInputStream(Hashing.sha256(), inputStream);
-    ByteStreams.copy(hashingStream, ByteStreams.nullOutputStream());
-    return hashingStream.hash();
+  private static HashCode getFileHash(File file) {
+    try {
+      InputStream inputStream;
+      inputStream = new FileInputStream(file);
+      HashingInputStream hashingStream = new HashingInputStream(Hashing.sha256(), inputStream);
+      ByteStreams.copy(hashingStream, ByteStreams.nullOutputStream());
+      return hashingStream.hash();
+    } catch (IOException e) {
+      // Throwing RuntimeException to fail the whole build, and still benefit from using parallel
+      // streams in getCombinedHashForFiles().
+      throw new RuntimeException(e);
+    }
   }
 
-  private static URL[] convertFileListToURLArray(List<File> jars) throws IOException {
+  private static URL[] convertFileListToURLArray(Set<File> jars) throws IOException {
     URL[] urls = new URL[jars.size()];
-    for (int i = 0; i < jars.size(); i++) {
-      urls[i] = jars.get(i).toURI().toURL();
+    int it = 0;
+    for (File jar : jars) {
+      urls[it++] = jar.toURI().toURL();
     }
     return urls;
   }
 
-  private static List<File> getFilesWithAbsolutePathPrefixFromFile(
+  private static Set<File> getFilesWithAbsolutePathPrefixFromFile(
       String runtimeClasspathFilename, String absolutePathPrefix) throws IOException {
     return Files.readLines(new File(runtimeClasspathFilename), UTF_8).stream()
         .map(entry -> new File(absolutePathPrefix + entry))
-        .collect(Collectors.toList());
+        .collect(Collectors.toSet());
+  }
+
+  private static Set<File> filterOutDupedDeps(Set<File> smallSet, Set<File> largeSet) {
+    return smallSet.stream()
+        .parallel()
+        .filter(f -> !largeSet.contains(f))
+        .collect(Collectors.toSet());
   }
 
   static boolean isPersistentTestRunner() {
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java
deleted file mode 100644
index c5e04bf..0000000
--- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2020 The Bazel Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package com.google.testing.junit.runner;
-
-import com.google.common.hash.HashCode;
-import java.net.URL;
-import java.net.URLClassLoader;
-
-/**
- * A custom classloader used by the persistent test runner.
- *
- * <p>Each classloader stores the combined hash code for the loaded jars.
- */
-final class PersistentTestRunnerClassLoader extends URLClassLoader {
-
-  private final HashCode checksum;
-  private PersistentTestRunnerClassLoader child;
-
-  public PersistentTestRunnerClassLoader(URL[] urls, ClassLoader parent, HashCode checksum) {
-    super(urls, parent);
-    this.checksum = checksum;
-  }
-
-  void setChild(PersistentTestRunnerClassLoader child) {
-    this.child = child;
-  }
-
-  HashCode getChecksum() {
-    return checksum;
-  }
-
-  /**
-   * Loads the class with the specified name and resolves it if required.
-   *
-   * <p>If the classloader has a child: check if the class was already loaded by the child if the
-   * current classloader did not succeed in loading the class.
-   *
-   * <p>If the classloader doesn't have a child: use the default class loading logic.
-   */
-  @Override
-  public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
-    if (child == null) {
-      return super.loadClass(name, resolve);
-    }
-
-    synchronized (this.getClassLoadingLock(name)) {
-      Class<?> result;
-      try {
-        result = super.loadClass(name, resolve);
-      } catch (ClassNotFoundException e) {
-        result = child.findLoadedClass(name);
-      }
-      if (result == null) {
-        throw new ClassNotFoundException("Could not find " + name);
-      }
-      return result;
-    }
-  }
-}
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
index b2a0400..1e7cf52 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/rules/java/BazelJavaSemantics.java
@@ -307,7 +307,7 @@
     }
     NestedSet<Artifact> classpath = classpathBuilder.build();
 
-    if (JavaSemantics.isPersistentTestRunner(ruleContext)) {
+    if (JavaSemantics.isTestTargetAndPersistentTestRunner(ruleContext)) {
       // Create an artifact that stores the test's runtime classpath (excluding the test support
       // classpath). The file is read by the test runner. The jars inside the file are loaded
       // dynamically for every test run into a custom classloader.
@@ -491,7 +491,7 @@
       // targets may break, we are keeping it behind this flag.
       return;
     }
-    if (!JavaSemantics.isPersistentTestRunner(ruleContext)) {
+    if (!JavaSemantics.isTestTargetAndPersistentTestRunner(ruleContext)) {
       // Only add the test support to the dependencies when running in regular mode.
       // In persistent test runner mode don't pollute the classpath of the test with
       // the test support classes.
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaBinary.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaBinary.java
index 4b015ae..0be02b5 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaBinary.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaBinary.java
@@ -351,7 +351,7 @@
           RunfilesSupport.withExecutable(
               ruleContext, defaultRunfiles, executableForRunfiles, extraArgs);
       extraFilesToRunBuilder.add(runfilesSupport.getRunfilesMiddleman());
-      if (JavaSemantics.isPersistentTestRunner(ruleContext)) {
+      if (JavaSemantics.isTestTargetAndPersistentTestRunner(ruleContext)) {
         persistentTestRunnerRunfiles = JavaSemantics.getTestSupportRunfiles(ruleContext);
       }
     }
@@ -465,6 +465,8 @@
             .addProvider(
                 JavaSourceInfoProvider.class,
                 JavaSourceInfoProvider.fromJavaTargetAttributes(attributes, semantics))
+            .maybeTransitiveOnlyRuntimeJarsToJavaInfo(
+                common.getDependencies(), JavaSemantics.isPersistentTestRunner(ruleContext))
             .build();
 
     return builder
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
index 9ebd86a..b585f4d 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
@@ -55,6 +55,7 @@
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import javax.annotation.Nullable;
 
@@ -961,6 +962,23 @@
   }
 
   /**
+   * Return the runtime jars of the transitive closure of the target, excluding the first level of
+   * dependencies and the current target itself.
+   *
+   * <p>This particular set of jars is used by the persistent test runner, to create a classloader
+   * for the transitive dependencies. The target itself and its direct dependencies are loaded into
+   * a different classloader.
+   */
+  public NestedSet<Artifact> getRuntimeClasspathExcludingDirect() {
+    NestedSetBuilder<Artifact> classpath = new NestedSetBuilder<>(Order.STABLE_ORDER);
+    targetsTreatedAsDeps(ClasspathType.RUNTIME_ONLY).stream()
+        .map(JavaInfo::getJavaInfo)
+        .filter(Objects::nonNull)
+        .forEach(j -> classpath.addTransitive(j.getTransitiveOnlyRuntimeJars()));
+    return classpath.build();
+  }
+
+  /**
    * Returns true if and only if this target has the neverlink attribute set to 1, or false if the
    * neverlink attribute does not exist (for example, on *_binary targets)
    *
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaImport.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaImport.java
index a72006d..0d8ad19 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaImport.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaImport.java
@@ -145,6 +145,8 @@
             .addProvider(JavaRuleOutputJarsProvider.class, ruleOutputJarsProvider)
             .addProvider(JavaSourceJarsProvider.class, sourceJarsProvider)
             .addProvider(JavaSourceInfoProvider.class, javaSourceInfoProvider)
+            .maybeTransitiveOnlyRuntimeJarsToJavaInfo(
+                common.getDependencies(), JavaSemantics.isPersistentTestRunner(ruleContext))
             .setRuntimeJars(javaArtifacts.getRuntimeJars())
             .setJavaConstraints(JavaCommon.getConstraints(ruleContext))
             .setNeverlink(neverLink)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaInfo.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaInfo.java
index 5de4841..36d0a0f 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaInfo.java
@@ -94,6 +94,16 @@
    */
   private final ImmutableList<Artifact> directRuntimeJars;
 
+  /**
+   * A set of runtime jars corresponding to the transitive dependencies of a certain target,
+   * excluding the runtime jars for the target itself and its direct dependencies.
+   *
+   * <p>This set is required only when the persistent test runner is enabled. It is used to create a
+   * custom classloader for loading the jars in the transitive dependencies. The persistent test
+   * runner creates a separate classloader for the target itself and its direct dependencies.
+   */
+  private final NestedSet<Artifact> transitiveOnlyRuntimeJars;
+
   /** Java constraints (e.g. "android") that are present on the target. */
   private final ImmutableList<String> javaConstraints;
 
@@ -246,11 +256,13 @@
   JavaInfo(
       TransitiveInfoProviderMap providers,
       ImmutableList<Artifact> directRuntimeJars,
+      NestedSet<Artifact> transitiveOnlyRuntimeJars,
       boolean neverlink,
       ImmutableList<String> javaConstraints,
       Location location) {
     super(PROVIDER, location);
     this.directRuntimeJars = directRuntimeJars;
+    this.transitiveOnlyRuntimeJars = transitiveOnlyRuntimeJars;
     this.providers = providers;
     this.neverlink = neverlink;
     this.javaConstraints = javaConstraints;
@@ -321,6 +333,11 @@
     return directRuntimeJars;
   }
 
+  // Do not expose to Starlark.
+  public NestedSet<Artifact> getTransitiveOnlyRuntimeJars() {
+    return transitiveOnlyRuntimeJars;
+  }
+
   @Override
   public Depset /*<Artifact>*/ getTransitiveDeps() {
     return Depset.of(
@@ -447,6 +464,8 @@
     TransitiveInfoProviderMapBuilder providerMap;
     private ImmutableList<Artifact> runtimeJars;
     private ImmutableList<String> javaConstraints;
+    private final NestedSetBuilder<Artifact> transitiveOnlyRuntimeJars =
+        new NestedSetBuilder<>(Order.STABLE_ORDER);
     private boolean neverlink;
     private Location location = Location.BUILTIN;
 
@@ -463,6 +482,7 @@
     public static Builder copyOf(JavaInfo javaInfo) {
       return new Builder(new TransitiveInfoProviderMapBuilder().addAll(javaInfo.getProviders()))
           .setRuntimeJars(javaInfo.getDirectRuntimeJars())
+          .addTransitiveOnlyRuntimeJars(javaInfo.getTransitiveOnlyRuntimeJars())
           .setNeverlink(javaInfo.isNeverlink())
           .setJavaConstraints(javaInfo.getJavaConstraints())
           .setLocation(javaInfo.getCreationLoc());
@@ -478,6 +498,25 @@
       return this;
     }
 
+    public Builder maybeTransitiveOnlyRuntimeJarsToJavaInfo(
+        List<? extends TransitiveInfoCollection> deps, boolean shouldAdd) {
+      if (shouldAdd) {
+        deps.stream()
+            .map(JavaInfo::getJavaInfo)
+            .filter(Objects::nonNull)
+            .map(j -> j.getProvider(JavaCompilationArgsProvider.class))
+            .filter(Objects::nonNull)
+            .map(JavaCompilationArgsProvider::getRuntimeJars)
+            .forEach(this::addTransitiveOnlyRuntimeJars);
+      }
+      return this;
+    }
+
+    private Builder addTransitiveOnlyRuntimeJars(NestedSet<Artifact> runtimeJars) {
+      this.transitiveOnlyRuntimeJars.addTransitive(runtimeJars);
+      return this;
+    }
+
     public Builder setJavaConstraints(ImmutableList<String> javaConstraints) {
       this.javaConstraints = javaConstraints;
       return this;
@@ -524,7 +563,13 @@
                 providerMap.getProvider(JavaCompilationArgsProvider.class));
         addProvider(JavaStrictCompilationArgsProvider.class, javaStrictCompilationArgsProvider);
       }
-      return new JavaInfo(providerMap.build(), runtimeJars, neverlink, javaConstraints, location);
+      return new JavaInfo(
+          providerMap.build(),
+          runtimeJars,
+          transitiveOnlyRuntimeJars.build(),
+          neverlink,
+          javaConstraints,
+          location);
     }
   }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaLibrary.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaLibrary.java
index dc57bf8..98c9875 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaLibrary.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaLibrary.java
@@ -188,6 +188,8 @@
             .addProvider(JavaRuleOutputJarsProvider.class, ruleOutputJarsProvider)
             // TODO(bazel-team): this should only happen for java_plugin
             .addProvider(JavaPluginInfoProvider.class, pluginInfoProvider)
+            .maybeTransitiveOnlyRuntimeJarsToJavaInfo(
+                common.getDependencies(), JavaSemantics.isPersistentTestRunner(ruleContext))
             .setRuntimeJars(javaArtifacts.getRuntimeJars())
             .setJavaConstraints(JavaCommon.getConstraints(ruleContext))
             .setNeverlink(neverLink)
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaSemantics.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaSemantics.java
index 2d5b3ec..0466a13 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaSemantics.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaSemantics.java
@@ -321,11 +321,15 @@
    */
   boolean isJavaExecutableSubstitution();
 
-  static boolean isPersistentTestRunner(RuleContext ruleContext) {
+  static boolean isTestTargetAndPersistentTestRunner(RuleContext ruleContext) {
     return ruleContext.isTestTarget()
         && ruleContext.getFragment(TestConfiguration.class).isPersistentTestRunner();
   }
 
+  static boolean isPersistentTestRunner(RuleContext ruleContext) {
+    return ruleContext.getFragment(TestConfiguration.class).isPersistentTestRunner();
+  }
+
   static Runfiles getTestSupportRunfiles(RuleContext ruleContext) {
     TransitiveInfoCollection testSupport = getTestSupport(ruleContext);
     if (testSupport == null) {