未验证 提交 ac5548a2 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP] fix bug in phi kp (#41069)

* [KP] fix bug in phi kp

* delete useless comment

* update

* update

* choose the xpu kp kernel in phi
上级 7c5dca9f
...@@ -161,24 +161,48 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -161,24 +161,48 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
framework::KernelSignature pt_kernel_signature; framework::KernelSignature pt_kernel_signature;
phi::KernelKey pt_kernel_key; phi::KernelKey pt_kernel_key;
std::string pt_kernel_name; std::string pt_kernel_name;
#ifdef PADDLE_WITH_XPU #if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport = bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) && paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(), !paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) || expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type()); paddle::platform::is_in_xpu_black_list(op.Type());
#endif #endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx); pt_kernel_signature = op.GetExpectedPhiKernelArgs(dygraph_exe_ctx);
VLOG(6) << pt_kernel_signature; VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name; pt_kernel_name = pt_kernel_signature.name;
// modify the expected_kernel_key for KP in phi
#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(
op.Type(), expected_kernel_key);
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) {
VLOG(3) << "phi xpu_kp using rt mode ";
}
if (use_xpu_kp_kernel_debug) {
VLOG(3) << "phi xpu_kp using debug mode ";
}
bool is_xpu_kp_support =
(use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
if (is_xpu_kp_support) {
expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
VLOG(3) << "modify XPU KP kernel: " << op.Type()
<< ", using_kernel_key:" << expected_kernel_key;
}
}
#endif
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key); pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name, auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
pt_kernel_key); pt_kernel_key);
if (pt_kernel.IsValid() if (pt_kernel.IsValid()
#ifdef PADDLE_WITH_XPU #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
&& !is_xpu_unsupport && !is_xpu_unsupport
#endif #endif
) { ) {
...@@ -206,7 +230,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -206,7 +230,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if ((kernels_iter == all_op_kernels.end() || if ((kernels_iter == all_op_kernels.end() ||
kernels_iter->second.find(expected_kernel_key) == kernels_iter->second.find(expected_kernel_key) ==
kernels_iter->second.end()) kernels_iter->second.end())
#ifdef PADDLE_WITH_XPU #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport || is_xpu_unsupport
#endif #endif
) { ) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册