Implement toolchain resolution via constraint checks.

Part of #2219.

Change-Id: I5777e9b6cafbb7586cbbfb5b300344fd4417513d
PiperOrigin-RevId: 162359389
diff --git a/src/main/java/com/google/devtools/build/lib/analysis/platform/PlatformInfo.java b/src/main/java/com/google/devtools/build/lib/analysis/platform/PlatformInfo.java
index 1d1092a..9e5c917 100644
--- a/src/main/java/com/google/devtools/build/lib/analysis/platform/PlatformInfo.java
+++ b/src/main/java/com/google/devtools/build/lib/analysis/platform/PlatformInfo.java
@@ -35,6 +35,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 
 /** Provider for a platform, which is a group of constraints and values. */
 @SkylarkModule(
@@ -90,7 +91,7 @@
       };
 
   private final Label label;
-  private final ImmutableList<ConstraintValueInfo> constraints;
+  private final ImmutableMap<ConstraintSettingInfo, ConstraintValueInfo> constraints;
   private final ImmutableMap<String, String> remoteExecutionProperties;
 
   private PlatformInfo(
@@ -106,8 +107,14 @@
         location);
 
     this.label = label;
-    this.constraints = constraints;
     this.remoteExecutionProperties = remoteExecutionProperties;
+
+    ImmutableMap.Builder<ConstraintSettingInfo, ConstraintValueInfo> constraintsBuilder =
+        new ImmutableMap.Builder<>();
+    for (ConstraintValueInfo constraint : constraints) {
+      constraintsBuilder.put(constraint.constraint(), constraint);
+    }
+    this.constraints = constraintsBuilder.build();
   }
 
   @SkylarkCallable(
@@ -126,8 +133,17 @@
             + "this platform.",
     structField = true
   )
