// blob: c677874b4ca04e8318d7e2c8789ca4cda17a93de [file] [log] [blame]
// Copyright 2018 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.remote.blobstore.http;
import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.lib.remote.util.Utils.getFromFuture;
import static java.util.Collections.singletonList;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.auth.Credentials;
import com.google.common.base.Charsets;
import com.google.devtools.build.lib.remote.blobstore.http.HttpBlobStoreTest.NotAuthorizedHandler.ErrorType;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.timeout.ReadTimeoutException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.net.ConnectException;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;
/** Tests for {@link HttpBlobStore}. */
@RunWith(JUnit4.class)
public class HttpBlobStoreTest {
/**
 * Starts an HTTP server on {@code localhost}, bound to port 0 so that the operating system picks
 * a free port. Every accepted connection gets an HTTP codec, an aggregator that assembles each
 * request into a single message (at most 1000 bytes of content) and finally {@code handler}.
 *
 * <p>The same {@code handler} instance is added to the pipeline of every accepted connection.
 * Stop the returned channel with {@link #closeServerChannel}, which also shuts down the event
 * loop created here.
 *
 * @param handler the handler that receives the aggregated requests
 * @return the bound server channel
 * @throws Exception if binding fails or the wait for it is interrupted
 */
private ServerSocketChannel startServer(ChannelHandler handler) throws Exception {
  EventLoopGroup eventLoop = new NioEventLoopGroup(1);
  ServerBootstrap sb =
      new ServerBootstrap()
          .group(eventLoop)
          .channel(NioServerSocketChannel.class)
          .childHandler(
              new ChannelInitializer<NioSocketChannel>() {
                @Override
                protected void initChannel(NioSocketChannel ch) {
                  ch.pipeline().addLast(new HttpServerCodec());
                  ch.pipeline().addLast(new HttpObjectAggregator(1000));
                  ch.pipeline().addLast(handler);
                }
              });
  boolean bound = false;
  try {
    ServerSocketChannel server = (ServerSocketChannel) sb.bind("localhost", 0).sync().channel();
    bound = true;
    return server;
  } finally {
    if (!bound) {
      // Nobody else holds a reference to the event loop yet, so release its thread here;
      // otherwise a failed bind would leak it.
      eventLoop.shutdownGracefully();
    }
  }
}
/**
 * Fetching from a port whose server has already been shut down must fail with a {@link
 * ConnectException}.
 */
@Test(expected = ConnectException.class, timeout = 30000)
public void timeoutShouldWork_connect() throws Exception {
  // Bind a throw-away server only to learn a free port, then stop it again so that nothing is
  // listening on that port any more.
  ServerSocketChannel stoppedServer = startServer(new ChannelHandlerAdapter() {});
  int port = stoppedServer.localAddress().getPort();
  closeServerChannel(stoppedServer);

  Credentials credentials = newCredentials();
  URI uri = new URI("http://localhost:" + port);
  HttpBlobStore blobStore = new HttpBlobStore(uri, 5, 0, credentials);
  getFromFuture(blobStore.get("key", new ByteArrayOutputStream()));

  fail("Exception expected");
}
/**
 * Fetching from a server that accepts the request but never answers must fail with a {@link
 * ReadTimeoutException}.
 */
@Test(expected = ReadTimeoutException.class, timeout = 30000)
public void timeoutShouldWork_read() throws Exception {
  ChannelHandler silentHandler =
      new SimpleChannelInboundHandler<FullHttpRequest>() {
        @Override
        protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
          // Don't respond and force a client timeout.
        }
      };
  ServerSocketChannel server = startServer(silentHandler);
  try {
    URI uri = new URI("http://localhost:" + server.localAddress().getPort());
    HttpBlobStore blobStore = new HttpBlobStore(uri, 5, 0, newCredentials());
    getFromFuture(blobStore.get("key", new ByteArrayOutputStream()));
    fail("Exception expected");
  } finally {
    closeServerChannel(server);
  }
}
/**
 * A download that is first rejected with a 401 must be retried once with refreshed credentials,
 * both for a bare 401 and for a 401 that carries an {@code invalid_token} error.
 */
