From 015532b4eb19627b8324efe7e6a77aaeb4b541f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Tue, 12 Jul 2022 21:32:39 +0800 Subject: [PATCH] add xpu_kp support for standalone executor. test=develop (#44231) --- paddle/fluid/framework/operator.cc | 37 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0a5de2bd3f2..0fca87df34f 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1300,17 +1300,38 @@ bool OperatorWithKernel::SupportsKernelType( const OpKernelType& kernel_type) const { auto& all_op_kernels = AllOpKernels(); auto kernels_iter = all_op_kernels.find(type_); - bool support = - kernels_iter != all_op_kernels.end() && - kernels_iter->second.find(kernel_type) != kernels_iter->second.end(); -#if defined(PADDLE_WITH_XPU) + if (kernels_iter == all_op_kernels.end()) return false; + OpKernelMap& kernels = kernels_iter->second; + auto kernel_iter = kernels.find(kernel_type); + +#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) if (paddle::platform::is_xpu_place(kernel_type.place_)) { - support = support && - paddle::platform::is_xpu_support_op(type_, kernel_type) && - !paddle::platform::is_in_xpu_black_list(type_); + return kernel_iter != kernels.end() && + paddle::platform::is_xpu_support_op(type_, kernel_type) && + !paddle::platform::is_in_xpu_black_list(type_); } #endif - return support; + +#ifdef PADDLE_WITH_XPU_KP + if (paddle::platform::is_xpu_place(kernel_type.place_)) { + bool use_xpu_kp_kernel_rt = + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(type_, kernel_type); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_in_xpu_kpwhite_list(type_); + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + auto tmp_kernel_type = kernel_type; + tmp_kernel_type.library_type_ = LibraryType::kKP; + return kernels.find(tmp_kernel_type) != kernels.end(); + } + return kernel_iter != kernels.end() && + paddle::platform::is_xpu_support_op(type_, kernel_type) && + !paddle::platform::is_in_xpu_black_list(type_); + } +#endif + + return kernel_iter != kernels.end(); } bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, -- GitLab