From 84f7835d67c2654dfcfdd49a4efb48f4b6f89227 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 28 Sep 2022 09:42:55 +0800 Subject: [PATCH] [PHI] phi support xpu black list (#46527) * phi support xpu black list --- paddle/phi/core/CMakeLists.txt | 15 +++++++++++---- paddle/phi/core/kernel_factory.cc | 12 ++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 669ca6c63c..d34f5f658b 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -11,10 +11,17 @@ cc_library( SRCS enforce.cc DEPS ${phi_enforce_deps}) -cc_library( - kernel_factory - SRCS kernel_factory.cc - DEPS phi_enforce fluid_convert_utils) +if(WITH_XPU) + cc_library( + kernel_factory + SRCS kernel_factory.cc + DEPS phi_enforce fluid_convert_utils convert_utils xpu_op_list) +else() + cc_library( + kernel_factory + SRCS kernel_factory.cc + DEPS phi_enforce fluid_convert_utils) +endif() cc_library( kernel_context SRCS kernel_context.cc diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 6e16029ee4..71256bdaba 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -16,6 +16,10 @@ #include "glog/logging.h" #include "paddle/phi/core/enforce.h" +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" +#include "paddle/phi/core/compat/convert_utils.h" +#endif DECLARE_bool(enable_api_kernel_fallback); @@ -112,7 +116,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError( << "] is not registered."; } #endif - auto kernel_iter = iter->second.find(kernel_key); // TODO(chenweihang): polish refind impl here if (kernel_iter == iter->second.end() && @@ -130,7 +133,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_key, kernel_name)); - if (FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) { + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) + || paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) + +#endif + ) { // Fallback CPU backend phi::KernelKey cpu_kernel_key( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); -- GitLab