Get PlatformInfo out of ToolchainResolutionKey, replace with the ConfiguredTargetKeys that own the PlatformInfo.
PiperOrigin-RevId: 185770105
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
index 6a8dcca..0211c46 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunction.java
@@ -27,9 +27,9 @@
import com.google.devtools.build.lib.events.Event;
import com.google.devtools.build.lib.events.EventHandler;
import com.google.devtools.build.lib.packages.NoSuchThingException;
-import com.google.devtools.build.lib.skyframe.ConfiguredTargetFunction.ConfiguredValueCreationException;
import com.google.devtools.build.lib.skyframe.RegisteredToolchainsFunction.InvalidToolchainLabelException;
import com.google.devtools.build.lib.skyframe.ToolchainResolutionValue.ToolchainResolutionKey;
+import com.google.devtools.build.lib.skyframe.ToolchainUtil.ToolchainContextException;
import com.google.devtools.build.lib.syntax.EvalException;
import com.google.devtools.build.skyframe.SkyFunction;
import com.google.devtools.build.skyframe.SkyFunctionException;
@@ -37,7 +37,9 @@
import com.google.devtools.build.skyframe.SkyValue;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
/** {@link SkyFunction} which performs toolchain resolution for a class of rules. */
@@ -46,7 +48,7 @@
@Nullable
@Override
public SkyValue compute(SkyKey skyKey, Environment env)
- throws SkyFunctionException, InterruptedException {
+ throws ToolchainResolutionFunctionException, InterruptedException {
ToolchainResolutionKey key = (ToolchainResolutionKey) skyKey.argument();
// This call could be combined with the call below, but this SkyFunction is evaluated so rarely
@@ -75,13 +77,24 @@
throw new ToolchainResolutionFunctionException(e);
}
+ Map<ConfiguredTargetKey, PlatformInfo> platforms;
+ try {
+ platforms =
+ ToolchainUtil.getPlatformInfo(
+ key.targetPlatformKey(), key.availableExecutionPlatformKeys(), env);
+ } catch (ToolchainContextException e) {
+ throw new ToolchainResolutionFunctionException(e);
+ }
// Find the right one.
boolean debug = configuration.getOptions().get(PlatformOptions.class).toolchainResolutionDebug;
ImmutableMap<PlatformInfo, Label> resolvedToolchainLabels =
resolveConstraints(
key.toolchainType(),
- key.availableExecutionPlatforms(),
- key.targetPlatform(),
+ key.availableExecutionPlatformKeys()
+ .stream()
+ .map(platforms::get)
+ .collect(Collectors.toList()),
+ platforms.get(key.targetPlatformKey()),
toolchains.registeredToolchains(),
debug ? env.getListener() : null);
@@ -227,7 +240,7 @@
super(e, Transience.PERSISTENT);
}
- public ToolchainResolutionFunctionException(ConfiguredValueCreationException e) {
+ public ToolchainResolutionFunctionException(ToolchainContextException e) {
super(e, Transience.PERSISTENT);
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
index f71c3bd..69ab599 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionValue.java
@@ -37,10 +37,10 @@
public static SkyKey key(
BuildConfigurationValue.Key configurationKey,
Label toolchainType,
- PlatformInfo targetPlatform,
- List<PlatformInfo> availableExecutionPlatforms) {
+ ConfiguredTargetKey targetPlatformKey,
+ List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
return ToolchainResolutionKey.create(
- configurationKey, toolchainType, targetPlatform, availableExecutionPlatforms);
+ configurationKey, toolchainType, targetPlatformKey, availableExecutionPlatformKeys);
}
/** {@link SkyKey} implementation used for {@link ToolchainResolutionFunction}. */
@@ -55,20 +55,20 @@
public abstract Label toolchainType();
- public abstract PlatformInfo targetPlatform();
+ abstract ConfiguredTargetKey targetPlatformKey();
- abstract ImmutableList<PlatformInfo> availableExecutionPlatforms();
+ abstract ImmutableList<ConfiguredTargetKey> availableExecutionPlatformKeys();
static ToolchainResolutionKey create(
BuildConfigurationValue.Key configuration,
Label toolchainType,
- PlatformInfo targetPlatform,
- List<PlatformInfo> availableExecutionPlatforms) {
+ ConfiguredTargetKey targetPlatformKey,
+ List<ConfiguredTargetKey> availableExecutionPlatformKeys) {
return new AutoValue_ToolchainResolutionValue_ToolchainResolutionKey(
configuration,
toolchainType,
- targetPlatform,
- ImmutableList.copyOf(availableExecutionPlatforms));
+ targetPlatformKey,
+ ImmutableList.copyOf(availableExecutionPlatformKeys));
}
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
index 28d5d9a..cf340d4 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ToolchainUtil.java
@@ -36,6 +36,7 @@
import com.google.devtools.build.skyframe.ValueOrException;
import com.google.devtools.build.skyframe.ValueOrException4;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -51,7 +52,7 @@
* of the {@link ToolchainResolutionFunction}.
*/
@Nullable
- public static ToolchainContext createToolchainContext(
+ static ToolchainContext createToolchainContext(
Environment env,
String targetDescription,
Set<Label> requiredToolchains,
@@ -79,7 +80,11 @@
ImmutableBiMap<Label, Label> resolvedLabels =
resolveToolchainLabels(
- env, requiredToolchains, executionPlatform, targetPlatform, configurationKey);
+ env,
+ requiredToolchains,
+ configurationKey,
+ platforms.hostPlatformKey(),
+ platforms.targetPlatformKey());
if (resolvedLabels == null) {
return null;
}
@@ -103,9 +108,17 @@
abstract PlatformInfo targetPlatform();
+ abstract ConfiguredTargetKey hostPlatformKey();
+
+ abstract ConfiguredTargetKey targetPlatformKey();
+
protected static PlatformDescriptors create(
- PlatformInfo hostPlatform, PlatformInfo targetPlatform) {
- return new AutoValue_ToolchainUtil_PlatformDescriptors(hostPlatform, targetPlatform);
+ PlatformInfo hostPlatform,
+ PlatformInfo targetPlatform,
+ ConfiguredTargetKey hostPlatformKey,
+ ConfiguredTargetKey targetPlatformKey) {
+ return new AutoValue_ToolchainUtil_PlatformDescriptors(
+ hostPlatform, targetPlatform, hostPlatformKey, targetPlatformKey);
}
}
@@ -136,6 +149,37 @@
}
@Nullable
+ static Map<ConfiguredTargetKey, PlatformInfo> getPlatformInfo(
+ ConfiguredTargetKey targetPlatformKey,
+ Iterable<ConfiguredTargetKey> hostPlatformKeys,
+ Environment env)
+ throws InterruptedException, ToolchainContextException {
+ Iterable<ConfiguredTargetKey> allKeys =
+ Iterables.concat(ImmutableList.of(targetPlatformKey), hostPlatformKeys);
+ Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
+ env.getValuesOrThrow(allKeys, ConfiguredValueCreationException.class);
+ boolean valuesMissing = env.valuesMissing();
+ Map<ConfiguredTargetKey, PlatformInfo> platforms = valuesMissing ? null : new HashMap<>();
+ try {
+ for (ConfiguredTargetKey key : allKeys) {
+ PlatformInfo platformInfo =
+ findPlatformInfo(
+ values.get(key),
+ key.equals(targetPlatformKey) ? "target platform" : "host platform");
+ if (!valuesMissing && platformInfo != null) {
+ platforms.put(key, platformInfo);
+ }
+ }
+ } catch (ConfiguredValueCreationException e) {
+ throw new ToolchainContextException(e);
+ }
+ if (valuesMissing) {
+ return null;
+ }
+ return platforms;
+ }
+
+ @Nullable
private static PlatformDescriptors loadPlatformDescriptors(
Environment env, BuildConfiguration configuration)
throws InterruptedException, ToolchainContextException {
@@ -147,36 +191,29 @@
Label hostPlatformLabel = platformConfiguration.getHostPlatform();
Label targetPlatformLabel = platformConfiguration.getTargetPlatforms().get(0);
- SkyKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
- SkyKey targetPlatformKey = ConfiguredTargetKey.of(targetPlatformLabel, configuration);
-
- Map<SkyKey, ValueOrException<ConfiguredValueCreationException>> values =
- env.getValuesOrThrow(
- ImmutableList.of(hostPlatformKey, targetPlatformKey),
- ConfiguredValueCreationException.class);
- boolean valuesMissing = env.valuesMissing();
- try {
- PlatformInfo hostPlatform = findPlatformInfo(values.get(hostPlatformKey), "host platform");
- PlatformInfo targetPlatform =
- findPlatformInfo(values.get(targetPlatformKey), "target platform");
-
- if (valuesMissing) {
- return null;
- }
-
- return PlatformDescriptors.create(hostPlatform, targetPlatform);
- } catch (ConfiguredValueCreationException e) {
- throw new ToolchainContextException(e);
+ ConfiguredTargetKey hostPlatformKey = ConfiguredTargetKey.of(hostPlatformLabel, configuration);
+ ConfiguredTargetKey targetPlatformKey =
+ ConfiguredTargetKey.of(targetPlatformLabel, configuration);
+ Map<ConfiguredTargetKey, PlatformInfo> platformResult =
+ getPlatformInfo(targetPlatformKey, ImmutableList.of(hostPlatformKey), env);
+ if (env.valuesMissing()) {
+ return null;
}
+
+ return PlatformDescriptors.create(
+ platformResult.get(hostPlatformKey),
+ platformResult.get(targetPlatformKey),
+ hostPlatformKey,
+ targetPlatformKey);
}
@Nullable
private static ImmutableBiMap<Label, Label> resolveToolchainLabels(
Environment env,
Set<Label> requiredToolchains,
- PlatformInfo executionPlatform,
- PlatformInfo targetPlatform,
- BuildConfigurationValue.Key configurationKey)
+ BuildConfigurationValue.Key configurationKey,
+ ConfiguredTargetKey executionPlatformKey,
+ ConfiguredTargetKey targetPlatformKey)
throws InterruptedException, ToolchainContextException {
// If there are no required toolchains, bail out early.
@@ -191,8 +228,8 @@
ToolchainResolutionValue.key(
configurationKey,
toolchainType,
- targetPlatform,
- ImmutableList.of(executionPlatform)));
+ targetPlatformKey,
+ ImmutableList.of(executionPlatformKey)));
}
Map<
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
index 35d0705..4f8b3bc 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/ToolchainResolutionFunctionTest.java
@@ -16,12 +16,18 @@
import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.testing.EqualsTester;
+import com.google.devtools.build.lib.actions.Actions;
+import com.google.devtools.build.lib.analysis.ConfiguredTarget;
import com.google.devtools.build.lib.analysis.platform.PlatformInfo;
import com.google.devtools.build.lib.cmdline.Label;
+import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
+import com.google.devtools.build.lib.collect.nestedset.Order;
import com.google.devtools.build.lib.rules.platform.ToolchainTestCase;
import com.google.devtools.build.lib.skyframe.util.SkyframeExecutorTestUtils;
import com.google.devtools.build.skyframe.EvaluationResult;
@@ -33,9 +39,34 @@
/** Tests for {@link ToolchainResolutionValue} and {@link ToolchainResolutionFunction}. */
@RunWith(JUnit4.class)
public class ToolchainResolutionFunctionTest extends ToolchainTestCase {
+ private static final ConfiguredTargetKey LINUX_CTKEY =
+ ConfiguredTargetKey.of(Label.parseAbsoluteUnchecked("//linux:key"), null, false);
+ private static final ConfiguredTargetKey MAC_CTKEY =
+ ConfiguredTargetKey.of(Label.parseAbsoluteUnchecked("//mac:key"), null, false);
+
+ private static ConfiguredTargetValue createConfiguredTargetValue(
+ ConfiguredTarget configuredTarget) {
+ return new ConfiguredTargetValue(
+ configuredTarget,
+ new Actions.GeneratingActions(ImmutableList.of(), ImmutableMap.of()),
+ NestedSetBuilder.emptySet(Order.STABLE_ORDER),
+ /*removeActionsAfterEvaluation=*/ false);
+ }
private EvaluationResult<ToolchainResolutionValue> invokeToolchainResolution(SkyKey key)
throws InterruptedException {
+ ConfiguredTarget mockLinuxTarget = mock(ConfiguredTarget.class);
+ when(mockLinuxTarget.get(PlatformInfo.SKYLARK_CONSTRUCTOR)).thenReturn(linuxPlatform);
+ ConfiguredTarget mockMacTarget = mock(ConfiguredTarget.class);
+ when(mockMacTarget.get(PlatformInfo.SKYLARK_CONSTRUCTOR)).thenReturn(macPlatform);
+ getSkyframeExecutor()
+ .getDifferencerForTesting()
+ .inject(
+ ImmutableMap.of(
+ LINUX_CTKEY,
+ createConfiguredTargetValue(mockLinuxTarget),
+ MAC_CTKEY,
+ createConfiguredTargetValue(mockMacTarget)));
try {
getSkyframeExecutor().getSkyframeBuildView().enableAnalysis(true);
@@ -50,7 +81,7 @@
public void testResolution_singleExecutionPlatform() throws Exception {
SkyKey key =
ToolchainResolutionValue.key(
- targetConfigKey, testToolchainType, linuxPlatform, ImmutableList.of(macPlatform));
+ targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
assertThatEvaluationResult(result).hasNoError();
@@ -78,8 +109,8 @@
ToolchainResolutionValue.key(
targetConfigKey,
testToolchainType,
- linuxPlatform,
- ImmutableList.of(linuxPlatform, macPlatform));
+ LINUX_CTKEY,
+ ImmutableList.of(LINUX_CTKEY, MAC_CTKEY));
EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
assertThatEvaluationResult(result).hasNoError();
@@ -100,7 +131,7 @@
SkyKey key =
ToolchainResolutionValue.key(
- targetConfigKey, testToolchainType, linuxPlatform, ImmutableList.of(macPlatform));
+ targetConfigKey, testToolchainType, LINUX_CTKEY, ImmutableList.of(MAC_CTKEY));
EvaluationResult<ToolchainResolutionValue> result = invokeToolchainResolution(key);
assertThatEvaluationResult(result)