Extend the repository_context.download{,_and_extract} with authentication

Extend, as described in the authentication API design[1], the functions
download{,_and_extract} of the repository context by an additional
parameter 'auth'. If given, the information provided there will be used
to authenticate upon download.

[1] https://github.com/bazelbuild/proposals/blob/master/designs/2019-05-27-auth.md

Change-Id: Ifa119987accbcc56219e9a0c965ed3e94dc9d38e
PiperOrigin-RevId: 253760773
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
index e1050bb..8259399 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnector.java
@@ -90,6 +90,7 @@
   URLConnection connect(
       URL originalUrl, ImmutableMap<String, String> requestHeaders)
           throws IOException {
+
     if (Thread.interrupted()) {
       throw new InterruptedIOException();
     }
@@ -116,7 +117,7 @@
             // appears to be compressed.
             continue;
           }
-          connection.setRequestProperty(entry.getKey(), entry.getValue());
+          connection.addRequestProperty(entry.getKey(), entry.getValue());
         }
         connection.setConnectTimeout(connectTimeout);
         // The read timeout is always large because it stays in effect after this method.
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
index 8798a33..36328ae 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpConnectorMultiplexer.java
@@ -30,6 +30,8 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InterruptedIOException;
+import java.net.URI;
+import java.net.URISyntaxException;
 import java.net.URL;
 import java.net.URLConnection;
 import java.util.ArrayList;
@@ -37,6 +39,7 @@
 import java.util.Deque;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 import javax.annotation.Nullable;
@@ -92,6 +95,10 @@
     this.sleeper = sleeper;
   }
 
+  public HttpStream connect(List<URL> urls, String sha256) throws IOException {
+    return connect(urls, sha256, ImmutableMap.<URI, Map<String, String>>of());
+  }
+
   /**
    * Establishes reliable HTTP connection to a good mirror URL.
    *
@@ -115,7 +122,8 @@
    * @throws InterruptedIOException if current thread is being cast into oblivion
    * @throws IllegalArgumentException if {@code urls} is empty or has an unsupported protocol
    */
