Call getValueOrThrow for unary StateMachine lookups.

This is more efficient than getValuesAndExceptions when it applies.

* Adds more test coverage for all the different kinds of lookups.

PiperOrigin-RevId: 514698671
Change-Id: I3160901426930ee9d61a3185903d5cac2dcc13ea
diff --git a/src/main/java/com/google/devtools/build/skyframe/state/Driver.java b/src/main/java/com/google/devtools/build/skyframe/state/Driver.java
index cfce35899..b0e6a48 100644
--- a/src/main/java/com/google/devtools/build/skyframe/state/Driver.java
+++ b/src/main/java/com/google/devtools/build/skyframe/state/Driver.java
@@ -94,11 +94,18 @@
       }
 
       // Performs lookups for any newly added keys.
-      SkyframeLookupResult result =
-          env.getValuesAndExceptions(Lists.transform(newlyAdded, Lookup::key));
-      for (var lookup : newlyAdded) {
-        if (!result.queryDep(lookup.key(), lookup)) {
-          pending.add(lookup); // Unhandled exceptions also end up here.
+      if (newlyAdded.size() == 1) { // Uses a lower overhead lookup for the unary case.
+        var onlyLookup = newlyAdded.get(0);
+        if (!onlyLookup.doLookup(env)) {
+          pending.add(onlyLookup);
+        }
+      } else {
+        SkyframeLookupResult result =
+            env.getValuesAndExceptions(Lists.transform(newlyAdded, Lookup::key));
+        for (var lookup : newlyAdded) {
+          if (!result.queryDep(lookup.key(), lookup)) {
+            pending.add(lookup); // Unhandled exceptions also end up here.
+          }
         }
       }
       newlyAdded.clear(); // Every entry is either done or has moved to pending.
diff --git a/src/main/java/com/google/devtools/build/skyframe/state/Lookup.java b/src/main/java/com/google/devtools/build/skyframe/state/Lookup.java
index bcf0c69..9a5238e 100644
--- a/src/main/java/com/google/devtools/build/skyframe/state/Lookup.java
+++ b/src/main/java/com/google/devtools/build/skyframe/state/Lookup.java
@@ -13,6 +13,7 @@
 // limitations under the License.
 package com.google.devtools.build.skyframe.state;
 
+import com.google.devtools.build.skyframe.SkyFunction.Environment;
 import com.google.devtools.build.skyframe.SkyKey;
 import com.google.devtools.build.skyframe.SkyValue;
 import com.google.devtools.build.skyframe.SkyframeLookupResult;
@@ -21,7 +22,7 @@
 /** Captures information about a lookup requested by a state machine. */
 abstract class Lookup implements SkyframeLookupResult.QueryDepCallback {
   private final TaskTreeNode parent;
-  private final SkyKey key;
+  final SkyKey key;
 
   private Lookup(TaskTreeNode parent, SkyKey key) {
     this.parent = parent;
@@ -32,6 +33,17 @@
     return key;
   }
 
+  /**
+   * Performs a lookup directly against the environment.
+   *
+   * <p>This is more efficient than {@link Environment#getValuesAndExceptions} when there is only
+   * one key at a time.
+   *
+   * @return true if a value was available or an exception was handled. Note: this is false for
+   *     unhandled exceptions.
+   */
+  abstract boolean doLookup(Environment env) throws InterruptedException;
+
   @Override
   public final void acceptValue(SkyKey unusedKey, SkyValue value) {
     acceptValue(value);
@@ -60,6 +72,16 @@
     }
 
     @Override
+    boolean doLookup(Environment env) throws InterruptedException {
+      var value = env.getValue(key);
+      if (value == null) {
+        return false;
+      }
+      acceptValue(key, value);
+      return true;
+    }
+
+    @Override
     void acceptValue(SkyValue value) {
       sink.accept(value);
     }
@@ -85,6 +107,25 @@
     }
 
     @Override
+    boolean doLookup(Environment env) throws InterruptedException {
+      SkyValue value;
+      try {
+        if ((value = env.getValueOrThrow(key(), exceptionClass)) == null) {
+          return false;
+        }
+        acceptValue(key, value);
+      } catch (Exception e) {
+        if (e instanceof InterruptedException) {
+          throw (InterruptedException) e;
+        }
+        if (!tryHandleException(e)) {
+          throw new IllegalArgumentException("Unexpected exception for " + key(), e);
+        }
+      }
+      return true;
+    }
+
+    @Override
     void acceptValue(SkyValue value) {
       sink.accept(value, /* exception= */ null);
     }
@@ -118,6 +159,25 @@
     }
 
     @Override
+    boolean doLookup(Environment env) throws InterruptedException {
+      SkyValue value;
+      try {
+        if ((value = env.getValueOrThrow(key(), exceptionClass1, exceptionClass2)) == null) {
+          return false;
+        }
+        acceptValue(key, value);
+      } catch (Exception e) {
+        if (e instanceof InterruptedException) {
+          throw (InterruptedException) e;
+        }
+        if (!tryHandleException(e)) {
+          throw new IllegalArgumentException("Unexpected exception for " + key(), e);
+        }
+      }
+      return true;
+    }
+
+    @Override
     void acceptValue(SkyValue value) {
       sink.accept(value, /* e1= */ null, /* e2= */ null);
     }
@@ -159,6 +219,26 @@
     }
 
     @Override
