Simplify logical `SkyKeyComputeState` API.
Replace `SkyKeyComputeState SkyFunction.createNewSkyKeyComputeState()` & `SkyFunction.compute(SkyKey, SkyKeyComputeState, Environment)` with `T Environment#getState(Supplier<T>)`
This accomplishes three things:
(1) Improves readability of Blaze-on-Skyframe code by making everything more concise and making the `SkyFunction#compute` method less cluttered (now, due to this CL here and https://github.com/bazelbuild/bazel/commit/927b625e15d2beac5bfecae1d3c05783151a8ff3, the `SkyFunction` interface has only one method that needs to be implemented). Notably, see the final bullet point of "Implementation notes" in the description of https://github.com/bazelbuild/bazel/commit/ed279ab4fa2d4356be00b54266f56edcc5aeae78. Yes, this CL here is the solution to the issue there :)
(2) The new `SkyFunction.Environment#getState` is actually strictly more powerful than the old `SkyFunction#createNewSkyKeyComputeState` because `SkyFunction.Environment` is associated with a specific `SkyKey` while `SkyFunction` is, of course, not. This additional power may be useful for the memory performance TODO in `ActionExecutionFunction` (due to "shared actions").
(3) Improves readability of the Skyframe engine implementation by making the `SkyKeyComputeStateManager` abstraction unnecessary. I originally thought that abstraction would be useful for the high water mark memory concern (e.g. by maintaining per-SkyFunctionName caches), but I fixed that concern already via https://github.com/bazelbuild/bazel/commit/343ba438a93f8c56a7b524ac7a54666c57a969d9. See the description of that CL for details.
PiperOrigin-RevId: 418639155
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
index 9d24aa5..34d4455 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageFunction.java
@@ -69,6 +69,7 @@
import com.google.devtools.build.lib.vfs.RootedPath;
import com.google.devtools.build.lib.vfs.UnixGlob;
import com.google.devtools.build.skyframe.SkyFunction;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunctionException;
import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
import com.google.devtools.build.skyframe.SkyKey;
@@ -356,16 +357,6 @@
return new PackageValue(pkg);
}
- @Override
- public boolean supportsSkyKeyComputeState() {
- return true;
- }
-
- @Override
- public SkyKeyComputeState createNewSkyKeyComputeState() {
- return new State();
- }
-
private static class LoadedPackage {
private final Package.Builder builder;
private final Set<SkyKey> globDepKeys;
@@ -386,12 +377,6 @@
@Override
public SkyValue compute(SkyKey key, Environment env)
throws PackageFunctionException, InterruptedException {
- return compute(key, createNewSkyKeyComputeState(), env);
- }
-
- @Override
- public SkyValue compute(SkyKey key, SkyKeyComputeState skyKeyComputeState, Environment env)
- throws PackageFunctionException, InterruptedException {
PackageIdentifier packageId = (PackageIdentifier) key.argument();
if (packageId.equals(LabelConstants.EXTERNAL_PACKAGE_IDENTIFIER)) {
return getExternalPackage(env);
@@ -526,7 +511,7 @@
// like we do for .bzl files, so that we don't need to recompile BUILD files each time their
// .bzl dependencies change.
- State state = (State) skyKeyComputeState;
+ State state = env.getState(State::new);
LoadedPackage loadedPackage = state.loadedPackage;
if (loadedPackage == null) {
loadedPackage =
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/PackageLookupFunction.java b/src/main/java/com/google/devtools/build/lib/skyframe/PackageLookupFunction.java
index c1d4e4b..bb8c362 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/PackageLookupFunction.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/PackageLookupFunction.java
@@ -38,6 +38,7 @@
import com.google.devtools.build.lib.vfs.Root;
import com.google.devtools.build.lib.vfs.RootedPath;
import com.google.devtools.build.skyframe.SkyFunction;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunctionException;
import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
import com.google.devtools.build.skyframe.SkyKey;
@@ -74,30 +75,14 @@
this.externalPackageHelper = externalPackageHelper;
}
- @Override
- public boolean supportsSkyKeyComputeState() {
- return true;
- }
-
private static class State implements SkyKeyComputeState {
private int packagePathEntryPos = 0;
private int buildFileNamePos = 0;
}
@Override
- public State createNewSkyKeyComputeState() {
- return new State();
- }
-
- @Override
public SkyValue compute(SkyKey skyKey, Environment env)
throws PackageLookupFunctionException, InterruptedException {
- return compute(skyKey, createNewSkyKeyComputeState(), env);
- }
-
- @Override
- public SkyValue compute(SkyKey skyKey, SkyKeyComputeState skyKeyComputeState, Environment env)
- throws PackageLookupFunctionException, InterruptedException {
PathPackageLocator pkgLocator = PrecomputedValue.PATH_PACKAGE_LOCATOR.get(env);
StarlarkSemantics semantics = PrecomputedValue.STARLARK_SEMANTICS.get(env);
@@ -135,7 +120,7 @@
return PackageLookupValue.DELETED_PACKAGE_VALUE;
}
- return findPackageByBuildFile((State) skyKeyComputeState, env, pkgLocator, packageKey);
+ return findPackageByBuildFile(env, pkgLocator, packageKey);
}
/**
@@ -169,8 +154,9 @@
@Nullable
private PackageLookupValue findPackageByBuildFile(
- State state, Environment env, PathPackageLocator pkgLocator, PackageIdentifier packageKey)
+ Environment env, PathPackageLocator pkgLocator, PackageIdentifier packageKey)
throws PackageLookupFunctionException, InterruptedException {
+ State state = env.getState(State::new);
while (state.packagePathEntryPos < pkgLocator.getPathEntries().size()) {
while (state.buildFileNamePos < buildFilesByPriority.size()) {
Root packagePathEntry = pkgLocator.getPathEntries().get(state.packagePathEntryPos);
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/ProgressEventSuppressingEnvironment.java b/src/main/java/com/google/devtools/build/lib/skyframe/ProgressEventSuppressingEnvironment.java
index da3be2a..f31f690 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/ProgressEventSuppressingEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/ProgressEventSuppressingEnvironment.java
@@ -27,6 +27,7 @@
import com.google.devtools.build.skyframe.Version;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
/**
@@ -295,4 +296,9 @@
public boolean restartPermitted() {
return delegate.restartPermitted();
}
+
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return delegate.getState(stateSupplier);
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/SkyFunctionEnvironmentForTesting.java b/src/main/java/com/google/devtools/build/lib/skyframe/SkyFunctionEnvironmentForTesting.java
index ad536cd..23579ad 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/SkyFunctionEnvironmentForTesting.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/SkyFunctionEnvironmentForTesting.java
@@ -27,6 +27,7 @@
import com.google.devtools.build.skyframe.ValueOrUntypedException;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
/**
* A {@link SkyFunction.Environment} backed by a {@link SkyframeExecutor} that can be used to
@@ -87,4 +88,9 @@
public boolean restartPermitted() {
return false;
}
+
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return stateSupplier.get();
+ }
}
diff --git a/src/main/java/com/google/devtools/build/lib/skyframe/StateInformingSkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/lib/skyframe/StateInformingSkyFunctionEnvironment.java
index 84315fc..3cd151f 100644
--- a/src/main/java/com/google/devtools/build/lib/skyframe/StateInformingSkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/skyframe/StateInformingSkyFunctionEnvironment.java
@@ -27,6 +27,7 @@
import com.google.devtools.build.skyframe.Version;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
/** An environment that wraps each call to its delegate by informing injected {@link Informee}s. */
@@ -382,6 +383,11 @@
return delegate.restartPermitted();
}
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return delegate.getState(stateSupplier);
+ }
+
interface Informee {
void inform() throws InterruptedException;
}
diff --git a/src/main/java/com/google/devtools/build/skyframe/AbstractExceptionalParallelEvaluator.java b/src/main/java/com/google/devtools/build/skyframe/AbstractExceptionalParallelEvaluator.java
index 896ca98..c97d1f7 100644
--- a/src/main/java/com/google/devtools/build/skyframe/AbstractExceptionalParallelEvaluator.java
+++ b/src/main/java/com/google/devtools/build/skyframe/AbstractExceptionalParallelEvaluator.java
@@ -32,7 +32,6 @@
import com.google.devtools.build.skyframe.MemoizingEvaluator.EmittedEventState;
import com.google.devtools.build.skyframe.NodeEntry.DependencyState;
import com.google.devtools.build.skyframe.QueryableGraph.Reason;
-import com.google.devtools.build.skyframe.SkyFunction.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunctionException.ReifiedSkyFunctionException;
import java.io.IOException;
import java.util.ArrayList;
@@ -372,7 +371,7 @@
throws InterruptedException {
// Remove all the compute states so as to give the SkyFunctions a chance to do fresh
// computations during error bubbling.
- skyKeyComputeStateManager.removeAll();
+ stateCache.invalidateAll();
Set<SkyKey> rootValues = ImmutableSet.copyOf(roots);
ErrorInfo error = leafFailure;
@@ -499,17 +498,12 @@
bubbleErrorInfo,
ImmutableSet.of(),
evaluatorContext);
- SkyKeyComputeState skyKeyComputeStateToUse = skyKeyComputeStateManager.maybeGet(parent);
externalInterrupt = externalInterrupt || Thread.currentThread().isInterrupted();
boolean completedRun = false;
try {
// This build is only to check if the parent node can give us a better error. We don't
// care about a return value.
- if (skyKeyComputeStateToUse == null) {
- factory.compute(parent, env);
- } else {
- factory.compute(parent, skyKeyComputeStateToUse, env);
- }
+ factory.compute(parent, env);
completedRun = true;
} catch (InterruptedException interruptedException) {
logger.atInfo().withCause(interruptedException).log("Interrupted during %s eval", parent);
diff --git a/src/main/java/com/google/devtools/build/skyframe/AbstractParallelEvaluator.java b/src/main/java/com/google/devtools/build/skyframe/AbstractParallelEvaluator.java
index c608466..b7b5b11e 100644
--- a/src/main/java/com/google/devtools/build/skyframe/AbstractParallelEvaluator.java
+++ b/src/main/java/com/google/devtools/build/skyframe/AbstractParallelEvaluator.java
@@ -13,6 +13,8 @@
// limitations under the License.
package com.google.devtools.build.skyframe;
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
@@ -41,8 +43,8 @@
import com.google.devtools.build.skyframe.NodeEntry.DependencyState;
import com.google.devtools.build.skyframe.NodeEntry.DirtyState;
import com.google.devtools.build.skyframe.QueryableGraph.Reason;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunction.Restart;
-import com.google.devtools.build.skyframe.SkyFunction.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunctionEnvironment.UndonePreviouslyRequestedDeps;
import com.google.devtools.build.skyframe.SkyFunctionException.ReifiedSkyFunctionException;
import com.google.devtools.build.skyframe.ThinNodeEntry.DirtyType;
@@ -91,7 +93,7 @@
*/
private final AtomicInteger globalEnqueuedIndex = new AtomicInteger(Integer.MIN_VALUE);
- protected final SkyKeyComputeStateManager skyKeyComputeStateManager;
+ protected final Cache<SkyKey, SkyKeyComputeState> stateCache = Caffeine.newBuilder().build();
AbstractParallelEvaluator(
ProcessableGraph graph,
@@ -113,7 +115,7 @@
Supplier<QuiescingExecutor> quiescingExecutorSupplier =
getQuiescingExecutorSupplier(
executorService, cpuHeavySkyKeysThreadPoolSize, executionJobsThreadPoolSize);
- evaluatorContext =
+ this.evaluatorContext =
new ParallelEvaluatorContext(
graph,
graphVersion,
@@ -128,8 +130,8 @@
() ->
new NodeEntryVisitor(
quiescingExecutorSupplier.get(), progressReceiver, Evaluate::new),
- /*mergingSkyframeAnalysisExecutionPhases=*/ executionJobsThreadPoolSize > 0);
- this.skyKeyComputeStateManager = new SkyKeyComputeStateManager(skyFunctions);
+ /*mergingSkyframeAnalysisExecutionPhases=*/ executionJobsThreadPoolSize > 0,
+ stateCache);
}
private Supplier<QuiescingExecutor> getQuiescingExecutorSupplier(
@@ -551,15 +553,11 @@
state);
SkyValue value = null;
- SkyKeyComputeState skyKeyComputeStateToUse = skyKeyComputeStateManager.maybeGet(skyKey);
long startTimeNanos = BlazeClock.instance().nanoTime();
try {
try {
evaluatorContext.getProgressReceiver().stateStarting(skyKey, NodeState.COMPUTE);
- value =
- skyKeyComputeStateToUse == null
- ? factory.compute(skyKey, env)
- : factory.compute(skyKey, skyKeyComputeStateToUse, env);
+ value = factory.compute(skyKey, env);
} finally {
evaluatorContext.getProgressReceiver().stateEnding(skyKey, NodeState.COMPUTE);
long elapsedTimeNanos = BlazeClock.instance().nanoTime() - startTimeNanos;
@@ -573,9 +571,7 @@
}
}
} catch (final SkyFunctionException builderException) {
- if (skyKeyComputeStateToUse != null) {
- skyKeyComputeStateManager.remove(skyKey);
- }
+ stateCache.invalidate(skyKey);
ReifiedSkyFunctionException reifiedBuilderException =
new ReifiedSkyFunctionException(builderException);
@@ -642,7 +638,7 @@
}
if (maybeHandleRestart(skyKey, state, value)) {
- skyKeyComputeStateManager.removeAll();
+ stateCache.invalidateAll();
cancelExternalDeps(env);
evaluatorContext.getVisitor().enqueueEvaluation(skyKey, determineRestartPriority());
return;
@@ -653,9 +649,7 @@
GroupedListHelper<SkyKey> newDirectDeps = env.getNewlyRequestedDeps();
if (value != null) {
- if (skyKeyComputeStateToUse != null) {
- skyKeyComputeStateManager.remove(skyKey);
- }
+ stateCache.invalidate(skyKey);
Preconditions.checkState(
!env.valuesMissing(),
diff --git a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
index a15dfcc..943897b 100644
--- a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
+++ b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluator.java
@@ -81,7 +81,7 @@
public <T extends SkyValue> EvaluationResult<T> eval(Iterable<? extends SkyKey> skyKeys)
throws InterruptedException {
unnecessaryTemporaryStateDropperReceiver.onEvaluationStarted(
- skyKeyComputeStateManager::removeAll);
+ () -> evaluatorContext.stateCache().invalidateAll());
try {
return this.evalExceptionally(skyKeys);
} finally {
diff --git a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluatorContext.java b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluatorContext.java
index 62da390..fb7311d 100644
--- a/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluatorContext.java
+++ b/src/main/java/com/google/devtools/build/skyframe/ParallelEvaluatorContext.java
@@ -13,6 +13,7 @@
// limitations under the License.
package com.google.devtools.build.skyframe;
+import com.github.benmanes.caffeine.cache.Cache;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
@@ -23,6 +24,7 @@
import com.google.devtools.build.lib.events.ExtendedEventHandler.Postable;
import com.google.devtools.build.skyframe.MemoizingEvaluator.EmittedEventState;
import com.google.devtools.build.skyframe.QueryableGraph.Reason;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import java.util.Map;
import javax.annotation.Nullable;
@@ -47,6 +49,7 @@
private final ErrorInfoManager errorInfoManager;
private final GraphInconsistencyReceiver graphInconsistencyReceiver;
private final boolean mergingSkyframeAnalysisExecutionPhases;
+ private final Cache<SkyKey, SkyKeyComputeState> stateCache;
/**
* The visitor managing the thread pool. Used to enqueue parents when an entry is finished, and,
@@ -80,7 +83,8 @@
ErrorInfoManager errorInfoManager,
GraphInconsistencyReceiver graphInconsistencyReceiver,
Supplier<NodeEntryVisitor> visitorSupplier,
- boolean mergingSkyframeAnalysisExecutionPhases) {
+ boolean mergingSkyframeAnalysisExecutionPhases,
+ Cache<SkyKey, SkyKeyComputeState> stateCache) {
this.graph = graph;
this.graphVersion = graphVersion;
this.skyFunctions = skyFunctions;
@@ -97,6 +101,7 @@
this.errorInfoManager = errorInfoManager;
this.visitorSupplier = Suppliers.memoize(visitorSupplier);
this.mergingSkyframeAnalysisExecutionPhases = mergingSkyframeAnalysisExecutionPhases;
+ this.stateCache = stateCache;
}
Map<SkyKey, ? extends NodeEntry> getBatchValues(
@@ -194,6 +199,10 @@
return mergingSkyframeAnalysisExecutionPhases;
}
+ Cache<SkyKey, SkyKeyComputeState> stateCache() {
+ return stateCache;
+ }
+
/** Receives the events from the NestedSet and delegates to the reporter. */
private static final class NestedSetEventReceiver
implements NestedSetVisitor.Receiver<TaggedEvents> {
diff --git a/src/main/java/com/google/devtools/build/skyframe/RecordingSkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/RecordingSkyFunctionEnvironment.java
index 661e7ef..f433adf 100644
--- a/src/main/java/com/google/devtools/build/skyframe/RecordingSkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/RecordingSkyFunctionEnvironment.java
@@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
/** An environment that can observe the deps requested through getValue(s) calls. */
@@ -409,4 +410,9 @@
public boolean restartPermitted() {
return delegate.restartPermitted();
}
+
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return delegate.getState(stateSupplier);
+ }
}
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java b/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
index 95c95bc..fa58c1d 100644
--- a/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
+++ b/src/main/java/com/google/devtools/build/skyframe/SkyFunction.java
@@ -22,6 +22,7 @@
import com.google.devtools.build.lib.util.GroupedList;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
/**
@@ -65,109 +66,6 @@
throws SkyFunctionException, InterruptedException;
/**
- * Same as {@link #compute(SkyKey, Environment)}, except with a {@link SkyKeyComputeState}.
- *
- * <p>The {@link SkyKeyComputeState} instance will either be freshly created via {@link
- * #createNewSkyKeyComputeState()}, or will be the same exact instance used on the previous call
- * to this method for the same {@link SkyKey}. This allows {@link SkyFunction} implementations to
- * avoid redoing the same intermediate work over-and-over again on each {@link #compute} call for
- * the same {@link SkyKey}, due to missing Skyframe dependencies. For example,
- *
- * <pre>
- * class MyFunction implements SkyFunction {
- * public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedException {
- * int x = (Integer) skyKey.argument();
- * SkyKey myDependencyKey = getSkyKeyForValue(someExpensiveComputation(x));
- * SkyValue myDependencyValue = env.getValue(myDependencyKey);
- * if (env.valuesMissing()) {
- * return null;
- * }
- * return createMyValue(myDependencyValue);
- * }
- * }
- * </pre>
- *
- * <p>If the dependency was missing, then we'll end up evaluating {@code
- * someExpensiveComputation(x)} twice, once on the initial call to {@link #compute} and then again
- * on the subsequent call after the dependency was computed.
- *
- * <p>To fix this, we can use a mutable {@link SkyKeyComputeState} implementation and store the
- * result of {@code someExpensiveComputation(x)} in there:
- *
- * <pre>
- * class MyFunction implements SkyFunction {
- * public boolean supportsSkyKeyComputeState() {
- * return true;
- * }
- *
- * private static class State implements SkyKeyComputeState {
- * private Integer result;
- * }
- *
- * public State createNewSkyKeyComputeState() {
- * return new State();
- * }
- *
- * public SkyValue compute(
- * SkyKey skyKey, SkyKeyComputeState skyKeyComputeState, Environment env)
- * throws InterruptedException {
- * int x = (Integer) skyKey.argument();
- * State state = (State) skyKeyComputeState;
- * if (state.result == null) {
- * state.result = someExpensiveComputation(x);
- * }
- * SkyKey myDependencyKey = getSkyKeyForValue(state.result);
- * SkyValue myDependencyValue = env.getValue(myDependencyKey);
- * if (env.valuesMissing()) {
- * return null;
- * }
- * return createMyValue(myDependencyValue);
- * }
- * }
- * </pre>
- *
- * <p>Now {@code someExpensiveComputation(x)} gets called exactly once for each {@code x}!
- *
- * <p>Important: There's no guarantee the {@link SkyKeyComputeState} instance will be the same
- * exact instance used on the previous call to this method for the same {@link SkyKey}. Therefore,
- * {@link SkyFunction} implementations should make use of this feature only as a performance
- * optimization.
- *
- * <p>TODO(b/209701268): Reimplement Blaze-on-Skyframe SkyFunctions that would benefit from this
- * sort of optimization.
- */
- @ThreadSafe
- @Nullable
- default SkyValue compute(SkyKey skyKey, SkyKeyComputeState state, Environment env)
- throws SkyFunctionException, InterruptedException {
- throw new UnsupportedOperationException();
- }
-
- /**
- * If this returns {@code true}, then the Skyframe engine will invoke {@link #compute(SkyKey,
- * SkyKeyComputeState, Environment)} instead of {@link #compute(SkyKey, Environment)}.
- */
- default boolean supportsSkyKeyComputeState() {
- return false;
- }
-
- /**
- * Container for data stored in between calls to {@link #compute} for the same {@link SkyKey}.
- *
- * <p>See the javadoc of {@link #compute(SkyKey, SkyKeyComputeState, Environment)} for motivation
- * and an example.
- */
- interface SkyKeyComputeState {}
-
- /**
- * Returns a new {@link SkyKeyComputeState} instance to use for {@link #compute(SkyKey,
- * SkyKeyComputeState, Environment)}.
- */
- default SkyKeyComputeState createNewSkyKeyComputeState() {
- throw new UnsupportedOperationException();
- }
-
- /**
* Extracts a tag (target label) from a SkyKey if it has one. Otherwise return {@code null}.
*
* <p>The tag is used for filtering out non-error event messages that do not match --output_filter
@@ -555,5 +453,77 @@
* true}.
*/
boolean restartPermitted();
+
+ /**
+ * Container for data stored in between calls to {@link #compute} for the same {@link SkyKey}.
+ *
+ * <p>See the javadoc of {@link #getState} for motivation and an example.
+ */
+ interface SkyKeyComputeState {}
+
+ /**
+ * Returns (or creates and returns) a "state" object to assist with temporary computations for
+ * the {@link SkyKey} associated with this {@link Environment}.
+ *
+ * <p>The {@link SkyKeyComputeState} will either be freshly created via the given {@link
+ * Supplier}, or will be the same exact instance used on the previous call to this method for
+ * the same {@link SkyKey}. This allows {@link SkyFunction} implementations to avoid redoing the
+ * same intermediate work over-and-over again on each {@link #compute} call for the same {@link
+ * SkyKey}, due to missing Skyframe dependencies. For example,
+ *
+ * <pre>
+ * class MyFunction implements SkyFunction {
+ * public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedException {
+ * int x = (Integer) skyKey.argument();
+ * SkyKey myDependencyKey = getSkyKeyForValue(someExpensiveComputation(x));
+ * SkyValue myDependencyValue = env.getValue(myDependencyKey);
+ * if (env.valuesMissing()) {
+ * return null;
+ * }
+ * return createMyValue(myDependencyValue);
+ * }
+ * }
+ * </pre>
+ *
+ * <p>If the dependency was missing, then we'll end up evaluating {@code
+ * someExpensiveComputation(x)} twice, once on the initial call to {@link #compute} and then
+ * again on the subsequent call after the dependency was computed.
+ *
+ * <p>To fix this, we can use a mutable {@link SkyKeyComputeState} implementation and store the
+ * result of {@code someExpensiveComputation(x)} in there:
+ *
+ * <pre>
+ * class MyFunction implements SkyFunction {
+ * private static class State implements SkyKeyComputeState {
+ * private Integer result;
+ * }
+ *
+ * public SkyValue compute(SkyKey skyKey, Environment env) throws InterruptedException {
+ * int x = (Integer) skyKey.argument();
+ * State state = env.getState(State::new);
+ * if (state.result == null) {
+ * state.result = someExpensiveComputation(x);
+ * }
+ * SkyKey myDependencyKey = getSkyKeyForValue(state.result);
+ * SkyValue myDependencyValue = env.getValue(myDependencyKey);
+ * if (env.valuesMissing()) {
+ * return null;
+ * }
+ * return createMyValue(myDependencyValue);
+ * }
+ * }
+ * </pre>
+ *
+ * <p>Now {@code someExpensiveComputation(x)} gets called exactly once for each {@code x}!
+ *
+ * <p>Important: There's no guarantee the{@link SkyKeyComputeState} instance will be the same
+ * exact instance used on the previous call to this method for the same {@link SkyKey}. The
+ * above example was just illustrating the best-case outcome. Therefore, {@link SkyFunction}
+ * implementations should make use of this feature only as a performance optimization.
+ *
+ * <p>TODO(b/209701268): Reimplement Blaze-on-Skyframe SkyFunctions that would benefit from this
+ * sort of optimization.
+ */
+ <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier);
}
}
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
index 2af7b38..ab67a96 100644
--- a/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/SkyFunctionEnvironment.java
@@ -37,6 +37,7 @@
import com.google.devtools.build.skyframe.EvaluationProgressReceiver.EvaluationState;
import com.google.devtools.build.skyframe.NodeEntry.DependencyState;
import com.google.devtools.build.skyframe.QueryableGraph.Reason;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import com.google.devtools.build.skyframe.proto.GraphInconsistency.Inconsistency;
import java.io.IOException;
import java.util.ArrayList;
@@ -49,6 +50,7 @@
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
+import java.util.function.Supplier;
import javax.annotation.Nullable;
/** A {@link SkyFunction.Environment} implementation for {@link ParallelEvaluator}. */
@@ -963,6 +965,12 @@
return evaluatorContext.restartPermitted();
}
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return (T) evaluatorContext.stateCache().get(skyKey, k -> stateSupplier.get());
+ }
+
/** Thrown during environment construction if previously requested deps are no longer done. */
static class UndonePreviouslyRequestedDeps extends Exception {
private final ImmutableList<SkyKey> depKeys;
diff --git a/src/main/java/com/google/devtools/build/skyframe/SkyKeyComputeStateManager.java b/src/main/java/com/google/devtools/build/skyframe/SkyKeyComputeStateManager.java
deleted file mode 100644
index 75343a1..0000000
--- a/src/main/java/com/google/devtools/build/skyframe/SkyKeyComputeStateManager.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2021 The Bazel Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-package com.google.devtools.build.skyframe;
-
-import com.github.benmanes.caffeine.cache.Caffeine;
-import com.github.benmanes.caffeine.cache.LoadingCache;
-import com.google.common.collect.ImmutableMap;
-import com.google.devtools.build.skyframe.SkyFunction.SkyKeyComputeState;
-import javax.annotation.Nullable;
-
-/** Helper class used to support {@link SkyKeyComputeState}. */
-class SkyKeyComputeStateManager {
- private final ImmutableMap<SkyFunctionName, SkyFunction> skyFunctions;
- private final LoadingCache<SkyKey, SkyKeyComputeState> cache;
-
- SkyKeyComputeStateManager(ImmutableMap<SkyFunctionName, SkyFunction> skyFunctions) {
- this.skyFunctions = skyFunctions;
- this.cache =
- Caffeine.newBuilder()
- .build(k -> skyFunctions.get(k.functionName()).createNewSkyKeyComputeState());
- }
-
- @Nullable
- SkyKeyComputeState maybeGet(SkyKey skyKey) {
- return skyFunctions.get(skyKey.functionName()).supportsSkyKeyComputeState()
- ? cache.get(skyKey)
- : null;
- }
-
- void remove(SkyKey skyKey) {
- cache.invalidate(skyKey);
- }
-
- void removeAll() {
- cache.invalidateAll();
- }
-}
diff --git a/src/test/java/com/google/devtools/build/lib/actions/util/ActionsTestUtil.java b/src/test/java/com/google/devtools/build/lib/actions/util/ActionsTestUtil.java
index dbf8fcd..db9c074 100644
--- a/src/test/java/com/google/devtools/build/lib/actions/util/ActionsTestUtil.java
+++ b/src/test/java/com/google/devtools/build/lib/actions/util/ActionsTestUtil.java
@@ -119,6 +119,7 @@
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
+import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -388,6 +389,11 @@
public boolean restartPermitted() {
return false;
}
+
+ @Override
+ public <T extends SkyKeyComputeState> T getState(Supplier<T> stateSupplier) {
+ return stateSupplier.get();
+ }
}
@SerializationConstant
diff --git a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
index 60f0626..6b4cb5a 100644
--- a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
@@ -57,7 +57,7 @@
import com.google.devtools.build.skyframe.GraphTester.StringValue;
import com.google.devtools.build.skyframe.NotifyingHelper.EventType;
import com.google.devtools.build.skyframe.NotifyingHelper.Order;
-import com.google.devtools.build.skyframe.SkyFunction.SkyKeyComputeState;
+import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
@@ -2650,84 +2650,61 @@
AtomicReference<WeakReference<State>> stateForKey2Ref = new AtomicReference<>();
AtomicReference<WeakReference<State>> stateForKey3Ref = new AtomicReference<>();
SkyFunction skyFunctionForTest =
- new SkyFunction() {
- // That supports compute state
- @Override
- public boolean supportsSkyKeyComputeState() {
- return true;
- }
+ // Whose #compute is such that
+ (skyKey, env) -> {
+ State state = env.getState(State::new);
+ state.usageCount++;
+ int numCallsForKey = (int) numCalls.incrementAndGet(skyKey);
+ // The number of calls to #compute is expected to be equal to the number of usages of
+ // the state for that key,
+ assertThat(state.usageCount).isEqualTo(numCallsForKey);
+ if (skyKey.equals(key1)) {
+ // And the semantics for key1 are:
- @Override
- public State createNewSkyKeyComputeState() {
- return new State();
- }
-
- // And crashes if the normal #compute method is called
- @Override
- public SkyValue compute(SkyKey skyKey, Environment env) {
- fail();
- throw new IllegalStateException();
- }
-
- // And whose #compute is such that
- @Override
- public SkyValue compute(
- SkyKey skyKey, SkyKeyComputeState skyKeyComputeState, Environment env)
- throws InterruptedException {
- State state = (State) skyKeyComputeState;
- state.usageCount++;
- int numCallsForKey = (int) numCalls.incrementAndGet(skyKey);
- // The number of calls to #compute is expected to be equal to the number of usages of
- // the state for that key,
- assertThat(state.usageCount).isEqualTo(numCallsForKey);
- if (skyKey.equals(key1)) {
- // And the semantics for key1 are:
-
- // The state for key1 is expected to be the first one created (since key1 is expected
- // to be the first node we attempt to compute).
- assertThat(state.instanceCount).isEqualTo(1);
- // And key1 declares a dep on key2,
- if (env.getValue(key2) == null) {
- // (And that dep is expected to be missing on the initial #compute call for key1)
- assertThat(numCallsForKey).isEqualTo(1);
- return null;
- }
- // And if that dep is not missing, then we expect:
- // - We're on the second #compute call for key1
- assertThat(numCallsForKey).isEqualTo(2);
- // - The state for key2 should have been eligible for GC. This is because the node
- // for key2 must have been fully computed, meaning its compute state is no longer
- // needed, and so ParallelEvaluator ought to have made it eligible for GC.
- GcFinalization.awaitClear(stateForKey2Ref.get());
- return new StringValue("value1");
- } else if (skyKey.equals(key2)) {
- // And the semantics for key2 are:
-
- // The state for key2 is expected to be the second one created.
- assertThat(state.instanceCount).isEqualTo(2);
- stateForKey2Ref.set(new WeakReference<>(state));
- // And key2 declares a dep on key3,
- if (env.getValue(key3) == null) {
- // (And that dep is expected to be missing on the initial #compute call for key2)
- assertThat(numCallsForKey).isEqualTo(1);
- return null;
- }
- // And if that dep is not missing, then we expect the same sort of things we expected
- // for key1 in this situation.
- assertThat(numCallsForKey).isEqualTo(2);
- GcFinalization.awaitClear(stateForKey3Ref.get());
- return new StringValue("value2");
- } else if (skyKey.equals(key3)) {
- // And the semantics for key3 are:
-
- // The state for key3 is expected to be the third one created.
- assertThat(state.instanceCount).isEqualTo(3);
- stateForKey3Ref.set(new WeakReference<>(state));
- // And key3 declares no deps.
- return new StringValue("value3");
+ // The state for key1 is expected to be the first one created (since key1 is expected
+ // to be the first node we attempt to compute).
+ assertThat(state.instanceCount).isEqualTo(1);
+ // And key1 declares a dep on key2,
+ if (env.getValue(key2) == null) {
+ // (And that dep is expected to be missing on the initial #compute call for key1)
+ assertThat(numCallsForKey).isEqualTo(1);
+ return null;
}
- throw new IllegalStateException();
+ // And if that dep is not missing, then we expect:
+ // - We're on the second #compute call for key1
+ assertThat(numCallsForKey).isEqualTo(2);
+ // - The state for key2 should have been eligible for GC. This is because the node
+ // for key2 must have been fully computed, meaning its compute state is no longer
+ // needed, and so ParallelEvaluator ought to have made it eligible for GC.
+ GcFinalization.awaitClear(stateForKey2Ref.get());
+ return new StringValue("value1");
+ } else if (skyKey.equals(key2)) {
+ // And the semantics for key2 are:
+
+ // The state for key2 is expected to be the second one created.
+ assertThat(state.instanceCount).isEqualTo(2);
+ stateForKey2Ref.set(new WeakReference<>(state));
+ // And key2 declares a dep on key3,
+ if (env.getValue(key3) == null) {
+ // (And that dep is expected to be missing on the initial #compute call for key2)
+ assertThat(numCallsForKey).isEqualTo(1);
+ return null;
+ }
+ // And if that dep is not missing, then we expect the same sort of things we expected
+ // for key1 in this situation.
+ assertThat(numCallsForKey).isEqualTo(2);
+ GcFinalization.awaitClear(stateForKey3Ref.get());
+ return new StringValue("value2");
+ } else if (skyKey.equals(key3)) {
+ // And the semantics for key3 are:
+
+ // The state for key3 is expected to be the third one created.
+ assertThat(state.instanceCount).isEqualTo(3);
+ stateForKey3Ref.set(new WeakReference<>(state));
+ // And key3 declares no deps.
+ return new StringValue("value3");
}
+ throw new IllegalStateException();
};
tester.putSkyFunction(SkyKeyForSkyKeyComputeStateTests.FUNCTION_NAME, skyFunctionForTest);
@@ -2764,71 +2741,48 @@
CountDownLatch key3SleepingLatch = new CountDownLatch(1);
AtomicBoolean onNormalEvaluation = new AtomicBoolean(true);
SkyFunction skyFunctionForTest =
- new SkyFunction() {
- // That supports compute state
- @Override
- public boolean supportsSkyKeyComputeState() {
- return true;
- }
+ // Whose #compute is such that
+ (skyKey, env) -> {
+ if (onNormalEvaluation.get()) {
+ // When we're on the normal evaluation:
- @Override
- public State createNewSkyKeyComputeState() {
- return new State();
- }
+ State state = env.getState(State::new);
+ if (skyKey.equals(key1)) {
+ // For key1:
- // And crashes if the normal #compute method is called
- @Override
- public SkyValue compute(SkyKey skyKey, Environment env) {
- fail();
- throw new IllegalStateException();
- }
+ stateForKey1Ref.set(new WeakReference<>(state));
+ // We declare a dep on key.
+ return env.getValue(key2);
+ } else if (skyKey.equals(key2)) {
+ // For key2:
- // And whose #compute is such that
- @Override
- public SkyValue compute(
- SkyKey skyKey, SkyKeyComputeState skyKeyComputeState, Environment env)
- throws InterruptedException, SkyFunctionException {
- if (onNormalEvaluation.get()) {
- // When we're on the normal evaluation:
+ // We wait for the thread computing key3 to be sleeping
+ key3SleepingLatch.await();
+ // And then we throw an error, which will fail the normal evaluation and trigger
+ // error bubbling.
+ onNormalEvaluation.set(false);
+ throw new SkyFunctionExceptionForTest("normal evaluation");
+ } else if (skyKey.equals(key3)) {
+ // For key3:
- State state = (State) skyKeyComputeState;
- if (skyKey.equals(key1)) {
- // For key1:
-
- stateForKey1Ref.set(new WeakReference<>(state));
- // We declare a dep on key.
- return env.getValue(key2);
- } else if (skyKey.equals(key2)) {
- // For key2:
-
- // We wait for the thread computing key3 to be sleeping
- key3SleepingLatch.await();
- // And then we throw an error, which will fail the normal evaluation and trigger
- // error bubbling.
- onNormalEvaluation.set(false);
- throw new SkyFunctionExceptionForTest("normal evaluation");
- } else if (skyKey.equals(key3)) {
- // For key3:
-
- stateForKey3Ref.set(new WeakReference<>(state));
- key3SleepingLatch.countDown();
- // We sleep forever. (To be interrupted by ParallelEvaluator when the normal
- // evaluation fails).
- Thread.sleep(Long.MAX_VALUE);
- }
- throw new IllegalStateException();
- } else {
- // When we're in error bubbling:
-
- // The states for the nodes from normal evaluation should have been eligible for GC.
- // This is because ParallelEvaluator ought to have them eligible for GC before
- // starting error bubbling.
- GcFinalization.awaitClear(stateForKey1Ref.get());
- GcFinalization.awaitClear(stateForKey3Ref.get());
-
- // We bubble up a unique error message.
- throw new SkyFunctionExceptionForTest("error bubbling for " + skyKey.argument());
+ stateForKey3Ref.set(new WeakReference<>(state));
+ key3SleepingLatch.countDown();
+ // We sleep forever. (To be interrupted by ParallelEvaluator when the normal
+ // evaluation fails).
+ Thread.sleep(Long.MAX_VALUE);
}
+ throw new IllegalStateException();
+ } else {
+ // When we're in error bubbling:
+
+ // The states for the nodes from normal evaluation should have been eligible for GC.
+ // This is because ParallelEvaluator ought to have them eligible for GC before
+ // starting error bubbling.
+ GcFinalization.awaitClear(stateForKey1Ref.get());
+ GcFinalization.awaitClear(stateForKey3Ref.get());
+
+ // We bubble up a unique error message.
+ throw new SkyFunctionExceptionForTest("error bubbling for " + skyKey.argument());
}
};
@@ -2882,64 +2836,41 @@
// And a SkyFunction for these nodes,
SkyFunction skyFunctionForTest =
- new SkyFunction() {
- // That supports compute staten
- @Override
- public boolean supportsSkyKeyComputeState() {
- return true;
- }
+ // Whose #compute is such that
+ (skyKey, env) -> {
+ State state = env.getState(State::new);
+ if (skyKey.equals(key1)) {
+ // The semantics for key1 are:
- @Override
- public State createNewSkyKeyComputeState() {
- return new State();
- }
+ // We declare a dep on key2.
+ if (env.getValue(key2) == null) {
+ // If key2 is missing, that means we're on the initial #compute call for key1,
+ // And so we expect the compute state to be the first instance ever.
+ assertThat(state.instanceCount).isEqualTo(1);
+ stateForKey1Ref.set(new WeakReference<>(state));
- // And crashes if the normal #compute method is called,
- @Override
- public SkyValue compute(SkyKey skyKey, Environment env) {
- fail();
- throw new IllegalStateException();
- }
+ return null;
+ } else {
+ // But if key2 is not missing, that means we're on the subsequent #compute call for
+ // key1. That means we expect the compute state to be the third instance ever,
+ // because...
+ assertThat(state.instanceCount).isEqualTo(3);
- // And whose #compute is such that
- @Override
- public SkyValue compute(
- SkyKey skyKey, SkyKeyComputeState skyKeyComputeState, Environment env)
- throws InterruptedException {
- State state = (State) skyKeyComputeState;
- if (skyKey.equals(key1)) {
- // The semantics for key1 are:
-
- // We declare a dep on key2.
- if (env.getValue(key2) == null) {
- // If key2 is missing, that means we're on the initial #compute call for key1,
- // And so we expect the compute state to be the first instance ever.
- assertThat(state.instanceCount).isEqualTo(1);
- stateForKey1Ref.set(new WeakReference<>(state));
-
- return null;
- } else {
- // But if key2 is not missing, that means we're on the subsequent #compute call for
- // key1. That means we expect the compute state to be the third instance ever,
- // because...
- assertThat(state.instanceCount).isEqualTo(3);
-
- return new StringValue("value1");
- }
- } else if (skyKey.equals(key2)) {
- // ... The semantics for key2 are:
-
- // Drop all compute states.
- dropperRef.get().drop();
- // Confirm the old compute state for key1 was GC'd.
- GcFinalization.awaitClear(stateForKey1Ref.get());
- // Also confirm key2's compute state is the second instance ever.
- assertThat(state.instanceCount).isEqualTo(2);
-
- return new StringValue("value2");
+ return new StringValue("value1");
}
- throw new IllegalStateException();
+ } else if (skyKey.equals(key2)) {
+ // ... The semantics for key2 are:
+
+ // Drop all compute states.
+ dropperRef.get().drop();
+ // Confirm the old compute state for key1 was GC'd.
+ GcFinalization.awaitClear(stateForKey1Ref.get());
+ // Also confirm key2's compute state is the second instance ever.
+ assertThat(state.instanceCount).isEqualTo(2);
+
+ return new StringValue("value2");
}
+ throw new IllegalStateException();
};
tester.putSkyFunction(SkyKeyForSkyKeyComputeStateTests.FUNCTION_NAME, skyFunctionForTest);