-  public HttpStream connect(List<URL> urls, String sha256) throws IOException {
+  public HttpStream connect(
+      List<URL> urls, String sha256, Map<URI, Map<String, String>> authHeaders) throws IOException {
     Preconditions.checkNotNull(sha256);
     HttpUtils.checkUrlsArgument(urls);
     if (Thread.interrupted()) {
@@ -123,7 +131,7 @@
     }
     // If there's only one URL then there's no need for us to run all our fancy thread stuff.
     if (urls.size() == 1) {
-      return establishConnection(urls.get(0), sha256);
+      return establishConnection(urls.get(0), sha256, authHeaders);
     }
     MutexConditionSharedMemory context = new MutexConditionSharedMemory();
     // The parent thread always holds the lock except when released by wait().
@@ -132,7 +140,7 @@
       long now = clock.currentTimeMillis();
       long startAtTime = now;
       for (URL url : urls) {
-        context.jobs.add(new WorkItem(url, sha256, startAtTime));
+        context.jobs.add(new WorkItem(url, sha256, startAtTime, authHeaders));
         startAtTime += FAILOVER_DELAY_MS;
       }
       // Create the worker thread pool.
@@ -204,11 +212,13 @@
     final URL url;
     final String sha256;
     final long startAtTime;
+    final Map<URI, Map<String, String>> authHeaders;
 
-    WorkItem(URL url, String sha256, long startAtTime) {
+    WorkItem(URL url, String sha256, long startAtTime, Map<URI, Map<String, String>> authHeaders) {
       this.url = url;
       this.sha256 = sha256;
       this.startAtTime = startAtTime;
+      this.authHeaders = authHeaders;
     }
   }
 
@@ -253,7 +263,7 @@
         // Now we're actually going to attempt to connect to the remote server.
         HttpStream result;
         try {
-          result = establishConnection(work.url, work.sha256);
+          result = establishConnection(work.url, work.sha256, work.authHeaders);
         } catch (InterruptedIOException e) {
           // The parent thread got its result from another thread and killed this one.
           synchronized (context) {
@@ -296,21 +306,38 @@
     }
   }
 
-  private HttpStream establishConnection(final URL url, String sha256) throws IOException {
-    final URLConnection connection = connector.connect(url, REQUEST_HEADERS);
+  private HttpStream establishConnection(
+      final URL url, String sha256, Map<URI, Map<String, String>> additionalHeaders)
+      throws IOException {
+    ImmutableMap<String, String> headers = REQUEST_HEADERS;
+    try {
+      if (additionalHeaders.containsKey(url.toURI())) {
+        headers =
+            ImmutableMap.<String, String>builder()
+                .putAll(headers)
+                .putAll(additionalHeaders.get(url.toURI()))
+                .build();
+      }
+    } catch (URISyntaxException e) {
+      // If we can't convert the URL to a URI (because it is syntactically malformed), still try to
+      // do the connection, not adding authentication information as we cannot look it up.
+    }
+    final URLConnection connection = connector.connect(url, headers);
+    final Map<String, String> allHeaders = headers;
     return httpStreamFactory.create(
-        connection, url, sha256,
+        connection,
+        url,
+        sha256,
         new Reconnector() {
           @Override
-          public URLConnection connect(
-              Throwable cause, ImmutableMap<String, String> extraHeaders)
-                  throws IOException {
+          public URLConnection connect(Throwable cause, ImmutableMap<String, String> extraHeaders)
+              throws IOException {
             eventHandler.handle(
                 Event.progress(String.format("Lost connection for %s due to %s", url, cause)));
             return connector.connect(
                 connection.getURL(),
                 new ImmutableMap.Builder<String, String>()
-                    .putAll(REQUEST_HEADERS)
+                    .putAll(allHeaders)
                     .putAll(extraHeaders)
                     .build());
           }
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
index 4f6a881..d9fad055 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/downloader/HttpDownloader.java
@@ -35,6 +35,7 @@
 import java.io.IOException;
 import java.io.InterruptedIOException;
 import java.io.OutputStream;
+import java.net.URI;
 import java.net.URL;
 import java.util.List;
 import java.util.Locale;
@@ -89,6 +90,7 @@
    */
   public Path download(
       List<URL> urls,
+      Map<URI, Map<String, String>> authHeaders,
       String sha256,
       String canonicalId,
       Optional<String> type,
@@ -199,7 +201,7 @@
     // Connect to the best mirror and download the file, while reporting progress to the CLI.
     semaphore.acquire();
     boolean success = false;
-    try (HttpStream payload = multiplexer.connect(urls, sha256);
+    try (HttpStream payload = multiplexer.connect(urls, sha256, authHeaders);
         OutputStream out = destination.getOutputStream()) {
       ByteStreams.copy(payload, out);
       success = true;
diff --git a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
index 729cdbd..25b521a 100644
--- a/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
+++ b/src/main/java/com/google/devtools/build/lib/bazel/repository/skylark/SkylarkRepositoryContext.java
@@ -61,9 +61,12 @@
 import java.io.IOException;
 import java.io.OutputStream;
 import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URISyntaxException;
 import java.net.URL;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Base64;
 import java.util.List;
 import java.util.Map;
 
@@ -460,8 +463,11 @@
       Boolean executable,
       Boolean allowFail,
       String canonicalId,
+      SkylarkDict<String, SkylarkDict<Object, Object>> auth,
       Location location)
       throws RepositoryFunctionException, EvalException, InterruptedException {
+    Map<URI, Map<String, String>> authHeaders = getAuthHeaders(auth);
+
     List<URL> urls = getUrls(url, /* ensureNonEmpty= */ !allowFail);
     RepositoryFunctionException sha256Validation = validateSha256(sha256, location);
     if (sha256Validation != null) {
@@ -480,6 +486,7 @@
       downloadedPath =
           httpDownloader.download(
               urls,
+              authHeaders,
               sha256,
               canonicalId,
               Optional.<String>absent(),
@@ -561,8 +568,11 @@
       String stripPrefix,
       Boolean allowFail,
       String canonicalId,
+      SkylarkDict<String, SkylarkDict<Object, Object>> auth,
       Location location)
       throws RepositoryFunctionException, InterruptedException, EvalException {
+    Map<URI, Map<String, String>> authHeaders = getAuthHeaders(auth);
+
     List<URL> urls = getUrls(url, /* ensureNonEmpty= */ !allowFail);
     RepositoryFunctionException sha256Validation = validateSha256(sha256, location);
     if (sha256Validation != null) {
@@ -590,6 +600,7 @@
       downloadedPath =
           httpDownloader.download(
               urls,
+              authHeaders,
               sha256,
               canonicalId,
               Optional.of(type),
@@ -792,4 +803,49 @@
       }
     }
   }
+
+  /**
+   * From an authentication dict extract a map of headers.
+   *
+   * <p>Given a dict as provided as "auth" argument, compute a map specifying for each URI provided
+   * which additional headers (as usual, represented as a map from Strings to Strings) should
+   * additionally be added to the request. For some form of authentication, in particular basic
+   * authentication, adding those headers is enough; for other forms of authentication other
+   * measures might be necessary.
+   */
+  private static Map<URI, Map<String, String>> getAuthHeaders(
+      SkylarkDict<String, SkylarkDict<Object, Object>> auth)
+      throws RepositoryFunctionException, EvalException {
+    ImmutableMap.Builder<URI, Map<String, String>> headers = new ImmutableMap.Builder<>();
+    for (Map.Entry<String, SkylarkDict<Object, Object>> entry : auth.entrySet()) {
+      try {
+        URL url = new URL(entry.getKey());
+        SkylarkDict<Object, Object> authMap = entry.getValue();
+        if (authMap.containsKey("type")) {
+          if ("basic".equals(authMap.get("type"))) {
+            if (!authMap.containsKey("login") || !authMap.containsKey("password")) {
+              throw new EvalException(
+                  null,
+                  "Found request to do basic auth for "
+                      + entry.getKey()
+                      + " without 'login' and 'password' being provided.");
+            }
+            String credentials = authMap.get("login") + ":" + authMap.get("password");
+            headers.put(
+                url.toURI(),
+                ImmutableMap.<String, String>of(
+                    "Authorization",
+                    "Basic "
+                        + Base64.getEncoder()
+                            .encodeToString(credentials.getBytes(StandardCharsets.UTF_8))));
+          }
+        }
+      } catch (MalformedURLException e) {
+        throw new RepositoryFunctionException(e, Transience.PERSISTENT);
+      } catch (URISyntaxException e) {
+        throw new EvalException(null, e.getMessage());
+      }
+    }
+    return headers.build();
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java b/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
index bca57b0..eebaf44 100644
--- a/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
+++ b/src/main/java/com/google/devtools/build/lib/skylarkbuildapi/repository/SkylarkRepositoryContextApi.java
@@ -368,6 +368,12 @@
             doc =
                 "If set, restrict cache hits to those cases where the file was added to the cache"
                     + " with the same canonical id"),
+        @Param(
+            name = "auth",
+            type = SkylarkDict.class,
+            defaultValue = "{}",
+            named = true,
+            doc = "An optional dict specifying authentication information for some of the URLs."),
       })
   public StructApi download(
       Object url,
@@ -376,6 +382,7 @@
       Boolean executable,
       Boolean allowFail,
       String canonicalId,
+      SkylarkDict<String, SkylarkDict<Object, Object>> auth,
       Location location)
       throws RepositoryFunctionExceptionT, EvalException, InterruptedException;
 
@@ -503,6 +510,12 @@
             doc =
                 "If set, restrict cache hits to those cases where the file was added to the cache"
                     + " with the same canonical id"),
+        @Param(
+            name = "auth",
+            type = SkylarkDict.class,
+            defaultValue = "{}",
+            named = true,
+            doc = "An optional dict specifying authentication information for some of the URLs."),
       })
   public StructApi downloadAndExtract(
       Object url,
@@ -512,6 +525,7 @@
       String stripPrefix,
       Boolean allowFail,
       String canonicalId,
+      SkylarkDict<String, SkylarkDict<Object, Object>> auth,
       Location location)
       throws RepositoryFunctionExceptionT, InterruptedException, EvalException;
 }
diff --git a/src/test/shell/bazel/remote_helpers.sh b/src/test/shell/bazel/remote_helpers.sh
index a6e8492..528f52e 100755
--- a/src/test/shell/bazel/remote_helpers.sh
+++ b/src/test/shell/bazel/remote_helpers.sh
@@ -38,6 +38,29 @@
   cd -
 }
 
+# Serves $1 as a file on localhost:$nc_port insisting on authentication (but
+# accepting any credentials.
+#   * nc_port - the port nc is listening on.
+#   * nc_log - the path to nc's log.
+#   * nc_pid - the PID of nc.
+function serve_file_auth() {
+  file_name=served_file.$$
+  cat $1 > "${TEST_TMPDIR}/$file_name"
+  nc_log="${TEST_TMPDIR}/nc.log"
+  rm -f $nc_log
+  touch $nc_log
+  cd "${TEST_TMPDIR}"
+  port_file=server-port.$$
+  rm -f $port_file
+  python $python_server auth $file_name > $port_file &
+  nc_pid=$!
+  while ! grep started $port_file; do sleep 1; done
+  nc_port=$(head -n 1 $port_file)
+  fileserver_port=$nc_port
+  wait_for_server_startup
+  cd -
+}
+
 # Creates a jar carnivore.Mongoose and serves it using serve_file.
 function serve_jar() {
   make_test_jar
diff --git a/src/test/shell/bazel/skylark_repository_test.sh b/src/test/shell/bazel/skylark_repository_test.sh
index 72fc395..aa83a7d 100755
--- a/src/test/shell/bazel/skylark_repository_test.sh
+++ b/src/test/shell/bazel/skylark_repository_test.sh
@@ -1436,6 +1436,47 @@
   expect_log "//:b.bzl"
 }
 
+function test_auth_provided() {
+  mkdir x
+  echo 'exports_files(["file.txt"])' > x/BUILD
+  echo 'Hello World' > x/file.txt
+  tar cvf x.tar x
+  serve_file_auth x.tar
+  cat > WORKSPACE <<EOF
+load("//:auth.bzl", "with_auth")
+with_auth(
+  name="ext",
+  url = "http://127.0.0.1:$nc_port/x.tar",
+)
+EOF
+  cat > auth.bzl <<'EOF'
+def _impl(ctx):
+  ctx.download_and_extract(
+    url = ctx.attr.url,
+    # Use the username/password pair hard-coded
+    # in the testing server.
+    auth = {ctx.attr.url : { "type": "basic",
+                            "login" : "foo",
+                            "password" : "bar"}}
+  )
+
+with_auth = repository_rule(
+  implementation = _impl,
+  attrs = { "url" : attr.string() }
+)
+EOF
+  cat > BUILD <<'EOF'
+genrule(
+  name = "it",
+  srcs = ["@ext//x:file.txt"],
+  outs = ["it.txt"],
+  cmd = "cp $< $@",
+)
+EOF
+  bazel build //:it \
+      || fail "Expected success despite needing a file behind basic auth"
+}
+
 
 function tear_down() {
   shutdown_server
diff --git a/src/test/shell/bazel/testing_server.py b/src/test/shell/bazel/testing_server.py
index 314e409..1285ec9 100644
--- a/src/test/shell/bazel/testing_server.py
+++ b/src/test/shell/bazel/testing_server.py
@@ -104,8 +104,10 @@
     Handler.not_found = True
   elif argv and argv[0] == 'timeout':
     Handler.simulate_timeout = True
-  elif argv:
+  elif argv and argv[0] == 'auth':
     Handler.auth = True
+    if len(argv) > 1:
+      Handler.filename = argv[1]
 
   httpd = None
   port = None