未验证 提交 603f8425 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP]fix bug that cannot fallback to CPU normally in XPU KP (#40576)

* [kp]fix bug that cannot fallback to CPU normally in XPU KP

* fix bug in static graph
上级 c040bbd7
...@@ -1456,7 +1456,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { ...@@ -1456,7 +1456,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif #endif
#ifdef PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (platform::is_xpu_place(expected_kernel_key.place_) && if (platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() || (kernel_iter == kernels.end() ||
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) || !paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
...@@ -1470,17 +1471,36 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const { ...@@ -1470,17 +1471,36 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
#endif #endif
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt = if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
FLAGS_run_kp_kernel && bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key); FLAGS_run_kp_kernel &&
bool use_xpu_kp_kernel_debug = paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
paddle::platform::is_in_xpu_kpwhite_list(type_); bool use_xpu_kp_kernel_debug =
if (platform::is_xpu_place(expected_kernel_key.place_) && paddle::platform::is_in_xpu_kpwhite_list(type_);
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) { if (use_xpu_kp_kernel_rt) {
expected_kernel_key.library_type_ = LibraryType::kKP; VLOG(3) << "xpu_kp using rt mode ";
kernel_iter = kernels.find(expected_kernel_key); }
VLOG(3) << "using XPU KP kernel: " << type_ if (use_xpu_kp_kernel_debug) {
<< ", using_kernel_key:" << expected_kernel_key; VLOG(3) << "xpu_kp using debug mode ";
}
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << type_
<< ", using_kernel_key:" << expected_kernel_key;
}
bool is_xpu_unsupport =
(!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(type_));
if (!is_xpu_kp_support &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
} }
#endif #endif
......
...@@ -234,7 +234,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -234,7 +234,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) { (kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << op.Type() VLOG(3) << "missing XPU kernel: " << op.Type()
...@@ -243,29 +243,36 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -243,29 +243,36 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif #endif
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
expected_kernel_key.place_ = platform::XPUPlace(); if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt = bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel && FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug = bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type()); paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) { if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode "; VLOG(3) << "xpu_kp using rt mode ";
} }
if (use_xpu_kp_kernel_debug) { if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode "; VLOG(3) << "xpu_kp using debug mode ";
} }
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) && bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) { if (is_xpu_kp_support) {
expected_kernel_key.place_ = platform::XPUPlace(); expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP; kernel_iter = kernels.find(expected_kernel_key);
kernel_iter = kernels.find(expected_kernel_key); VLOG(3) << "using XPU KP kernel: " << op.Type()
VLOG(3) << "using XPU KP kernel: " << op.Type() << ", using_kernel_key:" << expected_kernel_key;
<< ", using_kernel_key:" << expected_kernel_key; }
if (!is_xpu_kp_support &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册