Add batch methods to WalkableGraph and convert SkyQueryEnvironment to use them.

--
MOS_MIGRATED_REVID=96214911
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index 9c7a0bf..00ddd2b 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -13,13 +13,18 @@
 // limitations under the License.
 package com.google.devtools.build.lib.query2;
 
+import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.base.Predicates;
+import com.google.common.base.Supplier;
 import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Multimaps;
 import com.google.devtools.build.lib.cmdline.ResolvedTargets;
 import com.google.devtools.build.lib.cmdline.TargetParsingException;
 import com.google.devtools.build.lib.cmdline.TargetPattern;
@@ -27,6 +32,7 @@
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.events.StoredEventHandler;
 import com.google.devtools.build.lib.graph.Digraph;
+import com.google.devtools.build.lib.packages.NoSuchTargetException;
 import com.google.devtools.build.lib.packages.NoSuchThingException;
 import com.google.devtools.build.lib.packages.Package;
 import com.google.devtools.build.lib.packages.Rule;
@@ -45,10 +51,12 @@
 import com.google.devtools.build.lib.syntax.Label;
 import com.google.devtools.build.skyframe.SkyFunctionName;
 import com.google.devtools.build.skyframe.SkyKey;
+import com.google.devtools.build.skyframe.SkyValue;
 import com.google.devtools.build.skyframe.WalkableGraph;
 import com.google.devtools.build.skyframe.WalkableGraph.WalkableGraphFactory;
 
 import java.util.ArrayDeque;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Deque;
 import java.util.HashMap;
@@ -59,8 +67,6 @@
 import java.util.Map;
 import java.util.Set;
 
-import javax.annotation.Nullable;
-
 /**
  * {@link AbstractBlazeQueryEnvironment} that introspects the Skyframe graph to find forward and
  * reverse edges. Results obtained by calling {@link #evaluateQuery} are not guaranteed to be in
@@ -116,30 +122,46 @@
     return super.evaluateQuery(expr);
   }
 
-  private static SkyKey transformToKey(Target value) {
+  /** Returns the {@link TransitiveTargetValue} key for {@code value}'s label. */
+  private static SkyKey makeKey(Target value) {
     return TransitiveTargetValue.key(value.getLabel());
   }
 
