blob: 6e1555ca00bce16dbc5c54690329df9c943b21d6 [file] [log] [blame]
// Copyright 2016 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.server;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.InetAddresses;
import com.google.devtools.build.lib.runtime.BlazeCommandDispatcher.LockingMode;
import com.google.devtools.build.lib.runtime.CommandExecutor;
import com.google.devtools.build.lib.server.CommandProtos.CancelRequest;
import com.google.devtools.build.lib.server.CommandProtos.CancelResponse;
import com.google.devtools.build.lib.server.CommandProtos.PingRequest;
import com.google.devtools.build.lib.server.CommandProtos.PingResponse;
import com.google.devtools.build.lib.server.CommandProtos.RunRequest;
import com.google.devtools.build.lib.server.CommandProtos.RunResponse;
import com.google.devtools.build.lib.util.BlazeClock;
import com.google.devtools.build.lib.util.Clock;
import com.google.devtools.build.lib.util.ExitCode;
import com.google.devtools.build.lib.util.Preconditions;
import com.google.devtools.build.lib.util.ThreadUtils;
import com.google.devtools.build.lib.util.io.OutErr;
import com.google.devtools.build.lib.vfs.FileSystemUtils;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import io.grpc.Server;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.net.InetSocketAddress;
import java.nio.charset.Charset;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;
import javax.annotation.concurrent.GuardedBy;
/**
* gRPC server class.
*
* <p>Only this class should depend on gRPC so that we only need to exclude this during
* bootstrapping.
*/
public class GrpcServerImpl implements RPCServer {
  // UTF-8 won't do because we want to be able to pass arbitrary binary strings.
  // Not that the internals of Bazel handle that correctly, but why not make at least this little
  // part correct?
  private static final Charset CHARSET = Charset.forName("ISO-8859-1");

  private static final long NANOSECONDS_IN_MS = TimeUnit.MILLISECONDS.toNanos(1);

  private static final Logger LOG = Logger.getLogger(RPCServer.class.getName());

  /**
   * Registers the current thread in {@link #runningCommands} for as long as this object is open.
   *
   * <p>Construction adds an entry keyed by a fresh random UUID; {@link #close()} removes it. Both
   * notify the monitor of {@code runningCommands} so that a thread waiting on it (the idle-timeout
   * thread started by {@link #serve()}) re-evaluates whether the server is idle.
   */
  private class RunningCommand implements AutoCloseable {
    private final Thread thread;
    private final String id;

    private RunningCommand() {
      thread = Thread.currentThread();
      id = UUID.randomUUID().toString();
      synchronized (runningCommands) {
        runningCommands.put(id, this);
        runningCommands.notify();
      }
    }

    @Override
    public void close() {
      synchronized (runningCommands) {
        runningCommands.remove(id);
        runningCommands.notify();
      }
    }
  }

  /**
   * Factory class. Instantiated by reflection.
   */
  public static class Factory implements RPCServer.Factory {
    @Override
    public RPCServer create(CommandExecutor commandExecutor, Clock clock, int port,
        Path serverDirectory, int maxIdleSeconds) throws IOException {
      return new GrpcServerImpl(commandExecutor, clock, port, serverDirectory, maxIdleSeconds);
    }
  }

  /** Which field of {@link RunResponse} an {@link RpcOutputStream} writes its bytes into. */
  private enum StreamType {
    STDOUT,
    STDERR,
  }

  // TODO(lberki): Maybe we should implement line buffering?
  /**
   * An {@link OutputStream} that sends every write to the client as a {@link RunResponse}.
   *
   * <p>Each response carries the response cookie, the command id, and the written bytes as either
   * standard output or standard error, depending on {@link StreamType}.
   *
   * <p>The stdout and stderr streams of a command share one {@link StreamObserver}, and stream
   * observers are not thread-safe. Writes therefore lock on the shared observer rather than on this
   * stream. Otherwise a stdout write and a stderr write from different threads could call
   * {@code onNext} at the same time.
   */
  private class RpcOutputStream extends OutputStream {
    private final StreamObserver<RunResponse> observer;
    private final String commandId;
    private final StreamType type;

    private RpcOutputStream(
        StreamObserver<RunResponse> observer, String commandId, StreamType type) {
      this.observer = observer;
      this.commandId = commandId;
      this.type = type;
    }

    @Override
    public void write(byte[] b, int off, int inlen) {
      ByteString input = ByteString.copyFrom(b, off, inlen);
      RunResponse.Builder response =
          RunResponse.newBuilder().setCookie(responseCookie).setCommandId(commandId);
      switch (type) {
        case STDOUT:
          response.setStandardOutput(input);
          break;
        case STDERR:
          response.setStandardError(input);
          break;
        default:
          throw new IllegalStateException();
      }
      // Lock on the observer shared with the sibling stream, not on `this`.
      synchronized (observer) {
        observer.onNext(response.build());
      }
    }

