blob: 13b729bccfcecfabf0251e2cf19f7c52d710ffd6 [file] [log] [blame]
// Copyright 2019 The Bazel Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.testing.junit.runner;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.toList;

import com.google.common.hash.HashCode;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import com.google.common.hash.HashingInputStream;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/** A utility class for running Java tests using persistent workers. */
final class PersistentTestRunner {

  private PersistentTestRunner() {
    // utility class; should not be instantiated
  }

  /**
   * Runs new tests in the same process. Communicates with bazel using the worker's protocol: reads
   * {@link WorkRequest}s sent by bazel from {@code System.in} and sends back one {@link
   * WorkResponse} per request on {@code System.out}, until {@code System.in} reaches EOF.
   *
   * <p>Uses three different classloaders:
   *
   * <ul>
   *   <li>A classloader for direct dependencies: loads the classes of the test target and its
   *       direct dependencies, <em>excluding</em> those dependencies that are also among the rest
   *       of the transitive dependencies.
   *   <li>A classloader for transitive dependencies: loads the remaining classes from the
   *       transitive dependencies.
   *   <li>The system classloader: initialized at JVM startup time; loads the test runner's classes.
   * </ul>
   *
   * <p>The classloaders have a child-parent relationship: direct CL -> transitive CL -> system CL.
   * The jars of each classloader are searched in the order in which they appear in the
   * corresponding classpath file.
   *
   * <p>The direct dependencies classloader is set as the current thread's context classloader.
   *
   * <p>For example, given the following dependency graph (arrows point from a target to one of its
   * dependencies):
   *
   * <pre>
   *        TestTarget
   *       /    |     \
   *      v     v      v
   *      a     b      c
   *      |     ^      |
   *      v     |      v
   *      d ----'      e
   *      |
   *      v
   *      f
   * </pre>
   *
   * <p>the classloaders load the following:
   *
   * <ul>
   *   <li>the direct classloader loads the classes of TestTarget, a and c, which are direct deps;
   *   <li>the transitive classloader loads the classes of b, d, e and f, which are transitive deps.
   * </ul>
   *
   * <p>Note that b is loaded by the transitive classloader even though it is a direct dependency,
   * because b is also a dependency of d, which is a 2nd level dependency of TestTarget.
   *
   * <p>Excluding the direct dependencies that are also present lower in the dependency tree from
   * the direct dependencies classloader has two advantages:
   *
   * <ol>
   *   <li>it reduces the number of classes loaded by the direct dependencies classloader, making it
   *       faster to create/load;
   *   <li>it avoids additional custom logic for loading classes, since the transitive dependencies
   *       classloader is then guaranteed not to reference anything from its child classloader.
   * </ol>
   *
   * <p>The direct and transitive classloaders are rebuilt before every test run only if the
   * combined hash (paths, order and contents) of the jars to be loaded has changed. If the
   * transitive classloader has to be rebuilt, the direct classloader is also rebuilt to preserve
   * correctness. Classloaders that get replaced are closed so that they release their jar files.
   *
   * @param suitClassName the name of the test suite class passed to {@code suiteTestRunner}
   * @param workspacePrefix appended to {@code $JAVA_RUNFILES} and a path separator; the result is
   *     prepended verbatim to every entry of the classpath files
   * @param suiteTestRunner runs the suite once per work request
   * @return 0 once {@code System.in} reaches EOF, or 1 if an I/O error occurred while serving a
   *     request
   */
  static int runPersistentTestRunner(
      String suitClassName, String workspacePrefix, SuiteTestRunner suiteTestRunner) {
    PrintStream originalStdOut = System.out;
    PrintStream originalStdErr = System.err;

    String directClasspathFile = System.getenv("TEST_DIRECT_CLASSPATH_FILE");
    String transitiveClasspathFile = System.getenv("TEST_TRANSITIVE_CLASSPATH_FILE");
    String absolutePathPrefix = System.getenv("JAVA_RUNFILES") + File.separator + workspacePrefix;

    // Loads the classes of the test target and its direct dependencies *excluding* those
    // dependencies that are also among the rest of the transitive dependencies.
    URLClassLoader directDepsClassLoader = null;
    // Loads all the classes in the transitive dependencies, excluding those loaded
    // by directDepsClassLoader.
    URLClassLoader transitiveDepsClassLoader = null;
    HashCode previousTransitiveCombinedHash = null;
    HashCode previousDirectCombinedHash = null;

    // Reading the work requests and solving them in sequence is not a problem because Bazel creates
    // up to --worker_max_instances (defaults to 4) instances per worker key.
    while (true) {
      try {
        WorkRequest request = WorkRequest.parseDelimitedFrom(System.in);
        if (request == null) {
          // null is only returned when the stream reaches EOF
          break;
        }

        Set<File> transitiveDeps =
            getFilesWithAbsolutePathPrefixFromFile(transitiveClasspathFile, absolutePathPrefix);
        Set<File> directDeps =
            getFilesWithAbsolutePathPrefixFromFile(directClasspathFile, absolutePathPrefix);
        // Filter out duplicated dependencies from the directDeps, to ensure that the
        // transitiveDepsClassLoader doesn't reference anything from directDepsClassLoader.
        directDeps = filterOutDupedDeps(directDeps, transitiveDeps);

        HashCode transitiveCombinedHash = getCombinedHashForFiles(transitiveDeps);
        HashCode directCombinedHash = getCombinedHashForFiles(directDeps);

        boolean recreateTransitiveClassloader =
            transitiveDepsClassLoader == null
                || !transitiveCombinedHash.equals(previousTransitiveCombinedHash);
        // if the parent needs to be re-created, the child needs to be recreated also
        boolean recreateDirectClassloader =
            recreateTransitiveClassloader
                || directDepsClassLoader == null
                || !directCombinedHash.equals(previousDirectCombinedHash);
        previousTransitiveCombinedHash = transitiveCombinedHash;
        previousDirectCombinedHash = directCombinedHash;

        // Release the jar files held by classloaders that are about to be replaced; otherwise
        // every rebuild leaks open file handles for the lifetime of the worker. The child is
        // closed before its parent.
        if (recreateDirectClassloader) {
          closeQuietly(directDepsClassLoader);
        }
        if (recreateTransitiveClassloader) {
          closeQuietly(transitiveDepsClassLoader);
          transitiveDepsClassLoader =
              new URLClassLoader(
                  convertFileListToURLArray(transitiveDeps), ClassLoader.getSystemClassLoader());
        }
        if (recreateDirectClassloader) {
          directDepsClassLoader =
              new URLClassLoader(convertFileListToURLArray(directDeps), transitiveDepsClassLoader);
        }
        Thread.currentThread().setContextClassLoader(directDepsClassLoader);

        // Capture everything the tests print so it can be returned in the WorkResponse;
        // System.out itself carries the worker protocol messages. An explicit charset keeps
        // non-ASCII output intact regardless of the platform default.
        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        PrintStream printStream =
            new PrintStream(outputStream, /* autoFlush= */ true, UTF_8.name());
        System.setOut(printStream);
        System.setErr(printStream);
        String[] arguments = request.getArgumentsList().toArray(new String[0]);
        int exitCode = -1;
        try {
          exitCode =
              suiteTestRunner.runTestsInSuite(
                  suitClassName, arguments, directDepsClassLoader, /* resolve= */ true);
        } finally {
          System.setOut(originalStdOut);
          System.setErr(originalStdErr);
        }

        WorkResponse response =
            WorkResponse.newBuilder()
                .setOutput(outputStream.toString(UTF_8.name()))
                .setExitCode(exitCode)
                .build();
        response.writeDelimitedTo(System.out);
        System.out.flush();
      } catch (IOException | UncheckedIOException e) {
        // UncheckedIOException is how file-hashing failures surface from the parallel stream in
        // getCombinedHashForFiles(); treat them like any other I/O failure.
        e.printStackTrace();
        return 1;
      }
    }
    return 0;
  }

