| // Copyright 2018 The Bazel Authors. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| package com.google.devtools.build.lib.remote.blobstore.http; |
| |
import static com.google.common.truth.Truth.assertThat;
import static java.util.Collections.singletonList;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.auth.Credentials;
import com.google.common.base.Charsets;
import com.google.devtools.build.lib.remote.blobstore.http.HttpBlobStoreTest.NotAuthorizedHandler.ErrorType;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.timeout.ReadTimeoutException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.net.ConnectException;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;
| |
| /** Tests for {@link HttpBlobStore}. */ |
| @RunWith(JUnit4.class) |
| public class HttpBlobStoreTest { |
| |
| private ServerSocketChannel startServer(ChannelHandler handler) throws Exception { |
| EventLoopGroup eventLoop = new NioEventLoopGroup(1); |
| ServerBootstrap sb = |
| new ServerBootstrap() |
| .group(eventLoop) |
| .channel(NioServerSocketChannel.class) |
| .childHandler( |
| new ChannelInitializer<NioSocketChannel>() { |
| @Override |
| protected void initChannel(NioSocketChannel ch) { |
| ch.pipeline().addLast(new HttpServerCodec()); |
| ch.pipeline().addLast(new HttpObjectAggregator(1000)); |
| ch.pipeline().addLast(handler); |
| } |
| }); |
| return ((ServerSocketChannel) sb.bind("localhost", 0).sync().channel()); |
| } |
| |
| @Test(expected = ConnectException.class, timeout = 30000) |
| public void timeoutShouldWork_connect() throws Exception { |
| ServerSocketChannel server = startServer(new ChannelHandlerAdapter() {}); |
| int serverPort = server.localAddress().getPort(); |
| closeServerChannel(server); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); |
| blobStore.get("key", new ByteArrayOutputStream()); |
| fail("Exception expected"); |
| } |
| |
| @Test(expected = ReadTimeoutException.class, timeout = 30000) |
| public void timeoutShouldWork_read() throws Exception { |
| ServerSocketChannel server = null; |
| try { |
| server = |
| startServer( |
| new SimpleChannelInboundHandler<FullHttpRequest>() { |
| @Override |
| protected void channelRead0( |
| ChannelHandlerContext channelHandlerContext, FullHttpRequest fullHttpRequest) { |
| // Don't respond and force a client timeout. |
| } |
| }); |
| int serverPort = server.localAddress().getPort(); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 5, credentials); |
| blobStore.get("key", new ByteArrayOutputStream()); |
| fail("Exception expected"); |
| } finally { |
| closeServerChannel(server); |
| } |
| } |
| |
| @Test |
| public void expiredAuthTokensShouldBeRetried_get() throws Exception { |
| expiredAuthTokensShouldBeRetried_get(ErrorType.UNAUTHORIZED); |
| expiredAuthTokensShouldBeRetried_get(ErrorType.INVALID_TOKEN); |
| } |
| |
| private void expiredAuthTokensShouldBeRetried_get(ErrorType errorType) throws Exception { |
| ServerSocketChannel server = null; |
| try { |
| server = startServer(new NotAuthorizedHandler(errorType)); |
| int serverPort = server.localAddress().getPort(); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 30, credentials); |
| ByteArrayOutputStream out = Mockito.spy(new ByteArrayOutputStream()); |
| blobStore.get("key", out); |
| assertThat(out.toString(Charsets.US_ASCII.name())).isEqualTo("File Contents"); |
| verify(credentials, times(1)).refresh(); |
| verify(credentials, times(2)).getRequestMetadata(any(URI.class)); |
| verify(credentials, times(2)).hasRequestMetadata(); |
| // The caller is responsible to the close the stream. |
| verify(out, never()).close(); |
| verifyNoMoreInteractions(credentials); |
| } finally { |
| closeServerChannel(server); |
| } |
| } |
| |
| @Test |
| public void expiredAuthTokensShouldBeRetried_put() throws Exception { |
| expiredAuthTokensShouldBeRetried_put(ErrorType.UNAUTHORIZED); |
| expiredAuthTokensShouldBeRetried_put(ErrorType.INVALID_TOKEN); |
| } |
| |
| private void expiredAuthTokensShouldBeRetried_put(ErrorType errorType) throws Exception { |
| ServerSocketChannel server = null; |
| try { |
| server = startServer(new NotAuthorizedHandler(errorType)); |
| int serverPort = server.localAddress().getPort(); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 30, credentials); |
| byte[] data = "File Contents".getBytes(Charsets.US_ASCII); |
| ByteArrayInputStream in = new ByteArrayInputStream(data); |
| blobStore.put("key", data.length, in); |
| verify(credentials, times(1)).refresh(); |
| verify(credentials, times(2)).getRequestMetadata(any(URI.class)); |
| verify(credentials, times(2)).hasRequestMetadata(); |
| verifyNoMoreInteractions(credentials); |
| } finally { |
| closeServerChannel(server); |
| } |
| } |
| |
| @Test |
| public void errorCodesThatShouldNotBeRetried_get() throws InterruptedException { |
| errorCodeThatShouldNotBeRetried_get(ErrorType.INSUFFICIENT_SCOPE); |
| errorCodeThatShouldNotBeRetried_get(ErrorType.INVALID_REQUEST); |
| } |
| |
| private void errorCodeThatShouldNotBeRetried_get(ErrorType errorType) |
| throws InterruptedException { |
| ServerSocketChannel server = null; |
| try { |
| server = startServer(new NotAuthorizedHandler(errorType)); |
| int serverPort = server.localAddress().getPort(); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 30, credentials); |
| blobStore.get("key", new ByteArrayOutputStream()); |
| fail("Exception expected."); |
| } catch (Exception e) { |
| assertThat(e).isInstanceOf(HttpException.class); |
| assertThat(((HttpException) e).response().status()) |
| .isEqualTo(HttpResponseStatus.UNAUTHORIZED); |
| } finally { |
| closeServerChannel(server); |
| } |
| } |
| |
| @Test |
| public void errorCodesThatShouldNotBeRetried_put() throws InterruptedException { |
| errorCodeThatShouldNotBeRetried_put(ErrorType.INSUFFICIENT_SCOPE); |
| errorCodeThatShouldNotBeRetried_put(ErrorType.INVALID_REQUEST); |
| } |
| |
| private void errorCodeThatShouldNotBeRetried_put(ErrorType errorType) |
| throws InterruptedException { |
| ServerSocketChannel server = null; |
| try { |
| server = startServer(new NotAuthorizedHandler(errorType)); |
| int serverPort = server.localAddress().getPort(); |
| |
| Credentials credentials = newCredentials(); |
| HttpBlobStore blobStore = |
| new HttpBlobStore(new URI("http://localhost:" + serverPort), 30, credentials); |
| blobStore.put("key", 1, new ByteArrayInputStream(new byte[] {0})); |
| fail("Exception expected."); |
| } catch (Exception e) { |
| assertThat(e).isInstanceOf(HttpException.class); |
| assertThat(((HttpException) e).response().status()) |
| .isEqualTo(HttpResponseStatus.UNAUTHORIZED); |
| } finally { |
| closeServerChannel(server); |
| } |
| } |
| |
| private Credentials newCredentials() throws Exception { |
| Credentials credentials = mock(Credentials.class); |
| when(credentials.hasRequestMetadata()).thenReturn(true); |
| Map<String, List<String>> headers = new HashMap<>(); |
| headers.put("Authorization", singletonList("Bearer invalidToken")); |
| when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers); |
| Mockito.doAnswer( |
| (mock) -> { |
| Map<String, List<String>> headers2 = new HashMap<>(); |
| headers2.put("Authorization", singletonList("Bearer validToken")); |
| when(credentials.getRequestMetadata(any(URI.class))).thenReturn(headers2); |
| return null; |
| }) |
| .when(credentials) |
| .refresh(); |
| return credentials; |
| } |
| |
| /** |
| * {@link ChannelHandler} that on the first request responds with a 401 UNAUTHORIZED status code, |
| * which the client is expected to retry once with a new authentication token. |
| */ |
| @Sharable |
| static class NotAuthorizedHandler extends SimpleChannelInboundHandler<FullHttpRequest> { |
| |
| enum ErrorType { |
| UNAUTHORIZED, |
| INVALID_TOKEN, |
| INSUFFICIENT_SCOPE, |
| INVALID_REQUEST |
| } |
| |
| private final ErrorType errorType; |
| private int messageCount; |
| |
| NotAuthorizedHandler(ErrorType errorType) { |
| this.errorType = errorType; |
| } |
| |
| @Override |
| protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { |
| if (messageCount == 0) { |
| if (!"Bearer invalidToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { |
| ctx.writeAndFlush( |
| new DefaultFullHttpResponse( |
| HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) |
| .addListener(ChannelFutureListener.CLOSE); |
| return; |
| } |
| final FullHttpResponse response; |
| if (errorType == ErrorType.UNAUTHORIZED) { |
| response = |
| new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); |
| } else { |
| response = |
| new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); |
| response |
| .headers() |
| .set( |
| HttpHeaderNames.WWW_AUTHENTICATE, |
| "Bearer realm=\"localhost\"," |
| + "error=\"" |
| + errorType.name().toLowerCase() |
| + "\"," |
| + "error_description=\"The access token expired\""); |
| } |
| ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); |
| messageCount++; |
| } else if (messageCount == 1) { |
| if (!"Bearer validToken".equals(request.headers().get(HttpHeaderNames.AUTHORIZATION))) { |
| ctx.writeAndFlush( |
| new DefaultFullHttpResponse( |
| HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) |
| .addListener(ChannelFutureListener.CLOSE); |
| return; |
| } |
| ByteBuf content = ctx.alloc().buffer(); |
| content.writeCharSequence("File Contents", Charsets.US_ASCII); |
| FullHttpResponse response = |
| new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content); |
| HttpUtil.setKeepAlive(response, true); |
| HttpUtil.setContentLength(response, content.readableBytes()); |
| ctx.writeAndFlush(response); |
| messageCount++; |
| } else { |
| // No third message expected. |
| ctx.writeAndFlush( |
| new DefaultFullHttpResponse( |
| HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR)) |
| .addListener(ChannelFutureListener.CLOSE); |
| } |
| } |
| } |
| |
| private void closeServerChannel(ServerSocketChannel server) throws InterruptedException { |
| if (server != null) { |
| server.close(); |
| server.closeFuture().sync(); |
| server.eventLoop().shutdownGracefully().sync(); |
| } |
| } |
| } |