    @Override
    public void write(int byteAsInt) throws IOException {
      byte b = (byte) byteAsInt; // make sure we work with bytes in comparisons
      write(new byte[] {b}, 0, 1);
    }
  }

  // These paths are all relative to the server directory
  private static final String PORT_FILE = "command_port";
  private static final String REQUEST_COOKIE_FILE = "request_cookie";
  private static final String RESPONSE_COOKIE_FILE = "response_cookie";

  /** When false, the shutdown hooks registered by {@link #deleteAtExit} do nothing. */
  private static final AtomicBoolean runShutdownHooks = new AtomicBoolean(true);

  /** Commands (and pings/cancels) currently in progress, keyed by their random id. */
  @GuardedBy("runningCommands")
  private final Map<String, RunningCommand> runningCommands = new HashMap<>();

  private final CommandExecutor commandExecutor;
  private final Clock clock;
  private final Path serverDirectory;
  private final String requestCookie;
  private final String responseCookie;
  private final AtomicLong interruptCounter = new AtomicLong(0);
  private final int maxIdleSeconds;

  // Written by the thread calling serve() and read from gRPC handler threads (run/ping), so
  // volatile is needed for the write to be guaranteed visible there.
  private volatile Server server;
  private final int port;
  volatile boolean serving;

  public GrpcServerImpl(CommandExecutor commandExecutor, Clock clock, int port,
      Path serverDirectory, int maxIdleSeconds) throws IOException {
    // server.pid was written in the C++ launcher after fork() but before exec() .
    // The client only accesses the pid file after connecting to the socket
    // which ensures that it gets the correct pid value.
    Path pidFile = serverDirectory.getRelative("server.pid.txt");
    Path pidSymlink = serverDirectory.getRelative("server.pid");
    deleteAtExit(pidFile, /*deleteParent=*/ false);
    deleteAtExit(pidSymlink, /*deleteParent=*/ false);

    this.commandExecutor = commandExecutor;
    this.clock = clock;
    this.serverDirectory = serverDirectory;
    this.port = port;
    this.maxIdleSeconds = maxIdleSeconds;
    this.serving = false;

    SecureRandom random = new SecureRandom();
    requestCookie = generateCookie(random, 16);
    responseCookie = generateCookie(random, 16);
  }

  /**
   * Returns {@code byteCount} bytes from {@code random}, hex-encoded as exactly two lowercase
   * digits per byte.
   *
   * <p>The fixed width matters: a variable-width encoding (one digit for small values) lets
   * distinct byte sequences map to the same string, wasting cookie entropy.
   */
  private static String generateCookie(SecureRandom random, int byteCount) {
    byte[] bytes = new byte[byteCount];
    random.nextBytes(bytes);
    StringBuilder result = new StringBuilder(2 * byteCount);
    for (byte b : bytes) {
      result.append(Character.forDigit((b >> 4) & 0xf, 16));
      result.append(Character.forDigit(b & 0xf, 16));
    }
    return result.toString();
  }

  /**
   * Starts a daemon thread that, after 10 seconds, warns if any of {@code commandIds} is still
   * registered in {@link #runningCommands}, i.e. was not interrupted in time.
   *
   * <p>Does nothing when {@code commandIds} is empty.
   */
  private void startSlowInterruptWatcher(final ImmutableSet<String> commandIds) {
    if (commandIds.isEmpty()) {
      return;
    }
    Runnable interruptWatcher = new Runnable() {
      @Override
      public void run() {
        try {
          boolean ok;
          Thread.sleep(10 * 1000);
          synchronized (runningCommands) {
            ok = Collections.disjoint(commandIds, runningCommands.keySet());
          }
          if (!ok) {
            // At least one command was not interrupted. Interrupt took too long.
            ThreadUtils.warnAboutSlowInterrupt();
          }
        } catch (InterruptedException ignored) {
          // Ignore: being interrupted simply ends this best-effort watcher.
        }
      }
    };

    Thread interruptWatcherThread =
        new Thread(interruptWatcher, "interrupt-watcher-" + interruptCounter.incrementAndGet());
    interruptWatcherThread.setDaemon(true);
    interruptWatcherThread.start();
  }

