// blob: 370d5ffa141c398af86b612249e0085d57f84000 [file] [log] [blame]
// Copyright 2018 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.blobstore.http;
import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
import static java.util.Collections.singletonList;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.api.client.util.Preconditions;
import com.google.auth.Credentials;
import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableList;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerDomainSocketChannel;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueServerDomainSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.IntFunction;
import javax.annotation.Nullable;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.mockito.AdditionalAnswers;
import org.mockito.Mockito;
/**
* Tests for {@link HttpBlobStore}.
*/
@RunWith(Parameterized.class)
public class HttpBlobStoreTest {
/**
 * Binds a new HTTP server to {@code socketAddress} and returns its server channel.
 *
 * <p>Each accepted connection gets a pipeline consisting of an {@link HttpServerCodec}, an
 * {@link HttpObjectAggregator} constructed with a limit of 1000, and finally {@code handler}.
 *
 * @param serverChannelClass the server channel implementation to bootstrap
 * @param newEventLoopGroup factory for the event loop group; it is invoked with the argument 1
 * @param socketAddress the address to bind to
 * @param handler the handler appended last to every child channel's pipeline
 * @return the bound server channel
 * @throws IllegalStateException if binding fails or the calling thread is interrupted
 */
private static ServerChannel createServer(
    Class<? extends ServerChannel> serverChannelClass,
    IntFunction<EventLoopGroup> newEventLoopGroup,
    SocketAddress socketAddress,
    ChannelHandler handler) {
  EventLoopGroup eventLoop = newEventLoopGroup.apply(1);
  ServerBootstrap sb =
      new ServerBootstrap()
          .group(eventLoop)
          .channel(serverChannelClass)
          .childHandler(
              new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) {
                  ch.pipeline().addLast(new HttpServerCodec());
                  ch.pipeline().addLast(new HttpObjectAggregator(1000));
                  ch.pipeline().addLast(handler);
                }
              });
  try {
    return (ServerChannel) sb.bind(socketAddress).sync().channel();
  } catch (Exception e) {
    // No channel is handed back on failure, so the caller has no way to reach the event loop
    // group; release it here instead of leaking it.
    eventLoop.shutdownGracefully();
    if (e instanceof InterruptedException) {
      // Preserve the interrupt for code further up the stack.
      Thread.currentThread().interrupt();
    }
    throw new IllegalStateException(e);
  }
}
/**
 * Returns a fresh, currently unused Unix domain socket address under {@code /tmp}.
 *
 * <p>A uniquely named temporary file is created only to reserve a name and is deleted again
 * right away, so that the path is free when a server later binds to it.
 *
 * @throws IllegalStateException if the placeholder file cannot be created or deleted
 */
private static DomainSocketAddress newDomainSocketAddress() {
  File file;
  try {
    file = File.createTempFile("bazel", ".sock", new File("/tmp"));
  } catch (Exception e) {
    throw new IllegalStateException(e);
  }
  // A leftover file would make the later bind fail with a far less obvious error, so report the
  // problem here where the cause is still known.
  if (!file.delete()) {
    throw new IllegalStateException("unable to delete placeholder socket file " + file);
  }
  return new DomainSocketAddress(file.getAbsoluteFile());
}
/**
 * A server flavor the tests are parameterized with. Implementations in this file serve HTTP over
 * either a TCP socket or a Unix domain socket.
 */
interface TestServer {
/**
 * Makes the server dispatch incoming requests to {@code handler}.
 *
 * @return the server channel; tests read its local address to find out where to connect
 */
ServerChannel start(ChannelInboundHandler handler);
/** Stops serving requests on the server previously returned by {@link #start}. */
void stop(ServerChannel serverChannel);
}
private static final class InetTestServer implements TestServer {
public ServerChannel start(ChannelInboundHandler handler) {
return createServer(
NioServerSocketChannel.class,
NioEventLoopGroup::new,
new InetSocketAddress("localhost", 0),
handler);
}
public void stop(ServerChannel serverChannel) {
try {
serverChannel.close();
serverChannel.closeFuture().sync();
serverChannel.eventLoop().shutdownGracefully().sync();
} catch (Exception e) {
throw new IllegalStateException(e);
}
}
}
/**
 * {@link TestServer} that serves over a Unix domain socket.
 *
 * <p>The socket is bound once, in the constructor, and is never closed. {@link #start} and
 * {@link #stop} only swap the request handler and adjust what the (mock-wrapped) server channel
 * reports as its local address.
 */