@Test
public void expiredAuthTokensShouldBeRetried_get() throws Exception {
  expiredAuthTokensShouldBeRetried_get(ErrorType.UNAUTHORIZED);
  expiredAuthTokensShouldBeRetried_get(ErrorType.INVALID_TOKEN);
}

/**
 * Downloads "key" from a {@link NotAuthorizedHandler} server configured with {@code errorType}
 * and checks that the content arrives and that the credentials were refreshed exactly once.
 */
private void expiredAuthTokensShouldBeRetried_get(ErrorType errorType) throws Exception {
  ServerSocketChannel server = startServer(new NotAuthorizedHandler(errorType));
  try {
    URI serverUri = new URI("http://localhost:" + server.localAddress().getPort());
    Credentials credentials = newCredentials();
    HttpBlobStore blobStore = new HttpBlobStore(serverUri, 30, 0, credentials);

    ByteArrayOutputStream downloaded = Mockito.spy(new ByteArrayOutputStream());
    getFromFuture(blobStore.get("key", downloaded));

    assertThat(downloaded.toString(Charsets.US_ASCII.name())).isEqualTo("File Contents");
    // One refresh between the two attempts; request metadata is consulted once per attempt.
    verify(credentials, times(1)).refresh();
    verify(credentials, times(2)).getRequestMetadata(any(URI.class));
    verify(credentials, times(2)).hasRequestMetadata();
    // The caller is responsible for closing the stream.
    verify(downloaded, never()).close();
    verifyNoMoreInteractions(credentials);
  } finally {
    closeServerChannel(server);
  }
}
/**
 * An upload that is first rejected with a 401 must be retried once with refreshed credentials,
 * both for a bare 401 and for a 401 that carries an {@code invalid_token} error.
 */
@Test
public void expiredAuthTokensShouldBeRetried_put() throws Exception {
  expiredAuthTokensShouldBeRetried_put(ErrorType.UNAUTHORIZED);
  expiredAuthTokensShouldBeRetried_put(ErrorType.INVALID_TOKEN);
}

/**
 * Uploads "File Contents" as "key" to a {@link NotAuthorizedHandler} server configured with
 * {@code errorType} and checks that the credentials were refreshed exactly once.
 */
private void expiredAuthTokensShouldBeRetried_put(ErrorType errorType) throws Exception {
  ServerSocketChannel server = startServer(new NotAuthorizedHandler(errorType));
  try {
    URI serverUri = new URI("http://localhost:" + server.localAddress().getPort());
    Credentials credentials = newCredentials();
    HttpBlobStore blobStore = new HttpBlobStore(serverUri, 30, 0, credentials);

    byte[] payload = "File Contents".getBytes(Charsets.US_ASCII);
    blobStore.put("key", payload.length, new ByteArrayInputStream(payload));

    // One refresh between the two attempts; request metadata is consulted once per attempt.
    verify(credentials, times(1)).refresh();
    verify(credentials, times(2)).getRequestMetadata(any(URI.class));
    verify(credentials, times(2)).hasRequestMetadata();
    verifyNoMoreInteractions(credentials);
  } finally {
    closeServerChannel(server);
  }
}
/**
 * A download rejected with a 401 whose error is {@code insufficient_scope} or {@code
 * invalid_request} must fail with that 401.
 */
@Test
public void errorCodesThatShouldNotBeRetried_get() throws InterruptedException {
  errorCodeThatShouldNotBeRetried_get(ErrorType.INSUFFICIENT_SCOPE);
  errorCodeThatShouldNotBeRetried_get(ErrorType.INVALID_REQUEST);
}

/**
 * Downloads "key" from a {@link NotAuthorizedHandler} server configured with {@code errorType}
 * and expects an {@link HttpException} carrying the 401 status.
 */
