Use ConcurrentMap instead of Map+Semaphore.

Simplifies things immensely, since there are now fewer places that can be interrupted.

RELNOTES: n/a
PiperOrigin-RevId: 342609938
diff --git a/src/main/java/com/google/devtools/build/lib/worker/Worker.java b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
index 48d0f0b..0b035a3 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/Worker.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/Worker.java
@@ -14,7 +14,7 @@
 package com.google.devtools.build.lib.worker;
 
 import com.google.common.hash.HashCode;
-import com.google.devtools.build.lib.events.ExtendedEventHandler;
+import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxInputs;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxOutputs;
 import com.google.devtools.build.lib.vfs.Path;
@@ -70,7 +70,7 @@
    * Sets the reporter this {@code Worker} should report anomalous events to, or clears it. We
    * expect the reporter to be cleared at end of build.
    */
-  void setReporter(ExtendedEventHandler reporter) {}
+  void setReporter(EventHandler reporter) {}
 
   /**
    * Performs the necessary steps to prepare for execution. Once this is done, the worker should be
@@ -86,9 +86,8 @@
    * @param request The request to send.
    * @throws IOException If there was a problem doing I/O, or this thread was interrupted at a time
    *     where some or all of the expected I/O has been done.
-   * @throws InterruptedException If this thread was interrupted before doing any I/O.
    */
-  abstract void putRequest(WorkRequest request) throws IOException, InterruptedException;
+  abstract void putRequest(WorkRequest request) throws IOException;
 
   /**
    * Waits to receive a response from the worker. This method should return as soon as a response
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
index 5419457..858dbbe 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerMultiplexer.java
@@ -18,7 +18,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.flogger.GoogleLogger;
 import com.google.devtools.build.lib.events.Event;
-import com.google.devtools.build.lib.events.ExtendedEventHandler;
+import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.shell.Subprocess;
 import com.google.devtools.build.lib.shell.SubprocessBuilder;
 import com.google.devtools.build.lib.shell.SubprocessFactory;
@@ -29,10 +29,10 @@
 import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.Semaphore;
 
 /**
@@ -46,20 +46,16 @@
   private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
   /**
    * A map of {@code WorkResponse}s received from the worker process. They are stored in this map
-   * until the corresponding {@code WorkerProxy} picks them up.
+   * keyed by the request id until the corresponding {@code WorkerProxy} picks them up.
    */
-  private final Map<Integer, WorkResponse> workerProcessResponse;
-  /** A semaphore to protect {@code workerProcessResponse} object. */
-  private final Semaphore semWorkerProcessResponse;
+  private final ConcurrentMap<Integer, WorkResponse> workerProcessResponse;
   /**
    * A map of semaphores corresponding to {@code WorkRequest}s. After sending the {@code
    * WorkRequest}, {@code WorkerProxy} will wait on a semaphore to be released. {@code
    * WorkerMultiplexer} is responsible for releasing the corresponding semaphore in order to signal
    * {@code WorkerProxy} that the {@code WorkerResponse} has been received.
    */
-  private final Map<Integer, Semaphore> responseChecker;
-  /** A semaphore to protect responseChecker object. */
-  private final Semaphore semResponseChecker;
+  private final ConcurrentMap<Integer, Semaphore> responseChecker;
   /** The worker process that this WorkerMultiplexer should be talking to. */
   private Subprocess process;
   /**
@@ -93,28 +89,26 @@
    * The active Reporter object, non-null if {@code --worker_verbose} is set. This must be cleared
    * at the end of a command execution.
    */
-  public ExtendedEventHandler reporter;
+  public EventHandler reporter;
 
   WorkerMultiplexer(Path logFile, WorkerKey workerKey) {
     this.logFile = logFile;
     this.workerKey = workerKey;
-    semWorkerProcessResponse = new Semaphore(1);
-    semResponseChecker = new Semaphore(1);
-    responseChecker = new HashMap<>();
-    workerProcessResponse = new HashMap<>();
+    responseChecker = new ConcurrentHashMap<>();
+    workerProcessResponse = new ConcurrentHashMap<>();
     isWorkerStreamCorrupted = false;
     isWorkerStreamClosed = false;
     wasDestroyed = false;
   }
 
   /** Sets or clears the reporter for outputting verbose info. */
