Thread package-overhead estimate through to validation

PiperOrigin-RevId: 363639092
diff --git a/src/test/java/com/google/devtools/build/lib/skyframe/PackageFunctionTest.java b/src/test/java/com/google/devtools/build/lib/skyframe/PackageFunctionTest.java
index 13d3d72..c1c8ab1 100644
--- a/src/test/java/com/google/devtools/build/lib/skyframe/PackageFunctionTest.java
+++ b/src/test/java/com/google/devtools/build/lib/skyframe/PackageFunctionTest.java
@@ -19,7 +19,10 @@
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import com.google.common.base.Predicates;
 import com.google.common.collect.ImmutableList;
@@ -40,6 +43,7 @@
 import com.google.devtools.build.lib.packages.NoSuchPackageException;
 import com.google.devtools.build.lib.packages.NoSuchTargetException;
 import com.google.devtools.build.lib.packages.Package;
+import com.google.devtools.build.lib.packages.PackageOverheadEstimator;
 import com.google.devtools.build.lib.packages.PackageValidator;
 import com.google.devtools.build.lib.packages.PackageValidator.InvalidPackageException;
 import com.google.devtools.build.lib.packages.semantics.BuildLanguageOptions;
@@ -79,6 +83,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalLong;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -104,6 +109,8 @@
 
   @Mock private PackageValidator mockPackageValidator;
 
+  @Mock private PackageOverheadEstimator mockPackageOverheadEstimator;
+
   private CustomInMemoryFs fs = new CustomInMemoryFs(new ManualClock());
 
   private void preparePackageLoading(Path... roots) {
@@ -141,6 +148,11 @@
     return mockPackageValidator;
   }
 
+  @Override
+  protected PackageOverheadEstimator getPackageOverheadEstimator() {
+    return mockPackageOverheadEstimator; // Mockito mock: tests stub estimates as needed.
+  }
+
   private Package validPackageWithoutErrors(SkyKey skyKey) throws InterruptedException {
     return validPackageInternal(skyKey, /*checkPackageError=*/ true);
   }
@@ -200,13 +212,13 @@
             inv -> {
               Package pkg = inv.getArgument(0, Package.class);
               if (pkg.getName().equals("pkg")) {
-                inv.getArgument(1, ExtendedEventHandler.class).handle(Event.warn("warning event"));
+                inv.getArgument(2, ExtendedEventHandler.class).handle(Event.warn("warning event"));
                 throw new InvalidPackageException(pkg.getPackageIdentifier(), "no good");
               }
               return null;
             })
         .when(mockPackageValidator)
-        .validate(any(Package.class), any(ExtendedEventHandler.class));
+        .validate(any(Package.class), any(OptionalLong.class), any(ExtendedEventHandler.class));
 
     invalidatePackages();
 
@@ -217,6 +229,25 @@
   }
 
   @Test
+  public void testPackageOverheadPassedToValidationLogic() throws Exception {
+    // The BUILD file's contents are irrelevant; the estimator mock supplies the overhead.
+    scratch.file("pkg/BUILD", "# Contents doesn't matter, it's all fake");
+
+    OptionalLong fakeOverhead = OptionalLong.of(42);
+    when(mockPackageOverheadEstimator.estimatePackageOverhead(any(Package.class)))
+        .thenReturn(fakeOverhead);
+    invalidatePackages();
+
+    SkyKey pkgKey = PackageValue.key(PackageIdentifier.parse("@//pkg"));
+    SkyframeExecutorTestUtils.evaluate(
+        getSkyframeExecutor(), pkgKey, /*keepGoing=*/ false, reporter);
+
+    // Whatever the estimator returned must be forwarded to the validator as-is.
+    verify(mockPackageValidator)
+        .validate(any(Package.class), eq(fakeOverhead), any(ExtendedEventHandler.class));
+  }
+
+  @Test
   public void testSkyframeExecutorClearedPackagesResultsInReload() throws Exception {
     scratch.file("pkg/BUILD", "sh_library(name='foo', srcs=['foo.sh'])");
     scratch.file("pkg/foo.sh");
@@ -233,7 +264,7 @@
               return null;
             })
         .when(mockPackageValidator)
-        .validate(any(Package.class), any(ExtendedEventHandler.class));
+        .validate(any(Package.class), any(OptionalLong.class), any(ExtendedEventHandler.class));
 
     SkyKey skyKey = PackageValue.key(PackageIdentifier.parse("@//pkg"));
     EvaluationResult<PackageValue> result1 =