private static final class UnixDomainServer implements TestServer {
  // Note: this odd implementation is a workaround because we're unable to shut down and restart
  // KQueue backed implementations. See https://github.com/netty/netty/issues/7047.
  private final ServerChannel serverChannel;

  // Written by the test thread in start()/stop() and read from the channel initializer that the
  // server runs when a connection is accepted, so it must be volatile to be reliably visible.
  private volatile ChannelInboundHandler handler = null;

  /**
   * Binds the server to a fresh domain socket address.
   *
   * @param serverChannelClass the domain socket server channel implementation to bootstrap
   * @param newEventLoopGroup factory for the event loop group; it is invoked with the argument 1
   * @throws IllegalStateException if binding fails or the calling thread is interrupted
   */
  public UnixDomainServer(
      Class<? extends ServerChannel> serverChannelClass,
      IntFunction<EventLoopGroup> newEventLoopGroup) {
    EventLoopGroup eventLoop = newEventLoopGroup.apply(1);
    ServerBootstrap sb =
        new ServerBootstrap()
            .group(eventLoop)
            .channel(serverChannelClass)
            .childHandler(
                new ChannelInitializer<Channel>() {
                  @Override
                  protected void initChannel(Channel ch) {
                    ch.pipeline().addLast(new HttpServerCodec());
                    ch.pipeline().addLast(new HttpObjectAggregator(1000));
                    // Throws while the server is "stopped" (handler == null), which keeps the
                    // connection from being served.
                    ch.pipeline().addLast(Preconditions.checkNotNull(handler));
                  }
                });
    try {
      ServerChannel actual = (ServerChannel) sb.bind(newDomainSocketAddress()).sync().channel();
      // Wrap the real channel so that stop() can override individual methods via stubbing while
      // everything else is delegated to the real channel.
      this.serverChannel = mock(ServerChannel.class, AdditionalAnswers.delegatesTo(actual));
    } catch (Exception e) {
      // Construction failed, so nobody will ever be able to reach the event loop group again;
      // release it here instead of leaking it.
      eventLoop.shutdownGracefully();
      if (e instanceof InterruptedException) {
        // Preserve the interrupt for code further up the stack.
        Thread.currentThread().interrupt();
      }
      throw new IllegalStateException(e);
    }
  }

  /**
   * Re-arms the already bound server: clears any stubbing left over from a previous
   * {@link #stop} and installs {@code handler} for subsequent connections.
   */
  @Override
  public ServerChannel start(ChannelInboundHandler handler) {
    reset(this.serverChannel);
    this.handler = handler;
    return this.serverChannel;
  }

  /**
   * Simulates a stopped server. The {@code serverChannel} argument is not used; the operation
   * always applies to the single channel owned by this instance.
   */
  @Override
  public void stop(ServerChannel serverChannel) {
    // Note: In the tests, we expect that connecting to a closed server channel results
    // in a channel connection error. Netty doesn't seem to handle closing domain socket
    // addresses very well-- often connecting to a closed domain socket will result in a
    // read timeout instead of a connection timeout.
    //
    // This is a hack to ensure connection timeouts are "received" by the tests for this
    // dummy domain socket server. In particular, this lets the timeoutShouldWork_connect
    // test work for both inet and domain sockets.
    //
    // This is also part of the workaround for https://github.com/netty/netty/issues/7047.
    when(this.serverChannel.localAddress()).thenReturn(new DomainSocketAddress(""));
    this.handler = null;
  }
}
/**
 * Supplies the constructor arguments for the parameterized runs: one {@link TestServer} per
 * server flavor usable on this machine.
 *
 * <p>The TCP based server is always included. A Unix domain socket server is added for each
 * native transport (epoll, kqueue) that reports itself as available.
 *
 * @return a list of single-element argument arrays, each holding one {@link TestServer}
 */
