未验证 提交 8d797fd2 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Refactor logic of judging whether having a phi kernrel (#46920)

* refind logic of choose phi kernrel

* fix complie budg
上级 561fd8c8
...@@ -40,7 +40,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({ ...@@ -40,7 +40,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* after 2.0, and can no longer be occupied by the previously abandoned ops. * after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly. * They are marked here uniformly.
*/ */
const std::unordered_set<std::string> deprecated_op_names( static const std::unordered_set<std::string> deprecated_op_names(
{"diag", {"diag",
"flatten", "flatten",
"flatten_grad", "flatten_grad",
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#endif #endif
#include "paddle/phi/core/compat/op_utils.h"
DECLARE_bool(enable_api_kernel_fallback); DECLARE_bool(enable_api_kernel_fallback);
...@@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() { ...@@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory; return g_op_kernel_factory;
} }
bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) {
if (phi::OpUtilsMap::Instance().Contains(op_type)) {
return true;
} else if (kernels_.find(op_type) != kernels_.end()) {
return true;
}
}
return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const { const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
......
...@@ -272,9 +272,7 @@ class KernelFactory { ...@@ -272,9 +272,7 @@ class KernelFactory {
KernelNameMap& kernels() { return kernels_; } KernelNameMap& kernels() { return kernels_; }
bool HasCompatiblePhiKernel(const std::string& op_type) const { bool HasCompatiblePhiKernel(const std::string& op_type) const;
return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end();
}
KernelResult SelectKernelOrThrowError(const std::string& kernel_name, KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key, const KernelKey& kernel_key,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册