blob: 91559622062c0e1bf953a5aa9f6aa23d6e535c4c [file] [log] [blame]
// Copyright 2024 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.devtools.build.lib.util;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.devtools.build.lib.collect.ConcurrentIdentitySet;
import com.google.devtools.build.lib.util.ObjectGraphTraverser.DomainSpecificTraverser;
import com.google.devtools.build.lib.util.ObjectGraphTraverser.EdgeType;
import com.google.devtools.build.lib.util.ObjectGraphTraverser.FieldCache;
import com.google.devtools.build.lib.util.ObjectGraphTraverser.Traversal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
 * Tests for {@link ObjectGraphTraverser}.
 *
 * <p>Each test builds a small object graph, traverses it while a {@link LoggingObjectReceiver}
 * records the output, and then checks which objects and edges were reported. Where context
 * collection is enabled, it also checks which context strings were reported.
 */
@RunWith(JUnit4.class)
public class ObjectGraphTraverserTest {

  /**
   * A directed edge as reported to {@link LoggingObjectReceiver#edgeFound}.
   *
   * <p>Endpoints are compared by identity rather than {@code equals()}. This keeps two distinct
   * objects apart even if their classes override {@code equals()}.
   */
  private static final class Edge {
    private final Object from;
    private final Object to;
    private final EdgeType type;

    private Edge(Object from, Object to, EdgeType type) {
      this.from = from;
      this.to = to;
      this.type = type;
    }