@Parameters
public static Collection<Object[]> createInputValues() {
  List<Object[]> parameters = new ArrayList<>();
  parameters.add(new Object[] {new InetTestServer()});
  if (Epoll.isAvailable()) {
    parameters.add(
        new Object[] {
          new UnixDomainServer(EpollServerDomainSocketChannel.class, EpollEventLoopGroup::new)
        });
  }
  if (KQueue.isAvailable()) {
    parameters.add(
        new Object[] {
          new UnixDomainServer(KQueueServerDomainSocketChannel.class, KQueueEventLoopGroup::new)
        });
  }
  return parameters;
}
// The server flavor exercised by this parameterized instance of the test class.
private final TestServer testServer;
/**
 * Creates a test instance for one parameter set.
 *
 * @param testServer the server flavor to run the tests against, as produced by
 *     {@link #createInputValues()}
 */
public HttpBlobStoreTest(TestServer testServer) {
this.testServer = testServer;
}
/**
 * Builds an {@link HttpBlobStore} that talks to the server behind {@code serverChannel}.
 *
 * <p>The kind of local address reported by the channel selects the factory method: a domain
 * socket address is passed through together with a plain {@code http://localhost} URI, while an
 * inet address contributes its port to the URI.
 *
 * @param serverChannel the channel whose local address identifies the server
 * @param timeoutSeconds the timeout handed to the blob store
 * @param creds the credentials handed to the blob store, may be {@code null}
 * @throws IllegalStateException if the local address is neither a domain socket nor an inet
 *     socket address
 */
private HttpBlobStore createHttpBlobStore(
    ServerChannel serverChannel, int timeoutSeconds, @Nullable final Credentials creds)
    throws Exception {
  SocketAddress address = serverChannel.localAddress();

  if (address instanceof DomainSocketAddress) {
    return HttpBlobStore.create(
        (DomainSocketAddress) address,
        new URI("http://localhost"),
        timeoutSeconds,
        /* remoteMaxConnections= */ 0,
        ImmutableList.of(),
        creds);
  }

  if (address instanceof InetSocketAddress) {
    int port = ((InetSocketAddress) address).getPort();
    return HttpBlobStore.create(
        new URI("http://localhost:" + port),
        timeoutSeconds,
        /* remoteMaxConnections= */ 0,
        ImmutableList.of(),
        creds);
  }

  throw new IllegalStateException("unsupported socket address class " + address.getClass());
}
/** A download from a server that has already been stopped must fail with a connect error. */
@Test(expected = ConnectException.class, timeout = 30000)
public void connectTimeout() throws Exception {
  // Start and immediately stop the server so that nothing serves the address any more.
  ServerChannel stoppedServer = testServer.start(new ChannelInboundHandlerAdapter() {});
  testServer.stop(stoppedServer);

  HttpBlobStore store =
      createHttpBlobStore(stoppedServer, /* timeoutSeconds= */ 1, newCredentials());
  ByteArrayOutputStream sink = new ByteArrayOutputStream();
  getFromFuture(store.get("key", sink));

  fail("Exception expected");
}
/** An upload to a server that never answers must fail with an upload timeout. */
@Test(expected = UploadTimeoutException.class, timeout = 30000)
public void uploadTimeout() throws Exception {
  ServerChannel server = null;
  try {
    ChannelInboundHandler silentHandler =
        new SimpleChannelInboundHandler<FullHttpRequest>() {
          @Override
          protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
            // Deliberately empty: never answering forces the client to time out.
          }
        };
    server = testServer.start(silentHandler);

    HttpBlobStore store =
        createHttpBlobStore(server, /* timeoutSeconds= */ 1, newCredentials());
    byte[] payload = "File Contents".getBytes(Charsets.US_ASCII);
    store.put("key", payload.length, new ByteArrayInputStream(payload));

    fail("Exception expected");
  } finally {
    testServer.stop(server);
  }
}
/** A download from a server that never answers must fail with a download timeout. */
@Test(expected = DownloadTimeoutException.class, timeout = 30000)
public void downloadTimeout() throws Exception {
  ServerChannel server = null;
  try {
    ChannelInboundHandler silentHandler =
        new SimpleChannelInboundHandler<FullHttpRequest>() {
          @Override
          protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
            // Deliberately empty: never answering forces the client to time out.
          }
        };
    server = testServer.start(silentHandler);

    HttpBlobStore store =
        createHttpBlobStore(server, /* timeoutSeconds= */ 1, newCredentials());
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    getFromFuture(store.get("key", sink));

    fail("Exception expected");
  } finally {
    testServer.stop(server);
  }
}
/** Downloads must be retried with refreshed credentials for each retriable 401 variant. */
@Test
public void expiredAuthTokensShouldBeRetried_get() throws Exception {
  NotAuthorizedHandler.ErrorType[] retriableErrors = {
    NotAuthorizedHandler.ErrorType.UNAUTHORIZED, NotAuthorizedHandler.ErrorType.INVALID_TOKEN
  };
  for (NotAuthorizedHandler.ErrorType retriable : retriableErrors) {
    expiredAuthTokensShouldBeRetried_get(retriable);
  }
}
/**
 * Runs one download against a {@link NotAuthorizedHandler} configured with {@code errorType} and
 * checks that the content arrives after exactly one credential refresh.
 */