  /**
   * Closes {@code classLoader}, if it is non-null, so that it releases the files it opened.
   * Failures are printed to stderr and otherwise ignored: an obsolete classloader that cannot be
   * closed must not prevent the next test run.
   */
  private static void closeQuietly(URLClassLoader classLoader) {
    if (classLoader == null) {
      return;
    }
    try {
      classLoader.close();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /**
   * Returns the {@link Class} with the given name, or {@code null} if {@code name} is null or the
   * class cannot be found.
   *
   * <p>When {@code classLoader} is null the class is looked up with {@link Class#forName(String)},
   * i.e. through the classloader that defined this class, and is always initialized. Otherwise it
   * is looked up through {@code classLoader}, and {@code resolve} is passed as the {@code
   * initialize} argument of {@link Class#forName(String, boolean, ClassLoader)}.
   */
  static Class<?> getTestClass(String name, ClassLoader classLoader, boolean resolve) {
    if (name == null) {
      return null;
    }
    try {
      if (classLoader == null) {
        return Class.forName(name);
      }
      return Class.forName(name, resolve, classLoader);
    } catch (ClassNotFoundException e) {
      return null;
    }
  }

  /**
   * Returns a SHA-256 hash that combines the path and the contents of every file in {@code files},
   * in the iteration order of {@code files}. The result changes when a file is added, removed,
   * renamed, modified, or moved to a different position.
   *
   * @throws UncheckedIOException if one of the files cannot be read
   */
  private static HashCode getCombinedHashForFiles(Set<File> files) {
    // Snapshot into a list so each file can be paired with its hash by index: a list stream keeps
    // its encounter order in collect(toList()) even when hashed in parallel.
    List<File> orderedFiles = new ArrayList<>(files);
    List<HashCode> fileHashes =
        orderedFiles.parallelStream().map(PersistentTestRunner::getFileHash).collect(toList());
    // Update the hasher sequentially because a Hasher is not safe for concurrent use.
    Hasher hasher = Hashing.sha256().newHasher();
    for (int i = 0; i < orderedFiles.size(); i++) {
      String path = orderedFiles.get(i).getPath();
      // The length prefix keeps the byte sequence unambiguous across path boundaries.
      hasher.putInt(path.length()).putString(path, UTF_8).putBytes(fileHashes.get(i).asBytes());
    }
    return hasher.hash();
  }

  /**
   * Returns the SHA-256 hash of the contents of {@code file}. The stream opened on the file is
   * always closed.
   *
   * @throws UncheckedIOException if the file cannot be read; an unchecked exception is used so that
   *     this method can be called from the parallel stream in getCombinedHashForFiles()
   */
  private static HashCode getFileHash(File file) {
    try (InputStream inputStream = new FileInputStream(file);
        HashingInputStream hashingStream =
            new HashingInputStream(Hashing.sha256(), inputStream)) {
      ByteStreams.copy(hashingStream, ByteStreams.nullOutputStream());
      return hashingStream.hash();
    } catch (IOException e) {
      throw new UncheckedIOException("Failed to hash " + file, e);
    }
  }

  /**
   * Converts {@code jars} to URLs, in the iteration order of {@code jars}.
   *
   * @throws IOException (a {@link java.net.MalformedURLException}) if a file cannot be expressed
   *     as a URL
   */
  private static URL[] convertFileListToURLArray(Set<File> jars) throws IOException {
    URL[] urls = new URL[jars.size()];
    int it = 0;
    for (File jar : jars) {
      urls[it++] = jar.toURI().toURL();
    }
    return urls;
  }

  /**
   * Reads the classpath file {@code runtimeClasspathFilename} (UTF-8, one entry per line) and
   * returns one {@link File} per non-blank entry, formed by prepending {@code absolutePathPrefix}.
   * The returned set iterates in the order the entries appear in the file, because classpath order
   * decides which jar wins when several define the same class; duplicate entries keep their first
   * position.
   */
  private static Set<File> getFilesWithAbsolutePathPrefixFromFile(
      String runtimeClasspathFilename, String absolutePathPrefix) throws IOException {
    return Files.readLines(new File(runtimeClasspathFilename), UTF_8).stream()
        // A blank entry would otherwise resolve to the prefix directory itself.
        .filter(entry -> !entry.trim().isEmpty())
        .map(entry -> new File(absolutePathPrefix + entry))
        .collect(Collectors.toCollection(LinkedHashSet::new));
  }

  /**
   * Returns the elements of {@code smallSet} that are not contained in {@code largeSet}, keeping
   * the iteration order of {@code smallSet}.
   */
  private static Set<File> filterOutDupedDeps(Set<File> smallSet, Set<File> largeSet) {
    return smallSet.stream()
        .filter(f -> !largeSet.contains(f))
        .collect(Collectors.toCollection(LinkedHashSet::new));
  }

  /**
   * Returns true if the {@code PERSISTENT_TEST_RUNNER} environment variable is set to {@code
   * "true"} (ignoring case).
   */
  static boolean isPersistentTestRunner() {
    return Boolean.parseBoolean(System.getenv("PERSISTENT_TEST_RUNNER"));
  }
}