blob: 20904173abe244da61b67224af482304769e8ba6 [file] [log] [blame]
Googlerddc7dd72024-03-05 08:34:53 -08001// Copyright 2024 The Bazel Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package com.google.devtools.build.lib.util;
16
17import static com.google.common.truth.Truth.assertThat;
18import static org.mockito.ArgumentMatchers.any;
19import static org.mockito.ArgumentMatchers.refEq;
20import static org.mockito.Mockito.mock;
21import static org.mockito.Mockito.when;
22
23import com.google.common.collect.ImmutableList;
24import com.google.common.collect.ImmutableSet;
25import com.google.devtools.build.lib.collect.ConcurrentIdentitySet;
26import com.google.devtools.build.lib.util.ObjectGraphTraverser.DomainSpecificTraverser;
27import com.google.devtools.build.lib.util.ObjectGraphTraverser.EdgeType;
28import com.google.devtools.build.lib.util.ObjectGraphTraverser.FieldCache;
29import com.google.devtools.build.lib.util.ObjectGraphTraverser.Traversal;
30import java.util.ArrayList;
31import java.util.HashMap;
32import java.util.List;
33import java.util.Map;
34import java.util.Objects;
Googler85da0c02024-03-06 08:31:48 -080035import java.util.function.Function;
Googlerddc7dd72024-03-05 08:34:53 -080036import java.util.function.Supplier;
37import org.junit.Test;
38import org.junit.runner.RunWith;
39import org.junit.runners.JUnit4;
40
41@RunWith(JUnit4.class)
42public class ObjectGraphTraverserTest {
43 private static final class Edge {
44 private final Object from;
45 private final Object to;
46 private final EdgeType type;
47
48 private Edge(Object from, Object to, EdgeType type) {
49 this.from = from;
50 this.to = to;
51 this.type = type;
52 }
53
54 private static Edge of(Object from, Object to, EdgeType type) {
55 return new Edge(from, to, type);
56 }
57
58 @Override
59 public boolean equals(Object o) {
Googler6f48f1c2024-04-16 14:29:09 -070060 if (!(o instanceof Edge that)) {
Googlerddc7dd72024-03-05 08:34:53 -080061 return false;
62 }
63
Googlerddc7dd72024-03-05 08:34:53 -080064 return that.from == from && that.to == to && that.type == type;
65 }
66
67 @Override
68 public int hashCode() {
69 return Objects.hash(System.identityHashCode(from), System.identityHashCode(to), type);
70 }
71 }
72
73 private static final class LoggingObjectReceiver implements ObjectGraphTraverser.ObjectReceiver {
74 private List<Object> objects = new ArrayList<>();
75 private Map<Object, String> objectContexts = new HashMap<>();
76 private List<Edge> edges = new ArrayList<>();
77 private Map<Edge, String> edgeContexts = new HashMap<>();
78
79 @Override
80 public void objectFound(Object o, String context) {
81 objects.add(o);
82 if (context != null) {
83 objectContexts.put(o, context);
84 }
85 }
86
87 @Override
88 public void edgeFound(Object from, Object to, String toContext, EdgeType edgeType) {
89 Edge edge = Edge.of(from, to, edgeType);
90
91 edges.add(edge);
92 if (toContext != null) {
93 edgeContexts.put(edge, toContext);
94 }
95 }
96 }
97
98 private ObjectGraphTraverser createObjectGraphTraverser(
99 DomainSpecificTraverser domainSpecific,
100 ConcurrentIdentitySet seen,
101 LoggingObjectReceiver receiver,
102 boolean collectContext) {
103 ImmutableList<DomainSpecificTraverser> traversers =
104 domainSpecific == null ? ImmutableList.of() : ImmutableList.of(domainSpecific);
105 return new ObjectGraphTraverser(
Googlerfd6d9df2024-05-08 03:49:40 -0700106 new FieldCache(traversers), false, true, seen, collectContext, receiver, null);
Googlerddc7dd72024-03-05 08:34:53 -0800107 }
108
109 @Test
110 public void smoke() {
111 Object o1 = new Object();
112 Object o2 = new Object();
113 Object array = new Object[] {o2};
114 Object pair = Pair.of(o1, array);
115
116 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
117 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
118 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
119 cut.traverse(pair);
120
121 assertThat(receiver.objects).containsExactly(o1, o2, array, pair);
122 assertThat(receiver.edges).hasSize(3);
123 }
124
125 @Test
126 public void testAdmit() {
127 Object o1 = new Object();
128 Object o2 = new Object();
129 Object pair1 = Pair.of(o1, o1);
130 Object pair2 = Pair.of(o2, o2);
131 Object pair3 = Pair.of(pair1, pair2);
132
133 DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
134 when(domainSpecific.admit(any())).thenAnswer(i -> i.getArgument(0) != pair2);
135
136 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
137 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
138 ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);
139 cut.traverse(pair3);
140
141 assertThat(receiver.objects).containsExactly(o1, pair1, pair3);
142 assertThat(receiver.edges).hasSize(3);
143 }
144
145 @Test
146 public void testCustomTraversal() {
147 Object o1 = new Object();
148 Object o2 = new Object();
149
150 DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
151 when(domainSpecific.admit(any())).thenReturn(true);
152 when(domainSpecific.maybeTraverse(any(), any()))
153 .thenAnswer(
154 i -> {
155 Object arg = i.getArgument(0);
156 Traversal traversal = i.getArgument(1);
157
158 if (arg != o1) {
159 return false;
160 }
161
162 traversal.objectFound(o1, null);
163 traversal.edgeFound(o2, null);
164 return true;
165 });
166
167 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
168 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
169 ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);
170 cut.traverse(o1);
171
172 assertThat(receiver.objects).containsExactly(o1, o2);
173 assertThat(receiver.edges).containsExactly(Edge.of(o1, o2, EdgeType.CURRENT_TRAVERSAL));
174 }
175
176 @Test
177 public void testIgnoredFields() {
178 Object o1 = new Object();
179 Object o2 = new Object();
180 Object pair = Pair.of(o1, o2);
181
182 DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
183 when(domainSpecific.ignoredFields(Pair.class)).thenReturn(ImmutableSet.of("second"));
184 when(domainSpecific.admit(any())).thenReturn(true);
185
186 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
187 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
188 ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, false);
189 cut.traverse(pair);
190
191 assertThat(receiver.objects).containsExactly(o1, pair);
192 assertThat(receiver.edges).containsExactly(Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL));
193 }
194
195 @Test
196 public void testSeenObjects() {
197 Object o1 = new Object();
198 Object o2 = new Object();
199 Object pair = Pair.of(o1, o2);
200
201 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
202 var unused = seen.add(o2);
203 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
204 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
205 cut.traverse(pair);
206
207 assertThat(receiver.objects).containsExactly(o1, pair);
208 assertThat(receiver.edges)
209 .containsExactly(
210 Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL),
211 Edge.of(pair, o2, EdgeType.ALREADY_SEEN));
212 }
213
/**
 * Fixture for {@code testNonStaticClassTraversesEnclosingClass}: a class with a non-static inner
 * class, so that each {@link Inner} instance belongs to an enclosing {@code Outer} instance.
 */
