Add batch methods to WalkableGraph and convert SkyQueryEnvironment to use them.

--
MOS_MIGRATED_REVID=96214911
diff --git a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
index 9c7a0bf..00ddd2b 100644
--- a/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
+++ b/src/main/java/com/google/devtools/build/lib/query2/SkyQueryEnvironment.java
@@ -13,13 +13,18 @@
 // limitations under the License.
 package com.google.devtools.build.lib.query2;
 
+import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.base.Predicates;
+import com.google.common.base.Supplier;
 import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Multimaps;
 import com.google.devtools.build.lib.cmdline.ResolvedTargets;
 import com.google.devtools.build.lib.cmdline.TargetParsingException;
 import com.google.devtools.build.lib.cmdline.TargetPattern;
@@ -27,6 +32,7 @@
 import com.google.devtools.build.lib.events.EventHandler;
 import com.google.devtools.build.lib.events.StoredEventHandler;
 import com.google.devtools.build.lib.graph.Digraph;
+import com.google.devtools.build.lib.packages.NoSuchTargetException;
 import com.google.devtools.build.lib.packages.NoSuchThingException;
 import com.google.devtools.build.lib.packages.Package;
 import com.google.devtools.build.lib.packages.Rule;
@@ -45,10 +51,12 @@
 import com.google.devtools.build.lib.syntax.Label;
 import com.google.devtools.build.skyframe.SkyFunctionName;
 import com.google.devtools.build.skyframe.SkyKey;
+import com.google.devtools.build.skyframe.SkyValue;
 import com.google.devtools.build.skyframe.WalkableGraph;
 import com.google.devtools.build.skyframe.WalkableGraph.WalkableGraphFactory;
 
 import java.util.ArrayDeque;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Deque;
 import java.util.HashMap;
@@ -59,8 +67,6 @@
 import java.util.Map;
 import java.util.Set;
 
