Rename TaskFifo to ConcurrentFifo and generalize it beyond Runnable elements.

Uses a generic parameter to represent the element type.

PiperOrigin-RevId: 663286905
Change-Id: I29cc86df6e5fb6c0cf41594e894e016736f9510f
diff --git a/src/main/java/com/google/devtools/build/lib/concurrent/TaskFifo.java b/src/main/java/com/google/devtools/build/lib/concurrent/ConcurrentFifo.java
similarity index 79%
rename from src/main/java/com/google/devtools/build/lib/concurrent/TaskFifo.java
rename to src/main/java/com/google/devtools/build/lib/concurrent/ConcurrentFifo.java
index b19ae3d..fc0d323 100644
--- a/src/main/java/com/google/devtools/build/lib/concurrent/TaskFifo.java
+++ b/src/main/java/com/google/devtools/build/lib/concurrent/ConcurrentFifo.java
@@ -22,7 +22,7 @@
 import sun.misc.Unsafe;
 
 /**
- * A fixed-capacity concurrent FIFO for tasks.
+ * A fixed-capacity concurrent FIFO.
  *
  * <p>This class is a higher performance, nearly garbage-free, but less flexible substitute for
  * {@link ConcurrentLinkedQueue},
@@ -30,7 +30,7 @@
  * <ul>
  *   <li>The queue capacity is fixed.
  *   <li>The client must guarantee not to take more than it has added.
- *   <li>The client must have an fallback if {@link TaskFifo#tryAppend} fails.
+ *   <li>The client must have an fallback if {@link tryAppend} fails.
  * </ul>
  *
  * <p>This class is inspired by Morrison, Adam, and Yehuda Afek. "Fast concurrent queues for x86
@@ -38,25 +38,26 @@
  * programming. 2013.
  */
 @SuppressWarnings("SunApi") // TODO: b/359688989 - clean this up
