// Copyright 2016 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.server;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.devtools.build.lib.clock.JavaClock;
import com.google.devtools.build.lib.runtime.BlazeCommandResult;
import com.google.devtools.build.lib.runtime.CommandDispatcher;
import com.google.devtools.build.lib.runtime.proto.InvocationPolicyOuterClass.InvocationPolicy;
import com.google.devtools.build.lib.server.CommandProtos.CancelRequest;
import com.google.devtools.build.lib.server.CommandProtos.CancelResponse;
import com.google.devtools.build.lib.server.CommandProtos.EnvironmentVariable;
import com.google.devtools.build.lib.server.CommandProtos.RunRequest;
import com.google.devtools.build.lib.server.CommandProtos.RunResponse;
import com.google.devtools.build.lib.server.CommandServerGrpc.CommandServerStub;
import com.google.devtools.build.lib.server.FailureDetails.Command;
import com.google.devtools.build.lib.server.FailureDetails.FailureDetail;
import com.google.devtools.build.lib.server.FailureDetails.GrpcServer;
import com.google.devtools.build.lib.server.FailureDetails.Interrupted;
import com.google.devtools.build.lib.server.FailureDetails.Interrupted.Code;
import com.google.devtools.build.lib.server.GrpcServerImpl.BlockingStreamObserver;
import com.google.devtools.build.lib.testutil.Suite;
import com.google.devtools.build.lib.testutil.TestSpec;
import com.google.devtools.build.lib.testutil.TestUtils;
import com.google.devtools.build.lib.util.Pair;
import com.google.devtools.build.lib.util.io.OutErr;
import com.google.devtools.build.lib.vfs.DigestHashFunction;
import com.google.devtools.build.lib.vfs.FileSystem;
import com.google.devtools.build.lib.vfs.Path;
import com.google.devtools.build.lib.vfs.inmemoryfs.InMemoryFileSystem;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.BytesValue;
import com.google.protobuf.StringValue;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for the gRPC server. */
@TestSpec(size = Suite.SMALL_TESTS)
@RunWith(JUnit4.class)
public final class GrpcServerTest {

  // Fake PID handed to both the PidFileWatcher and the GrpcServerImpl built in createServer().
  private static final int SERVER_PID = 42;
  // Request cookie passed to the GrpcServerImpl constructor and set on well-formed requests.
  private static final String REQUEST_COOKIE = "request-cookie";

  // In-memory file system providing the server directory and PID-file paths used by the tests.
  private final FileSystem fileSystem = new InMemoryFileSystem(DigestHashFunction.SHA256);
  // In-process gRPC server and the client channel connected to it; assigned by each test.
  private Server server;
  private ManagedChannel channel;

  /**
   * Starts an in-process gRPC server backed by {@code dispatcher} and opens a client channel to
   * it. The results are stored in the {@code server} and {@code channel} fields.
   *
   * <p>Both the server and the channel are built with direct executors.
   *
   * @param dispatcher the command dispatcher the {@link GrpcServerImpl} delegates to
   * @throws Exception if the server directory cannot be created or the server fails to start
   */
  private void createServer(CommandDispatcher dispatcher) throws Exception {
    Path serverDir = fileSystem.getPath("/bazel_server_directory");
    serverDir.createDirectoryAndParents();

    GrpcServerImpl grpcService =
        new GrpcServerImpl(
            dispatcher,
            ShutdownHooks.createUnregistered(),
            new PidFileWatcher(fileSystem.getPath("/thread-not-running-dont-need"), SERVER_PID),
            new JavaClock(),
            /* port= */ -1,
            REQUEST_COOKIE,
            "response-cookie",
            serverDir,
            SERVER_PID,
            1000,
            false,
            false);

    // A freshly generated name keeps this server/channel pair separate from any other.
    String inProcessName = InProcessServerBuilder.generateName();
    InProcessServerBuilder serverBuilder = InProcessServerBuilder.forName(inProcessName);
    server = serverBuilder.directExecutor().addService(grpcService).build().start();

    InProcessChannelBuilder channelBuilder = InProcessChannelBuilder.forName(inProcessName);
    channel = channelBuilder.directExecutor().build();
  }

  /**
   * Builds a {@link RunRequest} carrying the valid request cookie, a fixed client description and
   * the given arguments encoded as UTF-8.
   */
  private static RunRequest createRequest(String... args) {
    RunRequest.Builder request =
        RunRequest.newBuilder()
            .setCookie(REQUEST_COOKIE)
            .setClientDescription("client-description");
    for (String arg : args) {
      request.addArg(ByteString.copyFromUtf8(arg));
    }
    return request.build();
  }

  /**
   * Builds the same request as {@link #createRequest}, but with the {@code preemptible} flag set.
   */
  private static RunRequest createPreemptibleRequest(String... args) {
    RunRequest base = createRequest(args);
    return base.toBuilder().setPreemptible(true).build();
  }