-  public ImmutableList<ConstraintValueInfo> constraints() {
-    return constraints;
+  public Iterable<ConstraintValueInfo> constraints() {
+    return constraints.values();
+  }
+
+  /**
+   * Returns the {@link ConstraintValueInfo} for the given {@link ConstraintSettingInfo}, or {@code
+   * null} if none exists.
+   */
+  @Nullable
+  public ConstraintValueInfo getConstraint(ConstraintSettingInfo constraint) {
+    return constraints.get(constraint);
   }
 
   @SkylarkCallable(
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
index 7a61c67..dbe0e9f 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
@@ -14,7 +14,9 @@
 
 package com.google.devtools.build.lib.skyframe;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
+import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
 import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
 import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
 import com.google.devtools.build.lib.cmdline.Label;
@@ -63,25 +65,55 @@
     DeclaredToolchainInfo toolchain =
         resolveConstraints(
             key.toolchainType(),
-            key.targetPlatform(),
             key.execPlatform(),
+            key.targetPlatform(),
             toolchains.registeredToolchains());
+
+    if (toolchain == null) {
+      throw new ToolchainResolutionFunctionException(
+          new NoToolchainFoundException(key.toolchainType()));
+    }
     return ToolchainResolutionValue.create(toolchain.toolchainLabel());
   }
 
-  // TODO(katre): Implement real resolution.
-  private DeclaredToolchainInfo resolveConstraints(
+  @VisibleForTesting
+  static DeclaredToolchainInfo resolveConstraints(
       Label toolchainType,
-      PlatformInfo targetPlatform,
       PlatformInfo execPlatform,
-      ImmutableList<DeclaredToolchainInfo> toolchains)
-      throws ToolchainResolutionFunctionException {
+      PlatformInfo targetPlatform,
+      ImmutableList<DeclaredToolchainInfo> toolchains) {
     for (DeclaredToolchainInfo toolchain : toolchains) {
-      if (toolchain.toolchainType().equals(toolchainType)) {
-        return toolchain;
+      // Make sure the type matches.
+      if (!toolchain.toolchainType().equals(toolchainType)) {
+        continue;
+      }
+      if (!checkConstraints(toolchain.execConstraints(), execPlatform)) {
+        continue;
+      }
+      if (!checkConstraints(toolchain.targetConstraints(), targetPlatform)) {
+        continue;
+      }
+
+      return toolchain;
+    }
+
+    return null;
+  }
+
+  /**
+   * Returns {@code true} iff all constraints set by the toolchain are present in the {@link
+   * PlatformInfo}.
+   */
+  private static boolean checkConstraints(
+      Iterable<ConstraintValueInfo> toolchainConstraints, PlatformInfo platform) {
+
+    for (ConstraintValueInfo constraint : toolchainConstraints) {
+      ConstraintValueInfo found = platform.getConstraint(constraint.constraint());
+      if (!constraint.equals(found)) {
+        return false;
       }
     }
-    throw new ToolchainResolutionFunctionException(new NoToolchainFoundException(toolchainType));
+    return true;
   }
 
   @Nullable
diff --git a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
index 5d170ba..3b9548d 100644
--- a/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
+++ b/src/test/java/com/google/devtools/build/lib/rules/platform/ToolchainTestCase.java
@@ -24,8 +24,8 @@
 /** Utility methods for setting up platform and toolchain related tests. */
 public abstract class ToolchainTestCase extends SkylarkTestCase {
 
-  public PlatformInfo targetPlatform;
-  public PlatformInfo hostPlatform;
+  public PlatformInfo linuxPlatform;
+  public PlatformInfo macPlatform;
 
   public ConstraintSettingInfo setting;
   public ConstraintValueInfo linuxConstraint;
@@ -46,18 +46,22 @@
     setting = ConstraintSettingInfo.create(makeLabel("//constraint:os"));
     linuxConstraint = ConstraintValueInfo.create(setting, makeLabel("//constraint:linux"));
     macConstraint = ConstraintValueInfo.create(setting, makeLabel("//constraint:mac"));
-  }
 
-  @Before
-  public void createPlatforms() throws Exception {
-    targetPlatform =
-        PlatformInfo.builder().setLabel(makeLabel("//platforms:target_platform")).build();
-    hostPlatform = PlatformInfo.builder().setLabel(makeLabel("//platforms:host_platform")).build();
+    linuxPlatform =
+        PlatformInfo.builder()
+            .setLabel(makeLabel("//platforms:target_platform"))
+            .addConstraint(linuxConstraint)
+            .build();
+    macPlatform =
+        PlatformInfo.builder()
+            .setLabel(makeLabel("//platforms:host_platform"))
+            .addConstraint(macConstraint)
+            .build();
   }
 
   @Before
   public void createToolchains() throws Exception {
-    rewriteWorkspace("register_toolchains('//toolchain:toolchain_1',  '//toolchain:toolchain_2')");
+    rewriteWorkspace("register_toolchains('//toolchain:toolchain_1', '//toolchain:toolchain_2')");
 
     scratch.file(
         "toolchain/BUILD",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/BUILD b/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
index d68e4e8..fd6714e 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/BUILD
@@ -73,7 +73,6 @@
         "//src/main/java/com/google/devtools/build/lib/analysis/platform",
         "//src/main/java/com/google/devtools/build/lib/cmdline",
         "//src/main/java/com/google/devtools/build/lib/rules/cpp",
-        "//src/main/java/com/google/devtools/build/lib/rules/platform",
         "//src/main/java/com/google/devtools/build/skyframe",
         "//src/main/java/com/google/devtools/build/skyframe:skyframe-objects",
         "//src/test/java/com/google/devtools/build/lib:actions_testutil",
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
index 88bc501..214cc61 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
@@ -17,7 +17,15 @@
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.testing.EqualsTester;
+import com.google.common.truth.DefaultSubject;
+import com.google.common.truth.Subject;
+import com.google.devtools.build.lib.analysis.platform.ConstraintSettingInfo;
+import com.google.devtools.build.lib.analysis.platform.ConstraintValueInfo;
+import com.google.devtools.build.lib.analysis.platform.DeclaredToolchainInfo;
+import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
+import com.google.devtools.build.lib.cmdline.Label;
 import com.google.devtools.build.lib.rules.platform.ToolchainTestCase;
 import com.google.devtools.build.lib.skyframe.util.SkyframeExecutorTestUtils;
 import com.google.devtools.build.skyframe.EvaluationResult;
@@ -32,6 +40,7 @@
 
   private EvaluationResult<ToolchainResolutionValue> invokeToolchainResolution(SkyKey key)
       throws InterruptedException {
+
     try {
       getSkyframeExecutor().getSkyframeBuildView().enableAnalysis(true);
       return SkyframeExecutorTestUtils.evaluate(
@@ -41,19 +50,17 @@
     }
   }
 
-  // TODO(katre): Current toolchain resolution does not actually check the constraints, it just
-  // returns the first toolchain available.
   @Test
   public void testResolution() throws Exception {
     SkyKey key =
-        ToolchainResolutionValue.key(targetConfig, testToolchainType, targetPlatform, hostPlatform);
+        ToolchainResolutionValue.key(targetConfig, testToolchainType, linuxPlatform, macPlatform);
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result).hasNoError();
 
     ToolchainResolutionValue toolchainResolutionValue = result.get(key);
     assertThat(toolchainResolutionValue.toolchainLabel())
-        .isEqualTo(makeLabel("//toolchain:test_toolchain_1"));
+        .isEqualTo(makeLabel("//toolchain:test_toolchain_2"));
   }
 
   @Test
@@ -62,7 +69,7 @@
     rewriteWorkspace();
 
     SkyKey key =
-        ToolchainResolutionValue.key(targetConfig, testToolchainType, targetPlatform, hostPlatform);
+        ToolchainResolutionValue.key(targetConfig, testToolchainType, linuxPlatform, macPlatform);
     EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
 
     assertThatEvaluationResult(result)
@@ -73,6 +80,110 @@
   }
 
   @Test
+  public void testResolveConstraints() throws Exception {
+    ConstraintSettingInfo setting1 =
+        ConstraintSettingInfo.create(makeLabel("//constraint:setting1"));
+    ConstraintSettingInfo setting2 =
+        ConstraintSettingInfo.create(makeLabel("//constraint:setting2"));
+    ConstraintValueInfo constraint1a =
+        ConstraintValueInfo.create(setting1, makeLabel("//constraint:value1a"));
+    ConstraintValueInfo constraint1b =
+        ConstraintValueInfo.create(setting1, makeLabel("//constraint:value1b"));
+    ConstraintValueInfo constraint2a =
+        ConstraintValueInfo.create(setting2, makeLabel("//constraint:value2a"));
+    ConstraintValueInfo constraint2b =
+        ConstraintValueInfo.create(setting2, makeLabel("//constraint:value2b"));
+
+    Label toolchainType1 = makeLabel("//toolchain:type1");
+    Label toolchainType2 = makeLabel("//toolchain:type2");
+
+    DeclaredToolchainInfo toolchain1a =
+        DeclaredToolchainInfo.create(
+            toolchainType1,
+            ImmutableList.of(constraint1a, constraint2a),
+            ImmutableList.of(constraint1a, constraint2a),
+            makeLabel("//toolchain:toolchain1a"));
+    DeclaredToolchainInfo toolchain1b =
+        DeclaredToolchainInfo.create(
+            toolchainType1,
+            ImmutableList.of(constraint1a, constraint2b),
+            ImmutableList.of(constraint1a, constraint2b),
+            makeLabel("//toolchain:toolchain1b"));
+    DeclaredToolchainInfo toolchain2a =
+        DeclaredToolchainInfo.create(
+            toolchainType2,
+            ImmutableList.of(constraint1b, constraint2a),
+            ImmutableList.of(constraint1b, constraint2a),
+            makeLabel("//toolchain:toolchain2a"));
+    DeclaredToolchainInfo toolchain2b =
+        DeclaredToolchainInfo.create(
+            toolchainType2,
+            ImmutableList.of(constraint1b, constraint2b),
+            ImmutableList.of(constraint1b, constraint2b),
+            makeLabel("//toolchain:toolchain2b"));
+
+    ImmutableList<DeclaredToolchainInfo> allToolchains =
+        ImmutableList.of(toolchain1a, toolchain1b, toolchain2a, toolchain2b);
+
+    assertToolchainResolution(
+            toolchainType1,
+            ImmutableList.of(constraint1a, constraint2a),
+            ImmutableList.of(constraint1a, constraint2a),
+            allToolchains)
+        .isEqualTo(toolchain1a);
+    assertToolchainResolution(
+            toolchainType1,
+            ImmutableList.of(constraint1a, constraint2b),
+            ImmutableList.of(constraint1a, constraint2b),
+            allToolchains)
+        .isEqualTo(toolchain1b);
+    assertToolchainResolution(
+            toolchainType2,
+            ImmutableList.of(constraint1b, constraint2a),
+            ImmutableList.of(constraint1b, constraint2a),
+            allToolchains)
+        .isEqualTo(toolchain2a);
+    assertToolchainResolution(
+            toolchainType2,
+            ImmutableList.of(constraint1b, constraint2b),
+            ImmutableList.of(constraint1b, constraint2b),
+            allToolchains)
+        .isEqualTo(toolchain2b);
+
+    // No toolchains of type.
+    assertToolchainResolution(
+            makeLabel("//toolchain:type3"),
+            ImmutableList.of(constraint1a, constraint2a),
+            ImmutableList.of(constraint1a, constraint2a),
+            allToolchains)
+        .isNull();
+  }
+
+  private Subject<DefaultSubject, Object> assertToolchainResolution(
+      Label toolchainType,
+      Iterable<ConstraintValueInfo> targetConstraints,
+      Iterable<ConstraintValueInfo> execConstraints,
+      ImmutableList<DeclaredToolchainInfo> toolchains)
+      throws Exception {
+
+    PlatformInfo execPlatform =
+        PlatformInfo.builder()
+            .setLabel(makeLabel("//platform:exec"))
+            .addConstraints(execConstraints)
+            .build();
+    PlatformInfo targetPlatform =
+        PlatformInfo.builder()
+            .setLabel(makeLabel("//platform:target"))
+            .addConstraints(targetConstraints)
+            .build();
+
+    DeclaredToolchainInfo resolvedToolchain =
+        ToolchainResolutionFunction.resolveConstraints(
+            toolchainType, execPlatform, targetPlatform, toolchains);
+    return assertThat(resolvedToolchain);
+  }
+
+  @Test
   public void testToolchainResolutionValue_equalsAndHashCode() {
     new EqualsTester()
         .addEqualityGroup(