    private static Edge of(Object from, Object to, EdgeType type) {
      return new Edge(from, to, type);
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Edge that)) {
        return false;
      }
      return that.from == from && that.to == to && that.type == type;
    }

    @Override
    public int hashCode() {
      // Must agree with the identity-based equals() above.
      return Objects.hash(System.identityHashCode(from), System.identityHashCode(to), type);
    }

    /** Gives readable assertion-failure messages instead of {@code Edge@1a2b3c}. */
    @Override
    public String toString() {
      return "Edge{" + from + " -> " + to + ", " + type + "}";
    }
  }

  /**
   * Receiver that records, in order, every object and edge the traverser reports.
   *
   * <p>It also records any non-null context string attached to each object or edge.
   */
  private static final class LoggingObjectReceiver implements ObjectGraphTraverser.ObjectReceiver {
    private final List<Object> objects = new ArrayList<>();
    private final Map<Object, String> objectContexts = new HashMap<>();
    private final List<Edge> edges = new ArrayList<>();
    private final Map<Edge, String> edgeContexts = new HashMap<>();

    @Override
    public void objectFound(Object o, String context) {
      // A List rather than a Set, so that duplicate reports show up in assertions.
      objects.add(o);
      if (context != null) {
        objectContexts.put(o, context);
      }
    }

    @Override
    public void edgeFound(Object from, Object to, String toContext, EdgeType edgeType) {
      Edge edge = Edge.of(from, to, edgeType);
      edges.add(edge);
      if (toContext != null) {
        edgeContexts.put(edge, toContext);
      }
    }
  }

  /**
   * Creates the traverser under test.
   *
   * @param domainSpecific the single domain-specific hook to install, or {@code null} for none
   * @param seen the identity set of already-visited objects, shared with the traverser; objects
   *     added before traversal are reported as {@link EdgeType#ALREADY_SEEN} edges
   * @param receiver collects everything the traversal reports
   * @param collectContext whether context strings should be computed and reported
   */
  private static ObjectGraphTraverser createObjectGraphTraverser(
      DomainSpecificTraverser domainSpecific,
      ConcurrentIdentitySet seen,
      LoggingObjectReceiver receiver,
      boolean collectContext) {
    ImmutableList<DomainSpecificTraverser> traversers =
        domainSpecific == null ? ImmutableList.of() : ImmutableList.of(domainSpecific);
    return new ObjectGraphTraverser(
        new FieldCache(traversers), true, seen, collectContext, receiver, null);
  }

  @Test
  public void smoke() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object array = new Object[] {o2};
    Object pair = Pair.of(o1, array);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(pair);

    assertThat(receiver.objects).containsExactly(o1, o2, array, pair);
    assertThat(receiver.edges).hasSize(3);
  }

  @Test
  public void testAdmit() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object pair1 = Pair.of(o1, o1);
    Object pair2 = Pair.of(o2, o2);
    Object pair3 = Pair.of(pair1, pair2);
    DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
    // Reject pair2: neither it nor anything reachable only through it (o2) may be reported.
    when(domainSpecific.admit(any())).thenAnswer(i -> i.getArgument(0) != pair2);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);

    cut.traverse(pair3);

    assertThat(receiver.objects).containsExactly(o1, pair1, pair3);
    assertThat(receiver.edges).hasSize(3);
  }

  @Test
  public void testCustomTraversal() {
    Object o1 = new Object();
    Object o2 = new Object();
    DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
    when(domainSpecific.admit(any())).thenReturn(true);
    // Take over the traversal of o1: report o1 itself plus a single outgoing edge to o2.
    when(domainSpecific.maybeTraverse(any(), any()))
        .thenAnswer(
            i -> {
              Object arg = i.getArgument(0);
              Traversal traversal = i.getArgument(1);
              if (arg != o1) {
                return false;
              }
              traversal.objectFound(o1, null);
              traversal.edgeFound(o2, null);
              return true;
            });
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);

    cut.traverse(o1);

    assertThat(receiver.objects).containsExactly(o1, o2);
    assertThat(receiver.edges).containsExactly(Edge.of(o1, o2, EdgeType.CURRENT_TRAVERSAL));
  }

  @Test
  public void testIgnoredFields() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object pair = Pair.of(o1, o2);
    DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
    // Ignoring Pair.second must hide o2 completely: no object report and no edge.
    when(domainSpecific.ignoredFields(Pair.class)).thenReturn(ImmutableSet.of("second"));
    when(domainSpecific.admit(any())).thenReturn(true);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);

    cut.traverse(pair);

    assertThat(receiver.objects).containsExactly(o1, pair);
    assertThat(receiver.edges).containsExactly(Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL));
  }

  @Test
  public void testSeenObjects() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object pair = Pair.of(o1, o2);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    // Pre-mark o2 as visited: it should yield an ALREADY_SEEN edge but no object report.
    var unused = seen.add(o2);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(pair);

    assertThat(receiver.objects).containsExactly(o1, pair);
    assertThat(receiver.edges)
        .containsExactly(
            Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL),
            Edge.of(pair, o2, EdgeType.ALREADY_SEEN));
  }

  /** Provides a non-static inner class whose instances reference their enclosing instance. */
  private static final class Outer {
    private Inner createInner() {
      return new Inner();
    }

    private class Inner {
      // Without a use of Outer.this, the compiler may drop the inner class's implicit reference to
      // its enclosing instance. This accessor keeps that reference alive for the test.
      @SuppressWarnings("unused")
      private Outer getOuter() {
        return Outer.this;
      }
    }
  }

  @Test
  public void testNonStaticClassTraversesEnclosingClass() {
    Outer outer = new Outer();
    Outer.Inner inner = outer.createInner();
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(inner);

    assertThat(receiver.objects).containsExactly(outer, inner);
  }

  @Test
  public void testLambdaClosingOverNothingReported() {
    Object o1 = new Object();
    Supplier<Object> lambda = () -> 3;
    Object pair = Pair.of(o1, lambda);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(pair);

    assertThat(receiver.objects).containsExactly(pair, o1, lambda);
  }

  @Test
  public void testLambdaClosingOverNothingReportedWhenReferencedTwice() {
    Supplier<Object> lambda = () -> 3;
    Object pair = Pair.of(lambda, lambda);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(pair);

    // containsExactly on a List also rejects a duplicate report of the same lambda.
    assertThat(receiver.objects).containsExactly(pair, lambda);
  }

  @Test
  public void testValuesClosedOverReported() {
    Object o1 = new Object();
    Supplier<Object> lambda = () -> o1;
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(lambda);

    assertThat(receiver.objects).containsExactly(lambda, o1);
  }

  @Test
  public void testMultipleClosuresWithSameCodeReported() {
    Object o1 = new Object();
    Object o2 = new Object();
    // Both closures share one lambda body but capture different values.
    Function<Object, Supplier<Object>> generator = o -> () -> o;
    Object l1 = generator.apply(o1);
    Object l2 = generator.apply(o2);
    Object pair = Pair.of(l1, l2);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);

    cut.traverse(pair);

    assertThat(receiver.objects).containsExactly(pair, l1, l2, o1, o2);
  }

  @Test
  public void testEdgeContexts() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object array = new Object[] {o2};
    Object pair = Pair.of(o1, array);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
    when(domainSpecific.admit(any())).thenReturn(true);
    // Use identity matchers so each stub fires only for the exact object it names. refEq() compares
    // fields reflectively and would match any field-less object, e.g. every bare new Object().
    when(domainSpecific.contextForField(same(pair), any(), any(), same(o1)))
        .thenReturn("o1context");
    when(domainSpecific.contextForArrayItem(same(array), any(), same(o2)))
        .thenReturn("o2context");
    ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, true);

    cut.traverse(pair);

    assertThat(receiver.edgeContexts)
        .containsEntry(Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL), "o1context");
    assertThat(receiver.edgeContexts)
        .containsEntry(Edge.of(array, o2, EdgeType.CURRENT_TRAVERSAL), "o2context");
    // An object's context defaults to the context of the edge it was reached through.
    assertThat(receiver.objectContexts).containsEntry(o1, "o1context");
    assertThat(receiver.objectContexts).containsEntry(o2, "o2context");
  }

  @Test
  public void testObjectContexts() {
    Object o1 = new Object();
    Object o2 = new Object();
    Object pair = Pair.of(o1, o2);
    ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
    LoggingObjectReceiver receiver = new LoggingObjectReceiver();
    DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
    when(domainSpecific.admit(any())).thenReturn(true);
    // Identity matchers: only the pair -> o1 field edge gets the "bad" edge context.
    when(domainSpecific.contextForField(same(pair), any(), any(), same(o1))).thenReturn("bad");
    when(domainSpecific.maybeTraverse(any(), any()))
        .thenAnswer(
            i -> {
              Object o = i.getArgument(0);
              Traversal traversal = i.getArgument(1);
              if (o == o1) {
                traversal.objectFound(o, "o1context");
                return true;
              } else if (o == o2) {
                traversal.objectFound(o, "o2context");
                return true;
              } else {
                return false;
              }
            });
    ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, true);

    cut.traverse(pair);

    assertThat(receiver.objectContexts).containsEntry(o1, "o1context"); // overrides edge context
    assertThat(receiver.objectContexts).containsEntry(o2, "o2context");
  }
}