未验证 提交 66d3c56e 编写于 作者: Y YuhangLi 提交者: GitHub

[CUSTOM]custom device add black_list (#50409)

* [CUSTOM]custom device add black_list

* change log level

* fix some issues
上级 86fa306a
......@@ -141,9 +141,11 @@ phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key,
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto place = phi::TransToPhiPlace(kernel_key.backend());
if (platform::is_custom_place(place)) {
VLOG(3) << "phi missing " << place.GetDeviceType()
<< " kernel: " << op.Type()
bool is_custom_place = platform::is_custom_place(place);
if (is_custom_place ||
phi::backends::custom_device::is_in_custom_black_list(op.Type())) {
std::string info = is_custom_place ? "phi missing " : "phi in black list ";
VLOG(3) << info << place.GetDeviceType() << " kernel: " << op.Type()
<< ", expected_kernel_key:" << kernel_key
<< ", fallback to CPU one!";
return phi::KernelKey(
......
......@@ -35,6 +35,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_device_op_list.h"
#endif
namespace paddle {
namespace framework {
......
......@@ -47,7 +47,8 @@ list(
device_manager.cc)
if(WITH_CUSTOM_DEVICE)
list(APPEND BACKENDS_SRCS custom/custom_context.cc custom/custom_device.cc)
list(APPEND BACKENDS_SRCS custom/custom_context.cc custom/custom_device.cc
custom/custom_device_op_list.cc)
endif()
add_library(phi_backends "${BACKENDS_SRCS}")
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_device_op_list.h"
#include <glog/logging.h>
#include <mutex>
#include <string>
#include <unordered_set>
namespace phi {
namespace backends {
namespace custom_device {
// ops_string contains op_list(e.g., 'mul,mul_grad'), parse the op string and
// insert op to op set
static void tokenize(const std::string& ops,
char delim,
std::unordered_set<std::string>* op_set) {
std::string::size_type beg = 0;
for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos;
++end) {
op_set->insert(ops.substr(beg, end - beg));
beg = end + 1;
}
op_set->insert(ops.substr(beg));
}
bool is_in_custom_black_list(const std::string& fluid_op_name) {
static bool inited = false;
static std::unordered_set<std::string> cs_black_list;
static std::mutex s_mtx;
if (!inited) {
std::lock_guard<std::mutex> guard(s_mtx);
if (!inited) {
if (std::getenv("CUSTOM_DEVICE_BLACK_LIST") != nullptr) {
std::string ops(std::getenv("CUSTOM_DEVICE_BLACK_LIST"));
tokenize(ops, ',', &cs_black_list);
}
inited = true;
VLOG(3) << "Custom Device Black List: ";
for (auto iter = cs_black_list.begin(); iter != cs_black_list.end();
++iter) {
VLOG(3) << *iter << " ";
}
}
}
if (cs_black_list.find(fluid_op_name) != cs_black_list.end()) {
return true;
}
return false;
}
} // namespace custom_device
} // namespace backends
} // namespace phi
#endif
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include <paddle/phi/common/data_type.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
namespace phi {
namespace backends {
namespace custom_device {
bool is_in_custom_black_list(const std::string& fluid_op_name);
} // namespace custom_device
} // namespace backends
} // namespace phi
#endif
......@@ -20,6 +20,9 @@
#include "paddle/phi/backends/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/phi/backends/custom/custom_device_op_list.h"
#endif
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/string/string_helper.h"
......@@ -191,6 +194,11 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) ||
!phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name),
kernel_key.dtype())
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
if (FLAGS_enable_api_kernel_fallback &&
(kernel_iter == iter->second.end() ||
phi::backends::custom_device::is_in_custom_black_list(
TransToFluidOpName(kernel_name)))
#else
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册