Validate `Exception` Type in `getOrThrow`.
PiperOrigin-RevId: 438123169
diff --git a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
index 016aeb2..26ab7a4 100644
--- a/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
+++ b/src/main/java/com/google/devtools/build/skyframe/AbstractSkyFunctionEnvironment.java
@@ -78,7 +78,8 @@
@Nullable
public <E extends Exception> SkyValue getValueOrThrow(SkyKey depKey, Class<E> exceptionClass)
throws E, InterruptedException {
- return getValueOrThrow(depKey, exceptionClass, null, null, null);
+ SkyFunctionException.validateExceptionType(exceptionClass);
+ return getValueOrThrowHelper(depKey, exceptionClass, null, null, null);
}
@Override
@@ -86,7 +87,9 @@
public <E1 extends Exception, E2 extends Exception> SkyValue getValueOrThrow(
SkyKey depKey, Class<E1> exceptionClass1, Class<E2> exceptionClass2)
throws E1, E2, InterruptedException {
- return getValueOrThrow(depKey, exceptionClass1, exceptionClass2, null, null);
+ SkyFunctionException.validateExceptionType(exceptionClass1);
+ SkyFunctionException.validateExceptionType(exceptionClass2);
+ return getValueOrThrowHelper(depKey, exceptionClass1, exceptionClass2, null, null);
}
@Override
@@ -98,19 +101,39 @@
Class<E2> exceptionClass2,
Class<E3> exceptionClass3)
throws E1, E2, E3, InterruptedException {
- return getValueOrThrow(depKey, exceptionClass1, exceptionClass2, exceptionClass3, null);
+ SkyFunctionException.validateExceptionType(exceptionClass1);
+ SkyFunctionException.validateExceptionType(exceptionClass2);
+ SkyFunctionException.validateExceptionType(exceptionClass3);
+ return getValueOrThrowHelper(depKey, exceptionClass1, exceptionClass2, exceptionClass3, null);
}
@Override
+ @Nullable
public <E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception>
SkyValue getValueOrThrow(
SkyKey depKey,
Class<E1> exceptionClass1,
+ Class<E2> exceptionClass2,
+ Class<E3> exceptionClass3,
+ Class<E4> exceptionClass4)
+ throws E1, E2, E3, E4, InterruptedException {
+ SkyFunctionException.validateExceptionType(exceptionClass1);
+ SkyFunctionException.validateExceptionType(exceptionClass2);
+ SkyFunctionException.validateExceptionType(exceptionClass3);
+ SkyFunctionException.validateExceptionType(exceptionClass4);
+ return getValueOrThrowHelper(
+ depKey, exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4);
+ }
+
+ @Nullable
+ private <E1 extends Exception, E2 extends Exception, E3 extends Exception, E4 extends Exception>
+ SkyValue getValueOrThrowHelper(
+ SkyKey depKey,
+ @Nullable Class<E1> exceptionClass1,
@Nullable Class<E2> exceptionClass2,
@Nullable Class<E3> exceptionClass3,
@Nullable Class<E4> exceptionClass4)
throws E1, E2, E3, E4, InterruptedException {
- SkyFunctionException.validateExceptionType(exceptionClass1);
SkyframeIterableResult result = getOrderedValuesAndExceptions(ImmutableSet.of(depKey));
return result.nextOrThrow(exceptionClass1, exceptionClass2, exceptionClass3, exceptionClass4);
}
diff --git a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
index 1b012ab..3903864 100644
--- a/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
+++ b/src/test/java/com/google/devtools/build/skyframe/ParallelEvaluatorTest.java
@@ -64,6 +64,7 @@
import com.google.devtools.build.skyframe.SkyFunctionException.Transience;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
+import java.io.IOException;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
@@ -2353,6 +2354,86 @@
}
@Test
+ public void validateExceptionTypeInDifferentPosition(
+ @TestParameter({"0", "1", "2", "3"}) int exceptionIndex) throws Exception {
+ ImmutableList<Class<? extends Exception>> exceptions =
+ ImmutableList.of(
+ Exception.class,
+ SomeOtherErrorException.class,
+ IOException.class,
+ SomeErrorException.class);
+ graph = new InMemoryGraphImpl();
+ SkyKey otherKey = GraphTester.toSkyKey("other");
+ tester.set(otherKey, new StringValue("other"));
+ SkyKey parentKey = GraphTester.toSkyKey("parent");
+ SomeErrorException parentExn = new SomeErrorException("parent error");
+ tester
+ .getOrCreate(parentKey)
+ .setBuilder(
+ (skyKey, env) -> {
+ IllegalStateException illegalStateException =
+ assertThrows(
+ IllegalStateException.class,
+ () ->
+ env.getValueOrThrow(
+ otherKey,
+ exceptions.get(exceptionIndex % 4),
+ exceptions.get((exceptionIndex + 1) % 4),
+ exceptions.get((exceptionIndex + 2) % 4),
+ exceptions.get((exceptionIndex + 3) % 4)));
+ assertThat(illegalStateException)
+ .hasMessageThat()
+ .contains("is a supertype of RuntimeException");
+ assertThat(env.valuesMissing()).isFalse();
+ throw new GenericFunctionException(parentExn, Transience.PERSISTENT);
+ });
+ EvaluationResult<StringValue> result = eval(/*keepGoing=*/ true, ImmutableList.of(parentKey));
+ assertThat(result.hasError()).isTrue();
+ assertThat(result.getError().getException()).isEqualTo(parentExn);
+ }
+
+ @Test
+ public void validateExceptionTypeWithDifferentException(
+ @TestParameter ExceptionOption exceptionOption) throws Exception {
+ graph = new InMemoryGraphImpl();
+ SkyKey otherKey = GraphTester.toSkyKey("other");
+ tester.set(otherKey, new StringValue("other"));
+ SkyKey parentKey = GraphTester.toSkyKey("parent");
+ SomeErrorException parentExn = new SomeErrorException("parent error");
+ tester
+ .getOrCreate(parentKey)
+ .setBuilder(
+ (skyKey, env) -> {
+ IllegalStateException illegalStateException =
+ assertThrows(
+ IllegalStateException.class,
+ () -> env.getValueOrThrow(otherKey, exceptionOption.exceptionClass));
+ assertThat(illegalStateException)
+ .hasMessageThat()
+ .contains(exceptionOption.errorMessage);
+ assertThat(env.valuesMissing()).isFalse();
+ throw new GenericFunctionException(parentExn, Transience.PERSISTENT);
+ });
+ EvaluationResult<StringValue> result = eval(/*keepGoing=*/ true, ImmutableList.of(parentKey));
+ assertThat(result.hasError()).isTrue();
+ assertThat(result.getError().getException()).isEqualTo(parentExn);
+ }
+
+ private enum ExceptionOption {
+ EXCEPTION(Exception.class, "is a supertype of RuntimeException"),
+ NULL_POINTER_EXCEPTION(NullPointerException.class, "is a subtype of RuntimeException"),
+ INTERRUPTED_EXCEPTION(InterruptedException.class, "is a subtype of InterruptedException");
+
+ final Class<? extends Exception> exceptionClass;
+ final String errorMessage;
+
+ ExceptionOption(Class<? extends Exception> exceptionClass, String errorMessage) {
+ this.exceptionClass = exceptionClass;
+ this.errorMessage = errorMessage;
+ }
+ }
+
+ @Test
public void duplicateCycles() throws Exception {
graph = new InMemoryGraphImpl();
SkyKey grandparentKey = GraphTester.toSkyKey("grandparent");