  /**
   * Sends one well-formed request and checks that the dispatcher receives its argument and command
   * extension. Also checks that the client sees a cookie-bearing first response followed by a
   * successful final response.
   */
  @Test
  public void testSendingSimpleMessage() throws Exception {
    // Any message works as the extension payload; this one is arbitrary.
    Any extension = Any.pack(EnvironmentVariable.getDefaultInstance());
    AtomicReference<List<String>> seenArgs = new AtomicReference<>();
    AtomicReference<List<Any>> seenExtensions = new AtomicReference<>();
    createServer(
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions) {
            // Record what the server handed over so the test thread can inspect it afterwards.
            seenArgs.set(args);
            seenExtensions.set(commandExtensions);
            return BlazeCommandResult.success();
          }
        });

    RunRequest request = createRequest("Foo").toBuilder().addCommandExtensions(extension).build();
    CountDownLatch finished = new CountDownLatch(1);
    List<RunResponse> responses = new ArrayList<>();
    CommandServerGrpc.newStub(channel).run(request, createResponseObserver(responses, finished));
    finished.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(seenArgs.get()).containsExactly("Foo");
    assertThat(seenExtensions.get()).containsExactly(extension);

    assertThat(responses).hasSize(2);
    RunResponse first = responses.get(0);
    assertThat(first.getFinished()).isFalse();
    assertThat(first.getCookie()).isNotEmpty();
    RunResponse last = responses.get(1);
    assertThat(last.getFinished()).isTrue();
    assertThat(last.getExitCode()).isEqualTo(0);
    assertThat(last.hasFailureDetail()).isFalse();
  }

  /**
   * Shuts the server down from the client's response callback while the dispatcher is parked in
   * {@code wait()}, and expects that wait to end with an {@link InterruptedException}.
   *
   * <p>The test makes no assertion on the responses; it passes once the server has terminated and
   * the {@code done} latch has been released.
   */
  @Test
  public void testClosingClientShouldInterrupt() throws Exception {
    // NOTE(review): this single-count latch is released both by the dispatcher (after it observed
    // the interrupt) and by the client's onNext below, so done.await() alone cannot tell the two
    // apart -- confirm whether the onNext countDown is intentional.
    CountDownLatch done = new CountDownLatch(1);
    CommandDispatcher dispatcher =
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions) {
            // Nothing ever notifies this monitor, so wait() is expected to end only by throwing.
            synchronized (this) {
              assertThrows(InterruptedException.class, this::wait);
            }
            // The only way this can happen is if the current thread is interrupted.
            done.countDown();
            return BlazeCommandResult.failureDetail(
                FailureDetail.newBuilder()
                    .setInterrupted(Interrupted.newBuilder().setCode(Code.INTERRUPTED_UNKNOWN))
                    .build());
          }
        };
    createServer(dispatcher);

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        createRequest("Foo"),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // Any response from the server triggers a forced shutdown of the server.
            server.shutdownNow();
            done.countDown();
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });
    server.awaitTermination();
    done.await();
  }

  /**
   * Has the dispatcher write ten 1024-byte chunks to stdout and return two response extensions.
   * Expects twelve responses: the cookie-bearing first one, one per chunk, and a successful final
   * one that carries the extensions.
   */
  @Test
  public void testStream() throws Exception {
    Any stringExtension = Any.pack(StringValue.of("foo"));
    Any bytesExtension = Any.pack(BytesValue.of(ByteString.copyFromUtf8("bar")));
    createServer(
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions) {
            OutputStream stdout = outErr.getOutputStream();
            int chunksWritten = 0;
            try {
              while (chunksWritten < 10) {
                stdout.write(new byte[1024]);
                chunksWritten++;
              }
            } catch (IOException e) {
              throw new IllegalStateException(e);
            }
            return BlazeCommandResult.withResponseExtensions(
                BlazeCommandResult.success(), ImmutableList.of(stringExtension, bytesExtension));
          }
        });

    List<RunResponse> responses = new ArrayList<>();
    CountDownLatch finished = new CountDownLatch(1);
    CommandServerGrpc.newStub(channel)
        .run(createRequest("Foo"), createResponseObserver(responses, finished));
    finished.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(responses).hasSize(12);

    RunResponse first = responses.get(0);
    assertThat(first.getFinished()).isFalse();
    assertThat(first.getCookie()).isNotEmpty();

    // Responses 1..10 each carry exactly one of the chunks written by the dispatcher.
    for (RunResponse chunk : responses.subList(1, 11)) {
      assertThat(chunk.getFinished()).isFalse();
      assertThat(chunk.getStandardOutput().toByteArray()).isEqualTo(new byte[1024]);
      assertThat(chunk.getCommandExtensionsList()).isEmpty();
    }

    RunResponse last = responses.get(11);
    assertThat(last.getFinished()).isTrue();
    assertThat(last.getExitCode()).isEqualTo(0);
    assertThat(last.hasFailureDetail()).isFalse();
    assertThat(last.getCommandExtensionsList()).containsExactly(stringExtension, bytesExtension);
  }

  /** A request carrying a cookie other than {@code REQUEST_COOKIE} must fail with BAD_COOKIE. */
  @Test
  public void badCookie() throws Exception {
    RunRequest.Builder request =
        RunRequest.newBuilder().setCookie("bad-cookie").setClientDescription("client-description");
    FailureDetail expected =
        FailureDetail.newBuilder()
            .setMessage("Invalid RunRequest: bad cookie")
            .setGrpcServer(GrpcServer.newBuilder().setCode(GrpcServer.Code.BAD_COOKIE))
            .build();
    runBadCommandTest(request, expected);
  }

  /** A request with an empty client description must fail with NO_CLIENT_DESCRIPTION. */
  @Test
  public void emptyClientDescription() throws Exception {
    RunRequest.Builder request =
        RunRequest.newBuilder().setCookie(REQUEST_COOKIE).setClientDescription("");
    FailureDetail expected =
        FailureDetail.newBuilder()
            .setMessage("Invalid RunRequest: no client description")
            .setGrpcServer(GrpcServer.newBuilder().setCode(GrpcServer.Code.NO_CLIENT_DESCRIPTION))
            .build();
    runBadCommandTest(request, expected);
  }

  /**
   * Sends an invalid request (with the argument "Foo" appended) and expects exactly one response:
   * finished, exit code 36, and carrying {@code failureDetail}.
   *
   * @param runRequestBuilder the partially built invalid request
   * @param failureDetail the failure detail the single response must equal
   */
  private void runBadCommandTest(RunRequest.Builder runRequestBuilder, FailureDetail failureDetail)
      throws Exception {
    createServer(throwingDispatcher());

    RunRequest request = runRequestBuilder.addArg(ByteString.copyFromUtf8("Foo")).build();
    List<RunResponse> responses = new ArrayList<>();
    CountDownLatch finished = new CountDownLatch(1);
    CommandServerGrpc.newStub(channel).run(request, createResponseObserver(responses, finished));
    finished.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(responses).hasSize(1);
    RunResponse only = responses.get(0);
    assertThat(only.getFinished()).isTrue();
    assertThat(only.getExitCode()).isEqualTo(36);
    assertThat(only.hasFailureDetail()).isTrue();
    assertThat(only.getFailureDetail()).isEqualTo(failureDetail);
  }

  /**
   * Sends a request whose invocation policy cannot be parsed. Expects three responses, the last of
   * which is finished with exit code 2 and an INVOCATION_POLICY_PARSE_FAILURE failure detail.
   */
  @Test
  public void unparseableInvocationPolicy() throws Exception {
    createServer(throwingDispatcher());

    RunRequest request =
        RunRequest.newBuilder()
            .setCookie(REQUEST_COOKIE)
            .setClientDescription("client-description")
            .setInvocationPolicy("invalid-invocation-policy")
            .addArg(ByteString.copyFromUtf8("Foo"))
            .build();
    FailureDetail expectedFailure =
        FailureDetail.newBuilder()
            .setMessage(
                "Invocation policy parsing failed: Malformed value of --invocation_policy: "
                    + "invalid-invocation-policy")
            .setCommand(Command.newBuilder().setCode(Command.Code.INVOCATION_POLICY_PARSE_FAILURE))
            .build();

    List<RunResponse> responses = new ArrayList<>();
    CountDownLatch finished = new CountDownLatch(1);
    CommandServerGrpc.newStub(channel).run(request, createResponseObserver(responses, finished));
    finished.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(responses).hasSize(3);
    // Only the final response is inspected.
    RunResponse last = responses.get(2);
    assertThat(last.getFinished()).isTrue();
    assertThat(last.getExitCode()).isEqualTo(2);
    assertThat(last.hasFailureDetail()).isTrue();
    assertThat(last.getFailureDetail()).isEqualTo(expectedFailure);
  }

  /**
   * Has the dispatcher stream stdout chunks until its thread is interrupted, while the client
   * forces a server shutdown after receiving ten responses.
   *
   * <p>The test makes no assertion on the responses; it passes once the server has terminated and
   * the client stream has ended through either {@code onError} or {@code onCompleted}.
   */
  @Test
  public void testInterruptStream() throws Exception {
    CountDownLatch done = new CountDownLatch(1);
    CommandDispatcher dispatcher =
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions) {
            OutputStream out = outErr.getOutputStream();
            try {
              // Keep writing 1024-byte chunks; the loop ends only when the thread's interrupt flag
              // is found set (Thread.interrupted() also clears it) or a write fails.
              while (true) {
                if (Thread.interrupted()) {
                  return BlazeCommandResult.failureDetail(
                      FailureDetail.newBuilder()
                          .setInterrupted(
                              Interrupted.newBuilder().setCode(Code.INTERRUPTED_UNKNOWN))
                          .build());
                }
                out.write(new byte[1024]);
              }
            } catch (IOException e) {
              throw new IllegalStateException(e);
            }
          }
        };
    createServer(dispatcher);

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    List<RunResponse> responses = new ArrayList<>();
    stub.run(
        createRequest("Foo"),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            responses.add(value);
            // Force the shutdown mid-stream, once ten responses have arrived.
            if (responses.size() == 10) {
              server.shutdownNow();
            }
          }

          @Override
          public void onError(Throwable t) {
            done.countDown();
          }

          @Override
          public void onCompleted() {
            done.countDown();
          }
        });
    server.awaitTermination();
    done.await();
  }

  /**
   * Starts a command that blocks indefinitely, then cancels it through the {@code cancel} RPC using
   * the command id from the first response. Expects the next response to be finished with exit
   * code 8 and an {@code INTERRUPTED} failure detail.
   */
  @Test
  public void testCancel() throws Exception {
    CommandDispatcher dispatcher =
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions)
              throws InterruptedException {
            // Nothing notifies this monitor; the wait is meant to end via InterruptedException,
            // which propagates out of exec().
            synchronized (this) {
              this.wait();
            }
            // Interruption expected before this is reached.
            throw new IllegalStateException();
          }
        };
    createServer(dispatcher);

    AtomicReference<String> commandId = new AtomicReference<>();
    CountDownLatch gotCommandId = new CountDownLatch(1);
    AtomicReference<RunResponse> secondResponse = new AtomicReference<>();
    CountDownLatch gotSecondResponse = new CountDownLatch(1);
    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        createRequest("Foo"),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // The first callback finds no previous id and only publishes the command id; any later
            // callback is recorded as the "second" response.
            String previousCommandId = commandId.getAndSet(value.getCommandId());
            if (previousCommandId == null) {
              gotCommandId.countDown();
            } else {
              secondResponse.set(value);
              gotSecondResponse.countDown();
            }
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });
    // Wait until we've got the command id.
    gotCommandId.await();

    CountDownLatch cancelRequestComplete = new CountDownLatch(1);
    CancelRequest cancelRequest =
        CancelRequest.newBuilder().setCookie(REQUEST_COOKIE).setCommandId(commandId.get()).build();
    stub.cancel(
        cancelRequest,
        new StreamObserver<CancelResponse>() {
          @Override
          public void onNext(CancelResponse value) {}

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {
            cancelRequestComplete.countDown();
          }
        });
    // Wait for the cancel RPC to complete and for the run stream to deliver its second response.
    cancelRequestComplete.await();
    gotSecondResponse.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(secondResponse.get().getFinished()).isTrue();
    assertThat(secondResponse.get().getExitCode()).isEqualTo(8);
    assertThat(secondResponse.get().hasFailureDetail()).isTrue();
    assertThat(secondResponse.get().getFailureDetail().hasInterrupted()).isTrue();
    assertThat(secondResponse.get().getFailureDetail().getInterrupted().getCode())
        .isEqualTo(Code.INTERRUPTED);
  }

  /**
   * Ensure that if a command is marked as preemptible, running a second command interrupts the
   * first command.
   *
   * <p>The first ("Foo") request is preemptible and sleeps until interrupted; the second ("Bar")
   * request is a regular request that succeeds immediately. The last response seen for "Foo" must
   * be finished with exit code 8 and an {@code INTERRUPTED} failure detail, and the last response
   * seen for "Bar" must be finished with exit code 0.
   */
  @Test
  public void testPreeempt() throws Exception {
    String firstCommandArg = "Foo";
    String secondCommandArg = "Bar";

    CommandDispatcher dispatcher =
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions) {
            if (args.contains(firstCommandArg)) {
              // The first command keeps sleeping until it is interrupted, then reports the
              // interruption as its result.
              while (true) {
                try {
                  Thread.sleep(TestUtils.WAIT_TIMEOUT_MILLISECONDS);
                } catch (InterruptedException e) {
                  return BlazeCommandResult.failureDetail(
                      FailureDetail.newBuilder()
                          .setInterrupted(Interrupted.newBuilder().setCode(Code.INTERRUPTED))
                          .build());
                }
              }
            } else {
              // Every other command succeeds right away.
              return BlazeCommandResult.success();
            }
          }
        };
    createServer(dispatcher);

    CountDownLatch gotFoo = new CountDownLatch(1);
    AtomicReference<RunResponse> lastFooResponse = new AtomicReference<>();
    AtomicReference<RunResponse> lastBarResponse = new AtomicReference<>();

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        createPreemptibleRequest(firstCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // Each response overwrites the previous one, so only the latest is kept.
            gotFoo.countDown();
            lastFooResponse.set(value);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });

    // Wait for the first command to startup
    gotFoo.await();

    CountDownLatch gotBar = new CountDownLatch(1);
    stub.run(
        createRequest(secondCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            gotBar.countDown();
            lastBarResponse.set(value);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });

    // Wait for the first response to the second command, then for the server to terminate,
    // before inspecting the recorded responses.
    gotBar.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(lastBarResponse.get().getFinished()).isTrue();
    assertThat(lastBarResponse.get().getExitCode()).isEqualTo(0);
    assertThat(lastFooResponse.get().getFinished()).isTrue();
    assertThat(lastFooResponse.get().getExitCode()).isEqualTo(8);
    assertThat(lastFooResponse.get().hasFailureDetail()).isTrue();
    assertThat(lastFooResponse.get().getFailureDetail().hasInterrupted()).isTrue();
    assertThat(lastFooResponse.get().getFailureDetail().getInterrupted().getCode())
        .isEqualTo(Code.INTERRUPTED);
  }

  /**
   * Ensure that if a command is marked as preemptible, running a second preemptible command
   * interrupts the first command.
   *
   * <p>Both requests are preemptible here. The first ("Foo") sleeps until interrupted and the
   * second ("Bar") succeeds immediately. The last response seen for "Foo" must be finished with
   * exit code 8 and an {@code INTERRUPTED} failure detail, and the last response seen for "Bar"
   * must be finished with exit code 0.
   */
  @Test
  public void testMultiPreeempt() throws Exception {
    String firstCommandArg = "Foo";
    String secondCommandArg = "Bar";

    CommandDispatcher dispatcher =
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions)
              throws InterruptedException {
            if (args.contains(firstCommandArg)) {
              // The first command keeps sleeping until it is interrupted, then reports the
              // interruption as its result.
              while (true) {
                try {
                  Thread.sleep(TestUtils.WAIT_TIMEOUT_MILLISECONDS);
                } catch (InterruptedException e) {
                  return BlazeCommandResult.failureDetail(
                      FailureDetail.newBuilder()
                          .setInterrupted(Interrupted.newBuilder().setCode(Code.INTERRUPTED))
                          .build());
                }
              }
            } else {
              // Every other command succeeds right away.
              return BlazeCommandResult.success();
            }
          }
        };
    createServer(dispatcher);

    CountDownLatch gotFoo = new CountDownLatch(1);
    AtomicReference<RunResponse> lastFooResponse = new AtomicReference<>();
    AtomicReference<RunResponse> lastBarResponse = new AtomicReference<>();

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        createPreemptibleRequest(firstCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // Each response overwrites the previous one, so only the latest is kept.
            gotFoo.countDown();
            lastFooResponse.set(value);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });

    // Wait for the first command to startup
    gotFoo.await();

    CountDownLatch gotBar = new CountDownLatch(1);
    stub.run(
        createPreemptibleRequest(secondCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            gotBar.countDown();
            lastBarResponse.set(value);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });

    // Wait for the first response to the second command, then for the server to terminate,
    // before inspecting the recorded responses.
    gotBar.await();
    server.shutdown();
    server.awaitTermination();

    assertThat(lastBarResponse.get().getFinished()).isTrue();
    assertThat(lastBarResponse.get().getExitCode()).isEqualTo(0);
    assertThat(lastFooResponse.get().getFinished()).isTrue();
    assertThat(lastFooResponse.get().getExitCode()).isEqualTo(8);
    assertThat(lastFooResponse.get().hasFailureDetail()).isTrue();
    assertThat(lastFooResponse.get().getFailureDetail().hasInterrupted()).isTrue();
    assertThat(lastFooResponse.get().getFailureDetail().getInterrupted().getCode())
        .isEqualTo(Code.INTERRUPTED);
  }

  /**
   * Checks that a second command leaves a running, non-preemptible first command alone: both
   * commands are held inside the dispatcher at the same time, then released, and both must finish
   * with exit code 0.
   */
  @Test
  public void testNoPreeempt() throws Exception {
    String firstCommandArg = "Foo";
    String secondCommandArg = "Bar";

    // For each command: one latch signals that exec() was entered, one lets exec() return.
    CountDownLatch fooEntered = new CountDownLatch(1);
    CountDownLatch fooRelease = new CountDownLatch(1);
    CountDownLatch barEntered = new CountDownLatch(1);
    CountDownLatch barRelease = new CountDownLatch(1);

    createServer(
        new CommandDispatcher() {
          @Override
          public BlazeCommandResult exec(
              InvocationPolicy invocationPolicy,
              List<String> args,
              OutErr outErr,
              LockingMode lockingMode,
              String clientDescription,
              long firstContactTimeMillis,
              Optional<List<Pair<String, String>>> startupOptionsTaggedWithBazelRc,
              List<Any> commandExtensions)
              throws InterruptedException {
            boolean isFirstCommand = args.contains(firstCommandArg);
            CountDownLatch entered = isFirstCommand ? fooEntered : barEntered;
            CountDownLatch release = isFirstCommand ? fooRelease : barRelease;
            entered.countDown();
            release.await();
            return BlazeCommandResult.success();
          }
        });

    AtomicReference<RunResponse> latestFoo = new AtomicReference<>();
    AtomicReference<RunResponse> latestBar = new AtomicReference<>();
    CommandServerStub client = CommandServerGrpc.newStub(channel);

    client.run(
        createRequest(firstCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse response) {
            latestFoo.set(response);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });
    fooEntered.await();

    client.run(
        createRequest(secondCommandArg),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse response) {
            latestBar.set(response);
          }

          @Override
          public void onError(Throwable t) {}

          @Override
          public void onCompleted() {}
        });
    barEntered.await();

    // Both commands are now parked inside exec(); let them run to completion.
    fooRelease.countDown();
    barRelease.countDown();

    server.shutdown();
    server.awaitTermination();

    RunResponse fooResult = latestFoo.get();
    assertThat(fooResult.getFinished()).isTrue();
    assertThat(fooResult.getExitCode()).isEqualTo(0);
    RunResponse barResult = latestBar.get();
    assertThat(barResult.getFinished()).isTrue();
    assertThat(barResult.getExitCode()).isEqualTo(0);
  }

  /**
   * Checks that {@link BlockingStreamObserver} holds the sender back once the client stops
   * consuming responses, and that every message is delivered after the client resumes.
   *
   * <p>The server and channel run on dedicated thread pools created by this test; the pools are
   * shut down in a {@code finally} block so their threads do not outlive the test.
   */
  @Test
  public void testFlowControl() throws Exception {
    // This test attempts to verify that flow control successfully blocks after some number of
    // onNext calls (however long it takes to fill up gRPC's internal buffers). In order to trigger
    // this behavior, we intentionally block the client after a few successful calls, then wait a
    // bit, and then check that the server has stopped prematurely. Unfortunately, we cannot
    // deterministically verify that the onNext call is blocking. A faulty implementation of flow
    // control could pass this test if the sleep is too short. However, a correct implementation
    // should never fail this test.
    // This test could start failing if gRPC's internal buffer size is increased. If it fails after
    // an upgrade of gRPC, you might want to check that.
    int messageCount = 100;
    CountDownLatch serverDone = new CountDownLatch(1);
    CountDownLatch clientBlocks = new CountDownLatch(1);
    CountDownLatch clientUnblocks = new CountDownLatch(1);
    CountDownLatch clientDone = new CountDownLatch(1);
    AtomicInteger sentCount = new AtomicInteger();
    AtomicInteger receiveCount = new AtomicInteger();
    CommandServerGrpc.CommandServerImplBase serverImpl =
        new CommandServerGrpc.CommandServerImplBase() {
          @Override
          public void run(RunRequest request, StreamObserver<RunResponse> observer) {
            ServerCallStreamObserver<RunResponse> serverCallStreamObserver =
                (ServerCallStreamObserver<RunResponse>) observer;
            BlockingStreamObserver<RunResponse> blockingStreamObserver =
                new BlockingStreamObserver<>(serverCallStreamObserver);
            // Send from a separate thread so that this RPC handler returns immediately.
            Thread t =
                new Thread(
                    () -> {
                      RunResponse response =
                          RunResponse.newBuilder()
                              .setStandardOutput(ByteString.copyFrom(new byte[1024]))
                              .build();
                      for (int i = 0; i < messageCount; i++) {
                        blockingStreamObserver.onNext(response);
                        sentCount.incrementAndGet();
                      }
                      blockingStreamObserver.onCompleted();
                      serverDone.countDown();
                    });
            t.start();
          }
        };

    // Keep references to the pools so they can be shut down; otherwise their non-daemon worker
    // threads would linger after the test.
    ExecutorService serverExecutor = Executors.newFixedThreadPool(4);
    ExecutorService clientExecutor = Executors.newFixedThreadPool(4);
    try {
      String uniqueName = InProcessServerBuilder.generateName();
      // Do not use .directExecutor here, as it makes both client and server run in the same
      // thread.
      server =
          InProcessServerBuilder.forName(uniqueName)
              .addService(serverImpl)
              .executor(serverExecutor)
              .build()
              .start();
      channel = InProcessChannelBuilder.forName(uniqueName).executor(clientExecutor).build();

      CommandServerStub stub = CommandServerGrpc.newStub(channel);
      stub.run(
          RunRequest.getDefaultInstance(),
          new StreamObserver<RunResponse>() {
            @Override
            public void onNext(RunResponse value) {
              // Once the server has sent a few messages, stall the client until the main thread
              // releases it. After the release, await() returns immediately on later calls.
              if (sentCount.get() >= 3) {
                clientBlocks.countDown();
                try {
                  clientUnblocks.await();
                } catch (InterruptedException e) {
                  throw new IllegalStateException(e);
                }
              }
              receiveCount.incrementAndGet();
            }

            @Override
            public void onError(Throwable t) {
              throw new IllegalStateException(t);
            }

            @Override
            public void onCompleted() {
              clientDone.countDown();
            }
          });
      clientBlocks.await();
      // Wait a bit for the server to (hopefully) block. If the server does not block, then this
      // may be flaky.
      Thread.sleep(10);
      assertThat(sentCount.get()).isLessThan(5);
      clientUnblocks.countDown();
      serverDone.await();
      clientDone.await();
      // After both sides are done, every message must have been sent and received.
      assertThat(sentCount.get()).isEqualTo(messageCount);
      assertThat(receiveCount.get()).isEqualTo(messageCount);
      server.shutdown();
      server.awaitTermination();
    } finally {
      // shutdownNow() also interrupts a client callback still parked on clientUnblocks if an
      // assertion above failed.
      clientExecutor.shutdownNow();
      serverExecutor.shutdownNow();
    }
  }

  @Test
  public void testFlowControlClientCancel() throws Exception {
    // This test attempts to verify that FlowControl unblocks if the client prematurely closes the
    // connection. In that case, FlowControl should observe the onCancel event and interrupt the
    // calling thread. I have observed this test failing with an intentionally introduced bug in
    // FlowControl.
    CountDownLatch serverDone = new CountDownLatch(1);
    CountDownLatch clientDone = new CountDownLatch(1);
    AtomicInteger sentCount = new AtomicInteger();
    AtomicInteger receiveCount = new AtomicInteger();
    // Interrupt status of the sender thread after it has pushed all responses. It is recorded on
    // the sender thread but asserted on the test thread: an AssertionError thrown on the sender
    // thread would only kill that thread, leaving serverDone uncounted and the test hanging
    // instead of failing.
    java.util.concurrent.atomic.AtomicBoolean senderInterrupted =
        new java.util.concurrent.atomic.AtomicBoolean();
    CommandServerGrpc.CommandServerImplBase serverImpl =
        new CommandServerGrpc.CommandServerImplBase() {
          @Override
          public void run(RunRequest request, StreamObserver<RunResponse> observer) {
            ServerCallStreamObserver<RunResponse> serverCallStreamObserver =
                (ServerCallStreamObserver<RunResponse>) observer;
            BlockingStreamObserver<RunResponse> blockingStreamObserver =
                new BlockingStreamObserver<>(serverCallStreamObserver);
            // Send from a dedicated thread so that the gRPC server executor is not blocked while
            // the sender waits on flow control.
            Thread t =
                new Thread(
                    () -> {
                      try {
                        RunResponse response =
                            RunResponse.newBuilder()
                                .setStandardOutput(ByteString.copyFrom(new byte[1024]))
                                .build();
                        for (int i = 0; i < 100; i++) {
                          blockingStreamObserver.onNext(response);
                          sentCount.incrementAndGet();
                        }
                        // FlowControl should have interrupted the current thread after learning
                        // of the client cancel. isInterrupted() does not clear the flag; the
                        // recorded value is checked on the test thread.
                        senderInterrupted.set(Thread.currentThread().isInterrupted());
                        blockingStreamObserver.onCompleted();
                      } finally {
                        // Always release the test thread, even if sending failed unexpectedly.
                        serverDone.countDown();
                      }
                    });
            t.start();
          }
        };

    String uniqueName = InProcessServerBuilder.generateName();
    // Do not use .directExecutor here, as it makes both client and server run in the same thread.
    server =
        InProcessServerBuilder.forName(uniqueName)
            .addService(serverImpl)
            .executor(Executors.newFixedThreadPool(4))
            .build()
            .start();
    channel =
        InProcessChannelBuilder.forName(uniqueName)
            .executor(Executors.newFixedThreadPool(4))
            .build();

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        RunRequest.getDefaultInstance(),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // After a few responses have arrived, abruptly close the connection from the client
            // side while the server is still sending.
            if (receiveCount.get() > 3) {
              channel.shutdownNow();
            }
            receiveCount.incrementAndGet();
          }

          @Override
          public void onError(Throwable t) {
            // Shutting down the channel is expected to terminate the call with an error.
            clientDone.countDown();
          }

          @Override
          public void onCompleted() {
            clientDone.countDown();
          }
        });
    serverDone.await();
    clientDone.await();
    // The sender must have been interrupted as a result of the client cancel.
    assertThat(senderInterrupted.get()).isTrue();
    server.shutdown();
    server.awaitTermination();
  }

  @Test
  public void testInterruptFlowControl() throws Exception {
    // This test attempts to verify that FlowControl does not hang if the current thread is
    // interrupted. The initial implementation of FlowControl (which was never submitted) would go
    // into an infinite loop holding the lock on FlowControl. This would prevent any other thread
    // from obtaining the lock on FlowControl, and hang the entire process. I have confirmed that
    // this test fails with the original faulty implementation of FlowControl.
    CountDownLatch serverDone = new CountDownLatch(1);
    CountDownLatch clientDone = new CountDownLatch(1);
    AtomicInteger sentCount = new AtomicInteger();
    AtomicInteger receiveCount = new AtomicInteger();
    // Error reported to the client observer, if any. It is recorded on the gRPC callback thread
    // and asserted on the test thread so that an RPC failure fails the test instead of leaving it
    // blocked on clientDone.
    java.util.concurrent.atomic.AtomicReference<Throwable> clientError =
        new java.util.concurrent.atomic.AtomicReference<>();
    CommandServerGrpc.CommandServerImplBase serverImpl =
        new CommandServerGrpc.CommandServerImplBase() {
          @Override
          public void run(RunRequest request, StreamObserver<RunResponse> observer) {
            ServerCallStreamObserver<RunResponse> serverCallStreamObserver =
                (ServerCallStreamObserver<RunResponse>) observer;
            BlockingStreamObserver<RunResponse> blockingStreamObserver =
                new BlockingStreamObserver<>(serverCallStreamObserver);
            Thread t =
                new Thread(
                    () -> {
                      RunResponse response =
                          RunResponse.newBuilder()
                              .setStandardOutput(ByteString.copyFrom(new byte[1024]))
                              .build();
                      // We want to trigger isReady() -> false, and we use sentCount to control
                      // whether to sleep on the client side. Therefore, we only set sentCount
                      // after isReady() changes.
                      int sent = 0;
                      while (serverCallStreamObserver.isReady()) {
                        blockingStreamObserver.onNext(response);
                        sent++;
                      }
                      sentCount.set(sent);
                      // If the current thread is interrupted, the subsequent onNext calls should
                      // not hang, but complete eventually (they may block on flow control).
                      Thread.currentThread().interrupt();
                      for (int i = 0; i < 10; i++) {
                        blockingStreamObserver.onNext(response);
                        sentCount.incrementAndGet();
                      }
                      blockingStreamObserver.onCompleted();
                      serverDone.countDown();
                    });
            t.start();
          }
        };

    String uniqueName = InProcessServerBuilder.generateName();
    // Do not use .directExecutor here, as it makes both client and server run in the same thread.
    server =
        InProcessServerBuilder.forName(uniqueName)
            .addService(serverImpl)
            .executor(Executors.newFixedThreadPool(4))
            .build()
            .start();
    channel =
        InProcessChannelBuilder.forName(uniqueName)
            .executor(Executors.newFixedThreadPool(4))
            .build();

    CommandServerStub stub = CommandServerGrpc.newStub(channel);
    stub.run(
        RunRequest.getDefaultInstance(),
        new StreamObserver<RunResponse>() {
          @Override
          public void onNext(RunResponse value) {
            // sentCount stays at 0 until the server has observed isReady() == false, so the
            // client consumes slowly only during that initial phase.
            if (sentCount.get() == 0) {
              try {
                Thread.sleep(1);
              } catch (InterruptedException e) {
                throw new IllegalStateException(e);
              }
            }
            receiveCount.incrementAndGet();
          }

          @Override
          public void onError(Throwable t) {
            // Record the failure and release the test thread. Throwing here would only surface on
            // the gRPC callback thread and leave the test waiting on clientDone.
            clientError.set(t);
            clientDone.countDown();
          }

          @Override
          public void onCompleted() {
            clientDone.countDown();
          }
        });
    serverDone.await();
    clientDone.await();
    // The call must have completed normally, and every response sent must have been received.
    assertThat(clientError.get()).isNull();
    assertThat(sentCount.get()).isEqualTo(receiveCount.get());
    server.shutdown();
    server.awaitTermination();
  }

  /**
   * Returns an observer that appends every received response to {@code responses} and counts down
   * {@code done} once the stream terminates, whether it completed normally or with an error.
   */
  private static StreamObserver<RunResponse> createResponseObserver(
      List<RunResponse> responses, CountDownLatch done) {
    final class CollectingObserver implements StreamObserver<RunResponse> {
      @Override
      public void onNext(RunResponse response) {
        responses.add(response);
      }

      @Override
      public void onError(Throwable error) {
        // An error ends the stream just like completion does; the cause itself is not recorded.
        done.countDown();
      }

      @Override
      public void onCompleted() {
        done.countDown();
      }
    }

    return new CollectingObserver();
  }

  /**
   * Returns a dispatcher whose every invocation throws {@link IllegalStateException}, regardless
   * of the arguments it is given.
   */
  private static CommandDispatcher throwingDispatcher() {
    CommandDispatcher rejectEverything =
        (policy,
            commandArgs,
            output,
            lockMode,
            client,
            firstContactMillis,
            taggedStartupOptions,
            extensions) -> {
          throw new IllegalStateException("Command exec not expected");
        };
    return rejectEverything;
  }
}
