diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 53c35fc41c07885346a8f5c0f6fdaec7224895d8..8da7b23a6a0449ea1e8d7f015552ed5cbc2d3b44 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -141,9 +141,11 @@ phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key, #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE auto place = phi::TransToPhiPlace(kernel_key.backend()); - if (platform::is_custom_place(place)) { - VLOG(3) << "phi missing " << place.GetDeviceType() - << " kernel: " << op.Type() + bool is_custom_place = platform::is_custom_place(place); + if (is_custom_place || + phi::backends::custom_device::is_in_custom_black_list(op.Type())) { + std::string info = is_custom_place ? "phi missing " : "phi in black list "; + VLOG(3) << info << place.GetDeviceType() << " kernel: " << op.Type() << ", expected_kernel_key:" << kernel_key << ", fallback to CPU one!"; return phi::KernelKey( diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 602528f5bb061603b90d350246baa8ce1992453f..0c214176f27e9a9b4d3da59123db3a14ce83d82d 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -35,6 +35,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/custom/custom_device_op_list.h" +#endif namespace paddle { namespace framework { diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index e90cdc9e0663abe516c932d49b4572debe420c4b..8b32aa00f7a3849f95258aa9bf1c3f2d6d057b92 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -47,7 +47,8 @@ list( device_manager.cc) if(WITH_CUSTOM_DEVICE) - list(APPEND BACKENDS_SRCS custom/custom_context.cc custom/custom_device.cc) + list(APPEND BACKENDS_SRCS custom/custom_context.cc custom/custom_device.cc + custom/custom_device_op_list.cc) endif() add_library(phi_backends "${BACKENDS_SRCS}") diff --git a/paddle/phi/backends/custom/custom_device_op_list.cc b/paddle/phi/backends/custom/custom_device_op_list.cc new file mode 100644 index 0000000000000000000000000000000000000000..db00eec81900f294c96158e235d01690bc1e4b93 --- /dev/null +++ b/paddle/phi/backends/custom/custom_device_op_list.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/custom/custom_device_op_list.h" +#include +#include +#include +#include +namespace phi { +namespace backends { +namespace custom_device { +// ops_string contains op_list(e.g., 'mul,mul_grad'), parse the op string and +// insert op to op set +static void tokenize(const std::string& ops, + char delim, + std::unordered_set* op_set) { + std::string::size_type beg = 0; + for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos; + ++end) { + op_set->insert(ops.substr(beg, end - beg)); + beg = end + 1; + } + op_set->insert(ops.substr(beg)); +} + +bool is_in_custom_black_list(const std::string& fluid_op_name) { + static bool inited = false; + static std::unordered_set cs_black_list; + static std::mutex s_mtx; + if (!inited) { + std::lock_guard guard(s_mtx); + if (!inited) { + if (std::getenv("CUSTOM_DEVICE_BLACK_LIST") != nullptr) { + std::string ops(std::getenv("CUSTOM_DEVICE_BLACK_LIST")); + tokenize(ops, ',', &cs_black_list); + } + inited = true; + VLOG(3) << "Custom Device Black List: "; + for (auto iter = cs_black_list.begin(); iter != cs_black_list.end(); + ++iter) { + VLOG(3) << *iter << " "; + } + } + } + if (cs_black_list.find(fluid_op_name) != cs_black_list.end()) { + return true; + } + return false; +} +} // namespace custom_device +} // namespace backends +} // namespace phi +#endif diff --git a/paddle/phi/backends/custom/custom_device_op_list.h b/paddle/phi/backends/custom/custom_device_op_list.h new file mode 100644 index 0000000000000000000000000000000000000000..695bfdb1a09c96e0522479ca4a10c9c81d5d3163 --- /dev/null +++ b/paddle/phi/backends/custom/custom_device_op_list.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include +#include +#include +#include +namespace phi { +namespace backends { +namespace custom_device { +bool is_in_custom_black_list(const std::string& fluid_op_name); +} // namespace custom_device +} // namespace backends +} // namespace phi +#endif diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index f44bfe6a2e0dd5ed4d11dd8bf8d6895791599893..096207e06fcde1e030618f0c18e75b742481c0d3 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -20,6 +20,9 @@ #include "paddle/phi/backends/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif +#if defined(PADDLE_WITH_CUSTOM_DEVICE) +#include "paddle/phi/backends/custom/custom_device_op_list.h" +#endif #include "paddle/phi/core/compat/op_utils.h" #include "paddle/utils/string/string_helper.h" @@ -191,6 +194,11 @@ KernelResult KernelFactory::SelectKernelOrThrowError( if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || !phi::backends::xpu::is_xpu_support_op(TransToFluidOpName(kernel_name), kernel_key.dtype()) +#elif defined(PADDLE_WITH_CUSTOM_DEVICE) + if (FLAGS_enable_api_kernel_fallback && + (kernel_iter == iter->second.end() || + phi::backends::custom_device::is_in_custom_black_list( + TransToFluidOpName(kernel_name))) #else if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #endif