// blob: ad3f3c73f1999dccc3f00ccb5d66c5a502245223 (source-viewer header; not part of the Java source)
// Copyright 2021 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.grpc;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.observers.TestObserver;
import java.io.IOException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
/** Tests for {@link SharedConnectionFactory}. */
@RunWith(JUnit4.class)
public class SharedConnectionFactoryTest {
@Rule public final MockitoRule mockito = MockitoJUnit.rule();
@Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule();
@Mock private Connection connection;
@Mock private ConnectionFactory connectionFactory;
/**
 * Stubs the shared {@code connectionFactory} mock so that each {@code create()} call returns a
 * {@link Single} that emits the mocked {@code connection}.
 */
@Before
public void setUp() {
  // thenAnswer builds a new Single on every create() invocation instead of reusing one instance.
  when(connectionFactory.create()).thenAnswer(ignoredInvocation -> Single.just(connection));
}
/**
 * Verifies that one subscription on a factory built with a limit of 1 yields a connection
 * wrapping the mock, calls the underlying factory once, and leaves no available connections.
 */
@Test
public void create_smoke() {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(1);

  TestObserver<SharedConnection> subscriber = sharedFactory.create().test();

  subscriber.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  subscriber.assertComplete();
  verify(connectionFactory, times(1)).create();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
}
/**
 * Verifies that merely calling {@code create()} does not invoke the underlying connection
 * factory when nothing subscribes to the returned value.
 */
@Test
public void create_noConnectionCreationBeforeSubscription() {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);

  // The result is intentionally discarded: the point of the test is that nobody subscribes.
  sharedFactory.create();

  verify(connectionFactory, times(0)).create();
}
/**
 * Verifies that, with a limit of 1, a second subscription made while the first connection is
 * still held receives nothing (no value, error or completion).
 */
@Test
public void create_exceedingMaxConcurrency_waiting() {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);

  TestObserver<SharedConnection> first = sharedFactory.create().test();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
  first.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  first.assertComplete();

  // The only slot is taken by the first subscriber, so the second one must still be empty.
  sharedFactory.create().test().assertEmpty();
}
/**
 * Verifies that a subscriber created while the single slot is taken receives a connection
 * wrapping the same mock once the first shared connection is closed.
 */
@Test
public void create_afterConnectionClosed_shareConnections() throws IOException {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);

  TestObserver<SharedConnection> holder = sharedFactory.create().test();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
  holder.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  holder.assertComplete();

  // Subscribe the second observer before the first connection is given back.
  TestObserver<SharedConnection> waiter = sharedFactory.create().test();
  SharedConnection held = holder.values().get(0);
  held.close();

  waiter.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  waiter.assertComplete();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
}
/**
 * Verifies that, with a limit of 2, two successive subscriptions both receive a connection
 * wrapping the same mock and the available count drops from 1 to 0.
 */
@Test
public void create_belowMaxConcurrency_shareConnections() {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 2);

  TestObserver<SharedConnection> first = sharedFactory.create().test();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(1);
  first.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  first.assertComplete();

  TestObserver<SharedConnection> second = sharedFactory.create().test();
  second.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  second.assertComplete();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
}
/**
 * Subscribes from {@code maxConcurrency} threads at once and verifies that every subscriber
 * receives a connection wrapping the mock while the underlying factory is invoked only once.
 */
@Test
public void create_concurrentCreate_shareConnections() throws InterruptedException {
  int maxConcurrency = 10;
  SharedConnectionFactory sharedFactory =
      new SharedConnectionFactory(connectionFactory, maxConcurrency);
  // Assertion failures raised on worker threads are carried back to the test thread through this.
  AtomicReference<Throwable> failure = new AtomicReference<>(null);
  Runnable subscribeAndCheck =
      () -> {
        try {
          sharedFactory
              .create()
              .test()
              .assertNoErrors()
              .assertValue(shared -> shared.getUnderlyingConnection() == connection)
              .assertComplete();
        } catch (Throwable t) {
          failure.set(t);
        }
      };

  // Build every worker first and only then start them.
  Thread[] workers = new Thread[maxConcurrency];
  for (int i = 0; i < workers.length; i++) {
    workers[i] = new Thread(subscribeAndCheck);
  }
  for (int i = 0; i < workers.length; i++) {
    workers[i].start();
  }
  for (int i = 0; i < workers.length; i++) {
    workers[i].join();
  }

  assertThat(failure.get()).isNull();
  verify(connectionFactory, times(1)).create();
}
/**
 * Verifies that when the first underlying creation fails, the error is propagated and the slot
 * is available again, and that re-subscribing to the same {@link Single} triggers a second
 * creation which succeeds.
 */