  /**
   * Blocks until no command has been registered in {@link #runningCommands} for
   * {@link #maxIdleSeconds} seconds, then shuts the server down.
   *
   * <p>The idle deadline is (re)computed each time the map goes from non-empty to empty;
   * {@link RunningCommand} notifies the monitor on every change so this loop re-checks.
   */
  private void timeoutThread() {
    synchronized (runningCommands) {
      boolean idle = runningCommands.isEmpty();
      boolean wasIdle = false;
      long shutdownTime = -1;

      while (true) {
        if (!wasIdle && idle) {
          shutdownTime = BlazeClock.nanoTime() + TimeUnit.SECONDS.toNanos(maxIdleSeconds);
        }

        try {
          if (idle) {
            Verify.verify(shutdownTime > 0);
            long waitTime = shutdownTime - BlazeClock.nanoTime();
            if (waitTime > 0) {
              // Round upwards so that we don't busy-wait in the last millisecond
              runningCommands.wait((waitTime + NANOSECONDS_IN_MS - 1) / NANOSECONDS_IN_MS);
            }
          } else {
            runningCommands.wait();
          }
        } catch (InterruptedException e) {
          // Dealt with by checking the current time below.
        }

        wasIdle = idle;
        idle = runningCommands.isEmpty();
        if (wasIdle && idle && BlazeClock.nanoTime() >= shutdownTime) {
          break;
        }
      }
    }

    server.shutdown();
  }

  /** Interrupts the thread of every registered command and watches for slow interrupts. */
  @Override
  public void interrupt() {
    synchronized (runningCommands) {
      for (RunningCommand command : runningCommands.values()) {
        command.thread.interrupt();
      }

      startSlowInterruptWatcher(ImmutableSet.copyOf(runningCommands.keySet()));
    }
  }

  /**
   * Binds the gRPC server to localhost, publishes the port and cookie files, and blocks until the
   * server terminates.
   *
   * @throws IOException if the server cannot be bound on either IPv6 or IPv4 localhost (the IPv6
   *     failure is attached as a suppressed exception), or if a server file cannot be written
   */
  @Override
  public void serve() throws IOException {
    Preconditions.checkState(!serving);

    // For reasons only Apple knows, you cannot bind to IPv4-localhost when you run in a sandbox
    // that only allows loopback traffic, but binding to IPv6-localhost works fine. This would
    // however break on systems that don't support IPv6. So what we'll do is to try to bind to IPv6
    // and if that fails, try again with IPv4.
    InetSocketAddress address = new InetSocketAddress("[::1]", port);
    try {
      server = NettyServerBuilder.forAddress(address).addService(commandServer).build().start();
    } catch (IOException ipv6Failure) {
      address = new InetSocketAddress("127.0.0.1", port);
      try {
        server = NettyServerBuilder.forAddress(address).addService(commandServer).build().start();
      } catch (IOException ipv4Failure) {
        // Keep the first failure; otherwise the reason the IPv6 bind failed is lost.
        ipv4Failure.addSuppressed(ipv6Failure);
        throw ipv4Failure;
      }
    }

    if (maxIdleSeconds > 0) {
      Thread timeoutThread =
          new Thread(
              new Runnable() {
                @Override
                public void run() {
                  timeoutThread();
                }
              },
              "grpc-server-idle-timeout");
      timeoutThread.setDaemon(true);
      timeoutThread.start();
    }
    serving = true;

    writeServerFile(
        PORT_FILE, InetAddresses.toUriString(address.getAddress()) + ":" + server.getPort());
    writeServerFile(REQUEST_COOKIE_FILE, requestCookie);
    writeServerFile(RESPONSE_COOKIE_FILE, responseCookie);

    try {
      server.awaitTermination();
    } catch (InterruptedException e) {
      // TODO(lberki): Handle SIGINT in a reasonable way
      throw new IllegalStateException(e);
    }
  }

  /**
   * Writes {@code contents} as Latin-1 to {@code name} in the server directory and schedules the
   * file for deletion at JVM exit.
   */
  private void writeServerFile(String name, String contents) throws IOException {
    Path file = serverDirectory.getChild(name);
    FileSystemUtils.writeContentAsLatin1(file, contents);
    deleteAtExit(file, false);
  }

  /** Makes every hook registered by {@link #deleteAtExit} a no-op. */
  protected void disableShutdownHooks() {
    runShutdownHooks.set(false);
  }

  /**
   * Schedule the specified file for (attempted) deletion at JVM exit.
   *
   * @param path the file to delete
   * @param deleteParent whether to also delete the file's parent directory afterwards
   */
  protected static void deleteAtExit(final Path path, final boolean deleteParent) {
    Runtime.getRuntime().addShutdownHook(new Thread() {
      @Override
      public void run() {
        if (!runShutdownHooks.get()) {
          return;
        }
        try {
          path.delete();
          if (deleteParent) {
            path.getParentDirectory().delete();
          }
        } catch (IOException e) {
          printStack(e);
        }
      }
    });
  }