private static final class Outer {
  /** Returns a new {@link Inner} whose enclosing instance is this object. */
  private Inner createInner() {
    return new Inner();
  }

  /** Deliberately non-static; see the comment on {@link #getOuter()}. */
  private class Inner {
    // Java is clever and will optimize out the reference to Outer without this
    // method: using Outer.this here keeps the enclosing instance referenced by Inner.
    @SuppressWarnings("unused")
    private Outer getOuter() {
      return Outer.this;
    }
  }
}
227
228 @Test
229 public void testNonStaticClassTraversesEnclosingClass() {
230 Outer outer = new Outer();
231 Outer.Inner inner = outer.createInner();
232
233 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
234 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
235 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
236
237 cut.traverse(inner);
238 assertThat(receiver.objects).containsExactly(outer, inner);
239 }
240
Googlerddc7dd72024-03-05 08:34:53 -0800241 @Test
Googler85da0c02024-03-06 08:31:48 -0800242 public void testLambdaClosingOverNothingReported() {
Googlerddc7dd72024-03-05 08:34:53 -0800243 Object o1 = new Object();
Googler85da0c02024-03-06 08:31:48 -0800244 Supplier<Object> lambda = () -> 3;
245 Object pair = Pair.of(o1, lambda);
Googlerddc7dd72024-03-05 08:34:53 -0800246
247 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
248 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
249 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
250
251 cut.traverse(pair);
Googler85da0c02024-03-06 08:31:48 -0800252 assertThat(receiver.objects).containsExactly(pair, o1, lambda);
253 }
254
255 @Test
256 public void testLambdaClosingOverNothingReportedWhenReferencedTwice() {
257 Supplier<Object> lambda = () -> 3;
258 Object pair = Pair.of(lambda, lambda);
259
260 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
261 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
262 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
263
264 cut.traverse(pair);
265 assertThat(receiver.objects).containsExactly(pair, lambda);
266 }
267
268 @Test
269 public void testValuesClosedOverReported() {
270 Object o1 = new Object();
271 Supplier<Object> lambda = () -> o1;
272
273 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
274 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
275 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
276
277 cut.traverse(lambda);
278 assertThat(receiver.objects).containsExactly(lambda, o1);
279 }
280
281 @Test
282 public void testMultipleClosuresWithSameCodeReported() {
283 Object o1 = new Object();
284 Object o2 = new Object();
285 Function<Object, Supplier<Object>> generator = o -> () -> o;
286 Object l1 = generator.apply(o1);
287 Object l2 = generator.apply(o2);
288 Object pair = Pair.of(l1, l2);
289
290 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
291 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
292 ObjectGraphTraverser cut = createObjectGraphTraverser(null, seen, receiver, false);
293
294 cut.traverse(pair);
295 assertThat(receiver.objects).containsExactly(pair, l1, l2, o1, o2);
Googlerddc7dd72024-03-05 08:34:53 -0800296 }
297
298 @Test
299 public void testEdgeContexts() {
300 Object o1 = new Object();
301 Object o2 = new Object();
302 Object array = new Object[] {o2};
303 Object pair = Pair.of(o1, array);
304
305 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
306 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
307 DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
308 when(domainSpecific.admit(any())).thenReturn(true);
309 when(domainSpecific.contextForField(refEq(pair), any(), any(), refEq(o1)))
310 .thenReturn("o1context");
311 when(domainSpecific.contextForArrayItem(refEq(array), any(), refEq(o2)))
312 .thenReturn("o2context");
313 ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, true);
314
315 cut.traverse(pair);
316 assertThat(receiver.edgeContexts)
317 .containsEntry(Edge.of(pair, o1, EdgeType.CURRENT_TRAVERSAL), "o1context");
318 assertThat(receiver.edgeContexts)
319 .containsEntry(Edge.of(array, o2, EdgeType.CURRENT_TRAVERSAL), "o2context");
320 assertThat(receiver.objectContexts).containsEntry(o1, "o1context");
321 assertThat(receiver.objectContexts).containsEntry(o2, "o2context");
322 }
323
324 @Test
325 public void testObjectContexts() {
326 Object o1 = new Object();
327 Object o2 = new Object();
328 Object pair = Pair.of(o1, o2);
329
330 ConcurrentIdentitySet seen = new ConcurrentIdentitySet(1);
331 LoggingObjectReceiver receiver = new LoggingObjectReceiver();
332 DomainSpecificTraverser domainSpecific = mock(DomainSpecificTraverser.class);
333 when(domainSpecific.admit(any())).thenReturn(true);
334 when(domainSpecific.contextForField(refEq(pair), any(), any(), refEq(o1))).thenReturn("bad");
335 when(domainSpecific.maybeTraverse(any(), any()))
336 .thenAnswer(
337 i -> {
338 Object o = i.getArgument(0);
339 Traversal traversal = i.getArgument(1);
340 if (o == o1) {
341 traversal.objectFound(o, "o1context");
342 return true;
343 } else if (o == o2) {
344 traversal.objectFound(o, "o2context");
345 return true;
346 } else {
347 return false;
348 }
349 });
350 ObjectGraphTraverser cut = createObjectGraphTraverser(domainSpecific, seen, receiver, true);
351
352 cut.traverse(pair);
353 assertThat(receiver.objectContexts).containsEntry(o1, "o1context"); // overrides edge context
354 assertThat(receiver.objectContexts).containsEntry(o2, "o2context");
355 }
356}