@Test
public void create_afterLastFailed_success() {
  AtomicInteger attempts = new AtomicInteger(0);
  // Dedicated mock: fails on the first create() call and succeeds on every later one.
  ConnectionFactory flakyFactory = mock(ConnectionFactory.class);
  when(flakyFactory.create())
      .thenAnswer(
          invocation -> {
            if (attempts.getAndIncrement() != 0) {
              return Single.just(connection);
            }
            return Single.error(new IllegalStateException("error"));
          });
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(flakyFactory, 1);
  Single<SharedConnection> sharedSingle = sharedFactory.create();

  TestObserver<SharedConnection> failedAttempt = sharedSingle.test();
  failedAttempt.assertError(IllegalStateException.class);
  failedAttempt.assertError(e -> e.getMessage().contains("error"));
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(1);

  TestObserver<SharedConnection> retriedAttempt = sharedSingle.test();
  retriedAttempt.assertValue(shared -> shared.getUnderlyingConnection() == connection);
  retriedAttempt.assertComplete();
  assertThat(attempts.get()).isEqualTo(2);
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
}
/**
 * Verifies that disposing a subscriber while the underlying connection is still being created
 * does not dispose the creation itself: the creation runs to completion and the slot becomes
 * available again.
 */
@Test
public void create_disposeWhenWaitingForConnectionCreation_doNotCancelCreation()
    throws InterruptedException {
  AtomicBoolean creationCanceled = new AtomicBoolean(false);
  AtomicBoolean creationFinished = new AtomicBoolean(false);
  // Released by the test once the subscriber has been disposed; gates the creator thread.
  Semaphore subscriberDisposed = new Semaphore(0);
  // Released by the creator thread when it is done; lets the test wait for it.
  Semaphore creatorTerminated = new Semaphore(0);
  ConnectionFactory slowFactory = mock(ConnectionFactory.class);
  when(slowFactory.create())
      .thenAnswer(
          invocation ->
              Single.create(
                      emitter -> {
                        Runnable createConnection =
                            () -> {
                              try {
                                subscriberDisposed.acquire();
                                creationFinished.set(true);
                                emitter.onSuccess(connection);
                              } catch (InterruptedException e) {
                                emitter.onError(e);
                              }
                              creatorTerminated.release();
                            };
                        new Thread(createConnection).start();
                      })
                  .doOnDispose(() -> creationCanceled.set(true)));
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(slowFactory, 1);

  TestObserver<SharedConnection> subscriber = sharedFactory.create().test();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);
  subscriber.assertEmpty();
  subscriber.dispose();

  // Only now allow the creation to proceed, then wait for the creator thread to finish.
  subscriberDisposed.release();
  creatorTerminated.acquire();

  assertThat(creationCanceled.get()).isFalse();
  assertThat(creationFinished.get()).isTrue();
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(1);
}
/**
 * Verifies that a thread blocked waiting for a connection whose creation never completes can be
 * interrupted: the waiting thread observes an {@link InterruptedException} and terminates, while
 * the pending creation does not finish.
 *
 * <p>The stubbed factory starts a thread that sleeps for {@code Integer.MAX_VALUE} milliseconds,
 * so no connection is ever emitted during the test.
 */