private void errorCodeThatShouldNotBeRetried_get(ErrorType errorType)
    throws InterruptedException {
  ServerSocketChannel server = null;
  try {
    // The server is started inside the try block because this method only declares
    // InterruptedException; any other failure is routed through the catch block below.
    server = startServer(new NotAuthorizedHandler(errorType));
    URI serverUri = new URI("http://localhost:" + server.localAddress().getPort());
    HttpBlobStore blobStore = new HttpBlobStore(serverUri, 30, 0, newCredentials());
    getFromFuture(blobStore.get("key", new ByteArrayOutputStream()));
    fail("Exception expected.");
  } catch (Exception e) {
    assertThat(e).isInstanceOf(HttpException.class);
    HttpException httpException = (HttpException) e;
    assertThat(httpException.response().status()).isEqualTo(HttpResponseStatus.UNAUTHORIZED);
  } finally {
    closeServerChannel(server);
  }
}
/**
 * An upload rejected with a 401 whose error is {@code insufficient_scope} or {@code
 * invalid_request} must fail with that 401.
 */
@Test
public void errorCodesThatShouldNotBeRetried_put() throws InterruptedException {
  errorCodeThatShouldNotBeRetried_put(ErrorType.INSUFFICIENT_SCOPE);
  errorCodeThatShouldNotBeRetried_put(ErrorType.INVALID_REQUEST);
}

/**
 * Uploads a single byte as "key" to a {@link NotAuthorizedHandler} server configured with
 * {@code errorType} and expects an {@link HttpException} carrying the 401 status.
 */
private void errorCodeThatShouldNotBeRetried_put(ErrorType errorType)
    throws InterruptedException {
  ServerSocketChannel server = null;
  try {
    // The server is started inside the try block because this method only declares
    // InterruptedException; any other failure is routed through the catch block below.
    server = startServer(new NotAuthorizedHandler(errorType));
    URI serverUri = new URI("http://localhost:" + server.localAddress().getPort());
    HttpBlobStore blobStore = new HttpBlobStore(serverUri, 30, 0, newCredentials());
    byte[] singleByte = new byte[] {0};
    blobStore.put("key", 1, new ByteArrayInputStream(singleByte));
    fail("Exception expected.");
  } catch (Exception e) {
    assertThat(e).isInstanceOf(HttpException.class);
    HttpException httpException = (HttpException) e;
    assertThat(httpException.response().status()).isEqualTo(HttpResponseStatus.UNAUTHORIZED);
  } finally {
    closeServerChannel(server);
  }
}
/**
 * Creates mock {@link Credentials} that hand out {@code Authorization: Bearer invalidToken}
 * until {@code refresh()} is called and {@code Authorization: Bearer validToken} afterwards.
 */
private Credentials newCredentials() throws Exception {
  Map<String, List<String>> staleHeaders = new HashMap<>();
  staleHeaders.put("Authorization", singletonList("Bearer invalidToken"));

  Credentials credentials = mock(Credentials.class);
  when(credentials.hasRequestMetadata()).thenReturn(true);
  when(credentials.getRequestMetadata(any(URI.class))).thenReturn(staleHeaders);
  Mockito.doAnswer(
          invocation -> {
            // Refreshing swaps the stubbed request metadata over to the valid token.
            Map<String, List<String>> freshHeaders = new HashMap<>();
            freshHeaders.put("Authorization", singletonList("Bearer validToken"));
            when(credentials.getRequestMetadata(any(URI.class))).thenReturn(freshHeaders);
            return null;
          })
      .when(credentials)
      .refresh();
  return credentials;
}
/**
 * {@link ChannelHandler} that on the first request responds with a 401 UNAUTHORIZED status code,
 * which the client is expected to retry once with a new authentication token.
 *
 * <p>The first request must carry {@code Authorization: Bearer invalidToken} and the second one
 * {@code Authorization: Bearer validToken}; the second request is answered with a 200 OK whose
 * body is "File Contents". A request with an unexpected token, or any request after the second
 * one, is answered with a 500 INTERNAL SERVER ERROR and the connection is closed.
 *
 * <p>NOTE(review): the request counter is a plain, unsynchronized field shared by every channel
 * this handler is added to. That looks safe only while all channels are served by one event loop
 * thread, as with the single-threaded group in {@code startServer} - confirm before reusing the
 * handler elsewhere.
 */
