Give multiplex WorkRequests consecutive request_ids instead of just workerid
This is needed for cancellation, lest a cancelled WorkResponse pretend to be a response to a newer WorkRequest.
PiperOrigin-RevId: 341099677
diff --git a/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java b/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
index 36c9cfa..af568e3 100644
--- a/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
+++ b/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
@@ -53,8 +53,12 @@
static SubprocessFactory defaultFactory = JavaSubprocessFactory.INSTANCE;
+ /**
+ * Sets the default factory class for creating subprocesses. Passing {@code null} resets it to the
+ * initial state.
+ */
public static void setDefaultSubprocessFactory(SubprocessFactory factory) {
- SubprocessBuilder.defaultFactory = factory;
+ SubprocessBuilder.defaultFactory = factory != null ? factory : JavaSubprocessFactory.INSTANCE;
}
public SubprocessBuilder() {
diff --git a/src/main/java/com/google/devtools/build/lib/worker/Worker.java b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
index 93c7993..cf720f0 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/Worker.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
@@ -159,7 +159,7 @@
workerProtocol.putRequest(request);
}
- WorkResponse getResponse() throws IOException {
+ WorkResponse getResponse(int requestId) throws IOException {
recordingInputStream.startRecording(4096);
return workerProtocol.getResponse();
}
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
index d97d0ab..881e516 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
@@ -14,10 +14,12 @@
package com.google.devtools.build.lib.worker;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.GoogleLogger;
import com.google.devtools.build.lib.shell.Subprocess;
import com.google.devtools.build.lib.shell.SubprocessBuilder;
+import com.google.devtools.build.lib.shell.SubprocessFactory;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
@@ -49,7 +51,7 @@
/** A semaphore to protect {@code workerProcessResponse} object. */
private Semaphore semWorkerProcessResponse;
/**
- * A map of semaphores corresponding to {@code WorkerProxy} objects. After sending the {@code
+ * A map of semaphores corresponding to {@code WorkRequest}s. After sending the {@code
* WorkRequest}, {@code WorkerProxy} will wait on a semaphore to be released. {@code
* WorkerMultiplexer} is responsible for releasing the corresponding semaphore in order to signal
* {@code WorkerProxy} that the {@code WorkerResponse} has been received.
@@ -79,6 +81,8 @@
*/
private final Path logFile;
+ private SubprocessFactory subprocessFactory = null;
+
WorkerMultiplexer(Path logFile) {
semWorkerProcessResponse = new Semaphore(1);
semResponseChecker = new Semaphore(1);
@@ -96,10 +100,7 @@
*/
public synchronized void createProcess(WorkerKey workerKey, Path workDir) throws IOException {
// The process may have died in the meanwhile (e.g. between builds).
- if (this.process != null && !this.process.isAlive()) {
- this.process = null;
- }
- if (this.process == null) {
+ if (this.process == null || !this.process.isAlive()) {
ImmutableList<String> args = workerKey.getArgs();
File executable = new File(args.get(0));
if (!executable.isAbsolute() && executable.getParent() != null) {
@@ -107,7 +108,10 @@
newArgs.set(0, new File(workDir.getPathFile(), newArgs.get(0)).getAbsolutePath());
args = ImmutableList.copyOf(newArgs);
}
- SubprocessBuilder processBuilder = new SubprocessBuilder();
+ SubprocessBuilder processBuilder =
+ subprocessFactory != null
+ ? new SubprocessBuilder(subprocessFactory)
+ : new SubprocessBuilder();
processBuilder.setArgv(args);
processBuilder.setWorkingDirectory(workDir.getPathFile());
processBuilder.setStderr(logFile.getPathFile());
@@ -178,43 +182,53 @@
* Waits on a semaphore for the {@code WorkResponse} returned from worker process. This method is
* called on the thread of a {@code WorkerProxy}.
*/
- public InputStream getResponse(Integer workerId) throws IOException, InterruptedException {
- semResponseChecker.acquire();
- Semaphore waitForResponse = responseChecker.get(workerId);
- semResponseChecker.release();
+ public InputStream getResponse(Integer requestId) throws IOException, InterruptedException {
+ try {
+ semResponseChecker.acquire();
+ Semaphore waitForResponse = responseChecker.get(requestId);
+ semResponseChecker.release();
- if (waitForResponse == null) {
- // If the multiplexer is interrupted when a {@code WorkerProxy} is trying to send a request,
- // the request is not sent, so there is no need to wait for a response.
- return null;
+ if (waitForResponse == null) {
+ // If the multiplexer is interrupted when a {@code WorkerProxy} is trying to send a request,
+ // the request is not sent, so there is no need to wait for a response.
+ return null;
+ }
+
+ // Wait for the multiplexer to get our response and release this semaphore. The semaphore will
+ // throw {@code InterruptedException} when the multiplexer is terminated.
+ waitForResponse.acquire();
+
+ if (isWorkerStreamClosed) {
+ return null;
+ }
+
+ if (isUnparseable) {
+ recordingStream.readRemaining();
+ throw new IOException(recordingStream.getRecordedDataAsString());
+ }
+
+ semWorkerProcessResponse.acquire();
+ InputStream response = workerProcessResponse.get(requestId);
+ semWorkerProcessResponse.release();
+ return response;
+ } finally {
+ // TODO(b/151767359): Make sure these also get cleared if a worker gets
+ semResponseChecker.acquire();
+ responseChecker.remove(requestId);
+ semResponseChecker.release();
+ semWorkerProcessResponse.acquire();
+ workerProcessResponse.remove(requestId);
+ semWorkerProcessResponse.release();
}
-
- // Wait for the multiplexer to get our response and release this semaphore. The semaphore will
- // throw {@code InterruptedException} when the multiplexer is terminated.
- waitForResponse.acquire();
-
- if (isWorkerStreamClosed) {
- return null;
- }
-
- if (isUnparseable) {
- recordingStream.readRemaining();
- throw new IOException(recordingStream.getRecordedDataAsString());
- }
-
- semWorkerProcessResponse.acquire();
- InputStream response = workerProcessResponse.get(workerId);
- semWorkerProcessResponse.release();
- return response;
}
/**
- * Resets the semaphore map for {@code workerId} before sending a request to the worker process.
+ * Resets the semaphore map for {@code requestId} before sending a request to the worker process.
* This method is called on the thread of a {@code WorkerProxy}.
*/
- public void resetResponseChecker(Integer workerId) throws InterruptedException {
+ public void resetResponseChecker(Integer requestId) throws InterruptedException {
semResponseChecker.acquire();
- responseChecker.put(workerId, new Semaphore(0));
+ responseChecker.put(requestId, new Semaphore(0));
semResponseChecker.release();
}
@@ -234,16 +248,18 @@
return;
}
- Integer workerId = parsedResponse.getRequestId();
+ int requestId = parsedResponse.getRequestId();
ByteArrayOutputStream tempOs = new ByteArrayOutputStream();
parsedResponse.writeDelimitedTo(tempOs);
semWorkerProcessResponse.acquire();
- workerProcessResponse.put(workerId, new ByteArrayInputStream(tempOs.toByteArray()));
+ workerProcessResponse.put(requestId, new ByteArrayInputStream(tempOs.toByteArray()));
semWorkerProcessResponse.release();
+ // TODO(b/151767359): When allowing cancellation, remove responses that have no matching
+ // entry in responseChecker.
semResponseChecker.acquire();
- responseChecker.get(workerId).release();
+ responseChecker.get(requestId).release();
semResponseChecker.release();
}
@@ -274,8 +290,8 @@
private void releaseAllSemaphores() {
try {
semResponseChecker.acquire();
- for (Integer workerId : responseChecker.keySet()) {
- responseChecker.get(workerId).release();
+ for (Integer requestId : responseChecker.keySet()) {
+ responseChecker.get(requestId).release();
}
} catch (InterruptedException e) {
// Do nothing
@@ -283,4 +299,14 @@
semResponseChecker.release();
}
}
+
+ /** For testing only, to verify that maps are cleared after responses are reaped. */
+ @VisibleForTesting
+ boolean noOutstandingRequests() {
+ return responseChecker.isEmpty() && workerProcessResponse.isEmpty();
+ }
+
+ public void setProcessFactory(SubprocessFactory factory) {
+ subprocessFactory = factory;
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
index 8b9b76b..7463926 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
@@ -80,7 +80,7 @@
@Override
void putRequest(WorkRequest request) throws IOException {
try {
- workerMultiplexer.resetResponseChecker(workerId);
+ workerMultiplexer.resetResponseChecker(request.getRequestId());
workerMultiplexer.putRequest(request);
} catch (InterruptedException e) {
/**
@@ -96,9 +96,9 @@
/** Wait for WorkResponse from multiplexer. */
@Override
- WorkResponse getResponse() throws IOException {
+ WorkResponse getResponse(int requestId) throws IOException {
try {
- InputStream inputStream = workerMultiplexer.getResponse(workerId);
+ InputStream inputStream = workerMultiplexer.getResponse(requestId);
if (inputStream == null) {
return null;
}
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
index b008a51..0b36c8e 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
@@ -64,6 +64,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;
/**
@@ -93,6 +94,7 @@
private final ResourceManager resourceManager;
private final RunfilesTreeUpdater runfilesTreeUpdater;
private final WorkerOptions workerOptions;
+ private final AtomicInteger requestIdCounter = new AtomicInteger(1);
public WorkerSpawnRunner(
SandboxHelpers helpers,
@@ -291,8 +293,7 @@
Spawn spawn,
SpawnExecutionContext context,
List<String> flagfiles,
- MetadataProvider inputFileCache,
- int workerId)
+ MetadataProvider inputFileCache)
throws IOException {
WorkRequest.Builder requestBuilder = WorkRequest.newBuilder();
for (String flagfile : flagfiles) {
@@ -317,7 +318,7 @@
.setDigest(digest)
.build();
}
- return requestBuilder.setRequestId(workerId).build();
+ return requestBuilder.setRequestId(requestIdCounter.getAndIncrement()).build();
}
/**
@@ -418,8 +419,7 @@
Stopwatch queueStopwatch = Stopwatch.createStarted();
try {
worker = workers.borrowObject(key);
- request =
- createWorkRequest(spawn, context, flagFiles, inputFileCache, worker.getWorkerId());
+ request = createWorkRequest(spawn, context, flagFiles, inputFileCache);
} catch (IOException e) {
String message = "IOException while borrowing a worker from the pool:";
throw createUserExecException(e, message, Code.BORROW_FAILURE);
@@ -464,7 +464,7 @@
}
try {
- response = worker.getResponse();
+ response = worker.getResponse(request.getRequestId());
} catch (IOException e) {
// If protobuf or json reader couldn't parse the response, try to print whatever the
// failing worker wrote to stdout - it's probably a stack trace or some kind of error
diff --git a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
index 4aed125..16b80aa 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
@@ -19,10 +19,18 @@
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.hash.HashCode;
import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
+import com.google.devtools.build.lib.shell.Subprocess;
import com.google.devtools.build.lib.vfs.FileSystem;
+import com.google.devtools.build.lib.vfs.Path;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
/** Utilities that come in handy when unit-testing the worker code. */
-public class TestUtils {
+class TestUtils {
private TestUtils() {}
@@ -39,4 +47,122 @@
/* proxied= */ proxied,
WorkerProtocolFormat.PROTO);
}
+
+ static WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat, FileSystem fs) {
+ return new WorkerKey(
+ /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
+ /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
+ /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
+ /* mnemonic= */ "dummy",
+ /* workerFilesCombinedHash= */ HashCode.fromInt(0),
+ /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
+ /* mustBeSandboxed= */ true,
+ /* proxied= */ true,
+ protocolFormat);
+ }
+
+ /** A worker that uses a fake subprocess for I/O. */
+ static class TestWorker extends Worker {
+ private final FakeSubprocess fakeSubprocess;
+
+ TestWorker(
+ WorkerKey workerKey,
+ int workerId,
+ final Path workDir,
+ Path logFile,
+ FakeSubprocess fakeSubprocess) {
+ super(workerKey, workerId, workDir, logFile);
+ this.fakeSubprocess = fakeSubprocess;
+ }
+
+ @Override
+ Subprocess createProcess() {
+ return fakeSubprocess;
+ }
+
+ FakeSubprocess getFakeSubprocess() {
+ return fakeSubprocess;
+ }
+ }
+
+ /**
+ * The {@link Worker} object uses a {@link Subprocess} to interact with persistent worker
+ * binaries. Since this test is strictly testing {@link Worker} and not any outside persistent
+ * worker binaries, a {@link FakeSubprocess} instance is used to fake the {@link InputStream} and
+ * {@link OutputStream} that normally write and read from a persistent worker.
+ */
+ static class FakeSubprocess implements Subprocess {
+ private final InputStream inputStream;
+ private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ private final ByteArrayInputStream errStream = new ByteArrayInputStream(new byte[0]);
+ private boolean wasDestroyed = false;
+
+ /** Creates a fake Subprocess that writes {@code bytes} to its "stdout". */
+ FakeSubprocess(byte[] bytes) throws IOException {
+ inputStream = new ByteArrayInputStream(bytes);
+ }
+
+ FakeSubprocess(InputStream responseStream) throws IOException {
+ this.inputStream = responseStream;
+ }
+
+ @Override
+ public InputStream getInputStream() {
+ return inputStream;
+ }
+
+ @Override
+ public OutputStream getOutputStream() {
+ return outputStream;
+ }
+
+ @Override
+ public InputStream getErrorStream() {
+ return errStream;
+ }
+
+ @Override
+ public boolean destroy() {
+ for (Closeable stream : new Closeable[] {inputStream, outputStream, errStream}) {
+ try {
+ stream.close();
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ wasDestroyed = true;
+ return true;
+ }
+
+ @Override
+ public int exitValue() {
+ return 0;
+ }
+
+ @Override
+ public boolean finished() {
+ return true;
+ }
+
+ @Override
+ public boolean timedout() {
+ return false;
+ }
+
+ @Override
+ public void waitFor() throws InterruptedException {
+ // Do nothing.
+ }
+
+ @Override
+ public void close() {
+ // Do nothing.
+ }
+
+ @Override
+ public boolean isAlive() {
+ return wasDestroyed;
+ }
+ }
}
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
new file mode 100644
index 0000000..69784c3
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
@@ -0,0 +1,206 @@
+// Copyright 2020 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.worker;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.devtools.build.lib.clock.BlazeClock;
+import com.google.devtools.build.lib.vfs.DigestHashFunction;
+import com.google.devtools.build.lib.vfs.FileSystem;
+import com.google.devtools.build.lib.vfs.Path;
+import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
+import com.google.devtools.build.lib.worker.TestUtils.FakeSubprocess;
+import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
+import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
+import java.lang.Thread.State;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for WorkerMultiplexer */
+@RunWith(JUnit4.class)
+public class WorkerMultiplexerTest {
+ private FileSystem fileSystem;
+ private Path logPath;
+
+ @Before
+ public void setUp() throws IOException {
+ fileSystem = new InMemoryFileSystem(BlazeClock.instance(), DigestHashFunction.SHA256);
+ logPath = fileSystem.getPath("/tmp/logs4");
+ fileSystem.createDirectoryAndParents(logPath);
+ }
+
+ @Test
+ public void testGetResponse_noOutstandingRequests() throws IOException, InterruptedException {
+ WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test1", true, "fakeBinary");
+ WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+ PipedInputStream serverInputStream = new PipedInputStream();
+ OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+ multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+ WorkRequest request1 = WorkRequest.newBuilder().setRequestId(1).build();
+ WorkerProxy worker = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+ worker.prepareExecution(null, null, null);
+ worker.putRequest(request1);
+ WorkResponse response1 = WorkResponse.newBuilder().setRequestId(1).build();
+ response1.writeDelimitedTo(workerOutputStream);
+ workerOutputStream.flush();
+ WorkResponse response = worker.getResponse(1);
+ assertThat(response.getRequestId()).isEqualTo(1);
+ // Can't get the same response twice - but the responseChecker is gone, so it just returns null
+ assertThat(multiplexer.getResponse(1)).isNull();
+ assertThat(multiplexer.noOutstandingRequests()).isTrue();
+ }
+
+ @Test
+ public void testGetResponse_basicConcurrency()
+ throws IOException, InterruptedException, ExecutionException {
+ WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test2", true, "fakeBinary");
+ WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+ PipedInputStream serverInputStream = new PipedInputStream();
+ OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+ multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+ WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+ worker1.prepareExecution(null, null, null);
+ WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+ worker1.putRequest(request1);
+
+ WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+ worker2.prepareExecution(null, null, null);
+ WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+ worker2.putRequest(request2);
+
+ Executor executor = Executors.newFixedThreadPool(2);
+ Future<WorkResponse> response1 = Futures.submit(() -> worker1.getResponse(3), executor);
+ Future<WorkResponse> response2 = Futures.submit(() -> worker2.getResponse(42), executor);
+
+ WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+ WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+ // Responses can arrive out of order
+ fakedResponse2.writeDelimitedTo(workerOutputStream);
+ fakedResponse1.writeDelimitedTo(workerOutputStream);
+ workerOutputStream.flush();
+
+ assertThat(response1.get().getRequestId()).isEqualTo(3);
+ assertThat(response2.get().getRequestId()).isEqualTo(42);
+ assertThat(multiplexer.noOutstandingRequests()).isTrue();
+ }
+
+ @Test
+ public void testGetResponse_slowMultiplexer()
+ throws IOException, InterruptedException, ExecutionException {
+ WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test3", true, "fakeBinary");
+ WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+ PipedInputStream serverInputStrean = new PipedInputStream();
+ OutputStream workerOutputStream = new PipedOutputStream(serverInputStrean);
+ multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStrean));
+
+ WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+ worker1.prepareExecution(null, null, null);
+ WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+ worker1.putRequest(request1);
+
+ WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+ worker2.prepareExecution(null, null, null);
+ WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+ worker2.putRequest(request2);
+
+ Thread[] proxyThreads = new Thread[2];
+ Executor executor = Executors.newFixedThreadPool(2);
+ Future<WorkResponse> response1 =
+ Futures.submit(
+ () -> {
+ proxyThreads[0] = Thread.currentThread();
+ return worker1.getResponse(3);
+ },
+ executor);
+ Future<WorkResponse> response2 =
+ Futures.submit(
+ () -> {
+ proxyThreads[1] = Thread.currentThread();
+ return worker2.getResponse(42);
+ },
+ executor);
+
+ // Makes sure both workers are waiting for responses before the multiplexer processes anything.
+ while (proxyThreads[0] == null
+ || proxyThreads[0].getState() != State.WAITING
+ || proxyThreads[1] == null
+ || proxyThreads[1].getState() != State.WAITING) {
+ Thread.sleep(1);
+ }
+
+ WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+ WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+ // Responses can arrive out of order
+ fakedResponse2.writeDelimitedTo(workerOutputStream);
+ fakedResponse1.writeDelimitedTo(workerOutputStream);
+ workerOutputStream.flush();
+
+ assertThat(response1.get().getRequestId()).isEqualTo(3);
+ assertThat(response2.get().getRequestId()).isEqualTo(42);
+ assertThat(multiplexer.noOutstandingRequests()).isTrue();
+ }
+
+ @Test
+ public void testGetResponse_slowProxy()
+ throws IOException, InterruptedException, ExecutionException {
+ WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test4", true, "fakeBinary");
+ WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+ PipedInputStream serverInputStream = new PipedInputStream();
+ OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+ multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+ WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+ worker1.prepareExecution(null, null, null);
+ WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+ worker1.putRequest(request1);
+
+ WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+ worker2.prepareExecution(null, null, null);
+ WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+ worker2.putRequest(request2);
+
+ WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+ WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+ // Responses can arrive out of order
+ fakedResponse2.writeDelimitedTo(workerOutputStream);
+ fakedResponse1.writeDelimitedTo(workerOutputStream);
+ workerOutputStream.flush();
+
+ Executor executor = Executors.newFixedThreadPool(2);
+ Future<WorkResponse> response1 = Futures.submit(() -> worker1.getResponse(3), executor);
+ Future<WorkResponse> response2 = Futures.submit(() -> worker2.getResponse(42), executor);
+
+ assertThat(response1.get().getRequestId()).isEqualTo(3);
+ assertThat(response2.get().getRequestId()).isEqualTo(42);
+ assertThat(multiplexer.noOutstandingRequests()).isTrue();
+ }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
index f36ae95..224c200 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
@@ -110,7 +110,7 @@
WorkerKey key = createWorkerKey(fs, "mnem", false);
Path logFile = fs.getPath("/worker.log");
when(worker.getLogFile()).thenReturn(logFile);
- when(worker.getResponse())
+ when(worker.getResponse(1))
.thenReturn(
WorkResponse.newBuilder().setExitCode(0).setOutput("out").setRequestId(1).build());
WorkResponse response =
@@ -149,7 +149,7 @@
WorkerKey key = createWorkerKey(fs, "mnem", false);
Path logFile = fs.getPath("/worker.log");
when(worker.getLogFile()).thenReturn(logFile);
- when(worker.getResponse()).thenThrow(new IOException("Bad protobuf"));
+ when(worker.getResponse(1)).thenThrow(new IOException("Bad protobuf"));
when(worker.getRecordingStreamMessage()).thenReturn(recordedResponse);
String workerLog = "Log from worker\n";
FileSystemUtils.writeIsoLatin1(logFile, workerLog);
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
index 9b38382..7f6fdf4 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
@@ -21,27 +21,22 @@
import static org.junit.Assert.assertThrows;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSortedMap;
-import com.google.common.hash.HashCode;
import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxInputs;
import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxOutputs;
-import com.google.devtools.build.lib.shell.Subprocess;
import com.google.devtools.build.lib.vfs.DigestHashFunction;
import com.google.devtools.build.lib.vfs.FileSystem;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
+import com.google.devtools.build.lib.worker.TestUtils.FakeSubprocess;
+import com.google.devtools.build.lib.worker.TestUtils.TestWorker;
import com.google.devtools.build.lib.worker.WorkerProtocol.Input;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import com.google.protobuf.ByteString;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
-import java.io.Closeable;
import java.io.IOException;
-import java.io.InputStream;
import java.io.OutputStream;
import org.junit.After;
import org.junit.Test;
@@ -53,26 +48,6 @@
public final class WorkerTest {
final FileSystem fs = new InMemoryFileSystem(DigestHashFunction.SHA256);
- /** A worker that uses a fake subprocess for I/O. */
- private static class TestWorker extends Worker {
- private final FakeSubprocess fakeSubprocess;
-
- public TestWorker(
- WorkerKey workerKey,
- int workerId,
- final Path workDir,
- Path logFile,
- FakeSubprocess fakeSubprocess) {
- super(workerKey, workerId, workDir, logFile);
- this.fakeSubprocess = fakeSubprocess;
- }
-
- @Override
- Subprocess createProcess() {
- return fakeSubprocess;
- }
- }
-
private TestWorker workerForCleanup = null;
@After
@@ -83,95 +58,6 @@
}
}
- private WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat) {
- return new WorkerKey(
- /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
- /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
- /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
- /* mnemonic= */ "dummy",
- /* workerFilesCombinedHash= */ HashCode.fromInt(0),
- /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
- /* mustBeSandboxed= */ true,
- /* proxied= */ true,
- protocolFormat);
- }
-
- /**
- * The {@link Worker} object uses a {@link Subprocess} to interact with persistent worker
- * binaries. Since this test is strictly testing {@link Worker} and not any outside persistent
- * worker binaries, a {@link FakeSubprocess} instance is used to fake the {@link InputStream} and
- * {@link OutputStream} that normally write and read from a persistent worker.
- */
- private static class FakeSubprocess implements Subprocess {
- private final ByteArrayInputStream inputStream;
- private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
- private final ByteArrayInputStream errStream = new ByteArrayInputStream(new byte[0]);
- private boolean wasDestroyed = false;
-
- public FakeSubprocess(byte[] bytes) throws IOException {
- inputStream = new ByteArrayInputStream(bytes);
- }
-
- @Override
- public InputStream getInputStream() {
- return inputStream;
- }
-
- @Override
- public OutputStream getOutputStream() {
- return outputStream;
- }
-
- @Override
- public InputStream getErrorStream() {
- return errStream;
- }
-
- @Override
- public boolean destroy() {
- for (Closeable stream : new Closeable[] {inputStream, outputStream, errStream}) {
- try {
- stream.close();
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- }
-
- wasDestroyed = true;
- return true;
- }
-
- @Override
- public int exitValue() {
- return 0;
- }
-
- @Override
- public boolean finished() {
- return true;
- }
-
- @Override
- public boolean timedout() {
- return false;
- }
-
- @Override
- public void waitFor() throws InterruptedException {
- // Do nothing.
- }
-
- @Override
- public void close() {
- // Do nothing.
- }
-
- @Override
- public boolean isAlive() {
- return wasDestroyed;
- }
- }
-
private static byte[] serializeResponseToProtoBytes(WorkResponse response) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
response.writeDelimitedTo(baos);
@@ -183,7 +69,7 @@
Preconditions.checkState(
workerForCleanup == null, "createTestWorker can only be called once per test");
- WorkerKey key = createWorkerKey(protocolFormat);
+ WorkerKey key = TestUtils.createWorkerKey(protocolFormat, fs);
FakeSubprocess fakeSubprocess = new FakeSubprocess(outputStreamBytes);
@@ -209,7 +95,7 @@
TestWorker testWorker = createTestWorker(new byte[0], PROTO);
testWorker.putRequest(request);
- OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
WorkRequest requestFromStdout =
WorkRequest.parseDelimitedFrom(new ByteArrayInputStream(stdout.toString().getBytes(UTF_8)));
@@ -221,7 +107,7 @@
WorkResponse response = WorkResponse.getDefaultInstance();
TestWorker testWorker = createTestWorker(serializeResponseToProtoBytes(response), PROTO);
- WorkResponse readResponse = testWorker.getResponse();
+ WorkResponse readResponse = testWorker.getResponse(0);
assertThat(readResponse).isEqualTo(response);
}
@@ -231,14 +117,14 @@
TestWorker testWorker = createTestWorker(new byte[0], JSON);
testWorker.putRequest(WorkRequest.getDefaultInstance());
- OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
assertThat(stdout.toString()).isEqualTo("{}");
}
@Test
public void testGetResponse_json_success() throws IOException {
TestWorker testWorker = createTestWorker("{}".getBytes(UTF_8), JSON);
- WorkResponse readResponse = testWorker.getResponse();
+ WorkResponse readResponse = testWorker.getResponse(0);
WorkResponse response = WorkResponse.getDefaultInstance();
assertThat(readResponse).isEqualTo(response);
@@ -260,7 +146,7 @@
TestWorker testWorker = createTestWorker(new byte[0], JSON);
testWorker.putRequest(request);
- OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
String requestJsonString =
"{\"arguments\":[\"testRequest\"],\"inputs\":"
+ "[{\"path\":\"testPath\",\"digest\":\"dGVzdERpZ2VzdA==\"}],\"requestId\":1}";
@@ -272,7 +158,7 @@
TestWorker testWorker =
createTestWorker(
"{\"exitCode\":1,\"output\":\"test output\",\"requestId\":1}".getBytes(UTF_8), JSON);
- WorkResponse readResponse = testWorker.getResponse();
+ WorkResponse readResponse = testWorker.getResponse(1);
WorkResponse response =
WorkResponse.newBuilder().setExitCode(1).setOutput("test output").setRequestId(1).build();
@@ -282,7 +168,7 @@
private void verifyGetResponseFailure(String responseString, String expectedError)
throws IOException {
TestWorker testWorker = createTestWorker(responseString.getBytes(UTF_8), JSON);
- IOException ex = assertThrows(IOException.class, testWorker::getResponse);
+ IOException ex = assertThrows(IOException.class, () -> testWorker.getResponse(0));
assertThat(ex).hasMessageThat().contains(expectedError);
}