private void expiredAuthTokensShouldBeRetried_get(NotAuthorizedHandler.ErrorType errorType)
    throws Exception {
  ServerChannel server = null;
  try {
    server = testServer.start(new NotAuthorizedHandler(errorType));
    Credentials creds = newCredentials();
    HttpBlobStore store = createHttpBlobStore(server, /* timeoutSeconds= */ 1, creds);
    ByteArrayOutputStream downloaded = Mockito.spy(new ByteArrayOutputStream());

    getFromFuture(store.get("key", downloaded));

    String body = downloaded.toString(Charsets.US_ASCII.name());
    assertThat(body).isEqualTo("File Contents");

    // One refresh, and the metadata is consulted once per attempt (initial try plus retry).
    verify(creds).refresh();
    verify(creds, times(2)).getRequestMetadata(any(URI.class));
    verify(creds, times(2)).hasRequestMetadata();
    // Closing the stream is the caller's responsibility, not the blob store's.
    verify(downloaded, never()).close();
    verifyNoMoreInteractions(creds);
  } finally {
    testServer.stop(server);
  }
}
/** Uploads must be retried with refreshed credentials for each retriable 401 variant. */
@Test
public void expiredAuthTokensShouldBeRetried_put() throws Exception {
  NotAuthorizedHandler.ErrorType[] retriableErrors = {
    NotAuthorizedHandler.ErrorType.UNAUTHORIZED, NotAuthorizedHandler.ErrorType.INVALID_TOKEN
  };
  for (NotAuthorizedHandler.ErrorType retriable : retriableErrors) {
    expiredAuthTokensShouldBeRetried_put(retriable);
  }
}
/**
 * Runs one upload against a {@link NotAuthorizedHandler} configured with {@code errorType} and
 * checks that it succeeds after exactly one credential refresh.
 */
