未验证 提交 603f8425 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP]fix bug that cannot fallback to CPU normally in XPU KP (#40576)

* [kp]fix bug that cannot fallback to CPU normally in XPU KP

* fix bug in static graph
上级 c040bbd7
......@@ -1456,7 +1456,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() ||
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
......@@ -1470,17 +1471,36 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
#endif
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (platform::is_xpu_place(expected_kernel_key.place_) &&
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) {
expected_kernel_key.library_type_ = LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << type_
<< ", using_kernel_key:" << expected_kernel_key;
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode ";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode ";
}
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << type_
<< ", using_kernel_key:" << expected_kernel_key;
}
bool is_xpu_unsupport =
(!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(type_));
if (!is_xpu_kp_support &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
}
#endif
......
......@@ -234,7 +234,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
auto& kernels = kernels_iter->second;
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_XPU
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << op.Type()
......@@ -243,29 +243,36 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_XPU_KP
expected_kernel_key.place_ = platform::XPUPlace();
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode ";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode ";
}
if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) {
expected_kernel_key.place_ = platform::XPUPlace();
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << op.Type()
<< ", using_kernel_key:" << expected_kernel_key;
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "xpu_kp using rt mode ";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "xpu_kp using debug mode ";
}
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
kernel_iter = kernels.find(expected_kernel_key);
VLOG(3) << "using XPU KP kernel: " << op.Type()
<< ", using_kernel_key:" << expected_kernel_key;
}
if (!is_xpu_kp_support &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册