Extend the repository_context.download{,_and_extract} with authentication
Extend, as described in the authentication API design[1], the functions
download{,_and_extract} of the repository context by an additional
parameter 'auth'. If given, the information provided there will be used
to authenticate upon download.
[1] https://github.com/bazelbuild/proposals/blob/master/designs/2019-05-27-auth.md
Change-Id: Ifa119987accbcc56219e9a0c965ed3e94dc9d38e
PiperOrigin-RevId: 253760773
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
index e1050bb..8259399 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
@@ -90,6 +90,7 @@
URLConnection connect(
URL originalUrl, ImmutableMap<String, String> requestHeaders)
throws IOException {
+
if (Thread.interrupted()) {
throw new InterruptedIOException();
}
@@ -116,7 +117,7 @@
// appears to be compressed.
continue;
}
- connection.setRequestProperty(entry.getKey(), entry.getValue());
+ connection.addRequestProperty(entry.getKey(), entry.getValue());
}
connection.setConnectTimeout(connectTimeout);
// The read timeout is always large because it stays in effect after this method.
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
index 8798a33..36328ae 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
@@ -30,6 +30,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
+import java.net.URI;
+import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.util.ArrayList;
@@ -37,6 +39,7 @@
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@@ -92,6 +95,10 @@
this.sleeper = sleeper;
}
+ public HttpStream connect(List<URL> urls, String sha256) throws IOException {
+ return connect(urls, sha256, ImmutableMap.<URI, Map<String, String>>of());
+ }
+
/**
* Establishes reliable HTTP connection to a good mirror URL.
*
@@ -115,7 +122,8 @@
* @throws InterruptedIOException if current thread is being cast into oblivion
* @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol
*/
- public HttpStream connect(List<URL> urls, String sha256) throws IOException {
+ public HttpStream connect(
+ List<URL> urls, String sha256, Map<URI, Map<String, String>> authHeaders) throws IOException {
Preconditions.checkNotNull(sha256);
HttpUtils.checkUrlsArgument(urls);
if (Thread.interrupted()) {
@@ -123,7 +131,7 @@
}
// If there's only one URL then there's no need for us to run all our fancy thread stuff.
if (urls.size() == 1) {
- return establishConnection(urls.get(0), sha256);
+ return establishConnection(urls.get(0), sha256, authHeaders);
}
MutexConditionSharedMemory context = new MutexConditionSharedMemory();
// The parent thread always holds the lock except when released by wait().
@@ -132,7 +140,7 @@
long now = clock.currentTimeMillis();
long startAtTime = now;
for (URL url : urls) {
- context.jobs.add(new WorkItem(url, sha256, startAtTime));
+ context.jobs.add(new WorkItem(url, sha256, startAtTime, authHeaders));
startAtTime += FAILOVER_DELAY_MS;
}
// Create the worker thread pool.
@@ -204,11 +212,13 @@
final URL url;
final String sha256;
final long startAtTime;
+ final Map<URI, Map<String, String>> authHeaders;
- WorkItem(URL url, String sha256, long startAtTime) {
+ WorkItem(URL url, String sha256, long startAtTime, Map<URI, Map<String, String>> authHeaders) {
this.url = url;
this.sha256 = sha256;
this.startAtTime = startAtTime;
+ this.authHeaders = authHeaders;
}
}
@@ -253,7 +263,7 @@
// Now we're actually going to attempt to connect to the remote server.
HttpStream result;
try {
- result = establishConnection(work.url, work.sha256);
+ result = establishConnection(work.url, work.sha256, work.authHeaders);
} catch (InterruptedIOException e) {
// The parent thread got its result from another thread and killed this one.
synchronized (context) {
@@ -296,21 +306,38 @@
}
}
- private HttpStream establishConnection(final URL url, String sha256) throws IOException {
- final URLConnection connection = connector.connect(url, REQUEST_HEADERS);
+ private HttpStream establishConnection(
+ final URL url, String sha256, Map<URI, Map<String, String>> additionalHeaders)
+ throws IOException {
+ ImmutableMap<String, String> headers = REQUEST_HEADERS;
+ try {
+ if (additionalHeaders.containsKey(url.toURI())) {
+ headers =
+ ImmutableMap.<String, String>builder()
+ .putAll(headers)
+ .putAll(additionalHeaders.get(url.toURI()))
+ .build();
+ }
+ } catch (URISyntaxException e) {
+ // If we can't convert the URL to a URI (because it is syntactically malformed), still try to
+ // do the connection, not adding authentication information as we cannot look it up.
+ }
+ final URLConnection connection = connector.connect(url, headers);
+ final Map<String, String> allHeaders = headers;
return httpStreamFactory.create(
- connection, url, sha256,
+ connection,
+ url,
+ sha256,
new Reconnector() {
@Override
- public URLConnection connect(
- Throwable cause, ImmutableMap<String, String> extraHeaders)
- throws IOException {
+ public URLConnection connect(Throwable cause, ImmutableMap<String, String> extraHeaders)
+ throws IOException {
eventHandler.handle(
Event.progress(String.format("Lost connection for %s due to %s", url, cause)));
return connector.connect(
connection.getURL(),
new ImmutableMap.Builder<String, String>()
- .putAll(REQUEST_HEADERS)
+ .putAll(allHeaders)
.putAll(extraHeaders)
.build());
}
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
index 4f6a881..d9fad055 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
@@ -35,6 +35,7 @@
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.OutputStream;
+import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Locale;
@@ -89,6 +90,7 @@
*/
public Path download(
List<URL> urls,
+ Map<URI, Map<String, String>> authHeaders,
String sha256,
String canonicalId,
Optional<String> type,
@@ -199,7 +201,7 @@
// Connect to the best mirror and download the file, while reporting progress to the CLI.
semaphore.acquire();
boolean success = false;
- try (HttpStream payload = multiplexer.connect(urls, sha256);
+ try (HttpStream payload = multiplexer.connect(urls, sha256, authHeaders);
OutputStream out = destination.getOutputStream()) {
ByteStreams.copy(payload, out);
success = true;
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
index 729cdbd..25b521a 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
@@ -61,9 +61,12 @@
import java.io.IOException;
import java.io.OutputStream;
import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.Base64;
import java.util.List;
import java.util.Map;
@@ -460,8 +463,11 @@
Boolean executable,
Boolean allowFail,
String canonicalId,
+ SkylarkDict<String, SkylarkDict<Object, Object>> auth,
Location location)
throws RepositoryFunctionException, EvalException, InterruptedException {
+ Map<URI, Map<String, String>> authHeaders = getAuthHeaders(auth);
+
List<URL> urls = getUrls(url, /* ensureNonEmpty= */ !allowFail);
RepositoryFunctionException sha256Validation = validateSha256(sha256, location);
if (sha256Validation != null) {
@@ -480,6 +486,7 @@
downloadedPath =
httpDownloader.download(
urls,
+ authHeaders,
sha256,
canonicalId,
Optional.<String>absent(),
@@ -561,8 +568,11 @@
String stripPrefix,
Boolean allowFail,
String canonicalId,
+ SkylarkDict<String, SkylarkDict<Object, Object>> auth,
Location location)
throws RepositoryFunctionException, InterruptedException, EvalException {
+ Map<URI, Map<String, String>> authHeaders = getAuthHeaders(auth);
+
List<URL> urls = getUrls(url, /* ensureNonEmpty= */ !allowFail);
RepositoryFunctionException sha256Validation = validateSha256(sha256, location);
if (sha256Validation != null) {
@@ -590,6 +600,7 @@
downloadedPath =
httpDownloader.download(
urls,
+ authHeaders,
sha256,
canonicalId,
Optional.of(type),
@@ -792,4 +803,49 @@
}
}
}
+
+ /**
+ * From an authentication dict extract a map of headers.
+ *
+ * <p>Given a dict as provided as "auth" argument, compute a map specifying for each URI provided
+ * which additional headers (as usual, represented as a map from Strings to Strings) should
+ * additionally be added to the request. For some form of authentication, in particular basic
+ * authentication, adding those headers is enough; for other forms of authentication other
+ * measures might be necessary.
+ */
+ private static Map<URI, Map<String, String>> getAuthHeaders(
+ SkylarkDict<String, SkylarkDict<Object, Object>> auth)
+ throws RepositoryFunctionException, EvalException {
+ ImmutableMap.Builder<URI, Map<String, String>> headers = new ImmutableMap.Builder<>();
+ for (Map.Entry<String, SkylarkDict<Object, Object>> entry : auth.entrySet()) {
+ try {
+ URL url = new URL(entry.getKey());
+ SkylarkDict<Object, Object> authMap = entry.getValue();
+ if (authMap.containsKey("type")) {
+ if ("basic".equals(authMap.get("type"))) {
+ if (!authMap.containsKey("login") || !authMap.containsKey("password")) {
+ throw new EvalException(
+ null,
+ "Found request to do basic auth for "
+ + entry.getKey()
+ + " without 'login' and 'password' being provided.");
+ }
+ String credentials = authMap.get("login") + ":" + authMap.get("password");
+ headers.put(
+ url.toURI(),
+ ImmutableMap.<String, String>of(
+ "Authorization",
+ "Basic "
+ + Base64.getEncoder()
+ .encodeToString(credentials.getBytes(StandardCharsets.UTF_8))));
+ }
+ }
+ } catch (MalformedURLException e) {
+ throw new RepositoryFunctionException(e, Transience.PERSISTENT);
+ } catch (URISyntaxException e) {
+ throw new EvalException(null, e.getMessage());
+ }
+ }
+ return headers.build();
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java b/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
index bca57b0..eebaf44 100644
--- a/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
+++ b/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
@@ -368,6 +368,12 @@
doc =
"If set, restrict cache hits to those cases where the file was added to the cache"
+ " with the same canonical id"),
+ @Param(
+ name = "auth",
+ type = SkylarkDict.class,
+ defaultValue = "{}",
+ named = true,
+ doc = "An optional dict specifying authentication information for some of the URLs."),
})
public StructApi download(
Object url,
@@ -376,6 +382,7 @@
Boolean executable,
Boolean allowFail,
String canonicalId,
+ SkylarkDict<String, SkylarkDict<Object, Object>> auth,
Location location)
throws RepositoryFunctionExceptionT, EvalException, InterruptedException;
@@ -503,6 +510,12 @@
doc =
"If set, restrict cache hits to those cases where the file was added to the cache"
+ " with the same canonical id"),
+ @Param(
+ name = "auth",
+ type = SkylarkDict.class,
+ defaultValue = "{}",
+ named = true,
+ doc = "An optional dict specifying authentication information for some of the URLs."),
})
public StructApi downloadAndExtract(
Object url,
@@ -512,6 +525,7 @@
String stripPrefix,
Boolean allowFail,
String canonicalId,
+ SkylarkDict<String, SkylarkDict<Object, Object>> auth,
Location location)
throws RepositoryFunctionExceptionT, InterruptedException, EvalException;
}
diff --git a/src/test/shell/bazel/remote_helpers.sh b/src/test/shell/bazel/remote_helpers.sh
index a6e8492..528f52e 100755
--- a/src/test/shell/bazel/remote_helpers.sh
+++ b/src/test/shell/bazel/remote_helpers.sh
@@ -38,6 +38,29 @@
cd -
}
+# Serves $1 as a file on localhost:$nc_port insisting on authentication (but
+# accepting any credentials.
+# * nc_port - the port nc is listening on.
+# * nc_log - the path to nc's log.
+# * nc_pid - the PID of nc.
+function serve_file_auth() {
+ file_name=served_file.$$
+ cat $1 > "${TEST_TMPDIR}/$file_name"
+ nc_log="${TEST_TMPDIR}/nc.log"
+ rm -f $nc_log
+ touch $nc_log
+ cd "${TEST_TMPDIR}"
+ port_file=server-port.$$
+ rm -f $port_file
+ python $python_server auth $file_name > $port_file &
+ nc_pid=$!
+ while ! grep started $port_file; do sleep 1; done
+ nc_port=$(head -n 1 $port_file)
+ fileserver_port=$nc_port
+ wait_for_server_startup
+ cd -
+}
+
# Creates a jar carnivore.Mongoose and serves it using serve_file.
function serve_jar() {
make_test_jar
diff --git a/src/test/shell/bazel/skylark_repository_test.sh b/src/test/shell/bazel/skylark_repository_test.sh
index 72fc395..aa83a7d 100755
--- a/src/test/shell/bazel/skylark_repository_test.sh
+++ b/src/test/shell/bazel/skylark_repository_test.sh
@@ -1436,6 +1436,47 @@
expect_log "//:b.bzl"
}
+function test_auth_provided() {
+ mkdir x
+ echo 'exports_files(["file.txt"])' > x/BUILD
+ echo 'Hello World' > x/file.txt
+ tar cvf x.tar x
+ serve_file_auth x.tar
+ cat > WORKSPACE <<EOF
+load("//:auth.bzl", "with_auth")
+with_auth(
+ name="ext",
+ url = "http://127.0.0.1:$nc_port/x.tar",
+)
+EOF
+ cat > auth.bzl <<'EOF'
+def _impl(ctx):
+ ctx.download_and_extract(
+ url = ctx.attr.url,
+ # Use the username/password pair hard-coded
+ # in the testing server.
+ auth = {ctx.attr.url : { "type": "basic",
+ "login" : "foo",
+ "password" : "bar"}}
+ )
+
+with_auth = repository_rule(
+ implementation = _impl,
+ attrs = { "url" : attr.string() }
+)
+EOF
+ cat > BUILD <<'EOF'
+genrule(
+ name = "it",
+ srcs = ["@ext//x:file.txt"],
+ outs = ["it.txt"],
+ cmd = "cp $< $@",
+)
+EOF
+ bazel build //:it \
+ || fail "Expected success despite needing a file behind basic auth"
+}
+
function tear_down() {
shutdown_server
diff --git a/src/test/shell/bazel/testing_server.py b/src/test/shell/bazel/testing_server.py
index 314e409..1285ec9 100644
--- a/src/test/shell/bazel/testing_server.py
+++ b/src/test/shell/bazel/testing_server.py
@@ -104,8 +104,10 @@
Handler.not_found = True
elif argv and argv[0] == 'timeout':
Handler.simulate_timeout = True
- elif argv:
+ elif argv and argv[0] == 'auth':
Handler.auth = True
+ if len(argv) > 1:
+ Handler.filename = argv[1]
httpd = None
port = None