/*
 * Copyright 2017 The Bazel Authors. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.google.idea.blaze.python.run.smrunner;

import com.google.common.collect.ImmutableList;
import com.google.idea.blaze.base.run.smrunner.SmRunnerUtils;
import com.google.idea.blaze.python.run.PyTestUtils;
import com.intellij.execution.Location;
import com.intellij.execution.PsiLocation;
import com.intellij.execution.testframework.sm.runner.SMTestLocator;
import com.intellij.openapi.progress.ProgressManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.search.GlobalSearchScope;
import com.jetbrains.python.psi.PyClass;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.psi.stubs.PyClassNameIndex;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;

/**
 * Locates python test classes / methods for test UI navigation.
 *
 * <p>Suite URLs are resolved to test classes looked up by name in {@link PyClassNameIndex}; test
 * URLs are resolved to a test method, with the owning test class offered as a fallback location.
 */
public final class BlazePythonTestLocator implements SMTestLocator {

  public static final BlazePythonTestLocator INSTANCE = new BlazePythonTestLocator();

  /** Prefix trimmed from the start of a test path before it is split into components. */
  static final String PY_TESTCASE_PREFIX = "__main__.";

  private BlazePythonTestLocator() {}

  /**
   * Resolves a test-tree URL to PSI locations.
   *
   * @param protocol URL protocol; only {@link SmRunnerUtils#GENERIC_SUITE_PROTOCOL} and {@link
   *     SmRunnerUtils#GENERIC_TEST_PROTOCOL} are handled
   * @param path for the suite protocol, the class name to look up; for the test protocol, a
   *     dot-separated path whose last two components are the class name and the method name
   * @param project project passed to the class-name index and to the created locations
   * @param scope search scope passed to the class-name index
   * @return the matching locations, or an empty list for an unhandled protocol or a test path
   *     with fewer than two components
   */
  @Override
  public List<Location> getLocation(
      String protocol, String path, Project project, GlobalSearchScope scope) {
    if (protocol.equals(SmRunnerUtils.GENERIC_SUITE_PROTOCOL)) {
      return findTestClass(project, scope, path);
    }
    if (protocol.equals(SmRunnerUtils.GENERIC_TEST_PROTOCOL)) {
      path = StringUtil.trimStart(path, PY_TESTCASE_PREFIX);
      String[] components = path.split("\\.");
      if (components.length < 2) {
        return ImmutableList.of();
      }
      // Only the trailing "<class>.<method>" pair is used; any leading components are ignored.
      return findTestMethod(
          project, scope, components[components.length - 2], components[components.length - 1]);
    }
    return ImmutableList.of();
  }

  /**
   * Finds locations for the given method in every indexed test class named {@code className}.
   *
   * <p>All method matches are returned ahead of all class fallbacks, so that a class which merely
   * shares the name but lacks the method never precedes a real method match.
   *
   * @param methodName method to look for; when {@code null} only the classes are located
   * @return method locations followed by the locations of every matching test class
   */
  private static List<Location> findTestMethod(
      Project project, GlobalSearchScope scope, String className, @Nullable String methodName) {
    if (methodName == null) {
      return findTestClass(project, scope, className);
    }
    List<Location> methodMatches = new ArrayList<>();
    List<Location> classFallbacks = new ArrayList<>();
    for (PyClass pyClass : PyClassNameIndex.find(className, project, scope)) {
      ProgressManager.checkCanceled();
      if (!PyTestUtils.isTestClass(pyClass)) {
        continue;
      }
      PyFunction method = pyClass.findMethodByName(methodName, true, null);
      if (method != null && PyTestUtils.isTestFunction(method)) {
        methodMatches.add(new PsiLocation<>(project, method));
      }
      // The class itself is always offered as a fallback, whether or not the method was found.
      classFallbacks.add(new PsiLocation<>(project, pyClass));
    }
    methodMatches.addAll(classFallbacks);
    return methodMatches;
  }

  /**
   * Finds locations for every indexed class named {@code className} that {@link
   * PyTestUtils#isTestClass} accepts.
   */
  private static List<Location> findTestClass(
      Project project, GlobalSearchScope scope, String className) {
    List<Location> results = new ArrayList<>();
    for (PyClass pyClass : PyClassNameIndex.find(className, project, scope)) {
      ProgressManager.checkCanceled();
      if (PyTestUtils.isTestClass(pyClass)) {
        results.add(new PsiLocation<>(project, pyClass));
      }
    }
    return results;
  }
}