-  void setReporter(ExtendedEventHandler reporter) {
+  void setReporter(EventHandler reporter) {
     this.reporter = reporter;
   }
 
   /** Reports a string to the user if reporting is enabled. */
   private void report(String s) {
-    ExtendedEventHandler r = this.reporter; // Protect against race condition with setReporter().
+    EventHandler r = this.reporter; // Protect against race condition with setReporter().
     if (r != null && s != null) {
       r.handle(Event.info(s));
     }
@@ -198,6 +192,7 @@
    * WorkerProxy}, and so is subject to interrupts by dynamic execution.
    */
   public synchronized void putRequest(WorkRequest request) throws IOException {
+    responseChecker.put(request.getRequestId(), new Semaphore(0));
     try {
       request.writeDelimitedTo(process.getOutputStream());
       process.getOutputStream().flush();
@@ -209,6 +204,7 @@
       if (e instanceof InterruptedIOException) {
         Thread.currentThread().interrupt();
       }
+      responseChecker.remove(request.getRequestId());
       throw e;
     }
   }
@@ -220,11 +216,10 @@
    */
   public WorkResponse getResponse(Integer requestId) throws InterruptedException {
     try {
-      semResponseChecker.acquire();
       Semaphore waitForResponse = responseChecker.get(requestId);
-      semResponseChecker.release();
 
       if (waitForResponse == null) {
+        report("Null response semaphore for " + requestId);
         // If the multiplexer is interrupted when a {@code WorkerProxy} is trying to send a request,
         // the request is not sent, so there is no need to wait for a response.
         return null;
@@ -233,38 +228,18 @@
       // Wait for the multiplexer to get our response and release this semaphore. The semaphore will
       // throw {@code InterruptedException} when the multiplexer is terminated.
       waitForResponse.acquire();
+      report("Acquired response semaphore for " + requestId);
 
-      if (isWorkerStreamClosed || isWorkerStreamCorrupted) {
-        return null;
-      }
-
-      semWorkerProcessResponse.acquire();
-      WorkResponse response = workerProcessResponse.get(requestId);
-      semWorkerProcessResponse.release();
-      return response;
+      WorkResponse workResponse = workerProcessResponse.get(requestId);
+      report("Response for " + requestId + " is " + workResponse);
+      return workResponse;
     } finally {
-      semResponseChecker.acquire();
       responseChecker.remove(requestId);
-      semResponseChecker.release();
-      semWorkerProcessResponse.acquire();
       workerProcessResponse.remove(requestId);
-      semWorkerProcessResponse.release();
     }
   }
 
   /**
-   * Puts an entry for {@code requestId} into the semaphore map before sending a request to the
-   * worker process. This method is called on the thread of a {@code WorkerProxy}, and so is subject
-   * to interrupts by dynamic execution.
-   */
-  void putResponseChecker(Integer requestId) throws InterruptedException {
-    // This is separate from putRequest to avoid waiting for a semaphore in a synchronized method.
-    semResponseChecker.acquire();
-    responseChecker.put(requestId, new Semaphore(0));
-    semResponseChecker.release();
-  }
-
-  /**
    * Waits to read a {@code WorkResponse} from worker process, put that {@code WorkResponse} in
    * {@code workerProcessResponse} and release the semaphore for the {@code WorkerProxy}.
    *
@@ -295,26 +270,17 @@
 
     int requestId = parsedResponse.getRequestId();
 
-    semWorkerProcessResponse.acquire();
     workerProcessResponse.put(requestId, parsedResponse);
-    semWorkerProcessResponse.release();
 
     // TODO(b/151767359): When allowing cancellation, just remove responses that have no matching
     // entry in responseChecker.
-    semResponseChecker.acquire();
     Semaphore semaphore = responseChecker.get(requestId);
     if (semaphore != null) {
       // This wakes up the WorkerProxy that should receive this response.
       semaphore.release();
-      semResponseChecker.release();
     } else {
       report(String.format("Multiplexer for %s found no semaphore", workerKey.getMnemonic()));
-      semResponseChecker.release();
-      logger.atWarning().log("Received response for unknown request %d.", requestId);
-      semWorkerProcessResponse.acquire();
-      // Prevent memory leak of useless responses.
       workerProcessResponse.remove(requestId);
-      semWorkerProcessResponse.release();
     }
   }
 
@@ -369,23 +335,11 @@
    * down the multiplexer.
    */
   private void releaseAllSemaphores() {
-    try {
-      semResponseChecker.acquire();
-      for (Semaphore semaphore : responseChecker.values()) {
-        semaphore.release();
-      }
-      responseChecker.clear();
-      semResponseChecker.release();
-    } catch (InterruptedException e) {
-      // Do nothing - we only get interrupted during shutdown
+    for (Semaphore semaphore : responseChecker.values()) {
+      semaphore.release();
     }
-    try {
-      semWorkerProcessResponse.acquire();
-      workerProcessResponse.clear();
-      semWorkerProcessResponse.release();
-    } catch (InterruptedException e) {
-      // Do nothing - we only get interrupted during shutdown
-    }
+    responseChecker.clear();
+    workerProcessResponse.clear();
   }
 
   String getRecordingStreamMessage() {
diff --git a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
index 40b2151..6b3258c 100644
--- a/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
+++ b/src/main/java/com/google/devtools/build/lib/worker/WorkerProxy.java
@@ -16,7 +16,7 @@
 
 import com.google.common.flogger.GoogleLogger;
 import com.google.devtools.build.lib.actions.UserExecException;
-import com.google.devtools.build.lib.events.ExtendedEventHandler;
+import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxInputs;
 import com.google.devtools.build.lib.sandbox.SandboxHelpers.SandboxOutputs;
 import com.google.devtools.build.lib.vfs.Path;
@@ -56,7 +56,7 @@
   }
 
   @Override
-  void setReporter(ExtendedEventHandler reporter) {
+  void setReporter(EventHandler reporter) {
     workerMultiplexer.setReporter(reporter);
   }
 
@@ -83,8 +83,7 @@
 
   /** Send the WorkRequest to multiplexer. */
   @Override
-  void putRequest(WorkRequest request) throws IOException, InterruptedException {
-    workerMultiplexer.putResponseChecker(request.getRequestId());
+  void putRequest(WorkRequest request) throws IOException {
     workerMultiplexer.putRequest(request);
   }
 
diff --git a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
index c216292..9075e2a 100644
--- a/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
+++ b/src/test/java/com/google/devtools/build/lib/worker/WorkerMultiplexerTest.java
@@ -62,8 +62,7 @@
     multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
 
     WorkRequest request1 = WorkRequest.newBuilder().setRequestId(1).build();
-    WorkerProxy worker =
-        new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
     worker.prepareExecution(null, null, null);
     worker.putRequest(request1);
     WorkResponse response1 = WorkResponse.newBuilder().setRequestId(1).build();
@@ -86,14 +85,12 @@
     OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
     multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
 
-    WorkerProxy worker1 =
-        new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
     worker1.prepareExecution(null, null, null);
     WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
     worker1.putRequest(request1);
 
-    WorkerProxy worker2 =
-        new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
     worker2.prepareExecution(null, null, null);
     WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
     worker2.putRequest(request2);
@@ -124,14 +121,12 @@
     OutputStream workerOutputStream = new PipedOutputStream(serverInputStrean);
     multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStrean));
 
-    WorkerProxy worker1 =
-        new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
     worker1.prepareExecution(null, null, null);
     WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
     worker1.putRequest(request1);
 
-    WorkerProxy worker2 =
-        new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
     worker2.prepareExecution(null, null, null);
     WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
     worker2.putRequest(request2);
@@ -183,14 +178,12 @@
     OutputStream workerOutputStream = new PipedOutputStream(serverInputStream);
     multiplexer.setProcessFactory(params -> new FakeSubprocess(serverInputStream));
 
-    WorkerProxy worker1 =
-        new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker1 = new WorkerProxy(workerKey, 1, logPath, logPath, multiplexer);
     worker1.prepareExecution(null, null, null);
     WorkRequest request1 = WorkRequest.newBuilder().setRequestId(3).build();
     worker1.putRequest(request1);
 
-    WorkerProxy worker2 =
-        new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer /* workerVerbose */);
+    WorkerProxy worker2 = new WorkerProxy(workerKey, 2, logPath, logPath, multiplexer);
     worker2.prepareExecution(null, null, null);
     WorkRequest request2 = WorkRequest.newBuilder().setRequestId(42).build();
     worker2.putRequest(request2);