@Test
public void create_interrupt_terminate() throws InterruptedException {
  AtomicBoolean finished = new AtomicBoolean(false);
  // Must start as false: it is only flipped when the waiting thread catches
  // InterruptedException, which is exactly what the final assertion checks.
  AtomicBoolean interrupted = new AtomicBoolean(false);
  Semaphore threadTerminatedSemaphore = new Semaphore(0);
  Semaphore connectionCreationSemaphore = new Semaphore(0);
  ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
  when(connectionFactory.create())
      .thenAnswer(
          invocation ->
              Single.create(
                  emitter -> {
                    Thread creator =
                        new Thread(
                            () -> {
                              try {
                                Thread.sleep(Integer.MAX_VALUE);
                                finished.set(true);
                                emitter.onSuccess(connection);
                              } catch (InterruptedException e) {
                                emitter.onError(e);
                              }
                            });
                    // Nothing in this test interrupts the creator, so make it a daemon thread
                    // to keep it from outliving the test as a live user thread.
                    creator.setDaemon(true);
                    creator.start();
                  }));
  SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2);
  // First subscriber kicks off the never-ending creation and must stay empty.
  factory.create().test().assertEmpty();
  Thread t =
      new Thread(
          () -> {
            try {
              TestObserver<SharedConnection> observer = factory.create().test();
              // Signal the test thread that the subscription exists before blocking in await().
              connectionCreationSemaphore.release();
              observer.await();
            } catch (InterruptedException e) {
              interrupted.set(true);
            }
            threadTerminatedSemaphore.release();
          });
  t.start();
  connectionCreationSemaphore.acquire();
  t.interrupt();
  threadTerminatedSemaphore.acquire();
  assertThat(finished.get()).isFalse();
  assertThat(interrupted.get()).isTrue();
}
/**
 * Verifies that closing a {@link SharedConnection} makes its slot available again without
 * closing the underlying connection.
 */
@Test
public void closeConnection_connectionBecomeAvailable() throws IOException {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);
  TestObserver<SharedConnection> subscriber = sharedFactory.create().test();
  subscriber.assertComplete();
  SharedConnection shared = subscriber.values().get(0);
  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(0);

  shared.close();

  assertThat(sharedFactory.numAvailableConnections()).isEqualTo(1);
  // The mocked underlying connection must not have been closed.
  verify(connection, times(0)).close();
}
/**
 * Verifies that closing the factory after a connection has been created closes the underlying
 * connection exactly once.
 */
@Test
public void closeFactory_closeUnderlyingConnection() throws IOException {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);
  // Subscribe once so that the underlying connection actually gets created.
  sharedFactory.create().test().assertComplete();

  sharedFactory.close();

  verify(connection, times(1)).close();
}
/**
 * Verifies that subscribing after the factory has been closed fails with an {@link
 * IllegalStateException} whose message contains "closed".
 */
@Test
public void closeFactory_noNewConnectionAllowed() throws IOException {
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(connectionFactory, 1);
  sharedFactory.close();

  TestObserver<SharedConnection> subscriber = sharedFactory.create().test();

  subscriber.assertError(IllegalStateException.class);
  subscriber.assertError(e -> e.getMessage().contains("closed"));
}
/**
 * Verifies that closing the factory while a connection creation is still pending disposes that
 * creation and fails the waiting subscriber with an {@link IllegalStateException} whose message
 * contains "closed".
 */
@Test
public void closeFactory_pendingConnectionCreation_closedError()
    throws IOException, InterruptedException {
  AtomicBoolean creationCanceled = new AtomicBoolean(false);
  AtomicBoolean creationFinished = new AtomicBoolean(false);
  // Released by the creator thread when it exits; lets the test wait for the interruption.
  Semaphore creatorTerminated = new Semaphore(0);
  ConnectionFactory pendingFactory = mock(ConnectionFactory.class);
  when(pendingFactory.create())
      .thenAnswer(
          invocation ->
              Single.create(
                      emitter -> {
                        Runnable sleepThenEmit =
                            () -> {
                              try {
                                Thread.sleep(Integer.MAX_VALUE);
                                creationFinished.set(true);
                                emitter.onSuccess(connection);
                              } catch (InterruptedException ignored) {
                                /* no-op */
                              }
                              creatorTerminated.release();
                            };
                        Thread creator = new Thread(sleepThenEmit);
                        creator.start();
                        // Cancelling the emitter interrupts the sleeping creator thread.
                        emitter.setCancellable(() -> creator.interrupt());
                      })
                  .doOnDispose(() -> creationCanceled.set(true)));
  SharedConnectionFactory sharedFactory = new SharedConnectionFactory(pendingFactory, 1);

  TestObserver<SharedConnection> subscriber = sharedFactory.create().test();
  subscriber.assertEmpty();
  assertThat(creationCanceled.get()).isFalse();

  sharedFactory.close();
  creatorTerminated.acquire();

  subscriber.assertError(IllegalStateException.class);
  subscriber.assertError(e -> e.getMessage().contains("closed"));
  assertThat(creationCanceled.get()).isTrue();
  assertThat(creationFinished.get()).isFalse();
}
}