Give multiplex WorkRequests consecutive request_ids instead of just workerid

This is needed for cancellation, lest a cancelled WorkResponse pretend to be a response to a newer WorkRequest.

PiperOrigin-RevId: 341099677
diff --git a/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java b/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
index 36c9cfa..af568e3 100644
--- a/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
+++ b/src/main/java/com/google/devtools/build/lib/shell/SubprocessBuilder.java
@@ -53,8 +53,12 @@
 
   static SubprocessFactory defaultFactory = JavaSubprocessFactory.INSTANCE;
 
+  /**
+   * Sets the default factory class for creating subprocesses. Passing {@code null} resets it to the
+   * initial state.
+   */
   public static void setDefaultSubprocessFactory(SubprocessFactory factory) {
-    SubprocessBuilder.defaultFactory = factory;
+    SubprocessBuilder.defaultFactory = factory != null ? factory : JavaSubprocessFactory.INSTANCE;
   }
 
   public SubprocessBuilder() {
diff --git a/src/main/java/com/google/devtools/build/lib/worker/Worker.java b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
index 93c7993..cf720f0 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/Worker.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
@@ -159,7 +159,7 @@
     workerProtocol.putRequest(request);
   }
 
-  WorkResponse getResponse() throws IOException {
+  WorkResponse getResponse(int requestId) throws IOException {
     recordingInputStream.startRecording(4096);
     return workerProtocol.getResponse();
   }
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
index d97d0ab..881e516 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
@@ -14,10 +14,12 @@
 
 package com.google.devtools.build.lib.worker;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.flogger.GoogleLogger;
 import com.google.devtools.build.lib.shell.Subprocess;
 import com.google.devtools.build.lib.shell.SubprocessBuilder;
+import com.google.devtools.build.lib.shell.SubprocessFactory;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
@@ -49,7 +51,7 @@
   /** A semaphore to protect {@code workerProcessResponse} object. */
   private Semaphore semWorkerProcessResponse;
   /**
-   * A map of semaphores corresponding to {@code WorkerProxy} objects. After sending the {@code
+   * A map of semaphores corresponding to {@code WorkRequest}s. After sending the {@code
    * WorkRequest}, {@code WorkerProxy} will wait on a semaphore to be released. {@code
    * WorkerMultiplexer} is responsible for releasing the corresponding semaphore in order to signal
    * {@code WorkerProxy} that the {@code WorkerResponse} has been received.
@@ -79,6 +81,8 @@
    */
   private final Path logFile;
 
+  private SubprocessFactory subprocessFactory = null;
+
   WorkerMultiplexer(Path logFile) {
     semWorkerProcessResponse = new Semaphore(1);
     semResponseChecker = new Semaphore(1);
@@ -96,10 +100,7 @@
    */
   public synchronized void createProcess(WorkerKey workerKey, Path workDir) throws IOException {
     // The process may have died in the meanwhile (e.g. between builds).
-    if (this.process != null && !this.process.isAlive()) {
-      this.process = null;
-    }
-    if (this.process == null) {
+    if (this.process == null || !this.process.isAlive()) {
       ImmutableList<String> args = workerKey.getArgs();
       File executable = new File(args.get(0));
       if (!executable.isAbsolute() && executable.getParent() != null) {
@@ -107,7 +108,10 @@
         newArgs.set(0, new File(workDir.getPathFile(), newArgs.get(0)).getAbsolutePath());
         args = ImmutableList.copyOf(newArgs);
       }
-      SubprocessBuilder processBuilder = new SubprocessBuilder();
+      SubprocessBuilder processBuilder =
+          subprocessFactory != null
+              ? new SubprocessBuilder(subprocessFactory)
+              : new SubprocessBuilder();
       processBuilder.setArgv(args);
       processBuilder.setWorkingDirectory(workDir.getPathFile());
       processBuilder.setStderr(logFile.getPathFile());
@@ -178,43 +182,53 @@
    * Waits on a semaphore for the {@code WorkResponse} returned from worker process. This method is
    * called on the thread of a {@code WorkerProxy}.
    */
-  public InputStream getResponse(Integer workerId) throws IOException, InterruptedException {
-    semResponseChecker.acquire();
-    Semaphore waitForResponse = responseChecker.get(workerId);
-    semResponseChecker.release();
+  public InputStream getResponse(Integer requestId) throws IOException, InterruptedException {
+    try {
+      semResponseChecker.acquire();
+      Semaphore waitForResponse = responseChecker.get(requestId);
+      semResponseChecker.release();
 
-    if (waitForResponse == null) {
-      // If the multiplexer is interrupted when a {@code WorkerProxy} is trying to send a request,
-      // the request is not sent, so there is no need to wait for a response.
-      return null;
+      if (waitForResponse == null) {
+        // If the multiplexer is interrupted when a {@code WorkerProxy} is trying to send a request,
+        // the request is not sent, so there is no need to wait for a response.
+        return null;
+      }
+
+      // Wait for the multiplexer to get our response and release this semaphore. The semaphore will
+      // throw {@code InterruptedException} when the multiplexer is terminated.
+      waitForResponse.acquire();
+
+      if (isWorkerStreamClosed) {
+        return null;
+      }
+
+      if (isUnparseable) {
+        recordingStream.readRemaining();
+        throw new IOException(recordingStream.getRecordedDataAsString());
+      }
+
+      semWorkerProcessResponse.acquire();
+      InputStream response = workerProcessResponse.get(requestId);
+      semWorkerProcessResponse.release();
+      return response;
+    } finally {
+      // TODO(b/151767359): Make sure these also get cleared if a worker gets
+      semResponseChecker.acquire();
+      responseChecker.remove(requestId);
+      semResponseChecker.release();
+      semWorkerProcessResponse.acquire();
+      workerProcessResponse.remove(requestId);
+      semWorkerProcessResponse.release();
     }
-
-    // Wait for the multiplexer to get our response and release this semaphore. The semaphore will
-    // throw {@code InterruptedException} when the multiplexer is terminated.
-    waitForResponse.acquire();
-
-    if (isWorkerStreamClosed) {
-      return null;
-    }
-
-    if (isUnparseable) {
-      recordingStream.readRemaining();
-      throw new IOException(recordingStream.getRecordedDataAsString());
-    }
-
-    semWorkerProcessResponse.acquire();
-    InputStream response = workerProcessResponse.get(workerId);
-    semWorkerProcessResponse.release();
-    return response;
   }
 
   /**
-   * Resets the semaphore map for {@code workerId} before sending a request to the worker process.
+   * Resets the semaphore map for {@code requestId} before sending a request to the worker process.
    * This method is called on the thread of a {@code WorkerProxy}.
    */
-  public void resetResponseChecker(Integer workerId) throws InterruptedException {
+  public void resetResponseChecker(Integer requestId) throws InterruptedException {
     semResponseChecker.acquire();
-    responseChecker.put(workerId, new Semaphore(0));
+    responseChecker.put(requestId, new Semaphore(0));
     semResponseChecker.release();
   }
 
@@ -234,16 +248,18 @@
       return;
     }
 
-    Integer workerId = parsedResponse.getRequestId();
+    int requestId = parsedResponse.getRequestId();
     ByteArrayOutputStream tempOs = new ByteArrayOutputStream();
     parsedResponse.writeDelimitedTo(tempOs);
 
     semWorkerProcessResponse.acquire();
-    workerProcessResponse.put(workerId, new ByteArrayInputStream(tempOs.toByteArray()));
+    workerProcessResponse.put(requestId, new ByteArrayInputStream(tempOs.toByteArray()));
     semWorkerProcessResponse.release();
 
+    // TODO(b/151767359): When allowing cancellation, remove responses that have no matching
+    // entry in responseChecker.
     semResponseChecker.acquire();
-    responseChecker.get(workerId).release();
+    responseChecker.get(requestId).release();
     semResponseChecker.release();
   }
 
@@ -274,8 +290,8 @@
   private void releaseAllSemaphores() {
     try {
       semResponseChecker.acquire();
-      for (Integer workerId : responseChecker.keySet()) {
-        responseChecker.get(workerId).release();
+      for (Integer requestId : responseChecker.keySet()) {
+        responseChecker.get(requestId).release();
       }
     } catch (InterruptedException e) {
       // Do nothing
@@ -283,4 +299,14 @@
       semResponseChecker.release();
     }
   }
+
+  /** For testing only, to verify that maps are cleared after responses are reaped. */
+  @VisibleForTesting
+  boolean noOutstandingRequests() {
+    return responseChecker.isEmpty() && workerProcessResponse.isEmpty();
+  }
+
+  public void setProcessFactory(SubprocessFactory factory) {
+    subprocessFactory = factory;
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
index 8b9b76b..7463926 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
@@ -80,7 +80,7 @@
   @Override
   void putRequest(WorkRequest request) throws IOException {
     try {
-      workerMultiplexer.resetResponseChecker(workerId);
+      workerMultiplexer.resetResponseChecker(request.getRequestId());
       workerMultiplexer.putRequest(request);
     } catch (InterruptedException e) {
       /**
@@ -96,9 +96,9 @@
 
   /** Wait for WorkResponse from multiplexer. */
   @Override
-  WorkResponse getResponse() throws IOException {
+  WorkResponse getResponse(int requestId) throws IOException {
     try {
-      InputStream inputStream = workerMultiplexer.getResponse(workerId);
+      InputStream inputStream = workerMultiplexer.getResponse(requestId);
       if (inputStream == null) {
         return null;
       }
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
index b008a51..0b36c8e 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerSpawnRunner.java
@@ -64,6 +64,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.SortedMap;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Pattern;
 
 /**
@@ -93,6 +94,7 @@
   private final ResourceManager resourceManager;
   private final RunfilesTreeUpdater runfilesTreeUpdater;
   private final WorkerOptions workerOptions;
+  private final AtomicInteger requestIdCounter = new AtomicInteger(1);
 
   public WorkerSpawnRunner(
       SandboxHelpers helpers,
@@ -291,8 +293,7 @@
       Spawn spawn,
       SpawnExecutionContext context,
       List<String> flagfiles,
-      MetadataProvider inputFileCache,
-      int workerId)
+      MetadataProvider inputFileCache)
       throws IOException {
     WorkRequest.Builder requestBuilder = WorkRequest.newBuilder();
     for (String flagfile : flagfiles) {
@@ -317,7 +318,7 @@
           .setDigest(digest)
           .build();
     }
-    return requestBuilder.setRequestId(workerId).build();
+    return requestBuilder.setRequestId(requestIdCounter.getAndIncrement()).build();
   }
 
   /**
@@ -418,8 +419,7 @@
       Stopwatch queueStopwatch = Stopwatch.createStarted();
       try {
         worker = workers.borrowObject(key);
-        request =
-            createWorkRequest(spawn, context, flagFiles, inputFileCache, worker.getWorkerId());
+        request = createWorkRequest(spawn, context, flagFiles, inputFileCache);
       } catch (IOException e) {
         String message = "IOException while borrowing a worker from the pool:";
         throw createUserExecException(e, message, Code.BORROW_FAILURE);
@@ -464,7 +464,7 @@
         }
 
         try {
-          response = worker.getResponse();
+          response = worker.getResponse(request.getRequestId());
         } catch (IOException e) {
           // If protobuf or json reader couldn't parse the response, try to print whatever the
           // failing worker wrote to stdout - it's probably a stack trace or some kind of error
diff --git a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
index 4aed125..16b80aa 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/TestUtils.java
@@ -19,10 +19,18 @@
 import com.google.common.collect.ImmutableSortedMap;
 import com.google.common.hash.HashCode;
 import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
+import com.google.devtools.build.lib.shell.Subprocess;
 import com.google.devtools.build.lib.vfs.FileSystem;
+import com.google.devtools.build.lib.vfs.Path;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 
 /** Utilities that come in handy when unit-testing the worker code. */
-public class TestUtils {
+class TestUtils {
 
   private TestUtils() {}
 
@@ -39,4 +47,122 @@
         /* proxied= */ proxied,
         WorkerProtocolFormat.PROTO);
   }
+
+  static WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat, FileSystem fs) {
+    return new WorkerKey(
+        /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
+        /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
+        /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
+        /* mnemonic= */ "dummy",
+        /* workerFilesCombinedHash= */ HashCode.fromInt(0),
+        /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
+        /* mustBeSandboxed= */ true,
+        /* proxied= */ true,
+        protocolFormat);
+  }
+
+  /** A worker that uses a fake subprocess for I/O. */
+  static class TestWorker extends Worker {
+    private final FakeSubprocess fakeSubprocess;
+
+    TestWorker(
+        WorkerKey workerKey,
+        int workerId,
+        final Path workDir,
+        Path logFile,
+        FakeSubprocess fakeSubprocess) {
+      super(workerKey, workerId, workDir, logFile);
+      this.fakeSubprocess = fakeSubprocess;
+    }
+
+    @Override
+    Subprocess createProcess() {
+      return fakeSubprocess;
+    }
+
+    FakeSubprocess getFakeSubprocess() {
+      return fakeSubprocess;
+    }
+  }
+
+  /**
+   * The {@link Worker} object uses a {@link Subprocess} to interact with persistent worker
+   * binaries. Since this test is strictly testing {@link Worker} and not any outside persistent
+   * worker binaries, a {@link FakeSubprocess} instance is used to fake the {@link InputStream} and
+   * {@link OutputStream} that normally write and read from a persistent worker.
+   */
+  static class FakeSubprocess implements Subprocess {
+    private final InputStream inputStream;
+    private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    private final ByteArrayInputStream errStream = new ByteArrayInputStream(new byte[0]);
+    private boolean wasDestroyed = false;
+
+    /** Creates a fake Subprocess that writes {@code bytes} to its "stdout". */
+    FakeSubprocess(byte[] bytes) throws IOException {
+      inputStream = new ByteArrayInputStream(bytes);
+    }
+
+    FakeSubprocess(InputStream responseStream) throws IOException {
+      this.inputStream = responseStream;
+    }
+
+    @Override
+    public InputStream getInputStream() {
+      return inputStream;
+    }
+
+    @Override
+    public OutputStream getOutputStream() {
+      return outputStream;
+    }
+
+    @Override
+    public InputStream getErrorStream() {
+      return errStream;
+    }
+
+    @Override
+    public boolean destroy() {
+      for (Closeable stream : new Closeable[] {inputStream, outputStream, errStream}) {
+        try {
+          stream.close();
+        } catch (IOException e) {
+          throw new IllegalStateException(e);
+        }
+      }
+
+      wasDestroyed = true;
+      return true;
+    }
+
+    @Override
+    public int exitValue() {
+      return 0;
+    }
+
+    @Override
+    public boolean finished() {
+      return true;
+    }
+
+    @Override
+    public boolean timedout() {
+      return false;
+    }
+
+    @Override
+    public void waitFor() throws InterruptedException {
+      // Do nothing.
+    }
+
+    @Override
+    public void close() {
+      // Do nothing.
+    }
+
+    @Override
+    public boolean isAlive() {
+      return wasDestroyed;
+    }
+  }
 }
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
new file mode 100644
index 0000000..69784c3
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
@@ -0,0 +1,206 @@
+// Copyright 2020 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.devtools.build.lib.worker;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.devtools.build.lib.clock.BlazeClock;
+import com.google.devtools.build.lib.vfs.DigestHashFunction;
+import com.google.devtools.build.lib.vfs.FileSystem;
+import com.google.devtools.build.lib.vfs.Path;
+import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
+import com.google.devtools.build.lib.worker.TestUtils.FakeSubprocess;
+import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
+import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
+import java.lang.Thread.State;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for WorkerMultiplexer */
+@RunWith(JUnit4.class)
+public class WorkerMultiplexerTest {
+  private FileSystem fileSystem;
+  private Path logPath;
+
+  @Before
+  public void setUp() throws IOException {
+    fileSystem = new InMemoryFileSystem(BlazeClock.instance(), DigestHashFunction.SHA256);
+    logPath = fileSystem.getPath("/tmp/logs4");
+    fileSystem.createDirectoryAndParents(logPath);
+  }
+
+  @Test
+  public void testGetResponse_noOutstandingRequests() throws IOException, InterruptedException {
+    WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test1", true, "fakeBinary");
+    WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+    PipedInputStream serverInputStream = new PipedInputStream();
+    OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+    multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+    WorkRequest request1 = WorkRequest.newBuilder().setRequestId(1).build();
+    WorkerProxy worker = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+    worker.prepareExecution(null, null, null);
+    worker.putRequest(request1);
+    WorkResponse response1 = WorkResponse.newBuilder().setRequestId(1).build();
+    response1.writeDelimitedTo(workerOutputStream);
+    workerOutputStream.flush();
+    WorkResponse response = worker.getResponse(1);
+    assertThat(response.getRequestId()).isEqualTo(1);
+    // Can't get the same response twice - but the responseChecker is gone, so it just returns null
+    assertThat(multiplexer.getResponse(1)).isNull();
+    assertThat(multiplexer.noOutstandingRequests()).isTrue();
+  }
+
+  @Test
+  public void testGetResponse_basicConcurrency()
+      throws IOException, InterruptedException, ExecutionException {
+    WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test2", true, "fakeBinary");
+    WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+    PipedInputStream serverInputStream = new PipedInputStream();
+    OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+    multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+    worker1.prepareExecution(null, null, null);
+    WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+    worker1.putRequest(request1);
+
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+    worker2.prepareExecution(null, null, null);
+    WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+    worker2.putRequest(request2);
+
+    Executor executor = Executors.newFixedThreadPool(2);
+    Future<WorkResponse> response1 = Futures.submit(() -> worker1.getResponse(3), executor);
+    Future<WorkResponse> response2 = Futures.submit(() -> worker2.getResponse(42), executor);
+
+    WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+    WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+    // Responses can arrive out of order
+    fakedResponse2.writeDelimitedTo(workerOutputStream);
+    fakedResponse1.writeDelimitedTo(workerOutputStream);
+    workerOutputStream.flush();
+
+    assertThat(response1.get().getRequestId()).isEqualTo(3);
+    assertThat(response2.get().getRequestId()).isEqualTo(42);
+    assertThat(multiplexer.noOutstandingRequests()).isTrue();
+  }
+
+  @Test
+  public void testGetResponse_slowMultiplexer()
+      throws IOException, InterruptedException, ExecutionException {
+    WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test3", true, "fakeBinary");
+    WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+    PipedInputStream serverInputStrean = new PipedInputStream();
+    OutputStream workerOutputStream = new PipedOutputStream(serverInputStrean);
+    multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStrean));
+
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+    worker1.prepareExecution(null, null, null);
+    WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+    worker1.putRequest(request1);
+
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+    worker2.prepareExecution(null, null, null);
+    WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+    worker2.putRequest(request2);
+
+    Thread[] proxyThreads = new Thread[2];
+    Executor executor = Executors.newFixedThreadPool(2);
+    Future<WorkResponse> response1 =
+        Futures.submit(
+            () -> {
+              proxyThreads[0] = Thread.currentThread();
+              return worker1.getResponse(3);
+            },
+            executor);
+    Future<WorkResponse> response2 =
+        Futures.submit(
+            () -> {
+              proxyThreads[1] = Thread.currentThread();
+              return worker2.getResponse(42);
+            },
+            executor);
+
+    // Makes sure both workers are waiting for responses before the multiplexer processes anything.
+    while (proxyThreads[0] == null
+        || proxyThreads[0].getState() != State.WAITING
+        || proxyThreads[1] == null
+        || proxyThreads[1].getState() != State.WAITING) {
+      Thread.sleep(1);
+    }
+
+    WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+    WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+    // Responses can arrive out of order
+    fakedResponse2.writeDelimitedTo(workerOutputStream);
+    fakedResponse1.writeDelimitedTo(workerOutputStream);
+    workerOutputStream.flush();
+
+    assertThat(response1.get().getRequestId()).isEqualTo(3);
+    assertThat(response2.get().getRequestId()).isEqualTo(42);
+    assertThat(multiplexer.noOutstandingRequests()).isTrue();
+  }
+
+  @Test
+  public void testGetResponse_slowProxy()
+      throws IOException, InterruptedException, ExecutionException {
+    WorkerKey workerKey = TestUtils.createWorkerKey(fileSystem, "test4", true, "fakeBinary");
+    WorkerMultiplexer multiplexer = WorkerMultiplexerManager.getInstance(workerKey, logPath);
+
+    PipedInputStream serverInputStream = new PipedInputStream();
+    OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
+    multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
+
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
+    worker1.prepareExecution(null, null, null);
+    WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
+    worker1.putRequest(request1);
+
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
+    worker2.prepareExecution(null, null, null);
+    WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
+    worker2.putRequest(request2);
+
+    WorkResponse fakedResponse1 = WorkResponse.newBuilder().setRequestId(3).build();
+    WorkResponse fakedResponse2 = WorkResponse.newBuilder().setRequestId(42).build();
+    // Responses can arrive out of order
+    fakedResponse2.writeDelimitedTo(workerOutputStream);
+    fakedResponse1.writeDelimitedTo(workerOutputStream);
+    workerOutputStream.flush();
+
+    Executor executor = Executors.newFixedThreadPool(2);
+    Future<WorkResponse> response1 = Futures.submit(() -> worker1.getResponse(3), executor);
+    Future<WorkResponse> response2 = Futures.submit(() -> worker2.getResponse(42), executor);
+
+    assertThat(response1.get().getRequestId()).isEqualTo(3);
+    assertThat(response2.get().getRequestId()).isEqualTo(42);
+    assertThat(multiplexer.noOutstandingRequests()).isTrue();
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
index f36ae95..224c200 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerSpawnRunnerTest.java
@@ -110,7 +110,7 @@
     WorkerKey key = createWorkerKey(fs, "mnem", false);
     Path logFile = fs.getPath("/worker.log");
     when(worker.getLogFile()).thenReturn(logFile);
-    when(worker.getResponse())
+    when(worker.getResponse(1))
         .thenReturn(
             WorkResponse.newBuilder().setExitCode(0).setOutput("out").setRequestId(1).build());
     WorkResponse response =
@@ -149,7 +149,7 @@
     WorkerKey key = createWorkerKey(fs, "mnem", false);
     Path logFile = fs.getPath("/worker.log");
     when(worker.getLogFile()).thenReturn(logFile);
-    when(worker.getResponse()).thenThrow(new IOException("Bad protobuf"));
+    when(worker.getResponse(1)).thenThrow(new IOException("Bad protobuf"));
     when(worker.getRecordingStreamMessage()).thenReturn(recordedResponse);
     String workerLog = "Log from worker\n";
     FileSystemUtils.writeIsoLatin1(logFile, workerLog);
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
index 9b38382..7f6fdf4 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
@@ -21,27 +21,22 @@
 import static org.junit.Assert.assertThrows;
 
 import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSortedMap;
-import com.google.common.hash.HashCode;
 import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxInputs;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxOutputs;
-import com.google.devtools.build.lib.shell.Subprocess;
 import com.google.devtools.build.lib.vfs.DigestHashFunction;
 import com.google.devtools.build.lib.vfs.FileSystem;
 import com.google.devtools.build.lib.vfs.Path;
 import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
+import com.google.devtools.build.lib.worker.TestUtils.FakeSubprocess;
+import com.google.devtools.build.lib.worker.TestUtils.TestWorker;
 import com.google.devtools.build.lib.worker.WorkerProtocol.Input;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
 import com.google.protobuf.ByteString;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
-import java.io.Closeable;
 import java.io.IOException;
-import java.io.InputStream;
 import java.io.OutputStream;
 import org.junit.After;
 import org.junit.Test;
@@ -53,26 +48,6 @@
 public final class WorkerTest {
   final FileSystem fs = new InMemoryFileSystem(DigestHashFunction.SHA256);
 
-  /** A worker that uses a fake subprocess for I/O. */
-  private static class TestWorker extends Worker {
-    private final FakeSubprocess fakeSubprocess;
-
-    public TestWorker(
-        WorkerKey workerKey,
-        int workerId,
-        final Path workDir,
-        Path logFile,
-        FakeSubprocess fakeSubprocess) {
-      super(workerKey, workerId, workDir, logFile);
-      this.fakeSubprocess = fakeSubprocess;
-    }
-
-    @Override
-    Subprocess createProcess() {
-      return fakeSubprocess;
-    }
-  }
-
   private TestWorker workerForCleanup = null;
 
   @After
@@ -83,95 +58,6 @@
     }
   }
 
-  private WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat) {
-    return new WorkerKey(
-        /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
-        /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
-        /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
-        /* mnemonic= */ "dummy",
-        /* workerFilesCombinedHash= */ HashCode.fromInt(0),
-        /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
-        /* mustBeSandboxed= */ true,
-        /* proxied= */ true,
-        protocolFormat);
-  }
-
-  /**
-   * The {@link Worker} object uses a {@link Subprocess} to interact with persistent worker
-   * binaries. Since this test is strictly testing {@link Worker} and not any outside persistent
-   * worker binaries, a {@link FakeSubprocess} instance is used to fake the {@link InputStream} and
-   * {@link OutputStream} that normally write and read from a persistent worker.
-   */
-  private static class FakeSubprocess implements Subprocess {
-    private final ByteArrayInputStream inputStream;
-    private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
-    private final ByteArrayInputStream errStream = new ByteArrayInputStream(new byte[0]);
-    private boolean wasDestroyed = false;
-
-    public FakeSubprocess(byte[] bytes) throws IOException {
-      inputStream = new ByteArrayInputStream(bytes);
-    }
-
-    @Override
-    public InputStream getInputStream() {
-      return inputStream;
-    }
-
-    @Override
-    public OutputStream getOutputStream() {
-      return outputStream;
-    }
-
-    @Override
-    public InputStream getErrorStream() {
-      return errStream;
-    }
-
-    @Override
-    public boolean destroy() {
-      for (Closeable stream : new Closeable[] {inputStream, outputStream, errStream}) {
-        try {
-          stream.close();
-        } catch (IOException e) {
-          throw new IllegalStateException(e);
-        }
-      }
-
-      wasDestroyed = true;
-      return true;
-    }
-
-    @Override
-    public int exitValue() {
-      return 0;
-    }
-
-    @Override
-    public boolean finished() {
-      return true;
-    }
-
-    @Override
-    public boolean timedout() {
-      return false;
-    }
-
-    @Override
-    public void waitFor() throws InterruptedException {
-      // Do nothing.
-    }
-
-    @Override
-    public void close() {
-      // Do nothing.
-    }
-
-    @Override
-    public boolean isAlive() {
-      return wasDestroyed;
-    }
-  }
-
   private static byte[] serializeResponseToProtoBytes(WorkResponse response) throws IOException {
     ByteArrayOutputStream baos = new ByteArrayOutputStream();
     response.writeDelimitedTo(baos);
@@ -183,7 +69,7 @@
     Preconditions.checkState(
         workerForCleanup == null, "createTestWorker can only be called once per test");
 
-    WorkerKey key = createWorkerKey(protocolFormat);
+    WorkerKey key = TestUtils.createWorkerKey(protocolFormat, fs);
 
     FakeSubprocess fakeSubprocess = new FakeSubprocess(outputStreamBytes);
 
@@ -209,7 +95,7 @@
     TestWorker testWorker = createTestWorker(new byte[0], PROTO);
     testWorker.putRequest(request);
 
-    OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+    OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
     WorkRequest requestFromStdout =
         WorkRequest.parseDelimitedFrom(new ByteArrayInputStream(stdout.toString().getBytes(UTF_8)));
 
@@ -221,7 +107,7 @@
     WorkResponse response = WorkResponse.getDefaultInstance();
 
     TestWorker testWorker = createTestWorker(serializeResponseToProtoBytes(response), PROTO);
-    WorkResponse readResponse = testWorker.getResponse();
+    WorkResponse readResponse = testWorker.getResponse(0);
 
     assertThat(readResponse).isEqualTo(response);
   }
@@ -231,14 +117,14 @@
     TestWorker testWorker = createTestWorker(new byte[0], JSON);
     testWorker.putRequest(WorkRequest.getDefaultInstance());
 
-    OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+    OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
     assertThat(stdout.toString()).isEqualTo("{}");
   }
 
   @Test
   public void testGetResponse_json_success() throws IOException {
     TestWorker testWorker = createTestWorker("{}".getBytes(UTF_8), JSON);
-    WorkResponse readResponse = testWorker.getResponse();
+    WorkResponse readResponse = testWorker.getResponse(0);
     WorkResponse response = WorkResponse.getDefaultInstance();
 
     assertThat(readResponse).isEqualTo(response);
@@ -260,7 +146,7 @@
     TestWorker testWorker = createTestWorker(new byte[0], JSON);
     testWorker.putRequest(request);
 
-    OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+    OutputStream stdout = testWorker.getFakeSubprocess().getOutputStream();
     String requestJsonString =
         "{\"arguments\":[\"testRequest\"],\"inputs\":"
             + "[{\"path\":\"testPath\",\"digest\":\"dGVzdERpZ2VzdA==\"}],\"requestId\":1}";
@@ -272,7 +158,7 @@
     TestWorker testWorker =
         createTestWorker(
             "{\"exitCode\":1,\"output\":\"test output\",\"requestId\":1}".getBytes(UTF_8), JSON);
-    WorkResponse readResponse = testWorker.getResponse();
+    WorkResponse readResponse = testWorker.getResponse(1);
     WorkResponse response =
         WorkResponse.newBuilder().setExitCode(1).setOutput("test output").setRequestId(1).build();
 
@@ -282,7 +168,7 @@
   private void verifyGetResponseFailure(String responseString, String expectedError)
       throws IOException {
     TestWorker testWorker = createTestWorker(responseString.getBytes(UTF_8), JSON);
-    IOException ex = assertThrows(IOException.class, testWorker::getResponse);
+    IOException ex = assertThrows(IOException.class, () -> testWorker.getResponse(0));
     assertThat(ex).hasMessageThat().contains(expectedError);
   }