Add `ConcurrentLookupEnvironment` subclass to support concurrently driving multiple `StateMachine`s under the same root `StateMachine`. The new `ConcurrentLookupEnvironment` will be used when evaluating all globs under the same package using a single `GLOBS` key.

PiperOrigin-RevId: 604408379
Change-Id: I4fe707c08a35af679b1295ac72df3e8a4a1aa3fc
diff --git a/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java
new file mode 100644
index 0000000..a7f6724
--- /dev/null
+++ b/src/main/java/com/google/devtools/build/skyframe/ConcurrentLookupEnvironment.java
@@ -0,0 +1,91 @@
+// Copyright 2024 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.skyframe;
+
+import com.google.devtools.build.skyframe.SkyFunction.LookupEnvironment;
+import javax.annotation.Nullable;
+
+/**
+ * Used when {@link SkyFunction#compute} needs to concurrently access {@link LookupEnvironment}.
+ *
+ * <p>A common use case can be {@link SkyFunction#compute} drives several in-parallel {@link
+ * com.google.devtools.build.skyframe.state.StateMachine}s, which could concurrently query deps in
+ * {@link LookupEnvironment}.
+ */
+final class ConcurrentLookupEnvironment implements LookupEnvironment {
+  private final LookupEnvironment delegate;
+
+  ConcurrentLookupEnvironment(LookupEnvironment delegate) {
+    this.delegate = delegate;
+  }
+
+  @Nullable
+  @Override
+  public synchronized SkyValue getValue(SkyKey valueName) throws InterruptedException {
+    return delegate.getValue(valueName);
+  }
+
+  @Nullable
+  @Override
+  public synchronized <E extends Exception> SkyValue getValueOrThrow(
+      SkyKey depKey, Class<E> exceptionClass) throws E, InterruptedException {
+    return delegate.getValueOrThrow(depKey, exceptionClass);
+  }
+
+  @Nullable
+  @Override
+  public synchronized <E1 extends Exception, E2 extends Exception> SkyValue getValueOrThrow(
+      SkyKey depKey, Class<E1> exceptionClass1, Class<E2> exceptionClass2)
+      throws E1, E2, InterruptedException {
+    return delegate.getValueOrThrow(depKey, exceptionClass1, exceptionClass2);
+  }
+
+  @Nullable
+  @Override
+  public synchronized <E1 extends Exception, E2 extends Exception, E3 extends Exception>
+      SkyValue getValueOrThrow(
+          SkyKey depKey,
+          Class<E1> exceptionClass1,
+          Class<E2> exceptionClass2,
+          Class<E3> exceptionClass3)
+          throws E1, E2, E3, InterruptedException {
+    return delegate.getValueOrThrow(depKey, exceptionClass1, exceptionClass2, exceptionClass3);
+  }
+
+  @Nullable
+  @Override
+  public synchronized <
+          E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception>
+      SkyValue getValueOrThrow(
+          SkyKey depKey,
+          Class<E1> exceptionClass1,
+          Class<E2> exceptionClass2,
+          Class<E3> exceptionClass3,
+          Class<E4> exceptionClass4)
+          throws E1, E2, E3, E4, InterruptedException {
+    return delegate.getValueOrThrow(
+        depKey, exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4);
+  }
+
+  @Override
+  public synchronized SkyframeLookupResult getValuesAndExceptions(
+      Iterable<? extends SkyKey> depKeys) throws InterruptedException {
+    return delegate.getValuesAndExceptions(depKeys);
+  }
+
+  @Override
+  public synchronized SkyframeLookupResult getLookupHandleForPreviouslyRequestedDeps() {
+    return delegate.getLookupHandleForPreviouslyRequestedDeps();
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
index 4d9b8ef..141e838 100644
--- a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
@@ -17,6 +17,7 @@
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult;
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.fail;
 
@@ -27,6 +28,7 @@
 import com.google.devtools.build.skyframe.GraphTester.StringValue;
 import com.google.devtools.build.skyframe.SkyFunction.Environment;
 import com.google.devtools.build.skyframe.SkyFunction.Environment.SkyKeyComputeState;
+import com.google.devtools.build.skyframe.SkyFunction.LookupEnvironment;
 import com.google.devtools.build.skyframe.state.Driver;
 import com.google.devtools.build.skyframe.state.StateMachine;
 import com.google.devtools.build.skyframe.state.StateMachineEvaluatorForTesting;
@@ -35,6 +37,11 @@
 import com.google.devtools.build.skyframe.state.ValueOrExceptionProducer;
 import com.google.testing.junit.testparameterinjector.TestParameter;
 import com.google.testing.junit.testparameterinjector.TestParameterInjector;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
@@ -1401,4 +1408,87 @@
 
     abstract int exceptionOrdinal();
   }
+
+  private static class StateMachineWithMultipleConcurrentDriverWrapper
+      implements SkyKeyComputeState {
+    private final List<Driver> drivers = new ArrayList<>();
+
+    private StateMachineWithMultipleConcurrentDriverWrapper(List<StateMachine> stateMachines) {
+      for (StateMachine stateMachine : stateMachines) {
+        drivers.add(new Driver(stateMachine));
+      }
+    }
+
+    private boolean drive(LookupEnvironment env) throws InterruptedException {
+      ExecutorService executor = Executors.newFixedThreadPool(4);
+      AtomicBoolean allCompletes = new AtomicBoolean(true);
+      ConcurrentLookupEnvironment concurrentEnvironment = new ConcurrentLookupEnvironment(env);
+      for (Driver driver : drivers) {
+        var unused =
+            executor.submit(
+                () -> {
+                  try {
+                    if (!driver.drive(concurrentEnvironment)) {
+                      allCompletes.set(false);
+                    }
+                  } catch (InterruptedException e) {
+                    throw new AssertionError("No exception is expected to be thrown", e);
+                  }
+                });
+      }
+
+      executor.shutdown();
+      executor.awaitTermination(Long.MAX_VALUE, NANOSECONDS);
+      return allCompletes.get();
+    }
+  }
+
+  private AtomicInteger defineRootMachineWithMultipleDriver(
+      Supplier<List<StateMachine>> rootMachineSupplier) {
+    AtomicInteger restartCount = new AtomicInteger();
+    tester
+        .getOrCreate(ROOT_KEY)
+        .setBuilder(
+            (k, env) -> {
+              if (!env.getState(
+                      () ->
+                          new StateMachineWithMultipleConcurrentDriverWrapper(
+                              rootMachineSupplier.get()))
+                  .drive(env)) {
+                restartCount.getAndIncrement();
+                return null;
+              }
+              return DONE_VALUE;
+            });
+    return restartCount;
+  }
+
+  private int evalMachineWithMultipleDrivers(Supplier<List<StateMachine>> rootMachineSupplier)
+      throws InterruptedException {
+    AtomicInteger restartCount = defineRootMachineWithMultipleDriver(rootMachineSupplier);
+    assertThat(eval(ROOT_KEY, /* keepGoing= */ false).get(ROOT_KEY)).isEqualTo(DONE_VALUE);
+    return restartCount.get();
+  }
+
+  @Test
+  public void test_multipleStateMachinesInParallelDriver() throws InterruptedException {
+    for (int i = 0; i < 100; ++i) {
+      graph.remove(ROOT_KEY);
+      graph.remove(KEY_A1);
+      graph.remove(KEY_A2);
+      var v1Sink = new SkyValueSink();
+      var v2Sink = new SkyValueSink();
+      var v3Sink = new SkyValueSink();
+      var v4Sink = new SkyValueSink();
+      var v5Sink = new SkyValueSink();
+      var v6Sink = new SkyValueSink();
+      Supplier<List<StateMachine>> factory =
+          () ->
+              Arrays.asList(
+                  new TwoStepMachine(v1Sink, v2Sink),
+                  new TwoStepMachine(v3Sink, v4Sink),
+                  new TwoStepMachine(v5Sink, v6Sink));
+      assertThat(evalMachineWithMultipleDrivers(factory)).isEqualTo(2);
+    }
+  }
 }