-  @Nullable
-  private Target transformToValue(SkyKey key) {
-    SkyFunctionName functionName = key.functionName();
-    if (!functionName.equals(SkyFunctions.TRANSITIVE_TARGET)) {
-      return null;
-    }
-    try {
-      return getTarget(((Label) key.argument()));
-    } catch (QueryException | TargetNotFoundException e) {
-      // Any problems with targets were already reported during #buildTransitiveClosure.
-      return null;
-    }
+  /** Returns the direct deps of {@code target} that resolve to targets, before any filtering. */
+  private Collection<Target> getRawFwdDeps(Target target) {
+    Iterable<SkyKey> directDepKeys = graph.getDirectDeps(makeKey(target));
+    return makeTargets(directDepKeys);
   }
 
-  private Collection<Target> getRawFwdDeps(Target target) {
-    return makeTargets(graph.getDirectDeps(transformToKey(target)));
+  /**
+   * Converts a map from a target's key to its neighbors' keys into a map from each resolvable
+   * {@link Target} to its neighboring {@link Target}s. Keys that do not resolve are omitted.
+   */
+  private Map<Target, Collection<Target>> makeTargetsMap(Map<SkyKey, Iterable<SkyKey>> input) {
+    Map<SkyKey, Target> keyToTarget = makeTargetsWithAssociations(input.keySet());
+    ImmutableMap.Builder<Target, Collection<Target>> builder = ImmutableMap.builder();
+    for (Map.Entry<SkyKey, Target> resolved : keyToTarget.entrySet()) {
+      builder.put(resolved.getValue(), makeTargets(input.get(resolved.getKey())));
+    }
+    return builder.build();
+  }
+
+  /** Batch form of {@link #getRawFwdDeps(Target)} backed by a single graph lookup. */
+  private Map<Target, Collection<Target>> getRawFwdDeps(Iterable<Target> targets) {
+    Map<SkyKey, Iterable<SkyKey>> keyToDirectDeps = graph.getDirectDeps(makeKeys(targets));
+    return makeTargetsMap(keyToDirectDeps);
   }
 
+  /** Returns the reverse deps of {@code target} that resolve to targets, before any filtering. */
   private Collection<Target> getRawReverseDeps(Target target) {
-    return makeTargets(graph.getReverseDeps(transformToKey(target)));
+    Iterable<SkyKey> reverseDepKeys = graph.getReverseDeps(makeKey(target));
+    return makeTargets(reverseDepKeys);
+  }
+
+  /** Batch form of {@link #getRawReverseDeps(Target)} backed by a single graph lookup. */
+  private Map<Target, Collection<Target>> getRawReverseDeps(Iterable<Target> targets) {
+    Map<SkyKey, Iterable<SkyKey>> keyToReverseDeps = graph.getReverseDeps(makeKeys(targets));
+    return makeTargetsMap(keyToReverseDeps);
   }
 
   private Set<Label> getAllowedDeps(Rule rule) {
@@ -150,14 +159,17 @@
     return allowedLabels;
   }
 
-  @Override
-  public Collection<Target> getFwdDeps(Target target) {
-    Collection<Target> unfilteredDeps = getRawFwdDeps(target);
+  /**
+   * Applies {@code target}'s {@link #getAllowedDeps} restriction to its raw forward deps.
+   * For a rule, returns a live {@link Collections2#filter} view of {@code unfilteredDeps}
+   * driven by those allowed labels; any other target gets {@code unfilteredDeps} back as-is.
+   */
+  private Collection<Target> filterFwdDeps(Target target, Collection<Target> unfilteredDeps) {
     if (!(target instanceof Rule)) {
-      return getRawFwdDeps(target);
+      return unfilteredDeps;
     }
     final Set<Label> allowedLabels = getAllowedDeps((Rule) target);
     return Collections2.filter(unfilteredDeps,
         new Predicate<Target>() {
           @Override
           public boolean apply(Target target) {
@@ -167,30 +174,49 @@
   }
 
   @Override
+  public Collection<Target> getFwdDeps(Target target) {
+    Collection<Target> rawDeps = getRawFwdDeps(target);
+    return filterFwdDeps(target, rawDeps);
+  }
+
+  @Override
   public Collection<Target> getFwdDeps(Iterable<Target> targets) {
     Set<Target> result = new HashSet<>();
-    for (Target target : targets) {
-      result.addAll(getFwdDeps(target));
+    Map<Target, Collection<Target>> rawDepsByTarget = getRawFwdDeps(targets);
+    for (Map.Entry<Target, Collection<Target>> rawDeps : rawDepsByTarget.entrySet()) {
+      result.addAll(filterFwdDeps(rawDeps.getKey(), rawDeps.getValue()));
     }
     return result;
   }
 
-  @Override
-  public Collection<Target> getReverseDeps(final Target target) {
-    return Collections2.filter(getRawReverseDeps(target), new Predicate<Target>() {
+  /**
+   * Returns a live {@link Collections2#filter} view of {@code unfilteredParents} that keeps the
+   * non-rule parents and those rule parents whose {@link #getAllowedDeps} contain
+   * {@code target}'s label.
+   */
+  private Collection<Target> filterReverseDeps(final Target target,
+      Collection<Target> unfilteredParents) {
+    return Collections2.filter(unfilteredParents, new Predicate<Target>() {
       @Override
       public boolean apply(Target parent) {
         return !(parent instanceof Rule)
             || getAllowedDeps((Rule) parent).contains(target.getLabel());
       }
     });
+  }
+
+  @Override
+  public Collection<Target> getReverseDeps(Target target) {
+    Collection<Target> rawParents = getRawReverseDeps(target);
+    return filterReverseDeps(target, rawParents);
   }
 
   @Override
   public Collection<Target> getReverseDeps(Iterable<Target> targets) {
     Set<Target> result = new HashSet<>();
-    for (Target target : targets) {
-      result.addAll(getReverseDeps(target));
+    Map<Target, Collection<Target>> rawParentsByTarget = getRawReverseDeps(targets);
+    for (Map.Entry<Target, Collection<Target>> rawParents : rawParentsByTarget.entrySet()) {
+      result.addAll(filterReverseDeps(rawParents.getKey(), rawParents.getValue()));
     }
     return result;
   }
@@ -297,11 +315,21 @@
     return accessor;
   }
 
+  /**
+   * Validates {@code label} with {@link #validateScope} and returns the key of the
+   * {@link PackageValue} for the package containing it.
+   *
+   * @throws QueryException propagated from {@code validateScope}
+   */
+  private SkyKey getPackageKeyAndValidateLabel(Label label) throws QueryException {
+    // Can't use strictScope here because we are expecting a target back.
+    validateScope(label, true);
+    return PackageValue.key(label.getPackageIdentifier());
+  }
+
   @Override
   public Target getTarget(Label label) throws TargetNotFoundException, QueryException {
-    // Can't use strictScope here because we are expecting a target back.
-    validateScope(label, true);
-    SkyKey packageKey = PackageValue.key(label.getPackageIdentifier());
+    SkyKey packageKey = getPackageKeyAndValidateLabel(label);
     checkExistence(packageKey);
     try {
       PackageValue packageValue =
@@ -406,16 +428,67 @@
     return result;
   }
 
-  private Set<Target> makeTargets(Iterable<SkyKey> keys) {
-    ImmutableSet.Builder<Target> builder = ImmutableSet.builder();
+  /** Returns the targets for {@code keys}; see {@link #makeTargetsWithAssociations}. */
+  private Collection<Target> makeTargets(Iterable<SkyKey> keys) {
+    return makeTargetsWithAssociations(keys).values();
+  }
+
+  /**
+   * Maps each {@link SkyFunctions#TRANSITIVE_TARGET} key in {@code keys} to its {@link Target},
+   * fetching all of the packages involved with a single {@code graph.getValuesMaybe} call.
+   *
+   * <p>Skipped silently: keys of any other type, labels rejected by
+   * {@link #getPackageKeyAndValidateLabel}, keys whose package has no entry in the
+   * {@code getValuesMaybe} result, and names their package does not contain. A key that occurs
+   * more than once in {@code keys} yields a single entry.
+   */
+  private Map<SkyKey, Target> makeTargetsWithAssociations(Iterable<SkyKey> keys) {
+    // A set multimap records a repeated target key only once; ImmutableMap.Builder#build()
+    // below would otherwise throw on the duplicate entry.
+    Multimap<SkyKey, SkyKey> packageKeyToTargetKeyMap = Multimaps.newSetMultimap(
+        new HashMap<SkyKey, Collection<SkyKey>>(),
+        new Supplier<Set<SkyKey>>() {
+          @Override
+          public Set<SkyKey> get() {
+            return new HashSet<>();
+          }
+        });
     for (SkyKey key : keys) {
-      Target value = transformToValue(key);
-      if (value != null) {
-        // Some values may be filtered out because they are not Targets.
-        builder.add(value);
+      SkyFunctionName functionName = key.functionName();
+      if (!functionName.equals(SkyFunctions.TRANSITIVE_TARGET)) {
+        // Skip non-targets.
+        continue;
+      }
+      try {
+        packageKeyToTargetKeyMap.put(getPackageKeyAndValidateLabel((Label) key.argument()), key);
+      } catch (QueryException e) {
+        // Skip disallowed labels.
       }
     }
-    return builder.build();
+    ImmutableMap.Builder<SkyKey, Target> result = ImmutableMap.builder();
+    Map<SkyKey, SkyValue> packageMap = graph.getValuesMaybe(packageKeyToTargetKeyMap.keySet());
+    for (Map.Entry<SkyKey, SkyValue> entry : packageMap.entrySet()) {
+      // Every target key in this group lives in the same package, so unwrap it once.
+      Package pkg = ((PackageValue) entry.getValue()).getPackage();
+      for (SkyKey targetKey : packageKeyToTargetKeyMap.get(entry.getKey())) {
+        try {
+          result.put(targetKey, pkg.getTarget(((Label) targetKey.argument()).getName()));
+        } catch (NoSuchTargetException e) {
+          // Skip missing target.
+        }
+      }
+    }
+    return result.build();
+  }
+
+  /** Returns a lazy view mapping each of {@code targets} to its {@link #makeKey} key. */
+  private Iterable<SkyKey> makeKeys(Iterable<Target> targets) {
+    return Iterables.transform(targets, new Function<Target, SkyKey>() {
+      @Override
+      public SkyKey apply(Target target) {
+        return makeKey(target);
+      }
+    });
   }
 
   private void checkExistence(SkyKey key) throws QueryException {