private void expiredAuthTokensShouldBeRetried_put(NotAuthorizedHandler.ErrorType errorType)
    throws Exception {
  ServerChannel server = null;
  try {
    server = testServer.start(new NotAuthorizedHandler(errorType));
    Credentials creds = newCredentials();
    HttpBlobStore store = createHttpBlobStore(server, /* timeoutSeconds= */ 1, creds);

    byte[] payload = "File Contents".getBytes(Charsets.US_ASCII);
    store.put("key", payload.length, new ByteArrayInputStream(payload));

    // One refresh, and the metadata is consulted once per attempt (initial try plus retry).
    verify(creds).refresh();
    verify(creds, times(2)).getRequestMetadata(any(URI.class));
    verify(creds, times(2)).hasRequestMetadata();
    verifyNoMoreInteractions(creds);
  } finally {
    testServer.stop(server);
  }
}
/** Downloads must fail without a retry for each non-retriable 401 variant. */
@Test
public void errorCodesThatShouldNotBeRetried_get() {
  NotAuthorizedHandler.ErrorType[] fatalErrors = {
    NotAuthorizedHandler.ErrorType.INSUFFICIENT_SCOPE,
    NotAuthorizedHandler.ErrorType.INVALID_REQUEST
  };
  for (NotAuthorizedHandler.ErrorType fatal : fatalErrors) {
    errorCodeThatShouldNotBeRetried_get(fatal);
  }
}
/**
 * Runs one download against a {@link NotAuthorizedHandler} configured with {@code errorType} and
 * checks that it fails with an {@link HttpException} carrying a 401 status.
 */
private void errorCodeThatShouldNotBeRetried_get(NotAuthorizedHandler.ErrorType errorType) {
  ServerChannel server = null;
  try {
    server = testServer.start(new NotAuthorizedHandler(errorType));
    HttpBlobStore store =
        createHttpBlobStore(server, /* timeoutSeconds= */ 1, newCredentials());
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    getFromFuture(store.get("key", sink));
    // fail() raises an Error, so it is not swallowed by the catch clause below.
    fail("Exception expected.");
  } catch (Exception e) {
    assertThat(e).isInstanceOf(HttpException.class);
    HttpException httpFailure = (HttpException) e;
    assertThat(httpFailure.response().status()).isEqualTo(HttpResponseStatus.UNAUTHORIZED);
  } finally {
    testServer.stop(server);
  }
}
/** Uploads must fail without a retry for each non-retriable 401 variant. */
@Test
public void errorCodesThatShouldNotBeRetried_put() {
  NotAuthorizedHandler.ErrorType[] fatalErrors = {
    NotAuthorizedHandler.ErrorType.INSUFFICIENT_SCOPE,
    NotAuthorizedHandler.ErrorType.INVALID_REQUEST
  };
  for (NotAuthorizedHandler.ErrorType fatal : fatalErrors) {
    errorCodeThatShouldNotBeRetried_put(fatal);
  }
}
/**
 * Runs one upload against a {@link NotAuthorizedHandler} configured with {@code errorType} and
 * checks that it fails with an {@link HttpException} carrying a 401 status.
 */
private void errorCodeThatShouldNotBeRetried_put(NotAuthorizedHandler.ErrorType errorType) {
  ServerChannel server = null;
  try {
    server = testServer.start(new NotAuthorizedHandler(errorType));
    HttpBlobStore store =
        createHttpBlobStore(server, /* timeoutSeconds= */ 1, newCredentials());
    ByteArrayInputStream singleByte = new ByteArrayInputStream(new byte[] {0});
    store.put("key", 1, singleByte);
    // fail() raises an Error, so it is not swallowed by the catch clause below.
    fail("Exception expected.");
  } catch (Exception e) {
    assertThat(e).isInstanceOf(HttpException.class);
    HttpException httpFailure = (HttpException) e;
    assertThat(httpFailure.response().status()).isEqualTo(HttpResponseStatus.UNAUTHORIZED);
  } finally {
    testServer.stop(server);
  }
}
/**
 * Creates mock credentials that initially hand out the header {@code Bearer invalidToken} and,
 * once {@code refresh()} has been called, hand out {@code Bearer validToken} instead.
 */
private Credentials newCredentials() throws Exception {
  Credentials creds = mock(Credentials.class);
  when(creds.hasRequestMetadata()).thenReturn(true);

  Map<String, List<String>> staleHeaders = authorizationHeaders("Bearer invalidToken");
  when(creds.getRequestMetadata(any(URI.class))).thenReturn(staleHeaders);

  Mockito.doAnswer(
          invocation -> {
            // Refreshing swaps the stubbed metadata over to the valid token.
            Map<String, List<String>> freshHeaders = authorizationHeaders("Bearer validToken");
            when(creds.getRequestMetadata(any(URI.class))).thenReturn(freshHeaders);
            return null;
          })
      .when(creds)
      .refresh();
  return creds;
}

