未验证 提交 84f7835d 编写于 作者: W wanghuancoder 提交者: GitHub

[PHI] phi support xpu black list (#46527)

* phi support xpu black list
上级 7e2e2ee7
...@@ -11,10 +11,17 @@ cc_library( ...@@ -11,10 +11,17 @@ cc_library(
SRCS enforce.cc SRCS enforce.cc
DEPS ${phi_enforce_deps}) DEPS ${phi_enforce_deps})
cc_library( if(WITH_XPU)
kernel_factory cc_library(
SRCS kernel_factory.cc kernel_factory
DEPS phi_enforce fluid_convert_utils) SRCS kernel_factory.cc
DEPS phi_enforce fluid_convert_utils convert_utils xpu_op_list)
else()
cc_library(
kernel_factory
SRCS kernel_factory.cc
DEPS phi_enforce fluid_convert_utils)
endif()
cc_library( cc_library(
kernel_context kernel_context
SRCS kernel_context.cc SRCS kernel_context.cc
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h"
#endif
DECLARE_bool(enable_api_kernel_fallback); DECLARE_bool(enable_api_kernel_fallback);
...@@ -112,7 +116,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -112,7 +116,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
<< "] is not registered."; << "] is not registered.";
} }
#endif #endif
auto kernel_iter = iter->second.find(kernel_key); auto kernel_iter = iter->second.find(kernel_key);
// TODO(chenweihang): polish refind impl here // TODO(chenweihang): polish refind impl here
if (kernel_iter == iter->second.end() && if (kernel_iter == iter->second.end() &&
...@@ -130,7 +133,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -130,7 +133,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_key, kernel_key,
kernel_name)); kernel_name));
if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) { if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name))
#endif
) {
// Fallback CPU backend // Fallback CPU backend
phi::KernelKey cpu_kernel_key( phi::KernelKey cpu_kernel_key(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册