From 705776ca7bafb6968c918a653895e2363f48d503 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Fri, 1 Apr 2022 14:31:10 +0800 Subject: [PATCH] [KP] fix bug in activation xpu kp kernel (#41219) * fix bug in activation xpu kp kernel * delete useless comment --- paddle/fluid/imperative/prepared_operator.cc | 34 ++++++++++++++++---- paddle/phi/core/kernel_factory.cc | 15 +++++++++ paddle/phi/core/kernel_factory.h | 3 ++ 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index f7d2ef1bf5d..d248715f00c 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -191,12 +191,23 @@ PreparedOp PrepareImpl(const NameVarMap& ins, bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { + auto expected_kernel_key_library_type = + expected_kernel_key.library_type_; expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; - VLOG(3) << "modify XPU KP kernel: " << op.Type() + VLOG(3) << "modifing XPU KP kernel: " << op.Type() << ", using_kernel_key:" << expected_kernel_key; + phi::KernelKey try_pt_kernel_key = + TransOpKernelTypeToPhiKernelKey(expected_kernel_key); + if (!phi::KernelFactory::Instance().IsSelectKernelValid( + pt_kernel_name, try_pt_kernel_key)) { + expected_kernel_key.library_type_ = expected_kernel_key_library_type; + VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed " + << expected_kernel_key; + } } } #endif + pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, pt_kernel_key); @@ -227,6 +238,20 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); +#ifdef PADDLE_WITH_XPU_KP + bool use_xpu_kp_kernel_rt = + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + paddle::platform::is_in_xpu_kpwhite_list(op.Type()); + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + } +#endif + if ((kernels_iter == all_op_kernels.end() || kernels_iter->second.find(expected_kernel_key) == kernels_iter->second.end()) @@ -255,6 +280,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, platform::errors::NotFound( "There are no kernels which are registered in the %s operator.", op.Type())); + auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); @@ -271,18 +297,12 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #ifdef PADDLE_WITH_XPU_KP if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { - bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && - paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); - bool use_xpu_kp_kernel_debug = - paddle::platform::is_in_xpu_kpwhite_list(op.Type()); if (use_xpu_kp_kernel_rt) { VLOG(3) << "xpu_kp using rt mode "; } if (use_xpu_kp_kernel_debug) { VLOG(3) << "xpu_kp using debug mode "; } - bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; kernel_iter = kernels.find(expected_kernel_key); diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index ba41e082ab9..81c43764fee 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -59,6 +59,21 @@ KernelKeyMap KernelFactory::SelectKernelMap( return iter->second; } +bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, + const KernelKey& kernel_key) const { + auto iter = kernels_.find(kernel_name); + PADDLE_ENFORCE_NE( + iter, + kernels_.end(), + phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); + + auto kernel_iter = iter->second.find(kernel_key); + if (kernel_iter == iter->second.end()) { + return false; + } + return true; +} + const Kernel& KernelFactory::SelectKernelOrThrowError( const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index e502b9cb3e0..6c098c75a0e 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -245,6 +245,9 @@ class KernelFactory { DataLayout layout, DataType dtype) const; + bool IsSelectKernelValid(const std::string& kernel_name, + const KernelKey& kernel_key) const; + Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; -- GitLab