未验证 提交 8d797fd2 编写于 作者: Z zyfncg 提交者: GitHub

[Phi] Refactor logic of judging whether having a phi kernrel (#46920)

* refind logic of choose phi kernrel

* fix complie budg
上级 561fd8c8
......@@ -40,7 +40,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names(
static const std::unordered_set<std::string> deprecated_op_names(
{"diag",
"flatten",
"flatten_grad",
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h"
#endif
#include "paddle/phi/core/compat/op_utils.h"
DECLARE_bool(enable_api_kernel_fallback);
......@@ -45,6 +46,17 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory;
}
bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) {
if (phi::OpUtilsMap::Instance().Contains(op_type)) {
return true;
} else if (kernels_.find(op_type) != kernels_.end()) {
return true;
}
}
return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
......
......@@ -272,9 +272,7 @@ class KernelFactory {
KernelNameMap& kernels() { return kernels_; }
bool HasCompatiblePhiKernel(const std::string& op_type) const {
return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end();
}
bool HasCompatiblePhiKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册