  /** Logs {@code e} and its stack trace at SEVERE level between banner lines. */
  static void printStack(IOException e) {
    /*
     * Hopefully this never happens. It's not very nice to just write this
     * to the user's console, but I'm not sure what better choice we have.
     */
    StringWriter err = new StringWriter();
    PrintWriter printErr = new PrintWriter(err);
    printErr.println("=======[BLAZE SERVER: ENCOUNTERED IO EXCEPTION]=======");
    e.printStackTrace(printErr);
    printErr.println("=====================================================");
    LOG.severe(err.toString());
  }

  /** The gRPC service implementation: run, ping and cancel, all guarded by the request cookie. */
  private final CommandServerGrpc.CommandServerImplBase commandServer =
      new CommandServerGrpc.CommandServerImplBase() {
        /**
         * Executes a command, streaming its stdout/stderr back to the client, then sends a final
         * response with the exit code and acts on any shutdown the executor requests.
         */
        @Override
        public void run(RunRequest request, StreamObserver<RunResponse> observer) {
          if (!request.getCookie().equals(requestCookie)
              || request.getClientDescription().isEmpty()) {
            observer.onNext(
                RunResponse.newBuilder()
                    .setExitCode(ExitCode.LOCAL_ENVIRONMENTAL_ERROR.getNumericExitCode())
                    .build());
            observer.onCompleted();
            return;
          }

          ImmutableList.Builder<String> args = ImmutableList.builder();
          for (ByteString requestArg : request.getArgList()) {
            args.add(requestArg.toString(CHARSET));
          }

          String commandId;
          int exitCode;

          try (RunningCommand command = new RunningCommand()) {
            commandId = command.id;
            OutErr rpcOutErr =
                OutErr.create(
                    new RpcOutputStream(observer, command.id, StreamType.STDOUT),
                    new RpcOutputStream(observer, command.id, StreamType.STDERR));

            exitCode =
                commandExecutor.exec(
                    args.build(),
                    rpcOutErr,
                    request.getBlockForLock() ? LockingMode.WAIT : LockingMode.ERROR_OUT,
                    request.getClientDescription(),
                    clock.currentTimeMillis());
          } catch (InterruptedException e) {
            exitCode = ExitCode.INTERRUPTED.getNumericExitCode();
            commandId = ""; // The default value, the client will ignore it
          }

          // There is a chance that a cancel request comes in after commandExecutor#exec() has
          // finished and no one calls Thread.interrupted() to receive the interrupt. So we just
          // reset the interruption state here to make these cancel requests not have any effect
          // outside of command execution (after the try block above, the cancel request won't find
          // the thread to interrupt)
          Thread.interrupted();

          RunResponse response =
              RunResponse.newBuilder()
                  .setCookie(responseCookie)
                  .setCommandId(commandId)
                  .setFinished(true)
                  .setExitCode(exitCode)
                  .build();

          // Same lock the RpcOutputStreams use, in case output is still being written elsewhere.
          synchronized (observer) {
            observer.onNext(response);
            observer.onCompleted();
          }

          switch (commandExecutor.shutdown()) {
            case NONE:
              break;

            case CLEAN:
              server.shutdownNow();
              break;

            case EXPUNGE:
              disableShutdownHooks();
              server.shutdownNow();
              break;
          }
        }

        /**
         * Replies with the response cookie if the request cookie matches (an empty response
         * otherwise). Registering a RunningCommand resets the idle-timeout countdown.
         */
        @Override
        public void ping(PingRequest pingRequest, StreamObserver<PingResponse> streamObserver) {
          Preconditions.checkState(serving);

          try (RunningCommand command = new RunningCommand()) {
            PingResponse.Builder response = PingResponse.newBuilder();
            if (pingRequest.getCookie().equals(requestCookie)) {
              response.setCookie(responseCookie);
            }

            streamObserver.onNext(response.build());
            streamObserver.onCompleted();
          }
        }

        /**
         * Interrupts the command with the requested id, if it is still running, and starts a
         * slow-interrupt watcher for it. Requests with a wrong cookie are completed without a
         * response.
         */
        @Override
        public void cancel(CancelRequest request, StreamObserver<CancelResponse> streamObserver) {
          if (!request.getCookie().equals(requestCookie)) {
            streamObserver.onCompleted();
            return;
          }

          try (RunningCommand cancelCommand = new RunningCommand()) {
            synchronized (runningCommands) {
              RunningCommand pendingCommand = runningCommands.get(request.getCommandId());
              if (pendingCommand != null) {
                pendingCommand.thread.interrupt();
              }

              startSlowInterruptWatcher(ImmutableSet.of(request.getCommandId()));
            }

            streamObserver.onNext(CancelResponse.newBuilder().setCookie(responseCookie).build());
            streamObserver.onCompleted();
          }
        }
      };
}