| // Copyright 2015 The Bazel Authors. All rights reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| package com.google.devtools.build.skyframe; |
| |
| import static com.google.common.truth.Truth.assertThat; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertNotNull; |
| import static org.junit.Assert.assertTrue; |
| import static org.junit.Assert.fail; |
| |
| import com.google.common.base.Preconditions; |
| import com.google.common.collect.Iterables; |
| import com.google.common.collect.Sets; |
| import com.google.devtools.build.lib.concurrent.ExecutorUtil; |
| import com.google.devtools.build.lib.concurrent.KeyedLocker; |
| import com.google.devtools.build.lib.concurrent.RefCountedMultisetKeyedLocker; |
| import com.google.devtools.build.lib.testutil.TestRunnableWrapper; |
| import com.google.devtools.build.lib.testutil.TestUtils; |
| import com.google.devtools.build.lib.util.GroupedList.GroupedListHelper; |
| import com.google.devtools.build.skyframe.GraphTester.StringValue; |
| import com.google.devtools.build.skyframe.NodeEntry.DependencyState; |
| |
| import org.junit.Before; |
| import org.junit.Test; |
| |
| import java.util.ArrayList; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Random; |
| import java.util.Set; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.TimeUnit; |
| |
| /** Base class for concurrency sanity tests on {@link EvaluableGraph} implementations. */ |
| public abstract class GraphConcurrencyTest { |
| |
| private static final SkyFunctionName SKY_FUNCTION_NAME = |
| SkyFunctionName.create("GraphConcurrencyTestKey"); |
| private ProcessableGraph graph; |
| private TestRunnableWrapper wrapper; |
| |
| // This code should really be in a @Before method, but @Before methods are executed from the |
| // top down, and this class's @Before method calls #getGraph, so makeGraph must have already |
| // been called. |
| protected abstract void makeGraph() throws Exception; |
| |
| protected abstract ProcessableGraph getGraph(Version version) throws Exception; |
| |
| private static final IntVersion startingVersion = new IntVersion(42); |
| |
| @Before |
| public void init() throws Exception { |
| makeGraph(); |
| this.graph = getGraph(startingVersion); |
| this.wrapper = new TestRunnableWrapper("GraphConcurrencyTest"); |
| } |
| |
| private SkyKey key(String name) { |
| return new SkyKey(SKY_FUNCTION_NAME, name); |
| } |
| |
| @Test |
| public void createIfAbsentSanity() { |
| graph.createIfAbsent(key("cat")); |
| } |
| |
| // Tests adding and removing Rdeps of a {@link NodeEntry} while a node transitions from |
| // not done to done. |
| @Test |
| public void testAddRemoveRdeps() throws Exception { |
| SkyKey key = key("foo"); |
| final NodeEntry entry = graph.createIfAbsent(key); |
| // These numbers are arbitrary. |
| int numThreads = 50; |
| int numKeys = numThreads; |
| // One chunk will be used to add and remove rdeps before setting the node value. The second |
| // chunk of work will have the node value set and the last chunk will be to add and remove |
| // rdeps after the value has been set. |
| final int chunkSize = 40; |
| final int numIterations = chunkSize * 2; |
| // This latch is used to signal that the runnables have been submitted to the executor. |
| final CountDownLatch waitForStart = new CountDownLatch(1); |
| // This latch is used to signal to the main thread that we have begun the second chunk |
| // for sufficiently many keys. The minimum of numThreads and numKeys is used to prevent |
| // thread starvation from causing a delay here. |
| final CountDownLatch waitForAddedRdep = new CountDownLatch(numThreads); |
| // This latch is used to guarantee that we set the node's value before we enter the third |
| // chunk for any key. |
| final CountDownLatch waitForSetValue = new CountDownLatch(1); |
| ExecutorService pool = Executors.newFixedThreadPool(numThreads); |
| // Add single rdep before transition to done. |
| assertEquals(DependencyState.NEEDS_SCHEDULING, entry.addReverseDepAndCheckIfDone(key("rdep"))); |
| for (int i = 0; i < numKeys; i++) { |
| final int j = i; |
| Runnable r = |
| new Runnable() { |
| @Override |
| public void run() { |
| try { |
| // Add and remove the rdep a bunch of times to test interleaving. |
| waitForStart.await(TestUtils.WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); |
| for (int k = 1; k < chunkSize; k++) { |
| assertThat(entry.addReverseDepAndCheckIfDone(key("rdep" + j))) |
| .isNotEqualTo(DependencyState.DONE); |
| entry.removeInProgressReverseDep(key("rdep" + j)); |
| assertThat(entry.getInProgressReverseDeps()).doesNotContain(key("rdep" + j)); |
| } |
| assertThat(entry.addReverseDepAndCheckIfDone(key("rdep" + j))) |
| .isNotEqualTo(DependencyState.DONE); |
| waitForAddedRdep.countDown(); |
| waitForSetValue.await(TestUtils.WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); |
| } catch (InterruptedException e) { |
| fail("Test failed: " + e.toString()); |
| } |
| for (int k = chunkSize; k <= numIterations; k++) { |
| entry.removeReverseDep(key("rdep" + j)); |
| entry.addReverseDepAndCheckIfDone(key("rdep" + j)); |
| entry.getReverseDeps(); |
| } |
| } |
| }; |
| pool.execute(wrapper.wrap(r)); |
| } |
| waitForStart.countDown(); |
| waitForAddedRdep.await(TestUtils.WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); |
| entry.setValue(new StringValue("foo1"), startingVersion); |
| waitForSetValue.countDown(); |
| entry.removeReverseDep(key("rdep")); |
| wrapper.waitForTasksAndMaybeThrow(); |
| assertFalse(ExecutorUtil.interruptibleShutdown(pool)); |
| assertEquals(new StringValue("foo1"), graph.get(key).getValue()); |
| assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps())); |
| |
| graph = getGraph(startingVersion.next()); |
| NodeEntry sameEntry = Preconditions.checkNotNull(graph.get(key)); |
| // Mark the node as dirty again and check that the reverse deps have been preserved. |
| sameEntry.markDirty(true); |
| startEvaluation(sameEntry); |
| sameEntry.markRebuildingAndGetAllRemainingDirtyDirectDeps(); |
| sameEntry.setValue(new StringValue("foo2"), startingVersion.next()); |
| assertEquals(new StringValue("foo2"), graph.get(key).getValue()); |
| assertEquals(numKeys, Iterables.size(graph.get(key).getReverseDeps())); |
| } |
| |
| // Tests adding inflight nodes with a given key while an existing node with the same key |
| // undergoes a transition from not done to done. |
| @Test |
| public void testAddingInflightNodes() throws Exception { |
| int numThreads = 50; |
| final KeyedLocker<SkyKey> locker = new RefCountedMultisetKeyedLocker<>(); |
| ExecutorService pool = Executors.newFixedThreadPool(numThreads); |
| final int numKeys = 500; |
| // Add each key 10 times. |
| final Set<SkyKey> nodeCreated = Sets.newConcurrentHashSet(); |
| final Set<SkyKey> valuesSet = Sets.newConcurrentHashSet(); |
| for (int i = 0; i < 10; i++) { |
| for (int j = 0; j < numKeys; j++) { |
| final int keyNum = j; |
| final SkyKey key = key("foo" + keyNum); |
| Runnable r = |
| new Runnable() { |
| public void run() { |
| NodeEntry entry; |
| try (KeyedLocker.AutoUnlocker unlocker = locker.lock(key)) { |
| entry = graph.get(key); |
| if (entry == null) { |
| assertTrue(nodeCreated.add(key)); |
| } |
| entry = graph.createIfAbsent(key); |
| } |
| // {@code entry.addReverseDepAndCheckIfDone(null)} should return NEEDS_SCHEDULING at |
| // most once. |
| if (startEvaluation(entry).equals(DependencyState.NEEDS_SCHEDULING)) { |
| assertTrue(valuesSet.add(key)); |
| // Set to done. |
| entry.setValue(new StringValue("bar" + keyNum), startingVersion); |
| assertThat(entry.isDone()).isTrue(); |
| } |
| // This shouldn't cause any problems from the other threads. |
| graph.createIfAbsent(key); |
| } |
| }; |
| pool.execute(wrapper.wrap(r)); |
| } |
| } |
| wrapper.waitForTasksAndMaybeThrow(); |
| assertFalse(ExecutorUtil.interruptibleShutdown(pool)); |
| // Check that all the values are as expected. |
| for (int i = 0; i < numKeys; i++) { |
| SkyKey key = key("foo" + i); |
| assertTrue(nodeCreated.contains(key)); |
| assertTrue(valuesSet.contains(key)); |
| assertThat(graph.get(key).getValue()).isEqualTo(new StringValue("bar" + i)); |
| assertThat(graph.get(key).getVersion()).isEqualTo(startingVersion); |
| } |
| } |
| |
| /** |
| * Initially calling {@link NodeEntry#setValue} and then making sure concurrent calls to |
| * {@link QueryableGraph#get} and {@link QueryableGraph#getBatch} do not interfere with the node. |
| */ |
| @Test |
| public void testDoneToDirty() throws Exception { |
| final int numKeys = 1000; |
| int numThreads = 50; |
| final int numBatchRequests = 100; |
| // Create a bunch of done nodes. |
| for (int i = 0; i < numKeys; i++) { |
| NodeEntry entry = graph.createIfAbsent(key("foo" + i)); |
| startEvaluation(entry); |
| entry.setValue(new StringValue("bar"), startingVersion); |
| } |
| |
| assertNotNull(graph.get(key("foo" + 0))); |
| graph = getGraph(startingVersion.next()); |
| assertNotNull(graph.get(key("foo" + 0))); |
| ExecutorService pool1 = Executors.newFixedThreadPool(numThreads); |
| ExecutorService pool2 = Executors.newFixedThreadPool(numThreads); |
| ExecutorService pool3 = Executors.newFixedThreadPool(numThreads); |
| |
| // Only start all the threads once the batch requests are ready. |
| final CountDownLatch makeBatchCountDownLatch = new CountDownLatch(numBatchRequests); |
| // Do at least 5 single requests and batch requests before transitioning node. |
| final CountDownLatch getBatchCountDownLatch = new CountDownLatch(5); |
| final CountDownLatch getCountDownLatch = new CountDownLatch(5); |
| |
| for (int i = 0; i < numKeys; i++) { |
| final int keyNum = i; |
| // Transition the nodes from done to dirty and then back to done. |
| Runnable r1 = |
| new Runnable() { |
| @Override |
| public void run() { |
| try { |
| makeBatchCountDownLatch.await(); |
| getBatchCountDownLatch.await(); |
| getCountDownLatch.await(); |
| } catch (InterruptedException e) { |
| throw new AssertionError(e); |
| } |
| NodeEntry entry = graph.get(key("foo" + keyNum)); |
| entry.markDirty(true); |
| // Make some changes, like adding a dep and rdep. |
| entry.addReverseDepAndCheckIfDone(key("rdep")); |
| entry.markRebuildingAndGetAllRemainingDirtyDirectDeps(); |
| addTemporaryDirectDep(entry, key("dep")); |
| entry.signalDep(); |
| // Move node from dirty back to done. |
| entry.setValue(new StringValue("bar" + keyNum), startingVersion.next()); |
| } |
| }; |
| |
| // Start a bunch of get() calls while the node transitions from dirty to done and back. |
| Runnable r2 = |
| new Runnable() { |
| @Override |
| public void run() { |
| try { |
| makeBatchCountDownLatch.await(); |
| } catch (InterruptedException e) { |
| throw new AssertionError(e); |
| } |
| NodeEntry entry = graph.get(key("foo" + keyNum)); |
| assertNotNull(entry); |
| // Requests for the value are made at the same time that the version increments from |
| // the base. Check that there is no problem in requesting the version and that the |
| // number is sane. |
| assertThat(entry.getVersion()).isAnyOf(startingVersion, startingVersion.next()); |
| getCountDownLatch.countDown(); |
| } |
| }; |
| pool1.execute(wrapper.wrap(r1)); |
| pool2.execute(wrapper.wrap(r2)); |
| } |
| Random r = new Random(TestUtils.getRandomSeed()); |
| // Start a bunch of getBatch() calls while the node transitions from dirty to done and back. |
| for (int i = 0; i < numBatchRequests; i++) { |
| final List<SkyKey> batch = new ArrayList<>(numKeys); |
| // Pseudorandomly uniformly sample the powerset of the keys. |
| for (int j = 0; j < numKeys; j++) { |
| if (r.nextBoolean()) { |
| batch.add(key("foo" + j)); |
| } |
| } |
| makeBatchCountDownLatch.countDown(); |
| Runnable r3 = |
| new Runnable() { |
| @Override |
| public void run() { |
| try { |
| makeBatchCountDownLatch.await(); |
| } catch (InterruptedException e) { |
| throw new AssertionError(e); |
| } |
| Map<SkyKey, NodeEntry> batchMap = graph.getBatch(batch); |
| getBatchCountDownLatch.countDown(); |
| assertThat(batchMap).hasSize(batch.size()); |
| for (NodeEntry entry : batchMap.values()) { |
| // Batch requests are made at the same time that the version increments from the |
| // base. Check that there is no problem in requesting the version and that the |
| // number is sane. |
| assertThat(entry.getVersion()).isAnyOf(startingVersion, startingVersion.next()); |
| } |
| } |
| }; |
| pool3.execute(wrapper.wrap(r3)); |
| } |
| wrapper.waitForTasksAndMaybeThrow(); |
| assertFalse(ExecutorUtil.interruptibleShutdown(pool1)); |
| assertFalse(ExecutorUtil.interruptibleShutdown(pool2)); |
| assertFalse(ExecutorUtil.interruptibleShutdown(pool3)); |
| for (int i = 0; i < numKeys; i++) { |
| NodeEntry entry = graph.get(key("foo" + i)); |
| assertThat(entry.getValue()).isEqualTo(new StringValue("bar" + i)); |
| assertThat(entry.getVersion()).isEqualTo(startingVersion.next()); |
| for (SkyKey key : entry.getReverseDeps()) { |
| assertEquals(key("rdep"), key); |
| } |
| for (SkyKey key : entry.getDirectDeps()) { |
| assertEquals(key("dep"), key); |
| } |
| } |
| } |
| |
| private DependencyState startEvaluation(NodeEntry entry) { |
| return entry.addReverseDepAndCheckIfDone(null); |
| } |
| |
| private static void addTemporaryDirectDep(NodeEntry entry, SkyKey key) { |
| GroupedListHelper<SkyKey> helper = new GroupedListHelper<>(); |
| helper.add(key); |
| entry.addTemporaryDirectDeps(helper); |
| } |
| } |