Add `ConcurrentLookupEnvironment` subclass to support concurrently driving multiple `StateMachine`s under the same root `StateMachine`. The new `ConcurrentLookupEnvironment` will be used when evaluating all globs under the same package using a single `GLOBS` key. PiperOrigin-RevId: 604408379 Change-Id: I4fe707c08a35af679b1295ac72df3e8a4a1aa3fc
diff --git a/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java new file mode 100644 index 0000000..a7f6724 --- /dev/null +++ b/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java
@@ -0,0 +1,91 @@ +// Copyright 2024 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.skyframe; + +import com.google.devtools.build.skyframe.SkyFunction.LookupEnvironment; +import javax.annotation.Nullable; + +/** + * Used when {@link SkyFunction#compute} needs to concurrently access {@link LookupEnvironment}. + * + * <p>A common use case can be {@link SkyFunction#compute} drives several in-parallel {@link + * com.google.devtools.build.skyframe.state.StateMachine}s, which could concurrently query deps in + * {@link LookupEnvironment}. + */ +final class ConcurrentLookupEnvironment implements LookupEnvironment { + private final LookupEnvironment delegate; + + ConcurrentLookupEnvironment(LookupEnvironment delegate) { + this.delegate = delegate; + } + + @Nullable + @Override + public synchronized SkyValue getValue(SkyKey valueName) throws InterruptedException { + return delegate.getValue(valueName); + } + + @Nullable + @Override + public synchronized <E extends Exception> SkyValue getValueOrThrow( + SkyKey depKey, Class<E> exceptionClass) throws E, InterruptedException { + return delegate.getValueOrThrow(depKey, exceptionClass); + } + + @Nullable + @Override + public synchronized <E1 extends Exception, E2 extends Exception> SkyValue getValueOrThrow( + SkyKey depKey, Class<E1> exceptionClass1, Class<E2> exceptionClass2) + throws E1, E2, InterruptedException { + return delegate.getValueOrThrow(depKey, exceptionClass1, exceptionClass2); + } + + @Nullable + @Override + public synchronized <E1 extends Exception, E2 extends Exception, E3 extends Exception> + SkyValue getValueOrThrow( + SkyKey depKey, + Class<E1> exceptionClass1, + Class<E2> exceptionClass2, + Class<E3> exceptionClass3) + throws E1, E2, E3, InterruptedException { + return delegate.getValueOrThrow(depKey, exceptionClass1, exceptionClass2, exceptionClass3); + } + + @Nullable + @Override + public synchronized < + E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception> + SkyValue getValueOrThrow( + SkyKey depKey, + Class<E1> exceptionClass1, + Class<E2> exceptionClass2, + Class<E3> exceptionClass3, + Class<E4> exceptionClass4) + throws E1, E2, E3, E4, InterruptedException { + return delegate.getValueOrThrow( + depKey, exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4); + } + + @Override + public synchronized SkyframeLookupResult getValuesAndExceptions( + Iterable<? extends SkyKey> depKeys) throws InterruptedException { + return delegate.getValuesAndExceptions(depKeys); + } + + @Override + public synchronized SkyframeLookupResult getLookupHandleForPreviouslyRequestedDeps() { + return delegate.getLookupHandleForPreviouslyRequestedDeps(); + } +}
diff --git a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java index 4d9b8ef..141e838 100644 --- a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java +++ b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
@@ -17,6 +17,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult; +import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; @@ -27,6 +28,7 @@ import com.google.devtools.build.skyframe.GraphTester.StringValue; import com.google.devtools.build.skyframe.SkyFunction.Environment; import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState; +import com.google.devtools.build.skyframe.SkyFunction.LookupEnvironment; import com.google.devtools.build.skyframe.state.Driver; import com.google.devtools.build.skyframe.state.StateMachine; import com.google.devtools.build.skyframe.state.StateMachineEvaluatorForTesting; @@ -35,6 +37,11 @@ import com.google.devtools.build.skyframe.state.ValueOrExceptionProducer; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -1401,4 +1408,87 @@ abstract int exceptionOrdinal(); } + + private static class StateMachineWithMultipleConcurrentDriverWrapper + implements SkyKeyComputeState { + private final List<Driver> drivers = new ArrayList<>(); + + private StateMachineWithMultipleConcurrentDriverWrapper(List<StateMachine> stateMachines) { + for (StateMachine stateMachine : stateMachines) { + drivers.add(new Driver(stateMachine)); + } + } + + private boolean drive(LookupEnvironment env) throws InterruptedException { + ExecutorService executor = Executors.newFixedThreadPool(4); + AtomicBoolean allCompletes = new AtomicBoolean(true); + ConcurrentLookupEnvironment concurrentEnvironment = new ConcurrentLookupEnvironment(env); + for (Driver driver : drivers) { + var unused = + executor.submit( + () -> { + try { + if (!driver.drive(concurrentEnvironment)) { + allCompletes.set(false); + } + } catch (InterruptedException e) { + throw new AssertionError("No exception is expected to be thrown", e); + } + }); + } + + executor.shutdown(); + executor.awaitTermination(Long.MAX_VALUE, NANOSECONDS); + return allCompletes.get(); + } + } + + private AtomicInteger defineRootMachineWithMultipleDriver( + Supplier<List<StateMachine>> rootMachineSupplier) { + AtomicInteger restartCount = new AtomicInteger(); + tester + .getOrCreate(ROOT_KEY) + .setBuilder( + (k, env) -> { + if (!env.getState( + () -> + new StateMachineWithMultipleConcurrentDriverWrapper( + rootMachineSupplier.get())) + .drive(env)) { + restartCount.getAndIncrement(); + return null; + } + return DONE_VALUE; + }); + return restartCount; + } + + private int evalMachineWithMultipleDrivers(Supplier<List<StateMachine>> rootMachineSupplier) + throws InterruptedException { + AtomicInteger restartCount = defineRootMachineWithMultipleDriver(rootMachineSupplier); + assertThat(eval(ROOT_KEY, /* keepGoing= */ false).get(ROOT_KEY)).isEqualTo(DONE_VALUE); + return restartCount.get(); + } + + @Test + public void test_multipleStateMachinesInParallelDriver() throws InterruptedException { + for (int i = 0; i < 100; ++i) { + graph.remove(ROOT_KEY); + graph.remove(KEY_A1); + graph.remove(KEY_A2); + var v1Sink = new SkyValueSink(); + var v2Sink = new SkyValueSink(); + var v3Sink = new SkyValueSink(); + var v4Sink = new SkyValueSink(); + var v5Sink = new SkyValueSink(); + var v6Sink = new SkyValueSink(); + Supplier<List<StateMachine>> factory = + () -> + Arrays.asList( + new TwoStepMachine(v1Sink, v2Sink), + new TwoStepMachine(v3Sink, v4Sink), + new TwoStepMachine(v5Sink, v6Sink)); + assertThat(evalMachineWithMultipleDrivers(factory)).isEqualTo(2); + } + } }