Validate `Exception` Type in `getOrThrow`.

PiperOrigin-RevId: 438123169
diff --git a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
index 016aeb2..26ab7a4 100644
--- a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
@@ -78,7 +78,8 @@
   @Nullable
   public <E extends Exception> SkyValue getValueOrThrow(SkyKey depKey, Class<E> exceptionClass)
       throws E, InterruptedException {
-    return getValueOrThrow(depKey, exceptionClass, null, null, null);
+    SkyFunctionException.validateExceptionType(exceptionClass);
+    return getValueOrThrowHelper(depKey, exceptionClass, null, null, null);
   }
 
   @Override
@@ -86,7 +87,9 @@
   public <E1 extends Exception, E2 extends Exception> SkyValue getValueOrThrow(
       SkyKey depKey, Class<E1> exceptionClass1, Class<E2> exceptionClass2)
       throws E1, E2, InterruptedException {
-    return getValueOrThrow(depKey, exceptionClass1, exceptionClass2, null, null);
+    SkyFunctionException.validateExceptionType(exceptionClass1);
+    SkyFunctionException.validateExceptionType(exceptionClass2);
+    return getValueOrThrowHelper(depKey, exceptionClass1, exceptionClass2, null, null);
   }
 
   @Override
@@ -98,19 +101,39 @@
           Class<E2> exceptionClass2,
           Class<E3> exceptionClass3)
           throws E1, E2, E3, InterruptedException {
-    return getValueOrThrow(depKey, exceptionClass1, exceptionClass2, exceptionClass3, null);
+    SkyFunctionException.validateExceptionType(exceptionClass1);
+    SkyFunctionException.validateExceptionType(exceptionClass2);
+    SkyFunctionException.validateExceptionType(exceptionClass3);
+    return getValueOrThrowHelper(depKey, exceptionClass1, exceptionClass2, exceptionClass3, null);
   }
 
   @Override
+  @Nullable
   public <E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception>
       SkyValue getValueOrThrow(
           SkyKey depKey,
           Class<E1> exceptionClass1,
+          Class<E2> exceptionClass2,
+          Class<E3> exceptionClass3,
+          Class<E4> exceptionClass4)
+          throws E1, E2, E3, E4, InterruptedException {
+    SkyFunctionException.validateExceptionType(exceptionClass1);
+    SkyFunctionException.validateExceptionType(exceptionClass2);
+    SkyFunctionException.validateExceptionType(exceptionClass3);
+    SkyFunctionException.validateExceptionType(exceptionClass4);
+    return getValueOrThrowHelper(
+        depKey, exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4);
+  }
+
+  @Nullable
+  private <E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception>
+      SkyValue getValueOrThrowHelper(
+          SkyKey depKey,
+          @Nullable Class<E1> exceptionClass1,
           @Nullable Class<E2> exceptionClass2,
           @Nullable Class<E3> exceptionClass3,
           @Nullable Class<E4> exceptionClass4)
           throws E1, E2, E3, E4, InterruptedException {
-    SkyFunctionException.validateExceptionType(exceptionClass1);
     SkyframeIterableResult result = getOrderedValuesAndExceptions(ImmutableSet.of(depKey));
     return result.nextOrThrow(exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4);
   }
diff --git a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
index 1b012ab..3903864 100644
--- a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
@@ -64,6 +64,7 @@
 import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
 import com.google.testing.junit.testparameterinjector.TestParameter;
 import com.google.testing.junit.testparameterinjector.TestParameterInjector;
+import java.io.IOException;
 import java.lang.ref.WeakReference;
 import java.util.ArrayList;
 import java.util.List;
@@ -2353,6 +2354,86 @@
   }
 
   @Test
+  public void validateExceptionTypeInDifferentPosition(
+      @TestParameter({"0", "1", "2", "3"}) int exceptionIndex) throws Exception {
+    ImmutableList<Class<? extends Exception>> exceptions =
+        ImmutableList.of(
+            Exception.class,
+            SomeOtherErrorException.class,
+            IOException.class,
+            SomeErrorException.class);
+    graph = new InMemoryGraphImpl();
+    SkyKey otherKey = GraphTester.toSkyKey("other");
+    tester.set(otherKey, new StringValue("other"));
+    SkyKey parentKey = GraphTester.toSkyKey("parent");
+    SomeErrorException parentExn = new SomeErrorException("parent error");
+    tester
+        .getOrCreate(parentKey)
+        .setBuilder(
+            (skyKey, env) -> {
+              IllegalStateException illegalStateException =
+                  assertThrows(
+                      IllegalStateException.class,
+                      () ->
+                          env.getValueOrThrow(
+                              otherKey,
+                              exceptions.get(exceptionIndex % 4),
+                              exceptions.get((exceptionIndex + 1) % 4),
+                              exceptions.get((exceptionIndex + 2) % 4),
+                              exceptions.get((exceptionIndex + 3) % 4)));
+              assertThat(illegalStateException)
+                  .hasMessageThat()
+                  .contains("is a supertype of RuntimeException");
+              assertThat(env.valuesMissing()).isFalse();
+              throw new GenericFunctionException(parentExn, Transience.PERSISTENT);
+            });
+    EvaluationResult<StringValue> result = eval(/*keepGoing=*/ true, ImmutableList.of(parentKey));
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getError().getException()).isEqualTo(parentExn);
+  }
+
+  @Test
+  public void validateExceptionTypeWithDifferentException(
+      @TestParameter ExceptionOption exceptionOption) throws Exception {
+    graph = new InMemoryGraphImpl();
+    SkyKey otherKey = GraphTester.toSkyKey("other");
+    tester.set(otherKey, new StringValue("other"));
+    SkyKey parentKey = GraphTester.toSkyKey("parent");
+    SomeErrorException parentExn = new SomeErrorException("parent error");
+    tester
+        .getOrCreate(parentKey)
+        .setBuilder(
+            (skyKey, env) -> {
+              IllegalStateException illegalStateException =
+                  assertThrows(
+                      IllegalStateException.class,
+                      () -> env.getValueOrThrow(otherKey, exceptionOption.exceptionClass));
+              assertThat(illegalStateException)
+                  .hasMessageThat()
+                  .contains(exceptionOption.errorMessage);
+              assertThat(env.valuesMissing()).isFalse();
+              throw new GenericFunctionException(parentExn, Transience.PERSISTENT);
+            });
+    EvaluationResult<StringValue> result = eval(/*keepGoing=*/ true, ImmutableList.of(parentKey));
+    assertThat(result.hasError()).isTrue();
+    assertThat(result.getError().getException()).isEqualTo(parentExn);
+  }
+
+  private enum ExceptionOption {
+    EXCEPTION(Exception.class, "is a supertype of RuntimeException"),
+    NULL_POINTER_EXCEPTION(NullPointerException.class, "is a subtype of RuntimeException"),
+    INTERRUPTED_EXCEPTION(InterruptedException.class, "is a subtype of InterruptedException");
+
+    final Class<? extends Exception> exceptionClass;
+    final String errorMessage;
+
+    ExceptionOption(Class<? extends Exception> exceptionClass, String errorMessage) {
+      this.exceptionClass = exceptionClass;
+      this.errorMessage = errorMessage;
+    }
+  }
+
+  @Test
   public void duplicateCycles() throws Exception {
     graph = new InMemoryGraphImpl();
     SkyKey grandparentKey = GraphTester.toSkyKey("grandparent");