From 8d797fd222b8efc02b8291f232aada2776885583 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 13 Oct 2022 21:29:24 +0800 Subject: [PATCH] [Phi] Refactor logic of judging whether having a phi kernrel (#46920) * refind logic of choose phi kernrel * fix complie budg --- paddle/phi/core/compat/op_utils.h | 2 +- paddle/phi/core/kernel_factory.cc | 12 ++++++++++++ paddle/phi/core/kernel_factory.h | 4 +--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b578afa7c2b..10b859fdac2 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -40,7 +40,7 @@ const std::unordered_set standard_kernel_suffixs({ * after 2.0, and can no longer be occupied by the previously abandoned ops. * They are marked here uniformly. */ -const std::unordered_set deprecated_op_names( +static const std::unordered_set deprecated_op_names( {"diag", "flatten", "flatten_grad", diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 71256bdabaa..480882550db 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif +#include "paddle/phi/core/compat/op_utils.h" DECLARE_bool(enable_api_kernel_fallback); @@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() { return g_op_kernel_factory; } +bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { + if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) { + if (phi::OpUtilsMap::Instance().Contains(op_type)) { + return true; + } else if (kernels_.find(op_type) != kernels_.end()) { + return true; + } + } + return false; +} + const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 8e98c276646..ed9280fa475 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -272,9 +272,7 @@ class KernelFactory { KernelNameMap& kernels() { return kernels_; } - bool HasCompatiblePhiKernel(const std::string& op_type) const { - return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end(); - } + bool HasCompatiblePhiKernel(const std::string& op_type) const; KernelResult SelectKernelOrThrowError(const std::string& kernel_name, const KernelKey& kernel_key, -- GitLab