+    boolean doLookup(Environment env) throws InterruptedException {
+      SkyValue value;
+      try {
+        if ((value = env.getValueOrThrow(key(), exceptionClass1, exceptionClass2, exceptionClass3))
+            == null) {
+          return false;
+        }
+        acceptValue(key, value);
+      } catch (Exception e) {
+        if (e instanceof InterruptedException) {
+          throw (InterruptedException) e;
+        }
+        if (!tryHandleException(e)) {
+          throw new IllegalArgumentException("Unexpected exception for " + key(), e);
+        }
+      }
+      return true;
+    }
+
+    @Override
     void acceptValue(SkyValue value) {
       sink.accept(value, /* e1= */ null, /* e2= */ null, /* e3= */ null);
     }
diff --git a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
index e1cda76..ec3b429 100644
--- a/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/StateMachineTest.java
@@ -13,6 +13,8 @@
 // limitations under the License.
 package com.google.devtools.build.skyframe;
 
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.truth.Truth.assertThat;
 import static com.google.devtools.build.skyframe.EvaluationResultSubjectFactory.assertThatEvaluationResult;
 import static org.junit.Assert.assertThrows;
@@ -607,6 +609,72 @@
     }
   }
 
+  @Test
+  public void lookupValue_matrix(
+      @TestParameter LookupType lookupType,
+      @TestParameter boolean useBatch,
+      @TestParameter boolean keepGoing)
+      throws InterruptedException {
+    var sink = new OmniSink();
+    var unused =
+        defineRootMachine(
+            () -> {
+              var lookup = lookupType.newLookup(KEY_A1, sink);
+              if (!useBatch) {
+                return lookup;
+              }
+              return new BatchPair(lookup);
+            });
+
+    assertThat(eval(ROOT_KEY, keepGoing).get(ROOT_KEY)).isEqualTo(DONE_VALUE);
+    assertThat(sink.value).isEqualTo(VALUE_A1);
+    assertThat(sink.exception).isNull();
+  }
+
+  @Test
+  public void lookupErrors_matrix(
+      @TestParameter LookupType lookupType,
+      @TestParameter ExceptionCase exceptionCase,
+      @TestParameter boolean useBatch,
+      @TestParameter boolean keepGoing)
+      throws InterruptedException {
+    var exception = exceptionCase.getException();
+    tester
+        .getOrCreate(KEY_A1)
+        .unsetConstantValue()
+        .setBuilder(
+            (k, env) -> {
+              throw new ExceptionWrapper(exception);
+            });
+    var sink = new OmniSink();
+    var unused =
+        defineRootMachine(
+            () -> {
+              var lookup = lookupType.newLookup(KEY_A1, sink);
+              if (!useBatch) {
+                return lookup;
+              }
+              return new BatchPair(lookup);
+            });
+    var result = eval(ROOT_KEY, keepGoing);
+    assertThat(sink.value).isNull();
+    if (exceptionCase.exceptionOrdinal() > lookupType.exceptionCount()) {
+      // The exception was not handled.
+      assertThat(sink.exception).isNull();
+      assertThat(result.get(ROOT_KEY)).isNull();
+      assertThatEvaluationResult(result).hasSingletonErrorThat(KEY_A1);
+      return;
+    }
+    assertThat(sink.exception).isEqualTo(exception);
+    if (keepGoing) {
+      // The error is completely handled.
+      assertThat(result.get(ROOT_KEY)).isEqualTo(DONE_VALUE);
+      return;
+    }
+    assertThatEvaluationResult(result).hasSingletonErrorThat(KEY_A1);
+    assertThat(result.get(ROOT_KEY)).isNull();
+  }
+
   /**
    * Sink for {@link SkyValue}s.
    *
@@ -626,4 +694,289 @@
       return value;
     }
   }
+
+  // -------------------- Helpers for lookupErrors_matrix --------------------
+  private static class Exception1 extends Exception {}
+
+  private static class Exception2 extends Exception {}
+
+  private static class Exception3 extends Exception {}
+
+  private static class Exception4 extends Exception {}
+
+  private static class ExceptionWrapper extends SkyFunctionException {
+    private ExceptionWrapper(Exception e) {
+      super(e, Transience.PERSISTENT);
+    }
+  }
+
+  /**
+   * Adds a secondary lookup in parallel with a given {@link StateMachine}.
+   *
+   * <p>This causes the {@link Environment#getValuesAndExceptions} codepath in {@link Driver#drive}
+   * to be used instead of the {@link Lookup#doLookup} when there is a single lookup.
+   */
+  private static class BatchPair implements StateMachine {
+    private final StateMachine other;
+
+    private BatchPair(StateMachine other) {
+      this.other = other;
+    }
+
+    @Override
+    @Nullable
+    public StateMachine step(Tasks tasks, ExtendedEventHandler listener) {
+      tasks.enqueue(other);
+      tasks.lookUp(KEY_B1, v -> assertThat(v).isEqualTo(VALUE_B1));
+      return null;
+    }
+  }
+
+  private static class Lookup0 implements StateMachine {
+    private final SkyKey key;
+    private final Consumer<SkyValue> sink;
+
+    private Lookup0(SkyKey key, Consumer<SkyValue> sink) {
+      this.key = key;
+      this.sink = sink;
+    }
+
+    @Override
+    @Nullable
+    public StateMachine step(Tasks tasks, ExtendedEventHandler listener) {
+      tasks.lookUp(key, sink);
+      return null;
+    }
+  }
+
+  private static class Lookup1 implements StateMachine {
+    private final SkyKey key;
+    private final ValueOrExceptionSink<Exception1> sink;
+
+    private Lookup1(SkyKey key, ValueOrExceptionSink<Exception1> sink) {
+      this.key = key;
+      this.sink = sink;
+    }
+
+    @Override
+    @Nullable
+    public StateMachine step(Tasks tasks, ExtendedEventHandler listener) {
+      tasks.lookUp(key, Exception1.class, sink);
+      return null;
+    }
+  }
+
+  private static class Lookup2 implements StateMachine {
+    private final SkyKey key;
+    private final ValueOrException2Sink<Exception1, Exception2> sink;
+
+    private Lookup2(SkyKey key, ValueOrException2Sink<Exception1, Exception2> sink) {
+      this.key = key;
+      this.sink = sink;
+    }
+
+    @Override
+    @Nullable
+    public StateMachine step(Tasks tasks, ExtendedEventHandler listener) {
+      tasks.lookUp(key, Exception1.class, Exception2.class, sink);
+      return null;
+    }
+  }
+
+  private static class Lookup3 implements StateMachine {
+    private final SkyKey key;
+    private final ValueOrException3Sink<Exception1, Exception2, Exception3> sink;
+
+    private Lookup3(SkyKey key, ValueOrException3Sink<Exception1, Exception2, Exception3> sink) {
+      this.key = key;
+      this.sink = sink;
+    }
+
+    @Override
+    @Nullable
+    public StateMachine step(Tasks tasks, ExtendedEventHandler listener) {
+      tasks.lookUp(key, Exception1.class, Exception2.class, Exception3.class, sink);
+      return null;
+    }
+  }
+
+  private static class OmniSink
+      implements Consumer<SkyValue>,
+          StateMachine.ValueOrExceptionSink<Exception1>,
+          StateMachine.ValueOrException2Sink<Exception1, Exception2>,
+          StateMachine.ValueOrException3Sink<Exception1, Exception2, Exception3> {
+    private SkyValue value;
+    private Exception exception;
+
+    @Override
+    public void accept(SkyValue value) {
+      checkState(this.value == null && exception == null);
+      this.value = checkNotNull(value);
+    }
+
+    @Override
+    public void accept(@Nullable SkyValue value, @Nullable Exception1 exception1) {
+      checkState(this.value == null && exception == null);
+      if (value != null) {
+        this.value = value;
+        return;
+      }
+      if (exception1 != null) {
+        checkState(value == null);
+        this.exception = exception1;
+      }
+    }
+
+    @Override
+    public void accept(
+        @Nullable SkyValue value,
+        @Nullable Exception1 exception1,
+        @Nullable Exception2 exception2) {
+      checkState(this.value == null && exception == null);
+      if (value != null) {
+        checkState(exception1 == null && exception2 == null);
+        this.value = value;
+        return;
+      }
+      if (exception1 != null) {
+        checkState(value == null && exception2 == null);
+        this.exception = exception1;
+        return;
+      }
+      if (exception2 != null) {
+        checkState(value == null && exception1 == null);
+        this.exception = exception2;
+      }
+    }
+
+    @Override
+    public void accept(
+        @Nullable SkyValue value,
+        @Nullable Exception1 exception1,
+        @Nullable Exception2 exception2,
+        @Nullable Exception3 exception3) {
+      checkState(this.value == null && exception == null);
+      if (value != null) {
+        checkState(exception1 == null && exception2 == null && exception3 == null);
+        this.value = value;
+        return;
+      }
+      if (exception1 != null) {
+        checkState(value == null && exception2 == null && exception3 == null);
+        this.exception = exception1;
+        return;
+      }
+      if (exception2 != null) {
+        checkState(value == null && exception1 == null && exception3 == null);
+        this.exception = exception2;
+        return;
+      }
+      if (exception3 != null) {
+        checkState(value == null && exception1 == null && exception2 == null);
+        this.exception = exception3;
+      }
+    }
+  }
+
+  private enum LookupType {
+    LOOKUP0 {
+      @Override
+      StateMachine newLookup(SkyKey key, OmniSink sink) {
+        return new Lookup0(key, sink);
+      }
+
+      @Override
+      int exceptionCount() {
+        return 0;
+      }
+    },
+    LOOKUP1 {
+      @Override
+      StateMachine newLookup(SkyKey key, OmniSink sink) {
+        return new Lookup1(key, sink);
+      }
+
+      @Override
+      int exceptionCount() {
+        return 1;
+      }
+    },
+    LOOKUP2 {
+      @Override
+      StateMachine newLookup(SkyKey key, OmniSink sink) {
+        return new Lookup2(key, sink);
+      }
+
+      @Override
+      int exceptionCount() {
+        return 2;
+      }
+    },
+    LOOKUP3 {
+      @Override
+      StateMachine newLookup(SkyKey key, OmniSink sink) {
+        return new Lookup3(key, sink);
+      }
+
+      @Override
+      int exceptionCount() {
+        return 3;
+      }
+    };
+
+    abstract StateMachine newLookup(SkyKey key, OmniSink sink);
+
+    abstract int exceptionCount();
+  }
+
+  private enum ExceptionCase {
+    EXCEPTION1 {
+      @Override
+      Exception getException() {
+        return new Exception1();
+      }
+
+      @Override
+      int exceptionOrdinal() {
+        return 1;
+      }
+    },
+    EXCEPTION2 {
+      @Override
+      Exception getException() {
+        return new Exception2();
+      }
+
+      @Override
+      int exceptionOrdinal() {
+        return 2;
+      }
+    },
+    EXCEPTION3 {
+      @Override
+      Exception getException() {
+        return new Exception3();
+      }
+
+      @Override
+      int exceptionOrdinal() {
+        return 3;
+      }
+    },
+    EXCEPTION4 {
+      @Override
+      Exception getException() {
+        return new Exception4();
+      }
+
+      @Override
+      int exceptionOrdinal() {
+        return 4;
+      }
+    };
+
+    abstract Exception getException();
+
+    abstract int exceptionOrdinal();
+  }
 }