-final class TaskFifo {
-  private static final int TASKS_MAX_VALUE = (1 << 20) - 1;
-
+final class ConcurrentFifo<T> {
   private static final Integer SKIP_SLOW_APPENDER = 1;
 
-  /**
-   * The power of 2 backing array capacity.
-   *
-   * <p>This is one more than the number of elements the queue can contain at one time. This helps
-   * the efficiency of bits used to represent the number of elements enqueued. For example, the
-   * number of bits needed to represent the element count for a queue of size 256 is 9, but only 8
-   * bits are needed for a queue of size 255.
-   */
-  @VisibleForTesting static final int CAPACITY = TASKS_MAX_VALUE + 1;
+  /** The power of 2 backing array capacity. */
+  @VisibleForTesting static final int CAPACITY = 1 << 20;
 
   /** AND with this mask performs modulo {@link #CAPACITY}. */
   private static final int CAPACITY_MASK = CAPACITY - 1;
 
   /**
+   * Maximum number of elements the FIFO can contain, one less than {@link #CAPACITY}.
+   *
+   * <p>While the backing array's size is a power of 2, this is one less than that, to improve the
+   * efficiency of bits used to represent the number of elements enqueued. For example, the number
+   * of bits needed to represent the element count for a queue of size 256 is 9, but only 8 bits are
+   * needed for a queue of size 255.
+   */
+  private static final int MAX_ELEMENTS = CAPACITY - 1;
+
+  /**
    * Circular buffer containing tasks and skip metadata.
    *
    * <p>The algorithm assigns to each caller of {@link #tryAppend} or {@link #take} a monotonically
@@ -81,8 +82,8 @@
    * should expect to find an empty queue position due to capacity constraints.
    *
    * <p>Appenders that observe a value when expecting an empty position wrap the value with {@link
-   * TaskWithSkippedAppends} then skip to the next available index. Slow takers decrement the counts
-   * on the wrappers then skip to the next available index.
+   * ElementWithSkippedAppends} then skip to the next available index. Slow takers decrement the
+   * counts on the wrappers then skip to the next available index.
    *
    * <p>The skip marker has a count because the number of threads that could potentially be
    * descheduled at a particular index is only limited by the queue capacity, though more than one
@@ -92,10 +93,10 @@
    *
    * <ul>
    *   <li>{@code null} is an empty position.
-   *   <li>{@link Runnable} is a position containing a task.
+   *   <li>{@link T} is a position containing a task.
    *   <li>{@link Integer} is a count of takers that skipped the position because they observed a
    *       null value. The count corresponds to slow appenders at the position.
-   *   <li>{@link TaskWithSkippedAppends} is a task with a count of appenders that skipped the
+   *   <li>{@link ElementWithSkippedAppends} is a task with a count of appenders that skipped the
    *       position due to it being still occupied with a task. The count corresponds to slow takers
    *       assigned to the position.
    * </ul>
@@ -107,15 +108,17 @@
    */
   private final Object[] queue = new Object[CAPACITY];
 
+  private final Class<? super T> elementType;
+
   /**
-   * Address of index for appending; incremented by appending.
+   * Address of int index for appending; incremented by appending.
    *
    * <p>The actual array offset is the value modulo {@link #CAPACITY}.
    */
   private final long appendIndexAddress;
 
   /**
-   * Address of index for taking; incremented by taking.
+   * Address of int index for taking; incremented by taking.
    *
    * <p>The actual array offset is the value modulo {@link #CAPACITY}.
    */
@@ -137,12 +140,17 @@
    * @param takeIndexAddress padded location of the {@code int} take index.
    * @param appendIndexAddress padded location of the {@code int} append index.
    */
-  TaskFifo(long sizeAddress, long appendIndexAddress, long takeIndexAddress) {
+  ConcurrentFifo(
+      Class<? super T> elementType,
+      long sizeAddress,
+      long appendIndexAddress,
+      long takeIndexAddress) {
+    this.elementType = elementType;
     this.sizeAddress = sizeAddress;
     this.appendIndexAddress = appendIndexAddress;
     this.takeIndexAddress = takeIndexAddress;
 
-    // Explicitly initializes the provided addresses.
+    // Explicitly initializes the memory at the provided addresses.
     UNSAFE.putInt(null, sizeAddress, 0);
     UNSAFE.putInt(null, appendIndexAddress, 0);
     UNSAFE.putInt(null, takeIndexAddress, 0);
@@ -153,9 +161,9 @@
    *
    * @return true if successful, false if it would have exceeded the capacity.
    */
-  boolean tryAppend(Runnable task) {
+  boolean tryAppend(T task) {
     // Optimistically increases size, and rolls back if it exceeds capacity.
-    if (UNSAFE.getAndAddInt(null, sizeAddress, 1) >= TASKS_MAX_VALUE) {
+    if (UNSAFE.getAndAddInt(null, sizeAddress, 1) >= MAX_ELEMENTS) {
       UNSAFE.getAndAddInt(null, sizeAddress, -1);
       return false;
     }
@@ -184,12 +192,14 @@
           // skipped this offset.
           int newCount = ((Integer) snapshot) - 1;
           target = newCount == 0 ? null : newCount;
-        } else if (snapshot instanceof Runnable) {
+        } else if (elementType.isInstance(snapshot)) {
           // A taker was slow.
-          target = new TaskWithSkippedAppends((Runnable) snapshot, /* skippedAppendCount= */ 1);
+          @SuppressWarnings("unchecked")
+          T castSnapshot = (T) snapshot;
+          target = new ElementWithSkippedAppends<T>(castSnapshot, /* skippedAppendCount= */ 1);
         } else {
           // Multiple takers are slow. This should be very rare. Increments the skip count.
-          target = ((TaskWithSkippedAppends) snapshot).incrementSkips();
+          target = ((ElementWithSkippedAppends) snapshot).incrementSkips();
         }
         if (UNSAFE.compareAndSwapObject(queue, offset, snapshot, target)) {
           break; // Success, skips to next.
@@ -203,7 +213,7 @@
    *
    * <p>This must not be called more times than {@link #tryAppend} has succeeded.
    */
-  Runnable take() {
+  T take() {
     do {
       int offset = getQueueOffset(UNSAFE.getAndAddInt(null, takeIndexAddress, 1));
       do {
@@ -214,11 +224,13 @@
         // 2. On subsequent reads, this immediately follows a failed CAS of the same memory
         //    location, which refreshes the memory.
         Object snapshot = UNSAFE.getObject(queue, offset);
-        if (snapshot instanceof Runnable) {
+        if (elementType.isInstance(snapshot)) {
           // Attempts to take ownership of the task.
           if (UNSAFE.compareAndSwapObject(queue, offset, snapshot, null)) {
             UNSAFE.getAndAddInt(null, sizeAddress, -1);
-            return (Runnable) snapshot;
+            @SuppressWarnings("unchecked")
+            T castSnapshot = (T) snapshot;
+            return castSnapshot;
           }
         } else {
           Object target;
@@ -229,7 +241,7 @@
             target = ((Integer) snapshot).intValue() + 1;
           } else {
             // There have been appends without corresponding takes. Acknowledges one skip.
-            target = ((TaskWithSkippedAppends) snapshot).decrementSkips();
+            target = ((ElementWithSkippedAppends) snapshot).decrementSkips();
           }
           if (UNSAFE.compareAndSwapObject(queue, offset, snapshot, target)) {
             break; // Success, skips to next.
@@ -270,12 +282,12 @@
       var elt = queue[i];
       if (elt == null) {
         buf.append('0');
-      } else if (elt instanceof Runnable) {
+      } else if (elementType.isInstance(elt)) {
         buf.append('1');
       } else if (elt instanceof Integer) {
         buf.append('S').append(elt);
       } else {
-        buf.append('T').append(((TaskWithSkippedAppends) elt).skippedAppendCount);
+        buf.append('T').append(((ElementWithSkippedAppends) elt).skippedAppendCount());
       }
     }
     helper.add("queue", buf.append(']').toString());
@@ -288,43 +300,26 @@
   }
 
   private static int getQueueOffset(int index) {
-    return TASKS_BASE + TASKS_SCALE * (index & CAPACITY_MASK);
+    return ELEMENTS_BASE + ELEMENTS_SCALE * (index & CAPACITY_MASK);
   }
 
   @VisibleForTesting
-  static class TaskWithSkippedAppends {
-    private final Runnable task;
-    private final int skippedAppendCount;
-
-    private TaskWithSkippedAppends(Runnable task, int skippedAppendCount) {
-      this.task = task;
-      this.skippedAppendCount = skippedAppendCount;
-    }
+  record ElementWithSkippedAppends<T>(T element, int skippedAppendCount) {
 
     private Object decrementSkips() {
       if (skippedAppendCount <= 1) {
-        return task;
+        return element;
       }
-      return new TaskWithSkippedAppends(task, skippedAppendCount - 1);
+      return new ElementWithSkippedAppends<>(element, skippedAppendCount - 1);
     }
 
-    private TaskWithSkippedAppends incrementSkips() {
-      return new TaskWithSkippedAppends(task, skippedAppendCount + 1);
-    }
-
-    @VisibleForTesting
-    Runnable taskForTesting() {
-      return task;
-    }
-
-    @VisibleForTesting
-    int skippedAppendCountForTesting() {
-      return skippedAppendCount;
+    private ElementWithSkippedAppends<T> incrementSkips() {
+      return new ElementWithSkippedAppends<T>(element, skippedAppendCount + 1);
     }
   }
 
   private static final Unsafe UNSAFE = UnsafeProvider.unsafe();
 
-  private static final int TASKS_BASE = Unsafe.ARRAY_OBJECT_BASE_OFFSET;
-  private static final int TASKS_SCALE = Unsafe.ARRAY_OBJECT_INDEX_SCALE;
+  private static final int ELEMENTS_BASE = Unsafe.ARRAY_OBJECT_BASE_OFFSET;
+  private static final int ELEMENTS_SCALE = Unsafe.ARRAY_OBJECT_INDEX_SCALE;
 }
diff --git a/src/test/java/com/google/devtools/build/lib/concurrent/TaskFifoTest.java b/src/test/java/com/google/devtools/build/lib/concurrent/ConcurrentFifoTest.java
similarity index 92%
rename from src/test/java/com/google/devtools/build/lib/concurrent/TaskFifoTest.java
rename to src/test/java/com/google/devtools/build/lib/concurrent/ConcurrentFifoTest.java
index 53fd3f4..c6f029b 100644
--- a/src/test/java/com/google/devtools/build/lib/concurrent/TaskFifoTest.java
+++ b/src/test/java/com/google/devtools/build/lib/concurrent/ConcurrentFifoTest.java
@@ -14,16 +14,16 @@
 package com.google.devtools.build.lib.concurrent;
 
 import static com.google.common.truth.Truth.assertThat;
+import static com.google.devtools.build.lib.concurrent.ConcurrentFifo.CAPACITY;
 import static com.google.devtools.build.lib.concurrent.PaddedAddresses.createPaddedBaseAddress;
 import static com.google.devtools.build.lib.concurrent.PaddedAddresses.getAlignedAddress;
-import static com.google.devtools.build.lib.concurrent.TaskFifo.CAPACITY;
 import static com.google.devtools.build.lib.testutil.TestUtils.WAIT_TIMEOUT_SECONDS;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.junit.Assert.fail;
 
 import com.google.common.collect.Sets;
 import com.google.common.flogger.GoogleLogger;
-import com.google.devtools.build.lib.concurrent.TaskFifo.TaskWithSkippedAppends;
+import com.google.devtools.build.lib.concurrent.ConcurrentFifo.ElementWithSkippedAppends;
 import com.google.devtools.build.lib.unsafe.UnsafeProvider;
 import com.google.testing.junit.testparameterinjector.TestParameter;
 import com.google.testing.junit.testparameterinjector.TestParameterInjector;
@@ -38,7 +38,7 @@
 
 @RunWith(TestParameterInjector.class)
 @SuppressWarnings("SunApi") // TODO: b/359688989 - clean this up
-public final class TaskFifoTest {
+public final class ConcurrentFifoTest {
   private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
 
   private static final int PARALLELISM = 10;
@@ -50,7 +50,7 @@
   private long appendIndexAddress;
   private long takeIndexAddress;
 
-  private TaskFifo queue;
+  private ConcurrentFifo<Runnable> queue;
 
   @Before
   public void setUp() {
@@ -58,7 +58,7 @@
     sizeAddress = getAlignedAddress(baseAddress, /* offset= */ 0);
     appendIndexAddress = getAlignedAddress(baseAddress, /* offset= */ 1);
     takeIndexAddress = getAlignedAddress(baseAddress, /* offset= */ 2);
-    queue = new TaskFifo(sizeAddress, appendIndexAddress, takeIndexAddress);
+    queue = new ConcurrentFifo<>(Runnable.class, sizeAddress, appendIndexAddress, takeIndexAddress);
   }
 
   @After
@@ -265,10 +265,10 @@
     assertThat(queue.tryAppend(task1)).isTrue();
 
     // Verifies that append adds a wrapper to the task.
-    var wrappedTask = (TaskWithSkippedAppends) queue.getQueueForTesting()[0];
-    assertThat(wrappedTask.taskForTesting()).isEqualTo(task0);
+    var wrappedTask = (ElementWithSkippedAppends) queue.getQueueForTesting()[0];
+    assertThat(wrappedTask.element()).isEqualTo(task0);
     // Verifies that the skip count is 1.
-    assertThat(wrappedTask.skippedAppendCountForTesting()).isEqualTo(1);
+    assertThat(wrappedTask.skippedAppendCount()).isEqualTo(1);
 
     // Verifies that append in fact skips to the next index and appends there.
     assertThat(queue.getQueueForTesting()[1]).isEqualTo(task1);
@@ -284,9 +284,9 @@
     assertThat(queue.tryAppend(task2)).isTrue();
 
     // Verifies that the skip count has been incremented to 2.
-    wrappedTask = (TaskWithSkippedAppends) queue.getQueueForTesting()[0];
-    assertThat(wrappedTask.taskForTesting()).isEqualTo(task0);
-    assertThat(wrappedTask.skippedAppendCountForTesting()).isEqualTo(2);
+    wrappedTask = (ElementWithSkippedAppends) queue.getQueueForTesting()[0];
+    assertThat(wrappedTask.element()).isEqualTo(task0);
+    assertThat(wrappedTask.skippedAppendCount()).isEqualTo(2);
     // Verifies that the append actually skipped to the next index.
     assertThat(queue.getQueueForTesting()[1]).isEqualTo(task2);
 
@@ -294,10 +294,10 @@
 
     // Take skips to the task in the next position when it observes the wrapper.
     assertThat(queue.take()).isEqualTo(task2);
-    wrappedTask = (TaskWithSkippedAppends) queue.getQueueForTesting()[0];
-    assertThat(wrappedTask.taskForTesting()).isEqualTo(task0);
+    wrappedTask = (ElementWithSkippedAppends) queue.getQueueForTesting()[0];
+    assertThat(wrappedTask.element()).isEqualTo(task0);
     // Take decrements the skip counter.
-    assertThat(wrappedTask.skippedAppendCountForTesting()).isEqualTo(1);
+    assertThat(wrappedTask.skippedAppendCount()).isEqualTo(1);
     // Verifies that it took the task in the next position out of the queue.
     assertThat(queue.getQueueForTesting()[1]).isNull();