// blob: 1582b120c48eea3270693d70fc28c6f54de5018b [file] [log] [blame]
/*
* Copyright 2017 The Bazel Authors. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.idea.blaze.python.run.smrunner;
import com.google.common.collect.ImmutableList;
import com.google.idea.blaze.base.run.smrunner.SmRunnerUtils;
import com.google.idea.blaze.python.run.PyTestUtils;
import com.intellij.execution.Location;
import com.intellij.execution.PsiLocation;
import com.intellij.execution.testframework.sm.runner.SMTestLocator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.search.GlobalSearchScope;
import com.jetbrains.python.psi.PyClass;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.psi.stubs.PyClassNameIndex;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
/**
 * Locate python test classes / methods for test UI navigation.
 *
 * <p>Suite URLs are resolved to test classes looked up by name. Test URLs are resolved to a test
 * method inside a named test class, with the class itself also returned as a fallback location.
 */
public final class BlazePythonTestLocator implements SMTestLocator {

  public static final BlazePythonTestLocator INSTANCE = new BlazePythonTestLocator();

  /** Prefix removed from the start of a test path before it is split into components. */
  static final String PY_TESTCASE_PREFIX = "__main__.";

  /**
   * Separator between test path components: either {@code .} or {@code ::}. Compiled once here
   * instead of on every lookup.
   */
  private static final Pattern PATH_SEPARATOR = Pattern.compile("\\.|::");

  private BlazePythonTestLocator() {}

  /**
   * Resolves a test-runner location to PSI locations.
   *
   * @param protocol the location protocol; only {@link SmRunnerUtils#GENERIC_SUITE_PROTOCOL} and
   *     {@link SmRunnerUtils#GENERIC_TEST_PROTOCOL} are handled
   * @param path for the suite protocol, the class name to look up; for the test protocol, a
   *     {@code .}- or {@code ::}-separated path whose last two components are the class name and
   *     the method name
   * @param project the project to search in
   * @param scope the search scope passed to the class name index
   * @return the matching locations; empty when the protocol is not handled, when a test path has
   *     fewer than two components, or when nothing matches
   */
  @Override
  public List<Location> getLocation(
      String protocol, String path, Project project, GlobalSearchScope scope) {
    if (protocol.equals(SmRunnerUtils.GENERIC_SUITE_PROTOCOL)) {
      return findTestClass(project, scope, path);
    }
    if (protocol.equals(SmRunnerUtils.GENERIC_TEST_PROTOCOL)) {
      String testPath = StringUtil.trimStart(path, PY_TESTCASE_PREFIX);
      String[] components = PATH_SEPARATOR.split(testPath);
      if (components.length < 2) {
        // A class name and a method name are both required.
        return ImmutableList.of();
      }
      // Only the trailing two components are used; anything before them is ignored.
      String className = components[components.length - 2];
      String methodName = components[components.length - 1];
      return findTestMethod(project, scope, className, methodName);
    }
    return ImmutableList.of();
  }

  /**
   * Finds locations for a test method in every test class called {@code className}.
   *
   * <p>The returned list holds all matching method locations first, followed by the locations of
   * every candidate test class as a fallback. Keeping the method matches ahead of all class
   * fallbacks means the first element is a method whenever any candidate class defines one, even
   * if an earlier same-named class does not.
   *
   * @param methodName the method to look for; when {@code null} only the classes are located
   * @return method locations followed by class locations; empty if no test class matches
   */
  private static List<Location> findTestMethod(
      Project project, GlobalSearchScope scope, String className, @Nullable String methodName) {
    if (methodName == null) {
      return findTestClass(project, scope, className);
    }
    List<Location> methodLocations = new ArrayList<>();
    List<Location> classLocations = new ArrayList<>();
    for (PyClass pyClass : PyClassNameIndex.find(className, project, scope)) {
      ProgressManager.checkCanceled();
      if (!PyTestUtils.isTestClass(pyClass)) {
        continue;
      }
      PyFunction method = findMethod(pyClass, methodName);
      if (method != null && PyTestUtils.isTestFunction(method)) {
        methodLocations.add(new PsiLocation<>(project, method));
      }
      classLocations.add(new PsiLocation<>(project, pyClass));
    }
    methodLocations.addAll(classLocations);
    return methodLocations;
  }

  /**
   * Finds locations for every class called {@code className} that {@link PyTestUtils#isTestClass}
   * accepts.
   *
   * @return one location per matching test class; empty if there is none
   */
  private static List<Location> findTestClass(
      Project project, GlobalSearchScope scope, String className) {
    List<Location> results = new ArrayList<>();
    for (PyClass pyClass : PyClassNameIndex.find(className, project, scope)) {
      ProgressManager.checkCanceled();
      if (PyTestUtils.isTestClass(pyClass)) {
        results.add(new PsiLocation<>(project, pyClass));
      }
    }
    return results;
  }

  /**
   * Looks up a method on {@code pyClass}, first under {@code methodName} itself and then under
   * each name produced by the registered {@link PyParameterizedNameConverter} extensions.
   *
   * @return the first method found, or {@code null} if no lookup succeeds
   */
  @Nullable
  private static PyFunction findMethod(PyClass pyClass, String methodName) {
    PyFunction method = pyClass.findMethodByName(methodName, true, null);
    if (method != null) {
      return method;
    }
    return Arrays.stream(PyParameterizedNameConverter.EP_NAME.getExtensions())
        .map(converter -> converter.toFunctionName(methodName))
        // NOTE(review): presumably a converter may return null for a name it cannot handle;
        // skip those instead of looking up a null name. Confirm against the converter contract.
        .filter(Objects::nonNull)
        .map(name -> pyClass.findMethodByName(name, true, null))
        .filter(Objects::nonNull)
        .findFirst()
        .orElse(null);
  }
}