Support for cancellation in WorkRequestHandler.
To actually use cancellation, a worker implementation will still have to implement a cancellation callback that actually cancels and add `supports-worker-cancellation = 1` to the execution requirements, and then the build must run with `--experimental_worker_cancellation`.
Cancellation design doc: https://docs.google.com/document/d/1-h4gcBV8Jn6DK9G_e23kZQ159jmX__uckhub1Gv9dzc
RELNOTES: None.
PiperOrigin-RevId: 373749452
diff --git a/site/docs/creating-workers.md b/site/docs/creating-workers.md
index b9140cb..3b2b537 100644
--- a/site/docs/creating-workers.md
+++ b/site/docs/creating-workers.md
@@ -84,6 +84,13 @@
}
```
+A `request_id` of 0 indicates a "singleplex" request, i.e. this request cannot
+be processed in parallel with other requests. The server guarantees that a
+given worker receives requests with either only `request_id` 0 or only
+`request_id` greater than zero. Singleplex requests are sent in serial, i.e. the
+server doesn't send another request until it has received a response (except
+for cancel requests, see below).
+
**Notes**
* Each protocol buffer is preceded by its length in `varint` format (see
@@ -94,6 +101,34 @@
* Bazel stores requests as protobufs and converts them to JSON using
[protobuf's JSON format](https://cs.opensource.google/protobuf/protobuf/+/master:java/util/src/main/java/com/google/protobuf/util/JsonFormat.java)
+### Cancellation
+
+Workers can optionally allow work requests to be cancelled before they finish.
+This is particularly useful in connection with dynamic execution, where local
+execution can regularly be interrupted by a faster remote execution. To allow
+cancellation, add `supports-worker-cancellation: 1` to the
+`execution-requirements` field (see below) and set the
+`--experimental_worker_cancellation` flag.
+
+A **cancel request** is a `WorkRequest` with the `cancel` field set (and
+similarly a **cancel response** is a `WorkResponse` with the `was_cancelled`
+field set). The only other field that must be in a cancel request or cancel
+response is `request_id`, indicating which
+request to cancel. The `request_id` field will be 0 for singleplex workers
+or the non-0 `request_id` of a previously sent `WorkRequest` for multiplex
+workers. The server may send cancel requests for requests that the worker has
+already responded to, in which case the cancel request must be ignored.
+
+Each non-cancel `WorkRequest` message must be answered exactly once, whether
+or not it was cancelled. Once the server has sent a cancel request, the worker
+may respond with a `WorkResponse` with the `request_id` set
+and the `was_cancelled` field set to true. Sending a regular `WorkResponse`
+is also accepted, but the `output` and `exit_code` fields will be ignored.
+
+Once a response has been sent for a `WorkRequest`, the worker must not touch
+the files in its working directory. The server is free to clean up the files,
+including temporary files.
+
## Making the rule that uses the worker
You'll also need to create a rule that generates actions to be performed by the
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
index faecc14..ebb7b62 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkRequestHandler.java
@@ -13,7 +13,6 @@
// limitations under the License.
package com.google.devtools.build.lib.worker;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
@@ -24,13 +23,12 @@
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.time.Duration;
-import java.util.ArrayDeque;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
-import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
import java.util.function.BiFunction;
/**
@@ -56,6 +54,10 @@
/** Holds information necessary to properly handle a request, especially for cancellation. */
static class RequestInfo {
+ /** The thread handling the request. */
+ final Thread thread;
+ /** If true, we have received a cancel request for this request. */
+ private boolean cancelled;
/**
* The builder for the response to this request. Since only one response must be sent per
* request, this builder must be accessed through takeBuilder(), which zeroes this field and
@@ -63,6 +65,20 @@
*/
private WorkResponse.Builder responseBuilder = WorkResponse.newBuilder();
+ RequestInfo(Thread thread) {
+ this.thread = thread;
+ }
+
+ /** Sets whether this request has been cancelled. */
+ void setCancelled() {
+ cancelled = true;
+ }
+
+ /** Returns true if this request has been cancelled. */
+ boolean isCancelled() {
+ return cancelled;
+ }
+
/**
* Returns the response builder. If called more than once on the same instance, subsequent calls
* will return {@code null}.
@@ -72,13 +88,22 @@
responseBuilder = null;
return Optional.ofNullable(b);
}
+
+ /**
+ * Adds {@code s} as output to when the response eventually gets built. Does nothing if the
+ * response has already been taken. There is no guarantee that the response hasn't already been
+ * taken, making this call a no-op. This may be called multiple times. No delimiters are added
+ * between strings from multiple calls.
+ */
+ synchronized void addOutput(String s) {
+ if (responseBuilder != null) {
+ responseBuilder.setOutput(responseBuilder.getOutput() + s);
+ }
+ }
}
/** Requests that are currently being processed. Visible for testing. */
- final Map<Integer, RequestInfo> activeRequests = new ConcurrentHashMap<>();
-
- /** WorkRequests that have been received but could not be processed yet. */
- private final Queue<WorkRequest> availableRequests = new ArrayDeque<>();
+ final ConcurrentMap<Integer, RequestInfo> activeRequests = new ConcurrentHashMap<>();
/** The function to be called after each {@link WorkRequest} is read. */
private final BiFunction<List<String>, PrintWriter, Integer> callback;
@@ -88,6 +113,7 @@
final WorkerMessageProcessor messageProcessor;
+ private final BiConsumer<Integer, Thread> cancelCallback;
private final CpuTimeBasedGcScheduler gcScheduler;
@@ -107,7 +133,7 @@
BiFunction<List<String>, PrintWriter, Integer> callback,
PrintStream stderr,
WorkerMessageProcessor messageProcessor) {
- this(callback, stderr, messageProcessor, Duration.ZERO);
+ this(callback, stderr, messageProcessor, Duration.ZERO, null);
}
/**
@@ -131,10 +157,24 @@
PrintStream stderr,
WorkerMessageProcessor messageProcessor,
Duration cpuUsageBeforeGc) {
+ this(callback, stderr, messageProcessor, cpuUsageBeforeGc, null);
+ }
+
+ /**
+ * Creates a {@code WorkRequestHandler} that will call {@code callback} for each WorkRequest
+ * received. Only used for the Builder.
+ */
+ private WorkRequestHandler(
+ BiFunction<List<String>, PrintWriter, Integer> callback,
+ PrintStream stderr,
+ WorkerMessageProcessor messageProcessor,
+ Duration cpuUsageBeforeGc,
+ BiConsumer<Integer, Thread> cancelCallback) {
this.callback = callback;
this.stderr = stderr;
this.messageProcessor = messageProcessor;
this.gcScheduler = new CpuTimeBasedGcScheduler(cpuUsageBeforeGc);
+ this.cancelCallback = cancelCallback;
}
/** Builder class for WorkRequestHandler. Required parameters are passed to the constructor. */
@@ -143,6 +183,7 @@
private final PrintStream stderr;
private final WorkerMessageProcessor messageProcessor;
private Duration cpuUsageBeforeGc = Duration.ZERO;
+ private BiConsumer<Integer, Thread> cancelCallback;
/**
* Creates a {@code WorkRequestHandlerBuilder}.
@@ -173,9 +214,19 @@
return this;
}
+ /**
+ * Sets a callback will be called when a cancellation message has been received. The callback
+ * will be call with the request ID and the thread executing the request.
+ */
+ public WorkRequestHandlerBuilder setCancelCallback(BiConsumer<Integer, Thread> cancelCallback) {
+ this.cancelCallback = cancelCallback;
+ return this;
+ }
+
/** Returns a WorkRequestHandler instance with the values in this Builder. */
public WorkRequestHandler build() {
- return new WorkRequestHandler(callback, stderr, messageProcessor, cpuUsageBeforeGc);
+ return new WorkRequestHandler(
+ callback, stderr, messageProcessor, cpuUsageBeforeGc, cancelCallback);
}
}
@@ -191,56 +242,42 @@
if (request == null) {
break;
}
- availableRequests.add(request);
- startRequestThreads();
+ if (request.getCancel()) {
+ respondToCancelRequest(request);
+ } else {
+ startResponseThread(request);
+ }
}
}
- /**
- * Starts threads for as many outstanding requests as possible. This is the only method that adds
- * to {@code activeRequests}.
- */
- private synchronized void startRequestThreads() {
- while (!availableRequests.isEmpty()) {
- // If there's a singleplex request in process, don't start more processes.
- if (activeRequests.containsKey(0)) {
- return;
- }
- WorkRequest request = availableRequests.peek();
- // Don't start new singleplex requests if there are other requests running.
- if (request.getRequestId() == 0 && !activeRequests.isEmpty()) {
- return;
- }
- availableRequests.remove();
- Thread t = createResponseThread(request);
- activeRequests.put(request.getRequestId(), new RequestInfo());
- t.start();
- }
- }
-
- /** Creates a new {@link Thread} to process a multiplex request. */
- Thread createResponseThread(WorkRequest request) {
+ /** Starts a thread for the given request. */
+ void startResponseThread(WorkRequest request) {
Thread currentThread = Thread.currentThread();
String threadName =
request.getRequestId() > 0
? "multiplex-request-" + request.getRequestId()
: "singleplex-request";
- return new Thread(
- () -> {
- RequestInfo requestInfo = activeRequests.get(request.getRequestId());
- try {
- respondToRequest(request, requestInfo);
- } catch (IOException e) {
- e.printStackTrace(stderr);
- // In case of error, shut down the entire worker.
- currentThread.interrupt();
- } finally {
- activeRequests.remove(request.getRequestId());
- // A good time to start more requests, especially if we finished a singleplex request
- startRequestThreads();
- }
- },
- threadName);
+ Thread t =
+ new Thread(
+ () -> {
+ RequestInfo requestInfo = activeRequests.get(request.getRequestId());
+ if (requestInfo == null) {
+ // Already cancelled
+ return;
+ }
+ try {
+ respondToRequest(request, requestInfo);
+ } catch (IOException e) {
+ e.printStackTrace(stderr);
+ // In case of error, shut down the entire worker.
+ currentThread.interrupt();
+ } finally {
+ activeRequests.remove(request.getRequestId());
+ }
+ },
+ threadName);
+ activeRequests.put(request.getRequestId(), new RequestInfo(t));
+ t.start();
}
/** Handles and responds to the given {@link WorkRequest}. */
@@ -260,7 +297,11 @@
if (optBuilder.isPresent()) {
WorkResponse.Builder builder = optBuilder.get();
builder.setRequestId(request.getRequestId());
- builder.setOutput(builder.getOutput() + sw.toString()).setExitCode(exitCode);
+ if (requestInfo.isCancelled()) {
+ builder.setWasCancelled(true);
+ } else {
+ builder.setOutput(builder.getOutput() + sw).setExitCode(exitCode);
+ }
WorkResponse response = builder.build();
synchronized (this) {
messageProcessor.writeWorkResponse(response);
@@ -270,6 +311,45 @@
}
}
+ /**
+ * Handles cancelling an existing request, including sending a response if that is not done by the
+ * time {@code cancelCallback.accept} returns.
+ */
+ void respondToCancelRequest(WorkRequest request) throws IOException {
+ // Theoretically, we could have gotten two singleplex requests, and we can't tell those apart.
+ // However, that's a violation of the protocol, so we don't try to handle it (not least because
+ // handling it would be quite error-prone).
+ RequestInfo ri = activeRequests.remove(request.getRequestId());
+
+ if (ri == null) {
+ return;
+ }
+ if (cancelCallback == null) {
+ ri.setCancelled();
+ // This is either an error on the server side or a version mismatch between the server setup
+ // and the binary. It's better to wait for the regular work to finish instead of breaking the
+ // build, but we should inform the user about the bad setup.
+ ri.addOutput(
+ String.format(
+ "Cancellation request received for worker request %d, but this worker does not"
+ + " support cancellation.\n",
+ request.getRequestId()));
+ } else {
+ if (ri.thread.isAlive() && !ri.isCancelled()) {
+ ri.setCancelled();
+ cancelCallback.accept(request.getRequestId(), ri.thread);
+ Optional<WorkResponse.Builder> builder = ri.takeBuilder();
+ if (builder.isPresent()) {
+ WorkResponse response =
+ builder.get().setWasCancelled(true).setRequestId(request.getRequestId()).build();
+ synchronized (this) {
+ messageProcessor.writeWorkResponse(response);
+ }
+ }
+ }
+ }
+ }
+
@Override
public void close() throws IOException {
messageProcessor.close();
diff --git a/src/main/protobuf/worker_protocol.proto b/src/main/protobuf/worker_protocol.proto
index a5f9545..2381df1 100644
--- a/src/main/protobuf/worker_protocol.proto
+++ b/src/main/protobuf/worker_protocol.proto
@@ -41,11 +41,12 @@
// Each WorkRequest must have either a unique
// request_id or request_id = 0. If request_id is 0, this WorkRequest must be
- // processed alone, otherwise the worker may process multiple WorkRequests in
- // parallel (multiplexing). As an exception to the above, if the cancel field
- // is true, the request_id must be the same as a previously sent WorkRequest.
- // The request_id must be attached unchanged to the corresponding
- // WorkResponse.
+ // processed alone (singleplex), otherwise the worker may process multiple
+ // WorkRequests in parallel (multiplexing). As an exception to the above, if
+ // the cancel field is true, the request_id must be the same as a previously
+ // sent WorkRequest. The request_id must be attached unchanged to the
+ // corresponding WorkResponse. Only one singleplex request may be sent to a
+ // worker at a time.
int32 request_id = 3;
// EXPERIMENTAL: When true, this is a cancel request, indicating that a
diff --git a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
index 6190bbc..96e2e86 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/ExampleWorker.java
@@ -94,11 +94,10 @@
if (poisoned && workerOptions.hardPoison) {
throw new IllegalStateException("I'm a very poisoned worker and will just crash.");
}
- if (request.getRequestId() != 0) {
- Thread t = createResponseThread(request);
- t.start();
+ if (request.getCancel()) {
+ respondToCancelRequest(request);
} else {
- respondToRequest(request, new RequestInfo());
+ startResponseThread(request);
}
if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
System.exit(0);
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
index cd3da8d..10af797 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkRequestHandlerTest.java
@@ -17,14 +17,20 @@
import static com.google.common.truth.Truth.assertThat;
import com.google.devtools.build.lib.worker.WorkRequestHandler.RequestInfo;
+import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkRequestHandlerBuilder;
+import com.google.devtools.build.lib.worker.WorkRequestHandler.WorkerMessageProcessor;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.io.PipedInputStream;
+import java.io.PipedOutputStream;
import java.io.PrintStream;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.Semaphore;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -51,7 +57,7 @@
List<String> args = Arrays.asList("--sources", "A.java");
WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
- handler.respondToRequest(request, new RequestInfo());
+ handler.respondToRequest(request, new RequestInfo(null));
WorkResponse response =
WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -71,7 +77,7 @@
List<String> args = Arrays.asList("--sources", "A.java");
WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).setRequestId(42).build();
- handler.respondToRequest(request, new RequestInfo());
+ handler.respondToRequest(request, new RequestInfo(null));
WorkResponse response =
WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -94,7 +100,7 @@
List<String> args = Arrays.asList("--sources", "A.java");
WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
- handler.respondToRequest(request, new RequestInfo());
+ handler.respondToRequest(request, new RequestInfo(null));
WorkResponse response =
WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -116,7 +122,7 @@
List<String> args = Arrays.asList("--sources", "A.java");
WorkRequest request = WorkRequest.newBuilder().addAllArguments(args).build();
- handler.respondToRequest(request, new RequestInfo());
+ handler.respondToRequest(request, new RequestInfo(null));
WorkResponse response =
WorkResponse.parseDelimitedFrom(new ByteArrayInputStream(out.toByteArray()));
@@ -124,4 +130,281 @@
assertThat(response.getExitCode()).isEqualTo(1);
assertThat(response.getOutput()).startsWith("java.lang.RuntimeException: Exploded!");
}
+
+ @Test
+ public void testCancelRequest_exactlyOneResponseSent() throws IOException, InterruptedException {
+ boolean[] handlerCalled = new boolean[] {false};
+ boolean[] cancelCalled = new boolean[] {false};
+ PipedOutputStream src = new PipedOutputStream();
+ PipedInputStream dest = new PipedInputStream();
+ Semaphore done = new Semaphore(0);
+ Semaphore finish = new Semaphore(0);
+ List<String> failures = new ArrayList<>();
+
+ WorkRequestHandler handler =
+ new WorkRequestHandlerBuilder(
+ (args, err) -> {
+ handlerCalled[0] = true;
+ err.println("Such work! Much progress! Wow!");
+ return 1;
+ },
+ new PrintStream(new ByteArrayOutputStream()),
+ new LimitedWorkerMessageProcessor(
+ new ProtoWorkerMessageProcessor(
+ new PipedInputStream(src), new PipedOutputStream(dest)),
+ 2))
+ .setCancelCallback(
+ (i, t) -> {
+ cancelCalled[0] = true;
+ })
+ .build();
+
+ runRequestHandlerThread(done, handler, finish, failures);
+ WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+ WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+ WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+ done.acquire();
+
+ assertThat(handlerCalled[0] || cancelCalled[0]).isTrue();
+ assertThat(response.getRequestId()).isEqualTo(42);
+ if (response.getWasCancelled()) {
+ assertThat(response.getOutput()).isEmpty();
+ assertThat(response.getExitCode()).isEqualTo(0);
+ } else {
+ assertThat(response.getOutput()).isEqualTo("Such work! Much progress! Wow!\n");
+ assertThat(response.getExitCode()).isEqualTo(1);
+ }
+
+ // Checks that nothing more was sent.
+ assertThat(dest.available()).isEqualTo(0);
+ finish.release();
+
+ // Checks that there weren't other unexpected failures.
+ assertThat(failures).isEmpty();
+ }
+
+ @Test
+ public void testCancelRequest_sendsResponseWhenNotAlreadySent()
+ throws IOException, InterruptedException {
+ Semaphore waitForCancel = new Semaphore(0);
+ Semaphore handlerCalled = new Semaphore(0);
+ Semaphore cancelCalled = new Semaphore(0);
+ PipedOutputStream src = new PipedOutputStream();
+ PipedInputStream dest = new PipedInputStream();
+ Semaphore done = new Semaphore(0);
+ Semaphore finish = new Semaphore(0);
+ List<String> failures = new ArrayList<>();
+
+ // We force the regular handling to not finish until after we have read the cancel response,
+ // to avoid flakiness.
+ WorkRequestHandler handler =
+ new WorkRequestHandlerBuilder(
+ (args, err) -> {
+ // This handler waits until the main thread has sent a cancel request.
+ handlerCalled.release(2);
+ try {
+ waitForCancel.acquire();
+ } catch (InterruptedException e) {
+ failures.add("Unexpected interrupt waiting for cancel request");
+ e.printStackTrace();
+ }
+ return 0;
+ },
+ new PrintStream(new ByteArrayOutputStream()),
+ new LimitedWorkerMessageProcessor(
+ new ProtoWorkerMessageProcessor(
+ new PipedInputStream(src), new PipedOutputStream(dest)),
+ 2))
+ .setCancelCallback(
+ (i, t) -> {
+ cancelCalled.release();
+ })
+ .build();
+
+ runRequestHandlerThread(done, handler, finish, failures);
+ WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+ // Make sure the handler is called before sending the cancel request, or we might process
+ // the cancellation entirely before that.
+ handlerCalled.acquire();
+ WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+ WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+ waitForCancel.release();
+ // Give the other request a chance to process, so we can check that no other response is sent
+ done.acquire();
+
+ assertThat(handlerCalled.availablePermits()).isEqualTo(1); // Released 2, one was acquired
+ assertThat(cancelCalled.availablePermits()).isEqualTo(1);
+ assertThat(response.getRequestId()).isEqualTo(42);
+ assertThat(response.getOutput()).isEmpty();
+ assertThat(response.getWasCancelled()).isTrue();
+
+ // Checks that nothing more was sent.
+ assertThat(dest.available()).isEqualTo(0);
+ src.close();
+ finish.release();
+
+ // Checks that there weren't other unexpected failures.
+ assertThat(failures).isEmpty();
+ }
+
+ @Test
+ public void testCancelRequest_noDoubleCancelResponse() throws IOException, InterruptedException {
+ Semaphore waitForCancel = new Semaphore(0);
+ Semaphore cancelCalled = new Semaphore(0);
+ PipedOutputStream src = new PipedOutputStream();
+ PipedInputStream dest = new PipedInputStream();
+ Semaphore done = new Semaphore(0);
+ Semaphore finish = new Semaphore(0);
+ List<String> failures = new ArrayList<>();
+
+ // We force the regular handling to not finish until after we have read the cancel response,
+ // to avoid flakiness.
+ WorkRequestHandler handler =
+ new WorkRequestHandlerBuilder(
+ (args, err) -> {
+ try {
+ waitForCancel.acquire();
+ } catch (InterruptedException e) {
+ failures.add("Unexpected interrupt waiting for cancel request");
+ e.printStackTrace();
+ }
+ return 0;
+ },
+ new PrintStream(new ByteArrayOutputStream()),
+ new LimitedWorkerMessageProcessor(
+ new ProtoWorkerMessageProcessor(
+ new PipedInputStream(src), new PipedOutputStream(dest)),
+ 3))
+ .setCancelCallback(
+ (i, t) -> {
+ cancelCalled.release();
+ })
+ .build();
+
+ runRequestHandlerThread(done, handler, finish, failures);
+ WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+ WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+ WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+ WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+ waitForCancel.release();
+ done.acquire();
+
+ assertThat(cancelCalled.availablePermits()).isLessThan(2);
+ assertThat(response.getRequestId()).isEqualTo(42);
+ assertThat(response.getOutput()).isEmpty();
+ assertThat(response.getWasCancelled()).isTrue();
+
+ // Checks that nothing more was sent.
+ assertThat(dest.available()).isEqualTo(0);
+ src.close();
+ finish.release();
+
+ // Checks that there weren't other unexpected failures.
+ assertThat(failures).isEmpty();
+ }
+
+ @Test
+ public void testCancelRequest_sendsNoResponseWhenAlreadySent()
+ throws IOException, InterruptedException {
+ Semaphore handlerCalled = new Semaphore(0);
+ PipedOutputStream src = new PipedOutputStream();
+ PipedInputStream dest = new PipedInputStream();
+ Semaphore done = new Semaphore(0);
+ Semaphore finish = new Semaphore(0);
+ List<String> failures = new ArrayList<>();
+
+ // We force the cancel request to not happen until after we have read the normal response,
+ // to avoid flakiness.
+ WorkRequestHandler handler =
+ new WorkRequestHandlerBuilder(
+ (args, err) -> {
+ handlerCalled.release();
+ err.println("Such work! Much progress! Wow!");
+ return 2;
+ },
+ new PrintStream(new ByteArrayOutputStream()),
+ new LimitedWorkerMessageProcessor(
+ new ProtoWorkerMessageProcessor(
+ new PipedInputStream(src), new PipedOutputStream(dest)),
+ 2))
+ .setCancelCallback((i, t) -> {})
+ .build();
+
+ runRequestHandlerThread(done, handler, finish, failures);
+ WorkRequest.newBuilder().setRequestId(42).build().writeDelimitedTo(src);
+ WorkResponse response = WorkResponse.parseDelimitedFrom(dest);
+ WorkRequest.newBuilder().setRequestId(42).setCancel(true).build().writeDelimitedTo(src);
+ done.acquire();
+
+ assertThat(response).isNotNull();
+
+ assertThat(handlerCalled.availablePermits()).isEqualTo(1);
+ assertThat(response.getRequestId()).isEqualTo(42);
+ assertThat(response.getWasCancelled()).isFalse();
+ assertThat(response.getExitCode()).isEqualTo(2);
+ assertThat(response.getOutput()).isEqualTo("Such work! Much progress! Wow!\n");
+
+ // Checks that nothing more was sent.
+ assertThat(dest.available()).isEqualTo(0);
+ src.close();
+ finish.release();
+
+ // Checks that there weren't other unexpected failures.
+ assertThat(failures).isEmpty();
+ }
+
+ private void runRequestHandlerThread(
+ Semaphore done, WorkRequestHandler handler, Semaphore finish, List<String> failures) {
+ // This thread just makes sure the WorkRequestHandler does work asynchronously.
+ new Thread(
+ () -> {
+ try {
+ handler.processRequests();
+ while (!handler.activeRequests.isEmpty()) {
+ Thread.sleep(1);
+ }
+ done.release();
+ finish.acquire();
+ } catch (IOException | InterruptedException e) {
+ failures.add("Unexpected I/O error talking to worker thread");
+ e.printStackTrace();
+ }
+ })
+ .start();
+ }
+
+ /**
+ * A wrapper around a WorkerMessageProcessor that stops after a given number of requests have been
+ * read. It stops by making readWorkRequest() return null.
+ */
+ private static class LimitedWorkerMessageProcessor implements WorkerMessageProcessor {
+ private final WorkerMessageProcessor delegate;
+ private final int maxMessages;
+ private int messages;
+
+ public LimitedWorkerMessageProcessor(WorkerMessageProcessor delegate, int maxMessages) {
+ this.delegate = delegate;
+ this.maxMessages = maxMessages;
+ }
+
+ @Override
+ public WorkRequest readWorkRequest() throws IOException {
+ System.out.println("Handling request #" + messages);
+ if (++messages > maxMessages) {
+ return null;
+ } else {
+ return delegate.readWorkRequest();
+ }
+ }
+
+ @Override
+ public void writeWorkResponse(WorkResponse workResponse) throws IOException {
+ delegate.writeWorkResponse(workResponse);
+ }
+
+ @Override
+ public void close() throws IOException {
+ delegate.close();
+ }
+ }
}