diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index f7d2ef1bf5d423af756aa4e89712a7bb0d5dd7ee..d248715f00c2ba7dddb24a79450f76cd45cfbf5f 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -191,12 +191,23 @@ PreparedOp PrepareImpl(const NameVarMap& ins, bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { + auto expected_kernel_key_library_type = + expected_kernel_key.library_type_; expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; - VLOG(3) << "modify XPU KP kernel: " << op.Type() + VLOG(3) << "modifing XPU KP kernel: " << op.Type() << ", using_kernel_key:" << expected_kernel_key; + phi::KernelKey try_pt_kernel_key = + TransOpKernelTypeToPhiKernelKey(expected_kernel_key); + if (!phi::KernelFactory::Instance().IsSelectKernelValid( + pt_kernel_name, try_pt_kernel_key)) { + expected_kernel_key.library_type_ = expected_kernel_key_library_type; + VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed " + << expected_kernel_key; + } } } #endif + pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, pt_kernel_key); @@ -227,6 +238,20 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto& all_op_kernels = op.AllOpKernels(); auto kernels_iter = all_op_kernels.find(op.Type()); +#ifdef PADDLE_WITH_XPU_KP + bool use_xpu_kp_kernel_rt = + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + FLAGS_run_kp_kernel && + paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); + bool use_xpu_kp_kernel_debug = + paddle::platform::is_xpu_place(expected_kernel_key.place_) && + paddle::platform::is_in_xpu_kpwhite_list(op.Type()); + bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); + if (is_xpu_kp_support) { + expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; + } +#endif + if ((kernels_iter == all_op_kernels.end() || kernels_iter->second.find(expected_kernel_key) == kernels_iter->second.end()) @@ -255,6 +280,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, platform::errors::NotFound( "There are no kernels which are registered in the %s operator.", op.Type())); + auto& kernels = kernels_iter->second; auto kernel_iter = kernels.find(expected_kernel_key); @@ -271,18 +297,12 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #ifdef PADDLE_WITH_XPU_KP if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) { - bool use_xpu_kp_kernel_rt = - FLAGS_run_kp_kernel && - paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); - bool use_xpu_kp_kernel_debug = - paddle::platform::is_in_xpu_kpwhite_list(op.Type()); if (use_xpu_kp_kernel_rt) { VLOG(3) << "xpu_kp using rt mode "; } if (use_xpu_kp_kernel_debug) { VLOG(3) << "xpu_kp using debug mode "; } - bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug); if (is_xpu_kp_support) { expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; kernel_iter = kernels.find(expected_kernel_key); diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index ba41e082ab9122c1d895dddc3f75c6956fb7e62b..81c43764fee9edf57c2b3682e8650e5c906f7795 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -59,6 +59,21 @@ KernelKeyMap KernelFactory::SelectKernelMap( return iter->second; } +bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name, + const KernelKey& kernel_key) const { + auto iter = kernels_.find(kernel_name); + PADDLE_ENFORCE_NE( + iter, + kernels_.end(), + phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); + + auto kernel_iter = iter->second.find(kernel_key); + if (kernel_iter == iter->second.end()) { + return false; + } + return true; +} + const Kernel& KernelFactory::SelectKernelOrThrowError( const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index e502b9cb3e02536e8d764a4cbc5e1d5509960303..6c098c75a0eda058d098fe1e5c83cf4f0e68af4a 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -245,6 +245,9 @@ class KernelFactory { DataLayout layout, DataType dtype) const; + bool IsSelectKernelValid(const std::string& kernel_name, + const KernelKey& kernel_key) const; + Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const;