Support for cancellation in WorkRequestHandler.

To actually use cancellation, a worker implementation will still have to implement a cancellation callback that actually cancels and add `supports-worker-cancellation = 1` to the execution requirements, and then the build must run with `--experimental_worker_cancellation`.

Cancellation design doc: https://docs.google.com/document/d/1-h4gcBV8Jn6DK9G_e23kZQ159jmX__uckhub1Gv9dzc

RELNOTES: None.
PiperOrigin-RevId: 373749452
diff --git a/site/docs/creating-workers.md b/site/docs/creating-workers.md
index b9140cb..3b2b537 100644
--- a/site/docs/creating-workers.md
+++ b/site/docs/creating-workers.md
@@ -84,6 +84,13 @@
 }
 ```
 
+A `request_id` of 0 indicates a "singleplex" request, i.e. this request cannot
+be processed in parallel with other requests. The server guarantees that a
+given worker receives requests with either only `request_id` 0 or only
+`request_id` greater than zero. Singleplex requests are sent in serial, i.e. the
+server doesn't send another request until it has received a response (except
+for cancel requests, see below).
+
 **Notes**
 
 * Each protocol buffer is preceded by its length in `varint` format (see
@@ -94,6 +101,34 @@
 * Bazel stores requests as protobufs and converts them to JSON using
 [protobuf's JSON format](https://cs.opensource.google/protobuf/protobuf/+/master:java/util/src/main/java/com/google/protobuf/util/JsonFormat.java)
 
+### Cancellation
+
+Workers can optionally allow work requests to be cancelled before they finish.
+This is particularly useful in connection with dynamic execution, where local
+execution can regularly be interrupted by a faster remote execution. To allow
+cancellation, add `supports-worker-cancellation: 1` to the
+`execution-requirements` field (see below) and set the
+`--experimental_worker_cancellation` flag.
+
+A **cancel request** is a `WorkRequest` with the `cancel` field set (and
+similarly a **cancel response** is a `WorkResponse` with the `was_cancelled`
+field set). The only other field that must be in a cancel request or cancel
+response is `request_id`, indicating which
+request to cancel. The `request_id` field will be 0 for singleplex workers
+or the non-0 `request_id` of a previously sent `WorkRequest` for multiplex
+workers. The server may send cancel requests for requests that the worker has
+already responded to, in which case the cancel request must be ignored.
+
+Each non-cancel `WorkRequest` message must be answered exactly once, whether
+or not it was cancelled. Once the server has sent a cancel request, the worker
+may respond with a `WorkResponse` with the `request_id` set
+and the `was_cancelled` field set to true. Sending a regular `WorkResponse`
+is also accepted, but the `output` and `exit_code` fields will be ignored.
+
+Once a response has been sent for a `WorkRequest`, the worker must not touch
+the files in its working directory. The server is free to clean up the files,
+including temporary files.
+
 ## Making the rule that uses the worker
 
 You'll also need to create a rule that generates actions to be performed by the
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
index faecc14..ebb7b62 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
@@ -13,7 +13,6 @@
 // limitations under the License.
 package com.google.devtools.build.lib.worker;
 
-
 import com.google.common.annotations.VisibleForTesting;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
@@ -24,13 +23,12 @@
 import java.io.StringWriter;
 import java.lang.management.ManagementFactory;
 import java.time.Duration;
-import java.util.ArrayDeque;
 import java.util.List;
-import java.util.Map;
 import java.util.Optional;
-import java.util.Queue;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 
 /**
@@ -56,6 +54,10 @@
 
   /** Holds information necessary to properly handle a request, especially for cancellation. */
   static class RequestInfo {
+    /** The thread handling the request. */
+    final Thread thread;
+    /** If true, we have received a cancel request for this request. */
+    private boolean cancelled;
     /**
      * The builder for the response to this request. Since only one response must be sent per
      * request, this builder must be accessed through takeBuilder(), which zeroes this field and
@@ -63,6 +65,20 @@
      */
     private WorkResponse.Builder responseBuilder = WorkResponse.newBuilder();
 
+    RequestInfo(Thread thread) {
+      this.thread = thread;
+    }
+
+    /** Marks this request as cancelled. */
+    void setCancelled() {
+      cancelled = true;
+    }
+
+    /** Returns true if this request has been cancelled. */
+    boolean isCancelled() {
+      return cancelled;
+    }
+
     /**
      * Returns the response builder. If called more than once on the same instance, subsequent calls
      * will return {@code null}.
@@ -72,13 +88,22 @@
       responseBuilder = null;
       return Optional.ofNullable(b);
     }
+
+    /**
+     * Appends {@code s} to the output of the response that eventually gets built. This is a no-op
+     * if the response builder has already been taken, and callers get no indication of whether
+     * that happened. May be called multiple times; no delimiters are added between the strings
+     * from multiple calls.
+     */
+    synchronized void addOutput(String s) {
+      if (responseBuilder != null) {
+        responseBuilder.setOutput(responseBuilder.getOutput() + s);
+      }
+    }
   }
 
   /** Requests that are currently being processed. Visible for testing. */
-  final Map<Integer, RequestInfo> activeRequests = new ConcurrentHashMap<>();
-
-  /** WorkRequests that have been received but could not be processed yet. */
-  private final Queue<WorkRequest> availableRequests = new ArrayDeque<>();
+  final ConcurrentMap<Integer, RequestInfo> activeRequests = new ConcurrentHashMap<>();
 
   /** The function to be called after each {@link WorkRequest} is read. */
   private final BiFunction<List<String>, PrintWriter, Integer> callback;
@@ -88,6 +113,7 @@
 
   final WorkerMessageProcessor messageProcessor;
 
+  private final BiConsumer<Integer, Thread> cancelCallback;
 
   private final CpuTimeBasedGcScheduler gcScheduler;
 
@@ -107,7 +133,7 @@
       BiFunction<List<String>, PrintWriter, Integer> callback,
       PrintStream stderr,
       WorkerMessageProcessor messageProcessor) {
-    this(callback, stderr, messageProcessor, Duration.ZERO);
+    this(callback, stderr, messageProcessor, Duration.ZERO, null);
   }
 
   /**
@@ -131,10 +157,24 @@
       PrintStream stderr,
       WorkerMessageProcessor messageProcessor,
       Duration cpuUsageBeforeGc) {
+    this(callback, stderr, messageProcessor, cpuUsageBeforeGc, null);
+  }
+
+  /**
+   * Creates a {@code WorkRequestHandler} that will call {@code callback} for each WorkRequest
+   * received. Only used for the Builder.
+   */
+  private WorkRequestHandler(
+      BiFunction<List<String>, PrintWriter, Integer> callback,
+      PrintStream stderr,
+      WorkerMessageProcessor messageProcessor,
+      Duration cpuUsageBeforeGc,
+      BiConsumer<Integer, Thread> cancelCallback) {
     this.callback = callback;
     this.stderr = stderr;
     this.messageProcessor = messageProcessor;
     this.gcScheduler = new CpuTimeBasedGcScheduler(cpuUsageBeforeGc);
+    this.cancelCallback = cancelCallback;
   }
 
   /** Builder class for WorkRequestHandler. Required parameters are passed to the constructor. */
@@ -143,6 +183,7 @@
     private final PrintStream stderr;
     private final WorkerMessageProcessor messageProcessor;
     private Duration cpuUsageBeforeGc = Duration.ZERO;
+    private BiConsumer<Integer, Thread> cancelCallback;
 
     /**
      * Creates a {@code WorkRequestHandlerBuilder}.
@@ -173,9 +214,19 @@
       return this;
     }
 
+    /**
+     * Sets a callback that will be called when a cancellation message has been received. The
+     * callback will be called with the request ID and the thread executing the request.
+     */
+    public WorkRequestHandlerBuilder setCancelCallback(BiConsumer<Integer, Thread> cancelCallback) {
+      this.cancelCallback = cancelCallback;
+      return this;
+    }
+
     /** Returns a WorkRequestHandler instance with the values in this Builder. */
     public WorkRequestHandler build() {
-      return new WorkRequestHandler(callback, stderr, messageProcessor, cpuUsageBeforeGc);
+      return new WorkRequestHandler(
+          callback, stderr, messageProcessor, cpuUsageBeforeGc, cancelCallback);
     }
   }
 
@@ -191,56 +242,42 @@
       if (request == null) {
         break;
       }
-      availableRequests.add(request);
-      startRequestThreads();
+      if (request.getCancel()) {
+        respondToCancelRequest(request);
+      } else {
+        startResponseThread(request);
+      }
     }
   }
 
-  /**
-   * Starts threads for as many outstanding requests as possible. This is the only method that adds
-   * to {@code activeRequests}.
-   */
-  private synchronized void startRequestThreads() {
-    while (!availableRequests.isEmpty()) {
-      // If there's a singleplex request in process, don't start more processes.
-      if (activeRequests.containsKey(0)) {
-        return;
-      }
-      WorkRequest request = availableRequests.peek();
-      // Don't start new singleplex requests if there are other requests running.
-      if (request.getRequestId() == 0 && !activeRequests.isEmpty()) {
-        return;
-      }
-      availableRequests.remove();
-      Thread t = createResponseThread(request);
-      activeRequests.put(request.getRequestId(), new RequestInfo());
-      t.start();
-    }
-  }
-
-  /** Creates a new {@link Thread} to process a multiplex request. */
-  Thread createResponseThread(WorkRequest request) {
+  /** Starts a thread for the given request. */
+  void startResponseThread(WorkRequest request) {
     Thread currentThread = Thread.currentThread();
     String threadName =
         request.getRequestId() > 0
             ? "multiplex-request-" + request.getRequestId()
             : "singleplex-request";
-    return new Thread(
-        () -> {
-          RequestInfo requestInfo = activeRequests.get(request.getRequestId());
-          try {
-            respondToRequest(request, requestInfo);
-          } catch (IOException e) {
-            e.printStackTrace(stderr);
-            // In case of error, shut down the entire worker.
-            currentThread.interrupt();
-          } finally {
-            activeRequests.remove(request.getRequestId());
-            // A good time to start more requests, especially if we finished a singleplex request
-            startRequestThreads();
-          }
-        },
-        threadName);
+    Thread t =
+        new Thread(
+            () -> {
+              RequestInfo requestInfo = activeRequests.get(request.getRequestId());
+              if (requestInfo == null) {
+                // Already cancelled
+                return;
+              }
+              try {
+                respondToRequest(request, requestInfo);
+              } catch (IOException e) {
+                e.printStackTrace(stderr);
+                // In case of error, shut down the entire worker.
+                currentThread.interrupt();
+              } finally {
+                activeRequests.remove(request.getRequestId());
+              }
+            },
+            threadName);
+    activeRequests.put(request.getRequestId(), new RequestInfo(t));
+    t.start();
   }
 
   /** Handles and responds to the given {@link WorkRequest}. */
@@ -260,7 +297,11 @@
       if (optBuilder.isPresent()) {
         WorkResponse.Builder builder = optBuilder.get();
         builder.setRequestId(request.getRequestId());
-        builder.setOutput(builder.getOutput() + sw.toString()).setExitCode(exitCode);
+        if (requestInfo.isCancelled()) {
+          builder.setWasCancelled(true);
+        } else {
+          builder.setOutput(builder.getOutput() + sw).setExitCode(exitCode);
+        }
         WorkResponse response = builder.build();
         synchronized (this) {
           messageProcessor.writeWorkResponse(response);
@@ -270,6 +311,45 @@
     }
   }
 
+  /**
+   * Handles cancelling an existing request, including sending a response if that is not done by the
+   * time {@code cancelCallback.accept} returns.
+   */
+  void respondToCancelRequest(WorkRequest request) throws IOException {
+    // Theoretically, we could have gotten two singleplex requests, and we can't tell those apart.
+    // However, that's a violation of the protocol, so we don't try to handle it (not least because
+    // handling it would be quite error-prone).
+    // Look up without removing: the request's own thread removes the entry when it finishes.
+    RequestInfo ri = activeRequests.get(request.getRequestId());
+    if (ri == null) {
+      return;
+    }
+    if (cancelCallback == null) {
+      ri.setCancelled();
+      // This is either an error on the server side or a version mismatch between the server setup
+      // and the binary. It's better to wait for the regular work to finish instead of breaking the
+      // build, but we should inform the user about the bad setup.
+      ri.addOutput(
+          String.format(
+              "Cancellation request received for worker request %d, but this worker does not"
+                  + " support cancellation.\n",
+              request.getRequestId()));
+    } else {
+      if (ri.thread.isAlive() && !ri.isCancelled()) {
+        ri.setCancelled();
+        cancelCallback.accept(request.getRequestId(), ri.thread);
+        Optional<WorkResponse.Builder> builder = ri.takeBuilder();
+        if (builder.isPresent()) {
+          WorkResponse response =
+              builder.get().setWasCancelled(true).setRequestId(request.getRequestId()).build();
+          synchronized (this) {
+            messageProcessor.writeWorkResponse(response);
+          }
+        }
+      }
+    }
+  }
+
+
   @Override
   public void close() throws IOException {
     messageProcessor.close();
diff --git a/src/main/protobuf/worker_protocol.proto b/src/main/protobuf/worker_protocol.proto
index a5f9545..2381df1 100644
--- a/src/main/protobuf/worker_protocol.proto
+++ b/src/main/protobuf/worker_protocol.proto
@@ -41,11 +41,12 @@
 
   // Each WorkRequest must have either a unique
   // request_id or request_id = 0. If request_id is 0, this WorkRequest must be
-  // processed alone, otherwise the worker may process multiple WorkRequests in
-  // parallel (multiplexing). As an exception to the above, if the cancel field
-  // is true, the request_id must be the same as a previously sent WorkRequest.
-  // The request_id must be attached unchanged to the corresponding
-  // WorkResponse.
+  // processed alone (singleplex), otherwise the worker may process multiple
+  // WorkRequests in parallel (multiplexing). As an exception to the above, if
+  // the cancel field is true, the request_id must be the same as a previously
+  // sent WorkRequest. The request_id must be attached unchanged to the
+  // corresponding WorkResponse. Only one singleplex request may be sent to a
+  // worker at a time.
   int32 request_id = 3;
 
   // EXPERIMENTAL: When true, this is a cancel request, indicating that a
diff --git a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
index 6190bbc..96e2e86 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
@@ -94,11 +94,10 @@
         if (poisoned && workerOptions.hardPoison) {
           throw new IllegalStateException("I'm a very poisoned worker and will just crash.");
         }
-        if (request.getRequestId() != 0) {
-          Thread t = createResponseThread(request);
-          t.start();
+        if (request.getCancel()) {
+          respondToCancelRequest(request);
         } else {
-          respondToRequest(request, new RequestInfo());
+          startResponseThread(request);
         }
         if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
           System.exit(0);
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
index cd3da8d..10af797 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
@@ -17,14 +17,20 @@
 import static com.google.common.truth.Truth.assertThat;
 
 import com.google.devtools.build.lib.worker.WorkRequestHandler.RequestInfo;
+import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkRequestHandlerBuilder;
+import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkerMessageProcessor;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
 import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
 import java.io.PrintStream;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.concurrent.Semaphore;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -51,7 +57,7 @@
 
     List<String> args = Arrays.asList("--sources", "A.java");
     WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
-    handler.respondToRequest(request, new RequestInfo());
+    handler.respondToRequest(request, new RequestInfo(null));
 
     WorkResponse response =
         WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -71,7 +77,7 @@
 
     List<String> args = Arrays.asList("--sources", "A.java");
     WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).setRequestId(42).build();
-    handler.respondToRequest(request, new RequestInfo());
+    handler.respondToRequest(request, new RequestInfo(null));
 
     WorkResponse response =
         WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -94,7 +100,7 @@
 
     List<String> args = Arrays.asList("--sources", "A.java");
     WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
-    handler.respondToRequest(request, new RequestInfo());
+    handler.respondToRequest(request, new RequestInfo(null));
 
     WorkResponse response =
         WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -116,7 +122,7 @@
 
     List<String> args = Arrays.asList("--sources", "A.java");
     WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
-    handler.respondToRequest(request, new RequestInfo());
+    handler.respondToRequest(request, new RequestInfo(null));
 
     WorkResponse response =
         WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -124,4 +130,281 @@
     assertThat(response.getExitCode()).isEqualTo(1);
     assertThat(response.getOutput()).startsWith("java.lang.RuntimeException: Exploded!");
   }
+
+  @Test
+  public void testCancelRequest_exactlyOneResponseSent() throws IOException, InterruptedException {
+    boolean[] handlerCalled = new boolean[] {false};
+    boolean[] cancelCalled = new boolean[] {false};
+    PipedOutputStream src = new PipedOutputStream();
+    PipedInputStream dest = new PipedInputStream();
+    Semaphore done = new Semaphore(0);
+    Semaphore finish = new Semaphore(0);
+    List<String> failures = new ArrayList<>();
+
+    WorkRequestHandler handler =
+        new WorkRequestHandlerBuilder(
+                (args, err) -> {
+                  handlerCalled[0] = true;
+                  err.println("Such work! Much progress! Wow!");
+                  return 1;
+                },
+                new PrintStream(new ByteArrayOutputStream()),
+                new LimitedWorkerMessageProcessor(
+                    new ProtoWorkerMessageProcessor(
+                        new PipedInputStream(src), new PipedOutputStream(dest)),
+                    2))
+            .setCancelCallback(
+                (i, t) -> {
+                  cancelCalled[0] = true;
+                })
+            .build();
+
+    runRequestHandlerThread(done, handler, finish, failures);
+    WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+    WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+    WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+    done.acquire();
+
+    assertThat(handlerCalled[0] || cancelCalled[0]).isTrue();
+    assertThat(response.getRequestId()).isEqualTo(42);
+    if (response.getWasCancelled()) {
+      assertThat(response.getOutput()).isEmpty();
+      assertThat(response.getExitCode()).isEqualTo(0);
+    } else {
+      assertThat(response.getOutput()).isEqualTo("Such work! Much progress! Wow!\n");
+      assertThat(response.getExitCode()).isEqualTo(1);
+    }
+
+    // Checks that nothing more was sent.
+    assertThat(dest.available()).isEqualTo(0);
+    finish.release();
+
+    // Checks that there weren't other unexpected failures.
+    assertThat(failures).isEmpty();
+  }
+
+  @Test
+  public void testCancelRequest_sendsResponseWhenNotAlreadySent()
+      throws IOException, InterruptedException {
+    Semaphore waitForCancel = new Semaphore(0);
+    Semaphore handlerCalled = new Semaphore(0);
+    Semaphore cancelCalled = new Semaphore(0);
+    PipedOutputStream src = new PipedOutputStream();
+    PipedInputStream dest = new PipedInputStream();
+    Semaphore done = new Semaphore(0);
+    Semaphore finish = new Semaphore(0);
+    List<String> failures = new ArrayList<>();
+
+    // We force the regular handling to not finish until after we have read the cancel response,
+    // to avoid flakiness.
+    WorkRequestHandler handler =
+        new WorkRequestHandlerBuilder(
+                (args, err) -> {
+                  // This handler waits until the main thread has sent a cancel request.
+                  handlerCalled.release(2);
+                  try {
+                    waitForCancel.acquire();
+                  } catch (InterruptedException e) {
+                    failures.add("Unexpected interrupt waiting for cancel request");
+                    e.printStackTrace();
+                  }
+                  return 0;
+                },
+                new PrintStream(new ByteArrayOutputStream()),
+                new LimitedWorkerMessageProcessor(
+                    new ProtoWorkerMessageProcessor(
+                        new PipedInputStream(src), new PipedOutputStream(dest)),
+                    2))
+            .setCancelCallback(
+                (i, t) -> {
+                  cancelCalled.release();
+                })
+            .build();
+
+    runRequestHandlerThread(done, handler, finish, failures);
+    WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+    // Make sure the handler is called before sending the cancel request, or we might process
+    // the cancellation entirely before that.
+    handlerCalled.acquire();
+    WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+    WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+    waitForCancel.release();
+    // Give the other request a chance to process, so we can check that no other response is sent
+    done.acquire();
+
+    assertThat(handlerCalled.availablePermits()).isEqualTo(1); // Released 2, one was acquired
+    assertThat(cancelCalled.availablePermits()).isEqualTo(1);
+    assertThat(response.getRequestId()).isEqualTo(42);
+    assertThat(response.getOutput()).isEmpty();
+    assertThat(response.getWasCancelled()).isTrue();
+
+    // Checks that nothing more was sent.
+    assertThat(dest.available()).isEqualTo(0);
+    src.close();
+    finish.release();
+
+    // Checks that there weren't other unexpected failures.
+    assertThat(failures).isEmpty();
+  }
+
+  @Test
+  public void testCancelRequest_noDoubleCancelResponse() throws IOException, InterruptedException {
+    Semaphore waitForCancel = new Semaphore(0);
+    Semaphore cancelCalled = new Semaphore(0);
+    PipedOutputStream src = new PipedOutputStream();
+    PipedInputStream dest = new PipedInputStream();
+    Semaphore done = new Semaphore(0);
+    Semaphore finish = new Semaphore(0);
+    List<String> failures = new ArrayList<>();
+
+    // We force the regular handling to not finish until after we have read the cancel response,
+    // to avoid flakiness.
+    WorkRequestHandler handler =
+        new WorkRequestHandlerBuilder(
+                (args, err) -> {
+                  try {
+                    waitForCancel.acquire();
+                  } catch (InterruptedException e) {
+                    failures.add("Unexpected interrupt waiting for cancel request");
+                    e.printStackTrace();
+                  }
+                  return 0;
+                },
+                new PrintStream(new ByteArrayOutputStream()),
+                new LimitedWorkerMessageProcessor(
+                    new ProtoWorkerMessageProcessor(
+                        new PipedInputStream(src), new PipedOutputStream(dest)),
+                    3))
+            .setCancelCallback(
+                (i, t) -> {
+                  cancelCalled.release();
+                })
+            .build();
+
+    runRequestHandlerThread(done, handler, finish, failures);
+    WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+    WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+    WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+    WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+    waitForCancel.release();
+    done.acquire();
+
+    assertThat(cancelCalled.availablePermits()).isLessThan(2);
+    assertThat(response.getRequestId()).isEqualTo(42);
+    assertThat(response.getOutput()).isEmpty();
+    assertThat(response.getWasCancelled()).isTrue();
+
+    // Checks that nothing more was sent.
+    assertThat(dest.available()).isEqualTo(0);
+    src.close();
+    finish.release();
+
+    // Checks that there weren't other unexpected failures.
+    assertThat(failures).isEmpty();
+  }
+
+  @Test
+  public void testCancelRequest_sendsNoResponseWhenAlreadySent()
+      throws IOException, InterruptedException {
+    Semaphore handlerCalled = new Semaphore(0);
+    PipedOutputStream src = new PipedOutputStream();
+    PipedInputStream dest = new PipedInputStream();
+    Semaphore done = new Semaphore(0);
+    Semaphore finish = new Semaphore(0);
+    List<String> failures = new ArrayList<>();
+
+    // We force the cancel request to not happen until after we have read the normal response,
+    // to avoid flakiness.
+    WorkRequestHandler handler =
+        new WorkRequestHandlerBuilder(
+                (args, err) -> {
+                  handlerCalled.release();
+                  err.println("Such work! Much progress! Wow!");
+                  return 2;
+                },
+                new PrintStream(new ByteArrayOutputStream()),
+                new LimitedWorkerMessageProcessor(
+                    new ProtoWorkerMessageProcessor(
+                        new PipedInputStream(src), new PipedOutputStream(dest)),
+                    2))
+            .setCancelCallback((i, t) -> {})
+            .build();
+
+    runRequestHandlerThread(done, handler, finish, failures);
+    WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+    WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+    WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+    done.acquire();
+
+    assertThat(response).isNotNull();
+
+    assertThat(handlerCalled.availablePermits()).isEqualTo(1);
+    assertThat(response.getRequestId()).isEqualTo(42);
+    assertThat(response.getWasCancelled()).isFalse();
+    assertThat(response.getExitCode()).isEqualTo(2);
+    assertThat(response.getOutput()).isEqualTo("Such work! Much progress! Wow!\n");
+
+    // Checks that nothing more was sent.
+    assertThat(dest.available()).isEqualTo(0);
+    src.close();
+    finish.release();
+
+    // Checks that there weren't other unexpected failures.
+    assertThat(failures).isEmpty();
+  }
+
+  private void runRequestHandlerThread(
+      Semaphore done, WorkRequestHandler handler, Semaphore finish, List<String> failures) {
+    // This thread just makes sure the WorkRequestHandler does work asynchronously.
+    new Thread(
+            () -> {
+              try {
+                handler.processRequests();
+                while (!handler.activeRequests.isEmpty()) {
+                  Thread.sleep(1);
+                }
+                done.release();
+                finish.acquire();
+              } catch (IOException | InterruptedException e) {
+                failures.add("Unexpected I/O error talking to worker thread");
+                e.printStackTrace();
+              }
+            })
+        .start();
+  }
+
+  /**
+   * A wrapper around a WorkerMessageProcessor that stops after a given number of requests have been
+   * read. It stops by making readWorkRequest() return null.
+   */
+  private static class LimitedWorkerMessageProcessor implements WorkerMessageProcessor {
+    private final WorkerMessageProcessor delegate;
+    private final int maxMessages;
+    private int messages;
+
+    public LimitedWorkerMessageProcessor(WorkerMessageProcessor delegate, int maxMessages) {
+      this.delegate = delegate;
+      this.maxMessages = maxMessages;
+    }
+
+    @Override
+    public WorkRequest readWorkRequest() throws IOException {
+      System.out.println("Handling request #" + messages);
+      if (++messages > maxMessages) {
+        return null;
+      } else {
+        return delegate.readWorkRequest();
+      }
+    }
+
+    @Override
+    public void writeWorkResponse(WorkResponse workResponse) throws IOException {
+      delegate.writeWorkResponse(workResponse);
+    }
+
+    @Override
+    public void close() throws IOException {
+      delegate.close();
+    }
+  }
 }