Add json worker protocol support to getRequest and getResponse.
Add additional unit tests for cases when Worker.java reads and writes json responses and requests.
RELNOTES: None
PiperOrigin-RevId: 325448100
diff --git a/src/main/java/com/google/devtools/build/lib/worker/BUILD b/src/main/java/com/google/devtools/build/lib/worker/BUILD
index 7ea5833..1c38226 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/BUILD
+++ b/src/main/java/com/google/devtools/build/lib/worker/BUILD
@@ -38,8 +38,10 @@
"//third_party:apache_commons_pool2",
"//third_party:auto_value",
"//third_party:flogger",
+ "//third_party:gson",
"//third_party:guava",
"//third_party:jsr305",
"//third_party/protobuf:protobuf_java",
+ "//third_party/protobuf:protobuf_java_util",
],
)
diff --git a/src/main/java/com/google/devtools/build/lib/worker/Worker.java b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
index 85b0eda..5299fc3 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/Worker.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
@@ -13,8 +13,13 @@
// limitations under the License.
package com.google.devtools.build.lib.worker;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
+import com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat;
import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxInputs;
import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxOutputs;
import com.google.devtools.build.lib.shell.Subprocess;
@@ -23,13 +28,21 @@
import com.google.devtools.build.lib.vfs.PathFragment;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkRequest;
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
+import com.google.gson.stream.JsonReader;
+import com.google.protobuf.util.JsonFormat;
+import com.google.protobuf.util.JsonFormat.Printer;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
+import javax.annotation.Nullable;
/**
* Interface to a worker process running as a child process.
@@ -51,8 +64,14 @@
protected final Path workDir;
/** The path of the log file. */
protected final Path logFile;
- /** Stream for reading the WorkResponse. */
- protected RecordingInputStream recordingStream;
+ /** Stream for reading the protobuf WorkResponse. */
+ @Nullable protected RecordingInputStream protoRecordingStream;
+ /** Reader for reading the JSON WorkResponse. */
+ @Nullable protected JsonReader jsonReader;
+ /** Printer for writing the JSON WorkRequest bytes */
+ @Nullable protected Printer jsonPrinter;
+ /** BufferedWriter for the JSON WorkRequest bytes */
+ @Nullable protected BufferedWriter jsonWriter;
private Subprocess process;
private Thread shutdownHook;
@@ -100,6 +119,11 @@
Runtime.getRuntime().removeShutdownHook(shutdownHook);
}
if (process != null) {
+ if (workerKey.getProtocolFormat() == WorkerProtocolFormat.JSON) {
+ jsonReader.close();
+ jsonWriter.close();
+ }
+
wasDestroyed = true;
process.destroyAndWait();
}
@@ -139,22 +163,95 @@
: Optional.empty();
}
+ // TODO(karlgray): Create wrapper class that handles writing and reading JSON worker protocol to
+ // and from stream.
void putRequest(WorkRequest request) throws IOException {
- request.writeDelimitedTo(process.getOutputStream());
- process.getOutputStream().flush();
+ switch (workerKey.getProtocolFormat()) {
+ case JSON:
+ checkNotNull(jsonWriter, "Did prepareExecution get called before putRequest?");
+ checkNotNull(jsonPrinter, "Did prepareExecution get called before putRequest?");
+
+ jsonPrinter.appendTo(request, jsonWriter);
+ jsonWriter.flush();
+ break;
+
+ case PROTO:
+ request.writeDelimitedTo(process.getOutputStream());
+ process.getOutputStream().flush();
+ break;
+ }
}
WorkResponse getResponse() throws IOException {
- recordingStream = new RecordingInputStream(process.getInputStream());
- recordingStream.startRecording(4096);
- // response can be null when the worker has already closed stdout at this point and thus
- // the InputStream is at EOF.
- return WorkResponse.parseDelimitedFrom(recordingStream);
+ switch (workerKey.getProtocolFormat()) {
+ case JSON:
+ checkNotNull(jsonReader, "Did prepareExecution get called before putRequest?");
+
+ return readResponse(jsonReader);
+
+ case PROTO:
+ protoRecordingStream = new RecordingInputStream(process.getInputStream());
+ protoRecordingStream.startRecording(4096);
+ // response can be null when the worker has already closed
+ // stdout at this point and thus the InputStream is at EOF.
+ return WorkResponse.parseDelimitedFrom(protoRecordingStream);
+ }
+
+ throw new IllegalStateException(
+ "Invalid protocol format; protocol formats are currently proto or json");
+ }
+
+ private static WorkResponse readResponse(JsonReader reader) throws IOException {
+ Integer exitCode = null;
+ String output = null;
+ Integer requestId = null;
+
+ reader.beginObject();
+ while (reader.hasNext()) {
+ String name = reader.nextName();
+ switch (name) {
+ case "exitCode":
+ if (exitCode != null) {
+ throw new IOException("Work response cannot have more than one exit code");
+ }
+ exitCode = reader.nextInt();
+ break;
+ case "output":
+ if (output != null) {
+ throw new IOException("Work response cannot have more than one output");
+ }
+ output = reader.nextString();
+ break;
+ case "requestId":
+ if (requestId != null) {
+ throw new IOException("Work response cannot have more than one requestId");
+ }
+ requestId = reader.nextInt();
+ break;
+ default:
+ throw new IOException(name + " is an incorrect field in work response");
+ }
+ }
+ reader.endObject();
+
+ WorkResponse.Builder responseBuilder = WorkResponse.newBuilder();
+
+ if (exitCode != null) {
+ responseBuilder.setExitCode(exitCode);
+ }
+ if (output != null) {
+ responseBuilder.setOutput(output);
+ }
+ if (requestId != null) {
+ responseBuilder.setRequestId(requestId);
+ }
+
+ return responseBuilder.build();
}
String getRecordingStreamMessage() {
- recordingStream.readRemaining();
- return recordingStream.getRecordedDataAsString();
+ protoRecordingStream.readRemaining();
+ return protoRecordingStream.getRecordedDataAsString();
}
public void prepareExecution(
@@ -162,6 +259,19 @@
throws IOException {
if (process == null) {
process = createProcess();
+
+ if (workerKey.getProtocolFormat() == WorkerProtocolFormat.JSON) {
+ checkState(jsonReader == null, "JSON streams inconsistent with process status");
+ checkState(jsonPrinter == null, "JSON streams inconsistent with process status");
+ checkState(jsonWriter == null, "JSON streams inconsistent with process status");
+
+ jsonReader =
+ new JsonReader(
+ new BufferedReader(new InputStreamReader(process.getInputStream(), UTF_8)));
+ jsonReader.setLenient(true);
+ jsonPrinter = JsonFormat.printer().omittingInsignificantWhitespace();
+ jsonWriter = new BufferedWriter(new OutputStreamWriter(process.getOutputStream(), UTF_8));
+ }
}
}
diff --git a/src/test/java/com/google/devtools/build/lib/BUILD b/src/test/java/com/google/devtools/build/lib/BUILD
index ace04f8..0b5c1b2 100644
--- a/src/test/java/com/google/devtools/build/lib/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/BUILD
@@ -555,11 +555,13 @@
"//src/test/java/com/google/devtools/build/lib/testutil:JunitUtils",
"//src/test/java/com/google/devtools/build/lib/testutil:TestUtils",
"//src/test/java/com/google/devtools/build/lib/vfs/util",
+ "//third_party:gson",
"//third_party:guava",
"//third_party:guava-testlib",
"//third_party:junit4",
"//third_party:truth",
"//third_party/protobuf:protobuf_java",
+ "//third_party/protobuf:protobuf_java_util",
],
)
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
index 72cce57..716d185 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerTest.java
@@ -15,7 +15,10 @@
package com.google.devtools.build.lib.worker;
import static com.google.common.truth.Truth.assertThat;
+import static com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat.JSON;
+import static com.google.devtools.build.lib.actions.ExecutionRequirements.WorkerProtocolFormat.PROTO;
import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertThrows;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -49,23 +52,48 @@
public final class WorkerTest {
final FileSystem fs = new InMemoryFileSystem();
- Path workerBaseDir = fs.getPath("/outputbase/bazel-workers");
- WorkerKey key =
- new WorkerKey(
- /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
- /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
- /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
- /* mnemonic= */ "dummy",
- /* workerFilesCombinedHash= */ HashCode.fromInt(0),
- /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
- /* mustBeSandboxed= */ true,
- /* proxied= */ true,
- WorkerProtocolFormat.PROTO);
+ /** A worker that uses a fake subprocess for I/O. */
+ private static class TestWorker extends Worker {
+ private final FakeSubprocess fakeSubprocess;
- int workerId = 1;
- String workTypeName = WorkerKey.makeWorkerTypeName(key.getProxied());
- Path logFile =
- workerBaseDir.getRelative(workTypeName + "-" + workerId + "-" + key.getMnemonic() + ".log");
+ public TestWorker(
+ WorkerKey workerKey,
+ int workerId,
+ final Path workDir,
+ Path logFile,
+ FakeSubprocess fakeSubprocess) {
+ super(workerKey, workerId, workDir, logFile);
+ this.fakeSubprocess = fakeSubprocess;
+ }
+
+ @Override
+ Subprocess createProcess() {
+ return fakeSubprocess;
+ }
+ }
+
+ private TestWorker workerForCleanup = null;
+
+ @After
+ public void destroyWorker() throws IOException {
+ if (workerForCleanup != null) {
+ workerForCleanup.destroy();
+ workerForCleanup = null;
+ }
+ }
+
+ private WorkerKey createWorkerKey(WorkerProtocolFormat protocolFormat) {
+ return new WorkerKey(
+ /* args= */ ImmutableList.of("arg1", "arg2", "arg3"),
+ /* env= */ ImmutableMap.of("env1", "foo", "env2", "bar"),
+ /* execRoot= */ fs.getPath("/outputbase/execroot/workspace"),
+ /* mnemonic= */ "dummy",
+ /* workerFilesCombinedHash= */ HashCode.fromInt(0),
+ /* workerFilesWithHashes= */ ImmutableSortedMap.of(),
+ /* mustBeSandboxed= */ true,
+ /* proxied= */ true,
+ protocolFormat);
+ }
/**
* The {@link Worker} object uses a {@link Subprocess} to interact with persistent worker
@@ -143,47 +171,25 @@
}
}
- /**
- * To use the {@link FakeSubprocess}, {@link TestWorker} is created to override createProcess()
- * and create an additional getter method to ensure the proper bytes are being written to stream.
- */
- private static class TestWorker extends Worker {
- private final FakeSubprocess fakeSubprocess;
-
- public TestWorker(
- WorkerKey workerKey,
- int workerId,
- final Path workDir,
- Path logFile,
- FakeSubprocess fakeSubprocess) {
- super(workerKey, workerId, workDir, logFile);
- this.fakeSubprocess = fakeSubprocess;
- }
-
- @Override
- Subprocess createProcess() {
- return fakeSubprocess;
- }
+ private static byte[] serializeResponseToProtoBytes(WorkResponse response) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ response.writeDelimitedTo(baos);
+ return baos.toByteArray();
}
- private TestWorker workerForCleanup = null;
-
- @After
- public void destroyWorker() throws IOException {
- if (workerForCleanup != null) {
- workerForCleanup.destroy();
- workerForCleanup = null;
- }
- }
-
- private TestWorker createTestWorker(WorkResponse response) throws IOException {
+ private TestWorker createTestWorker(byte[] outputStreamBytes, WorkerProtocolFormat protocolFormat)
+ throws IOException {
Preconditions.checkState(
workerForCleanup == null, "createTestWorker can only be called once per test");
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- response.writeDelimitedTo(baos);
+ WorkerKey key = createWorkerKey(protocolFormat);
- FakeSubprocess fakeSubprocess = new FakeSubprocess(baos.toByteArray());
+ FakeSubprocess fakeSubprocess = new FakeSubprocess(outputStreamBytes);
+
+ Path workerBaseDir = fs.getPath("/outputbase/bazel-workers");
+ int workerId = 1;
+ Path logFile = workerBaseDir.getRelative("test-log-file.log");
+
TestWorker worker = new TestWorker(key, workerId, key.getExecRoot(), logFile, fakeSubprocess);
SandboxInputs sandboxInputs = null;
@@ -197,6 +203,48 @@
@Test
public void testPutRequest_success() throws IOException {
+ WorkRequest request = WorkRequest.getDefaultInstance();
+
+ TestWorker testWorker = createTestWorker(new byte[0], PROTO);
+ testWorker.putRequest(request);
+
+ OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ WorkRequest requestFromStdout =
+ WorkRequest.parseDelimitedFrom(new ByteArrayInputStream(stdout.toString().getBytes(UTF_8)));
+
+ assertThat(requestFromStdout).isEqualTo(request);
+ }
+
+ @Test
+ public void testGetResponse_success() throws IOException {
+ WorkResponse response = WorkResponse.getDefaultInstance();
+
+ TestWorker testWorker = createTestWorker(serializeResponseToProtoBytes(response), PROTO);
+ WorkResponse readResponse = testWorker.getResponse();
+
+ assertThat(readResponse).isEqualTo(response);
+ }
+
+ @Test
+ public void testPutRequest_json_success() throws IOException {
+ TestWorker testWorker = createTestWorker(new byte[0], JSON);
+ testWorker.putRequest(WorkRequest.getDefaultInstance());
+
+ OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ assertThat(stdout.toString()).isEqualTo("{}");
+ }
+
+ @Test
+ public void testGetResponse_json_success() throws IOException {
+ TestWorker testWorker = createTestWorker("{}".getBytes(UTF_8), JSON);
+ WorkResponse readResponse = testWorker.getResponse();
+ WorkResponse response = WorkResponse.getDefaultInstance();
+
+ assertThat(readResponse).isEqualTo(response);
+ }
+
+ @Test
+ public void testPutRequest_json_populatedFields_success() throws IOException {
WorkRequest request =
WorkRequest.newBuilder()
.addArguments("testRequest")
@@ -208,27 +256,56 @@
.setRequestId(1)
.build();
- WorkResponse response =
- WorkResponse.newBuilder().setExitCode(1).setOutput("test output").setRequestId(1).build();
-
- TestWorker testWorker = createTestWorker(response);
+ TestWorker testWorker = createTestWorker(new byte[0], JSON);
testWorker.putRequest(request);
- OutputStream output = testWorker.fakeSubprocess.getOutputStream();
- WorkRequest requestFromOutput =
- WorkRequest.parseDelimitedFrom(new ByteArrayInputStream(output.toString().getBytes(UTF_8)));
-
- assertThat(request).isEqualTo(requestFromOutput);
+ OutputStream stdout = testWorker.fakeSubprocess.getOutputStream();
+ String requestJsonString =
+ "{\"arguments\":[\"testRequest\"],\"inputs\":"
+ + "[{\"path\":\"testPath\",\"digest\":\"dGVzdERpZ2VzdA==\"}],\"requestId\":1}";
+ assertThat(stdout.toString()).isEqualTo(requestJsonString);
}
@Test
- public void testGetResponse_success() throws IOException {
+ public void testGetResponse_json_populatedFields_success() throws IOException {
+ TestWorker testWorker =
+ createTestWorker(
+ "{\"exitCode\":1,\"output\":\"test output\",\"requestId\":1}".getBytes(UTF_8), JSON);
+ WorkResponse readResponse = testWorker.getResponse();
WorkResponse response =
WorkResponse.newBuilder().setExitCode(1).setOutput("test output").setRequestId(1).build();
- TestWorker testWorker = createTestWorker(response);
- WorkResponse readResponse = testWorker.getResponse();
+ assertThat(readResponse).isEqualTo(response);
+ }
- assertThat(response).isEqualTo(readResponse);
+ private void verifyGetResponseFailure(String responseString, String expectedError)
+ throws IOException {
+ TestWorker testWorker = createTestWorker(responseString.getBytes(UTF_8), JSON);
+ IOException ex = assertThrows(IOException.class, testWorker::getResponse);
+ assertThat(ex).hasMessageThat().isEqualTo(expectedError);
+ }
+
+ @Test
+ public void testGetResponse_json_multipleExitCode_fails() throws IOException {
+ verifyGetResponseFailure(
+ "{\"exitCode\":1,\"exitCode\":1}", "Work response cannot have more than one exit code");
+ }
+
+ @Test
+ public void testGetResponse_json_multipleOutput_fails() throws IOException {
+ verifyGetResponseFailure(
+ "{\"output\":\"\",\"output\":\"\"}", "Work response cannot have more than one output");
+ }
+
+ @Test
+ public void testGetResponse_json_multipleRequestId_fails() throws IOException {
+ verifyGetResponseFailure(
+ "{\"requestId\":0,\"requestId\":0}", "Work response cannot have more than one requestId");
+ }
+
+ @Test
+ public void testGetResponse_json_incorrectFields_fails() throws IOException {
+ verifyGetResponseFailure(
+ "{\"testField\":0}", "testField is an incorrect field in work response");
}
}