/** Returns a new mutable header map holding a single {@code Authorization} entry. */
private static Map<String, List<String>> authorizationHeaders(String headerValue) {
  Map<String, List<String>> headers = new HashMap<>();
  headers.put("Authorization", singletonList(headerValue));
  return headers;
}
/**
 * {@link ChannelHandler} that simulates an expired authentication token.
 *
 * <ul>
 *   <li>The first request must carry {@code Bearer invalidToken} and is answered with a 401
 *       UNAUTHORIZED status code, which the client is expected to retry once with a new
 *       authentication token. For every error type other than {@link ErrorType#UNAUTHORIZED}
 *       the response also carries a {@code WWW-Authenticate} header naming the error.
 *   <li>The second request must carry {@code Bearer validToken} and is answered with 200 OK and
 *       the body {@code File Contents}.
 *   <li>A request with an unexpected {@code Authorization} header, or any third request, is
 *       answered with 500 INTERNAL SERVER ERROR.
 * </ul>
 */
@Sharable
static class NotAuthorizedHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
  /** The flavor of authorization failure reported for the first request. */
  enum ErrorType {
    UNAUTHORIZED,
    INVALID_TOKEN,
    INSUFFICIENT_SCOPE,
    INVALID_REQUEST
  }

  private final ErrorType errorType;

  // Number of requests answered as expected so far. Requests rejected with a 500 because of a
  // wrong Authorization header are not counted.
  // NOTE(review): plain int on a @Sharable handler; this relies on the test servers in this file
  // creating their event loop groups with a thread count of 1 - confirm if that ever changes.
  private int messageCount;

  NotAuthorizedHandler(ErrorType errorType) {
    this.errorType = errorType;
  }

  @Override
  protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
    if (messageCount == 0) {
      if (!hasAuthorization(request, "Bearer invalidToken")) {
        respondInternalErrorAndClose(ctx);
        return;
      }
      FullHttpResponse response =
          new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED);
      if (errorType != ErrorType.UNAUTHORIZED) {
        // Locale.ROOT keeps the error code stable regardless of the JVM's default locale; a
        // Turkish default locale, for example, would otherwise lower-case 'I' to a dotless i.
        response
            .headers()
            .set(
                HttpHeaderNames.WWW_AUTHENTICATE,
                "Bearer realm=\"localhost\","
                    + "error=\""
                    + errorType.name().toLowerCase(java.util.Locale.ROOT)
                    + "\","
                    + "error_description=\"The access token expired\"");
      }
      ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
      messageCount++;
    } else if (messageCount == 1) {
      if (!hasAuthorization(request, "Bearer validToken")) {
        respondInternalErrorAndClose(ctx);
        return;
      }
      ByteBuf content = ctx.alloc().buffer();
      content.writeCharSequence("File Contents", Charsets.US_ASCII);
      FullHttpResponse response =
          new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content);
      HttpUtil.setKeepAlive(response, true);
      HttpUtil.setContentLength(response, content.readableBytes());
      ctx.writeAndFlush(response);
      messageCount++;
    } else {
      // No third message expected.
      respondInternalErrorAndClose(ctx);
    }
  }

  /** Returns whether the request's {@code Authorization} header equals {@code expected}. */
  private static boolean hasAuthorization(FullHttpRequest request, String expected) {
    return expected.equals(request.headers().get(HttpHeaderNames.AUTHORIZATION));
  }

  /** Answers with 500 INTERNAL SERVER ERROR and closes the connection once it is written. */
  private static void respondInternalErrorAndClose(ChannelHandlerContext ctx) {
    ctx.writeAndFlush(
            new DefaultFullHttpResponse(
                HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR))
        .addListener(ChannelFutureListener.CLOSE);
  }
}
}