diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b578afa7c2b854c058e238b2923b2fe5830243d2..10b859fdac260396618116e0da5efbe3f3de192f 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -40,7 +40,7 @@ const std::unordered_set standard_kernel_suffixs({ * after 2.0, and can no longer be occupied by the previously abandoned ops. * They are marked here uniformly. */ -const std::unordered_set deprecated_op_names( +static const std::unordered_set deprecated_op_names( {"diag", "flatten", "flatten_grad", diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 71256bdabaa67582625a8de04fdc4c86c0ca2c50..480882550dbcaddd42d50e7ca8a30df95242330a 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/phi/core/compat/convert_utils.h" #endif +#include "paddle/phi/core/compat/op_utils.h" DECLARE_bool(enable_api_kernel_fallback); @@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() { return g_op_kernel_factory; } +bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { + if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) { + if (phi::OpUtilsMap::Instance().Contains(op_type)) { + return true; + } else if (kernels_.find(op_type) != kernels_.end()) { + return true; + } + } + return false; +} + const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 8e98c276646d9cf387758417b53eaedba10e0564..ed9280fa475bf5b46c689b0a5e49e470484b77a6 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -272,9 +272,7 @@ class KernelFactory { KernelNameMap& kernels() { return kernels_; } - bool HasCompatiblePhiKernel(const std::string& op_type) const { - return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end(); - } + bool HasCompatiblePhiKernel(const std::string& op_type) const; KernelResult SelectKernelOrThrowError(const std::string& kernel_name, const KernelKey& kernel_key,