-import javax.annotation.Nullable;
-
 /**
  * {@link AbstractBlazeQueryEnvironment} that introspects the Skyframe graph to find forward and
  * reverse edges. Results obtained by calling {@link #evaluateQuery} are not guaranteed to be in
@@ -116,30 +122,33 @@
     return super.evaluateQuery(expr);
   }
 
-  private static SkyKey transformToKey(Target value) {
+  private static SkyKey makeKey(Target value) {
     return TransitiveTargetValue.key(value.getLabel());
   }
 
-  @Nullable
-  private Target transformToValue(SkyKey key) {
-    SkyFunctionName functionName = key.functionName();
-    if (!functionName.equals(SkyFunctions.TRANSITIVE_TARGET)) {
-      return null;
-    }
-    try {
-      return getTarget(((Label) key.argument()));
-    } catch (QueryException | TargetNotFoundException e) {
-      // Any problems with targets were already reported during #buildTransitiveClosure.
-      return null;
-    }
+  private Collection<Target> getRawFwdDeps(Target target) {
+    return makeTargets(graph.getDirectDeps(makeKey(target)));
   }
 
-  private Collection<Target> getRawFwdDeps(Target target) {
-    return makeTargets(graph.getDirectDeps(transformToKey(target)));
+  private Map<Target, Collection<Target>> makeTargetsMap(Map<SkyKey, Iterable<SkyKey>> input) {
+    ImmutableMap.Builder<Target, Collection<Target>> result = ImmutableMap.builder();
+
+    for (Map.Entry<SkyKey, Target> entry : makeTargetsWithAssociations(input.keySet()).entrySet()) {
+      result.put(entry.getValue(), makeTargets(input.get(entry.getKey())));
+    }
+    return result.build();
+  }
+
+  private Map<Target, Collection<Target>> getRawFwdDeps(Iterable<Target> targets) {
+    return makeTargetsMap(graph.getDirectDeps(makeKeys(targets)));
   }
 
   private Collection<Target> getRawReverseDeps(Target target) {
-    return makeTargets(graph.getReverseDeps(transformToKey(target)));
+    return makeTargets(graph.getReverseDeps(makeKey(target)));
+  }
+
+  private Map<Target, Collection<Target>> getRawReverseDeps(Iterable<Target> targets) {
+    return makeTargetsMap(graph.getReverseDeps(makeKeys(targets)));
   }
 
   private Set<Label> getAllowedDeps(Rule rule) {
@@ -150,14 +159,12 @@
     return allowedLabels;
   }
 
-  @Override
-  public Collection<Target> getFwdDeps(Target target) {
-    Collection<Target> unfilteredDeps = getRawFwdDeps(target);
+  private Collection<Target> filterFwdDeps(Target target, Collection<Target> rawFwdDeps) {
     if (!(target instanceof Rule)) {
-      return getRawFwdDeps(target);
+      return rawFwdDeps;
     }
     final Set<Label> allowedLabels = getAllowedDeps((Rule) target);
-    return Collections2.filter(unfilteredDeps,
+    return Collections2.filter(rawFwdDeps,
         new Predicate<Target>() {
           @Override
           public boolean apply(Target target) {
@@ -167,30 +174,41 @@
   }
 
   @Override
+  public Collection<Target> getFwdDeps(Target target) {
+    return filterFwdDeps(target, getRawFwdDeps(target));
+  }
+
+  @Override
   public Collection<Target> getFwdDeps(Iterable<Target> targets) {
     Set<Target> result = new HashSet<>();
-    for (Target target : targets) {
-      result.addAll(getFwdDeps(target));
+    for (Map.Entry<Target, Collection<Target>> entry : getRawFwdDeps(targets).entrySet()) {
+      result.addAll(filterFwdDeps(entry.getKey(), entry.getValue()));
     }
     return result;
   }
 
-  @Override
-  public Collection<Target> getReverseDeps(final Target target) {
-    return Collections2.filter(getRawReverseDeps(target), new Predicate<Target>() {
+  private Collection<Target> filterReverseDeps(final Target target,
+      Collection<Target> rawReverseDeps) {
+    return Collections2.filter(rawReverseDeps, new Predicate<Target>() {
       @Override
       public boolean apply(Target parent) {
         return !(parent instanceof Rule)
             || getAllowedDeps((Rule) parent).contains(target.getLabel());
       }
     });
+
+  }
+
+  @Override
+  public Collection<Target> getReverseDeps(Target target) {
+    return filterReverseDeps(target, getRawReverseDeps(target));
   }
 
   @Override
   public Collection<Target> getReverseDeps(Iterable<Target> targets) {
     Set<Target> result = new HashSet<>();
-    for (Target target : targets) {
-      result.addAll(getReverseDeps(target));
+    for (Map.Entry<Target, Collection<Target>> entry : getRawReverseDeps(targets).entrySet()) {
+      result.addAll(filterReverseDeps(entry.getKey(), entry.getValue()));
     }
     return result;
   }
@@ -297,11 +315,15 @@
     return accessor;
   }
 
-  @Override
-  public Target getTarget(Label label) throws TargetNotFoundException, QueryException {
+  private SkyKey getPackageKeyAndValidateLabel(Label label) throws QueryException {
     // Can't use strictScope here because we are expecting a target back.
     validateScope(label, true);
-    SkyKey packageKey = PackageValue.key(label.getPackageIdentifier());
+    return PackageValue.key(label.getPackageIdentifier());
+  }
+
+  @Override
+  public Target getTarget(Label label) throws TargetNotFoundException, QueryException {
+    SkyKey packageKey = getPackageKeyAndValidateLabel(label);
     checkExistence(packageKey);
     try {
       PackageValue packageValue =
@@ -406,16 +428,53 @@
     return result;
   }
 
-  private Set<Target> makeTargets(Iterable<SkyKey> keys) {
-    ImmutableSet.Builder<Target> builder = ImmutableSet.builder();
+  private Collection<Target> makeTargets(Iterable<SkyKey> keys) {
+    return makeTargetsWithAssociations(keys).values();
+  }
+
+  private Map<SkyKey, Target> makeTargetsWithAssociations(Iterable<SkyKey> keys) {
+    Multimap<SkyKey, SkyKey> packageKeyToTargetKeyMap = Multimaps.newListMultimap(
+        new HashMap<SkyKey, Collection<SkyKey>>(),
+        new Supplier<List<SkyKey>>() {
+          @Override
+          public List<SkyKey> get() {
+            return new ArrayList<>();
+          }
+        });
     for (SkyKey key : keys) {
-      Target value = transformToValue(key);
-      if (value != null) {
-        // Some values may be filtered out because they are not Targets.
-        builder.add(value);
+      SkyFunctionName functionName = key.functionName();
+      if (!functionName.equals(SkyFunctions.TRANSITIVE_TARGET)) {
+        // Skip non-targets.
+        continue;
+      }
+      try {
+        packageKeyToTargetKeyMap.put(getPackageKeyAndValidateLabel((Label) key.argument()), key);
+      } catch (QueryException e) {
+        // Skip disallowed labels.
       }
     }
-    return builder.build();
+    ImmutableMap.Builder<SkyKey, Target> result = ImmutableMap.builder();
+    Map<SkyKey, SkyValue> packageMap = graph.getValuesMaybe(packageKeyToTargetKeyMap.keySet());
+    for (Map.Entry<SkyKey, SkyValue> entry : packageMap.entrySet()) {
+      for (SkyKey targetKey : packageKeyToTargetKeyMap.get(entry.getKey())) {
+        try {
+          result.put(targetKey, ((PackageValue) entry.getValue()).getPackage()
+              .getTarget(((Label) targetKey.argument()).getName()));
+        } catch (NoSuchTargetException e) {
+          // Skip missing target.
+        }
+      }
+    }
+    return result.build();
+  }
+
+  private Iterable<SkyKey> makeKeys(Iterable<Target> targets) {
+    return Iterables.transform(targets, new Function<Target, SkyKey>() {
+      @Override
+      public SkyKey apply(Target target) {
+        return TransitiveTargetValue.key(target.getLabel());
+      }
+    });
   }
 
   private void checkExistence(SkyKey key) throws QueryException {
diff --git a/src/main/java/com/google/devtools/build/skyframe/DelegatingWalkableGraph.java b/src/main/java/com/google/devtools/build/skyframe/DelegatingWalkableGraph.java
index 5328a48..fee4f9c 100644
--- a/src/main/java/com/google/devtools/build/skyframe/DelegatingWalkableGraph.java
+++ b/src/main/java/com/google/devtools/build/skyframe/DelegatingWalkableGraph.java
@@ -13,7 +13,13 @@
 // limitations under the License.
 package com.google.devtools.build.skyframe;
 
+import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
+import com.google.common.base.Predicates;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Maps;
+
+import java.util.Map;
 
 import javax.annotation.Nullable;
 
@@ -33,6 +39,15 @@
     return entry;
   }
 
+  private Map<SkyKey, NodeEntry> getEntries(Iterable<SkyKey> keys) {
+    Map<SkyKey, NodeEntry> result = graph.getBatch(keys);
+    Preconditions.checkState(result.size() == Iterables.size(keys), "%s %s", keys, result);
+    for (Map.Entry<SkyKey, NodeEntry> entry : result.entrySet()) {
+      Preconditions.checkState(entry.getValue().isDone(), entry);
+    }
+    return result;
+  }
+
   @Override
   public boolean exists(SkyKey key) {
     NodeEntry entry = graph.get(key);
@@ -45,6 +60,21 @@
     return getEntry(key).getValue();
   }
 
+  private static final Function<NodeEntry, SkyValue> GET_SKY_VALUE_FUNCTION =
+      new Function<NodeEntry, SkyValue>() {
+        @Nullable
+        @Override
+        public SkyValue apply(NodeEntry entry) {
+          return entry.isDone() ? entry.getValue() : null;
+        }
+      };
+
+  @Override
+  public Map<SkyKey, SkyValue> getValuesMaybe(Iterable<SkyKey> keys) {
+    return Maps.filterValues(Maps.transformValues(graph.getBatch(keys), GET_SKY_VALUE_FUNCTION),
+        Predicates.notNull());
+  }
+
   @Nullable
   @Override
   public Exception getException(SkyKey key) {
@@ -57,8 +87,34 @@
     return getEntry(key).getDirectDeps();
   }
 
+  private static final Function<NodeEntry, Iterable<SkyKey>> GET_DIRECT_DEPS_FUNCTION =
+      new Function<NodeEntry, Iterable<SkyKey>>() {
+        @Override
+        public Iterable<SkyKey> apply(NodeEntry entry) {
+          return entry.getDirectDeps();
+        }
+      };
+
+  @Override
+  public Map<SkyKey, Iterable<SkyKey>> getDirectDeps(Iterable<SkyKey> keys) {
+    return Maps.transformValues(getEntries(keys), GET_DIRECT_DEPS_FUNCTION);
+  }
+
   @Override
   public Iterable<SkyKey> getReverseDeps(SkyKey key) {
     return getEntry(key).getReverseDeps();
   }
+
+  private static final Function<NodeEntry, Iterable<SkyKey>> GET_REVERSE_DEPS_FUNCTION =
+      new Function<NodeEntry, Iterable<SkyKey>>() {
+        @Override
+        public Iterable<SkyKey> apply(NodeEntry entry) {
+          return entry.getReverseDeps();
+        }
+      };
+
+  @Override
+  public Map<SkyKey, Iterable<SkyKey>> getReverseDeps(Iterable<SkyKey> keys) {
+    return Maps.transformValues(getEntries(keys), GET_REVERSE_DEPS_FUNCTION);
+  }
 }
diff --git a/src/main/java/com/google/devtools/build/skyframe/WalkableGraph.java b/src/main/java/com/google/devtools/build/skyframe/WalkableGraph.java
index 68c39c9..e020a26 100644
--- a/src/main/java/com/google/devtools/build/skyframe/WalkableGraph.java
+++ b/src/main/java/com/google/devtools/build/skyframe/WalkableGraph.java
@@ -16,6 +16,7 @@
 import com.google.devtools.build.lib.events.EventHandler;
 
 import java.util.Collection;
+import java.util.Map;
 
 import javax.annotation.Nullable;
 
@@ -40,6 +41,12 @@
   SkyValue getValue(SkyKey key);
 
   /**
+   * Returns a map giving the values of the given keys for done keys. Keys not present in the graph
+   * or whose nodes are not done will not be present in the returned map.
+   */
+  Map<SkyKey, SkyValue> getValuesMaybe(Iterable<SkyKey> keys);
+
+  /**
    * Returns the exception thrown when computing the node with the given key, if any. If the node
    * was computed successfully, returns null. A node with this key must exist in the graph.
    */
@@ -52,11 +59,23 @@
   Iterable<SkyKey> getDirectDeps(SkyKey key);
 
   /**
-   * Returns the reverse dependencies of the node with the given key. A node with this key must
-   * exist in the graph.
+   * Returns a map giving the direct dependencies of the nodes with the given keys. Same semantics
+   * as {@link #getDirectDeps(SkyKey)}.
    */
+  Map<SkyKey, Iterable<SkyKey>> getDirectDeps(Iterable<SkyKey> keys);
+
+  /**
+   * Returns the reverse dependencies of the node with the given key. A node with this key must
+   * exist in the graph.
+   */
   Iterable<SkyKey> getReverseDeps(SkyKey key);
 
+  /**
+   * Returns a map giving the reverse dependencies of the nodes with the given keys. Same semantics
+   * as {@link #getReverseDeps(SkyKey)}.
+   */
+  Map<SkyKey, Iterable<SkyKey>> getReverseDeps(Iterable<SkyKey> keys);
+
   /** Provides a WalkableGraph on demand after preparing it. */
   interface WalkableGraphFactory {
     WalkableGraph prepareAndGet(Collection<String> roots, int numThreads,