Use 2 custom classloaders in the persistent test runner.

RELNOTES: None.
PiperOrigin-RevId: 290900473
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
index 30f03b0..c627275 100644
--- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
@@ -38,11 +38,13 @@
     name = "persistent_test_runner",
     srcs = [
         "PersistentTestRunner.java",
+        "PersistentTestRunnerClassLoader.java",
         "SuiteTestRunner.java",
     ],
     deps = [
         "//src/main/protobuf:worker_protocol_java_proto",
         "//third_party:guava",
+        "//third_party:jsr305",
     ],
 )
 
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
index d825667..f465063 100644
--- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
@@ -16,17 +16,24 @@
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 
+import com.google.common.hash.HashCode;
+import com.google.common.hash.Hasher;
+import com.google.common.hash.Hashing;
+import com.google.common.hash.HashingInputStream;
+import com.google.common.io.ByteStreams;
 import com.google.common.io.Files;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
 import java.io.ByteArrayOutputStream;
 import java.io.File;
+import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.PrintStream;
 import java.net.URL;
-import java.net.URLClassLoader;
-import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 
 /** A utility class for running Java tests using persistent workers. */
 final class PersistentTestRunner {
@@ -39,17 +46,44 @@
    * Runs new tests in the same process. Communicates with bazel using the worker's protocol. Reads
    * a {@link WorkRequest} sent by bazel and sends back a {@link WorkResponse}.
    *
-   * <p>Before running a test it creates a new classloader loading the test and its dependencies'
-   * classes. A class already loaded in a classloader can not be unloaded. To overcome this issue a
-   * new classloader has to be created at every test run.
+   * <p>Uses three different classloaders:
+   *
+   * <ul>
+   *   <li>A classloader for direct dependencies: loads the test's classes and the first two layers
+   *       of its dependencies.
+   *   <li>A classloader for transitive dependencies: loads the remaining classes from the
+   *       transitive dependencies.
+   *   <li>The system classloader: initialized at JVM startup time; loads the test runner's classes.
+   * </ul>
+   *
+   * <p>The classloaders have a child-parent relationship: direct CL -> transitive CL -> system CL.
+   *
+   * <p>The direct dependencies classloader is the current thread's classloader.
+   *
+   * <p>The transitive dependencies classloader checks if a class was already loaded by the direct
+   * dependencies classloader if it did not succeed in loading the class itself. This is required
+   * for classes loaded by the transitive classloader that reference classes in the direct
+   * dependencies classloader.
+   *
+   * <p>The default class loading logic in {@link ClassLoader} applies in all other cases.
+   *
+   * <p>The direct and transitive classloaders are rebuilt before every test run only if the
+   * combined hash of the jars to be loaded has changed.
    */
   static int runPersistentTestRunner(
       String suitClassName, String workspacePrefix, SuiteTestRunner suiteTestRunner) {
     PrintStream originalStdOut = System.out;
     PrintStream originalStdErr = System.err;
 
-    String testRuntimeClasspathFile = System.getenv("TEST_RUNTIME_CLASSPATH_FILE");
-    String javaRunfilesPath = System.getenv("JAVA_RUNFILES");
+    // TODO(elenairina): Remove this variable after cl/282553936 is released.
+    String legacyTestRuntimeClasspathFile = System.getenv("TEST_RUNTIME_CLASSPATH_FILE");
+
+    String directClasspathFile = System.getenv("TEST_DIRECT_CLASSPATH_FILE");
+    String transitiveClasspathFile = System.getenv("TEST_TRANSITIVE_CLASSPATH_FILE");
+    String absolutePathPrefix = System.getenv("JAVA_RUNFILES") + File.separator + workspacePrefix;
+
+    PersistentTestRunnerClassLoader transitiveDepsClassLoader = null;
+    PersistentTestRunnerClassLoader directDepsClassLoader = null;
 
     // Reading the work requests and solving them in sequence is not a problem because Bazel creates
     // up to --worker_max_instances (defaults to 4) instances per worker key.
@@ -62,8 +96,31 @@
           break;
         }
 
-        URLClassLoader testRunnerClassLoader =
-            recreateClassLoader(testRuntimeClasspathFile, javaRunfilesPath, workspacePrefix);
+        if (legacyTestRuntimeClasspathFile != null) {
+          // Re-use the same classloader variable in the legacy case for simplicity.
+          directDepsClassLoader =
+              maybeRecreateClassLoader(
+                  getFilesWithAbsolutePathPrefixFromFile(
+                      legacyTestRuntimeClasspathFile, absolutePathPrefix),
+                  ClassLoader.getSystemClassLoader(),
+                  null);
+        } else {
+          transitiveDepsClassLoader =
+              maybeRecreateClassLoader(
+                  getFilesWithAbsolutePathPrefixFromFile(
+                      transitiveClasspathFile, absolutePathPrefix),
+                  ClassLoader.getSystemClassLoader(),
+                  transitiveDepsClassLoader);
+
+          directDepsClassLoader =
+              maybeRecreateClassLoader(
+                  getFilesWithAbsolutePathPrefixFromFile(directClasspathFile, absolutePathPrefix),
+                  transitiveDepsClassLoader,
+                  directDepsClassLoader);
+          transitiveDepsClassLoader.setChild(directDepsClassLoader);
+        }
+
+        Thread.currentThread().setContextClassLoader(directDepsClassLoader);
 
         ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
         PrintStream printStream = new PrintStream(outputStream, true);
@@ -74,7 +131,7 @@
         try {
           exitCode =
               suiteTestRunner.runTestsInSuite(
-                  suitClassName, arguments, testRunnerClassLoader, /* resolve= */ true);
+                  suitClassName, arguments, directDepsClassLoader, /* resolve= */ true);
         } finally {
           System.setOut(originalStdOut);
           System.setErr(originalStdErr);
@@ -116,40 +173,57 @@
   }
 
   /**
-   * Returns a classloader containing the jars read from the given runtime classpath file.
+   * Returns a classloader that loads the given jars, only if the combined hash of their content is
+   * different than the one for the given previous classloader.
    *
-   * <p>Sets the classloader of the current thread to the newly created classloader.
-   *
-   * <p>Sets the classloader used to load the test classes and their dependencies.
+   * <p>Returns previousClassLoader if the hashes are the same.
    *
    * <p>Needs to be called before every test run to avoid having stale classes. A class already
    * loaded in a classloader can not be unloaded. To overcome this a new classloader has to be
    * created at every test run.
    */