@Sharable
static class NotAuthorizedHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
  /** Flavors of the 401 response sent for the first request. */
  enum ErrorType {
    /** A bare 401 without a {@code WWW-Authenticate} header. */
    UNAUTHORIZED,
    INVALID_TOKEN,
    INSUFFICIENT_SCOPE,
    INVALID_REQUEST
  }

  private final ErrorType errorType;
  /** Number of requests that have been answered as expected so far. */
  private int messageCount;

  NotAuthorizedHandler(ErrorType errorType) {
    this.errorType = errorType;
  }

  @Override
  protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
    String authorization = request.headers().get(HttpHeaderNames.AUTHORIZATION);
    if (messageCount == 0) {
      if (!"Bearer invalidToken".equals(authorization)) {
        // Deliberately leaves messageCount untouched, exactly like a request that never came.
        sendInternalServerError(ctx);
        return;
      }
      ctx.writeAndFlush(newUnauthorizedResponse()).addListener(ChannelFutureListener.CLOSE);
      messageCount++;
    } else if (messageCount == 1) {
      if (!"Bearer validToken".equals(authorization)) {
        sendInternalServerError(ctx);
        return;
      }
      ByteBuf content = ctx.alloc().buffer();
      content.writeCharSequence("File Contents", Charsets.US_ASCII);
      FullHttpResponse response =
          new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content);
      HttpUtil.setKeepAlive(response, true);
      HttpUtil.setContentLength(response, content.readableBytes());
      ctx.writeAndFlush(response);
      messageCount++;
    } else {
      // No third message expected.
      sendInternalServerError(ctx);
    }
  }

  /**
   * Builds the 401 response for the first request. Every error type other than {@link
   * ErrorType#UNAUTHORIZED} additionally gets a {@code WWW-Authenticate} header that names the
   * error in lower case.
   */
  private FullHttpResponse newUnauthorizedResponse() {
    FullHttpResponse response =
        new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED);
    if (errorType != ErrorType.UNAUTHORIZED) {
      // Locale.ROOT keeps the error code stable regardless of the JVM's default locale; a
      // Turkish default locale, for instance, would lower-case 'I' to a dotless 'ı'.
      String errorCode = errorType.name().toLowerCase(Locale.ROOT);
      response
          .headers()
          .set(
              HttpHeaderNames.WWW_AUTHENTICATE,
              "Bearer realm=\"localhost\","
                  + "error=\""
                  + errorCode
                  + "\","
                  + "error_description=\"The access token expired\"");
    }
    return response;
  }

  /** Answers with a 500 INTERNAL SERVER ERROR and closes the connection afterwards. */
  private static void sendInternalServerError(ChannelHandlerContext ctx) {
    ctx.writeAndFlush(
            new DefaultFullHttpResponse(
                HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR))
        .addListener(ChannelFutureListener.CLOSE);
  }
}
/**
 * Closes {@code server} and shuts down its event loop, waiting for both to complete. Does
 * nothing when {@code server} is {@code null}, so it can be called from a {@code finally} block
 * even if the server was never started.
 *
 * @param server the server channel to stop, or {@code null}
 * @throws InterruptedException if the thread is interrupted while waiting
 */
private void closeServerChannel(ServerSocketChannel server) throws InterruptedException {
  if (server == null) {
    return;
  }
  try {
    server.close();
    server.closeFuture().sync();
  } finally {
    // Shut the event loop down even if waiting for the close failed or was interrupted;
    // otherwise its thread would be leaked.
    server.eventLoop().shutdownGracefully().sync();
  }
}
}