Use 2 custom classloaders in the persistent test runner. RELNOTES: None. PiperOrigin-RevId: 290900473
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD index 30f03b0..c627275 100644 --- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD +++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/BUILD
@@ -38,11 +38,13 @@ name = "persistent_test_runner", srcs = [ "PersistentTestRunner.java", + "PersistentTestRunnerClassLoader.java", "SuiteTestRunner.java", ], deps = [ "//src/main/protobuf:worker_protocol_java_proto", "//third_party:guava", + "//third_party:jsr305", ], )
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java index d825667..f465063 100644 --- a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java +++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunner.java
@@ -16,17 +16,24 @@ import static java.nio.charset.StandardCharsets.UTF_8; +import com.google.common.hash.HashCode; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; +import com.google.common.hash.HashingInputStream; +import com.google.common.io.ByteStreams; import com.google.common.io.Files; import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest; import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.FileInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.PrintStream; import java.net.URL; -import java.net.URLClassLoader; -import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; +import javax.annotation.Nullable; /** A utility class for running Java tests using persistent workers. */ final class PersistentTestRunner { @@ -39,17 +46,44 @@ * Runs new tests in the same process. Communicates with bazel using the worker's protocol. Reads * a {@link WorkRequest} sent by bazel and sends back a {@link WorkResponse}. * - * <p>Before running a test it creates a new classloader loading the test and its dependencies' - * classes. A class already loaded in a classloader can not be unloaded. To overcome this issue a - * new classloader has to be created at every test run. + * <p>Uses three different classloaders: + * + * <ul> + * <li>A classloader for direct dependencies: loads the test's classes and the first two layers + * of its dependencies. + * <li>A classloader for transitive dependencies: loads the remaining classes from the + * transitive dependencies. + * <li>The system classloader: initialized at JVM startup time; loads the test runner's classes. + * </ul> + * + * <p>The classloaders have a child-parent relationship: direct CL -> transitive CL -> system CL. + * + * <p>The direct dependencies classloader is the current thread's classloader. + * + * <p>The transitive dependencies classloader checks if a class was already loaded by the direct + * dependencies classloader if it did not succeed in loading the class itself. This is required + * for classes loaded by the transitive classloader that reference classes in the direct + * dependencies classloader. + * + * <p>The default class loading logic in {@link ClassLoader} applies in all other cases. + * + * <p>The direct and transitive classloaders are rebuilt before every test run only if the + * combined hash of the jars to be loaded has changed. */ static int runPersistentTestRunner( String suitClassName, String workspacePrefix, SuiteTestRunner suiteTestRunner) { PrintStream originalStdOut = System.out; PrintStream originalStdErr = System.err; - String testRuntimeClasspathFile = System.getenv("TEST_RUNTIME_CLASSPATH_FILE"); - String javaRunfilesPath = System.getenv("JAVA_RUNFILES"); + // TODO(elenairina): Remove this variable after cl/282553936 is released. + String legacyTestRuntimeClasspathFile = System.getenv("TEST_RUNTIME_CLASSPATH_FILE"); + + String directClasspathFile = System.getenv("TEST_DIRECT_CLASSPATH_FILE"); + String transitiveClasspathFile = System.getenv("TEST_TRANSITIVE_CLASSPATH_FILE"); + String absolutePathPrefix = System.getenv("JAVA_RUNFILES") + File.separator + workspacePrefix; + + PersistentTestRunnerClassLoader transitiveDepsClassLoader = null; + PersistentTestRunnerClassLoader directDepsClassLoader = null; // Reading the work requests and solving them in sequence is not a problem because Bazel creates // up to --worker_max_instances (defaults to 4) instances per worker key. @@ -62,8 +96,31 @@ break; } - URLClassLoader testRunnerClassLoader = - recreateClassLoader(testRuntimeClasspathFile, javaRunfilesPath, workspacePrefix); + if (legacyTestRuntimeClasspathFile != null) { + // Re-use the same classloader variable in the legacy case for simplicity. + directDepsClassLoader = + maybeRecreateClassLoader( + getFilesWithAbsolutePathPrefixFromFile( + legacyTestRuntimeClasspathFile, absolutePathPrefix), + ClassLoader.getSystemClassLoader(), + null); + } else { + transitiveDepsClassLoader = + maybeRecreateClassLoader( + getFilesWithAbsolutePathPrefixFromFile( + transitiveClasspathFile, absolutePathPrefix), + ClassLoader.getSystemClassLoader(), + transitiveDepsClassLoader); + + directDepsClassLoader = + maybeRecreateClassLoader( + getFilesWithAbsolutePathPrefixFromFile(directClasspathFile, absolutePathPrefix), + transitiveDepsClassLoader, + directDepsClassLoader); + transitiveDepsClassLoader.setChild(directDepsClassLoader); + } + + Thread.currentThread().setContextClassLoader(directDepsClassLoader); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); PrintStream printStream = new PrintStream(outputStream, true); @@ -74,7 +131,7 @@ try { exitCode = suiteTestRunner.runTestsInSuite( - suitClassName, arguments, testRunnerClassLoader, /* resolve= */ true); + suitClassName, arguments, directDepsClassLoader, /* resolve= */ true); } finally { System.setOut(originalStdOut); System.setErr(originalStdErr); @@ -116,40 +173,57 @@ } /** - * Returns a classloader containing the jars read from the given runtime classpath file. + * Returns a classloader that loads the given jars, only if the combined hash of their content is + * different than the one for the given previous classloader. * - * <p>Sets the classloader of the current thread to the newly created classloader. - * - * <p>Sets the classloader used to load the test classes and their dependencies. + * <p>Returns previousClassLoader if the hashes are the same. * * <p>Needs to be called before every test run to avoid having stale classes. A class already * loaded in a classloader can not be unloaded. To overcome this a new classloader has to be * created at every test run. */ - private static URLClassLoader recreateClassLoader( - String runtimeClasspathFilename, String javaRunfilesPath, String workspacePrefix) + private static PersistentTestRunnerClassLoader maybeRecreateClassLoader( + List<File> runtimeJars, + ClassLoader parent, + @Nullable PersistentTestRunnerClassLoader previousClassLoader) throws IOException { - URLClassLoader classLoader = - new URLClassLoader( - createURLsFromRelativePathsInFile( - runtimeClasspathFilename, javaRunfilesPath, workspacePrefix)); - Thread.currentThread().setContextClassLoader(classLoader); - return classLoader; + HashCode combinedHash = getCombinedHashForFiles(runtimeJars); + if (previousClassLoader != null && combinedHash.equals(previousClassLoader.getChecksum())) { + return previousClassLoader; + } + + return new PersistentTestRunnerClassLoader( + convertFileListToURLArray(runtimeJars), parent, combinedHash); } - private static URL[] createURLsFromRelativePathsInFile( - String runtimeClasspathFilename, String javaRunfilesPath, String workspacePrefix) - throws IOException { - List<String> testRuntimeClasspath = Files.readLines(new File(runtimeClasspathFilename), UTF_8); - ArrayList<URL> urlList = new ArrayList<>(); - for (String classPathEntry : testRuntimeClasspath) { - urlList.add( - new File(javaRunfilesPath + File.separator + workspacePrefix + classPathEntry) - .toURI() - .toURL()); + private static HashCode getCombinedHashForFiles(List<File> files) throws IOException { + Hasher hasher = Hashing.sha256().newHasher(); + for (File file : files) { + hasher.putBytes(getFileHash(file).asBytes()); } - URL[] urls = new URL[urlList.size()]; - return urlList.toArray(urls); + return hasher.hash(); + } + + private static HashCode getFileHash(File file) throws IOException { + InputStream inputStream = new FileInputStream(file); + HashingInputStream hashingStream = new HashingInputStream(Hashing.sha256(), inputStream); + ByteStreams.copy(hashingStream, ByteStreams.nullOutputStream()); + return hashingStream.hash(); + } + + private static URL[] convertFileListToURLArray(List<File> jars) throws IOException { + URL[] urls = new URL[jars.size()]; + for (int i = 0; i < jars.size(); i++) { + urls[i] = jars.get(i).toURI().toURL(); + } + return urls; + } + + private static List<File> getFilesWithAbsolutePathPrefixFromFile( + String runtimeClasspathFilename, String absolutePathPrefix) throws IOException { + return Files.readLines(new File(runtimeClasspathFilename), UTF_8).stream() + .map(entry -> new File(absolutePathPrefix + entry)) + .collect(Collectors.toList()); } static boolean isPersistentTestRunner() {
diff --git a/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java new file mode 100644 index 0000000..c5e04bf --- /dev/null +++ b/src/java_tools/junitrunner/java/com/google/testing/junit/runner/PersistentTestRunnerClassLoader.java
@@ -0,0 +1,71 @@ +// Copyright 2020 The Bazel Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.testing.junit.runner; + +import com.google.common.hash.HashCode; +import java.net.URL; +import java.net.URLClassLoader; + +/** + * A custom classloader used by the persistent test runner. + * + * <p>Each classloader stores the combined hash code for the loaded jars. + */ +final class PersistentTestRunnerClassLoader extends URLClassLoader { + + private final HashCode checksum; + private PersistentTestRunnerClassLoader child; + + public PersistentTestRunnerClassLoader(URL[] urls, ClassLoader parent, HashCode checksum) { + super(urls, parent); + this.checksum = checksum; + } + + void setChild(PersistentTestRunnerClassLoader child) { + this.child = child; + } + + HashCode getChecksum() { + return checksum; + } + + /** + * Loads the class with the specified name and resolves it if required. + * + * <p>If the classloader has a child: check if the class was already loaded by the child if the + * current classloader did not succeed in loading the class. + * + * <p>If the classloader doesn't have a child: use the default class loading logic. + */ + @Override + public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { + if (child == null) { + return super.loadClass(name, resolve); + } + + synchronized (this.getClassLoadingLock(name)) { + Class<?> result; + try { + result = super.loadClass(name, resolve); + } catch (ClassNotFoundException e) { + result = child.findLoadedClass(name); + } + if (result == null) { + throw new ClassNotFoundException("Could not find " + name); + } + return result; + } + } +}
diff --git a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java index 8fdf635..1ddb0bb 100644 --- a/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java +++ b/src/main/java/com/google/devtools/build/lib/rules/java/JavaCommon.java
@@ -940,6 +940,27 @@ } /** + * Returns a list of the current target's runtime jars and the first two levels of its direct + * dependencies. + * + * <p>This method is meant to aid the persistent test runner, which aims at avoiding loading all + * classes on the classpath for each test run. To that extent this method computes a small jars + * set of the most likely to be changed classes when writing code for a test. Their classes should + * be loaded in a separate classloader by the persistent test runner. + */ + public ImmutableSet<Artifact> getDirectRuntimeClasspath() { + ImmutableSet.Builder<Artifact> directDeps = new ImmutableSet.Builder<>(); + directDeps.addAll(javaArtifacts.getRuntimeJars()); + for (TransitiveInfoCollection dep : targetsTreatedAsDeps(ClasspathType.RUNTIME_ONLY)) { + JavaInfo javaInfo = JavaInfo.getJavaInfo(dep); + if (javaInfo != null) { + directDeps.addAll(javaInfo.getDirectRuntimeJars()); + } + } + return directDeps.build(); + } + + /** * Returns true if and only if this target has the neverlink attribute set to 1, or false if the * neverlink attribute does not exist (for example, on *_binary targets) *