-  private static URLClassLoader recreateClassLoader(
-      String runtimeClasspathFilename, String javaRunfilesPath, String workspacePrefix)
+  private static PersistentTestRunnerClassLoader maybeRecreateClassLoader(
+      List<File> runtimeJars,
+      ClassLoader parent,
+      @Nullable PersistentTestRunnerClassLoader previousClassLoader)
       throws IOException {
-    URLClassLoader classLoader =
-        new URLClassLoader(
-            createURLsFromRelativePathsInFile(
-                runtimeClasspathFilename, javaRunfilesPath, workspacePrefix));
-    Thread.currentThread().setContextClassLoader(classLoader);
-    return classLoader;
+    HashCode combinedHash = getCombinedHashForFiles(runtimeJars);
+    if (previousClassLoader != null && combinedHash.equals(previousClassLoader.getChecksum())) {
+      return previousClassLoader;
+    }
+
+    return new PersistentTestRunnerClassLoader(
+        convertFileListToURLArray(runtimeJars), parent, combinedHash);
   }
 
-  private static URL[] createURLsFromRelativePathsInFile(
-      String runtimeClasspathFilename, String javaRunfilesPath, String workspacePrefix)
-      throws IOException {
-    List<String> testRuntimeClasspath = Files.readLines(new File(runtimeClasspathFilename), UTF_8);
-    ArrayList<URL> urlList = new ArrayList<>();
-    for (String classPathEntry : testRuntimeClasspath) {
-      urlList.add(
-          new File(javaRunfilesPath + File.separator + workspacePrefix + classPathEntry)
-              .toURI()
-              .toURL());
+  private static HashCode getCombinedHashForFiles(List<File> files) throws IOException {
+    Hasher hasher = Hashing.sha256().newHasher();
+    for (File file : files) {
+      hasher.putBytes(getFileHash(file).asBytes());
     }
-    URL[] urls = new URL[urlList.size()];
-    return urlList.toArray(urls);
+    return hasher.hash();
+  }
+
+  private static HashCode getFileHash(File file) throws IOException {
+    InputStream inputStream = new FileInputStream(file);
+    HashingInputStream hashingStream = new HashingInputStream(Hashing.sha256(), inputStream);
+    ByteStreams.copy(hashingStream, ByteStreams.nullOutputStream());
+    return hashingStream.hash();
+  }
+
+  private static URL[] convertFileListToURLArray(List<File> jars) throws IOException {
+    URL[] urls = new URL[jars.size()];
+    for (int i = 0; i < jars.size(); i++) {
+      urls[i] = jars.get(i).toURI().toURL();
+    }
+    return urls;
+  }
+
+  private static List<File> getFilesWithAbsolutePathPrefixFromFile(
+      String runtimeClasspathFilename, String absolutePathPrefix) throws IOException {
+    return Files.readLines(new File(runtimeClasspathFilename), UTF_8).stream()
+        .map(entry -> new File(absolutePathPrefix + entry))
+        .collect(Collectors.toList());
   }
 
   static boolean isPersistentTestRunner() {
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java
new file mode 100644
index 0000000..c5e04bf
--- /dev/null
+++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java
@@ -0,0 +1,71 @@
+// Copyright 2020 The Bazel Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.testing.junit.runner;
+
+import com.google.common.hash.HashCode;
+import java.net.URL;
+import java.net.URLClassLoader;
+
+/**
+ * A custom classloader used by the persistent test runner.
+ *
+ * <p>Each classloader stores the combined hash code for the loaded jars.
+ */
+final class PersistentTestRunnerClassLoader extends URLClassLoader {
+
+  private final HashCode checksum;
+  private PersistentTestRunnerClassLoader child;
+
+  public PersistentTestRunnerClassLoader(URL[] urls, ClassLoader parent, HashCode checksum) {
+    super(urls, parent);
+    this.checksum = checksum;
+  }
+
+  void setChild(PersistentTestRunnerClassLoader child) {
+    this.child = child;
+  }
+
+  HashCode getChecksum() {
+    return checksum;
+  }
+
+  /**
+   * Loads the class with the specified name and resolves it if required.
+   *
+   * <p>If the classloader has a child: check if the class was already loaded by the child if the
+   * current classloader did not succeed in loading the class.
+   *
+   * <p>If the classloader doesn't have a child: use the default class loading logic.
+   */
+  @Override
+  public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
+    if (child == null) {
+      return super.loadClass(name, resolve);
+    }
+
+    synchronized (this.getClassLoadingLock(name)) {
+      Class<?> result;
+      try {
+        result = super.loadClass(name, resolve);
+      } catch (ClassNotFoundException e) {
+        result = child.findLoadedClass(name);
+      }
+      if (result == null) {
+        throw new ClassNotFoundException("Could not find " + name);
+      }
+      return result;
+    }
+  }
+}
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
index 8fdf635..1ddb0bb 100644
--- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
+++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
@@ -940,6 +940,27 @@
   }
 
   /**
+   * Returns a list of the current target's runtime jars and the first two levels of its direct
+   * dependencies.
+   *
+   * <p>This method is meant to aid the persistent test runner, which aims at avoiding loading all
+   * classes on the classpath for each test run. To that extent this method computes a small jars
+   * set of the most likely to be changed classes when writing code for a test. Their classes should
+   * be loaded in a separate classloader by the persistent test runner.
+   */
+  public ImmutableSet<Artifact> getDirectRuntimeClasspath() {
+    ImmutableSet.Builder<Artifact> directDeps = new ImmutableSet.Builder<>();
+    directDeps.addAll(javaArtifacts.getRuntimeJars());
+    for (TransitiveInfoCollection dep : targetsTreatedAsDeps(ClasspathType.RUNTIME_ONLY)) {
+      JavaInfo javaInfo = JavaInfo.getJavaInfo(dep);
+      if (javaInfo != null) {
+        directDeps.addAll(javaInfo.getDirectRuntimeJars());
+      }
+    }
+    return directDeps.build();
+  }
+
+  /**
    * Returns true if and only if this target has the neverlink attribute set to 1, or false if the
    * neverlink attribute does not exist (for example, on *_binary targets)
    *