diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0a5de2bd3f2622d1863db7b5b327a5c98c247fcc..0fca87df34f5a500a630c7023c11ed50d0f7ec73 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1300,17 +1300,38 @@ bool OperatorWithKernel::SupportsKernelType( const OpKernelType& kernel_type) const { auto& all_op_kernels = AllOpKernels(); auto kernels_iter = all_op_kernels.find(type_); - bool support = - kernels_iter != all_op_kernels.end() && - kernels_iter->second.find(kernel_type) != kernels_iter->second.end(); -#if defined(PADDLE_WITH_XPU) + if (kernels_iter == all_op_kernels.end()) return false; + OpKernelMap& kernels = kernels_iter->second; + auto kernel_iter = kernels.find(kernel_type); + +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) if (paddle::platform::is_xpu_place(kernel_type.place_)) { - support = support && - paddle::platform::is_xpu_support_op(type_, kernel_type) && - !paddle::platform::is_in_xpu_black_list(type_); + return kernel_iter != kernels.end() && + paddle::platform::is_xpu_support_op(type_, kernel_type) && + !paddle::platform::is_in_xpu_black_list(type_); } #endif - return support; + +#ifdef PADDLE_WITH_XPU_KP + if (paddle::platform::is_xpu_place(kernel_type.place_)) { + bool use_xpu_kp_kernel_rt = + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(type_, kernel_type); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_in_xpu_kpwhite_list(type_); + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + auto tmp_kernel_type = kernel_type; + tmp_kernel_type.library_type_ = LibraryType::kKP; + return kernels.find(tmp_kernel_type) != kernels.end(); + } + return kernel_iter != kernels.end() && + paddle::platform::is_xpu_support_op(type_, kernel_type) && + !paddle::platform::is_in_xpu_black_list(type_); + } +#endif + + return kernel_iter != kernels.end(); } bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,