未验证 提交 c1394c6a 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP]fix bug when TruncatedNormal cannot fall back in cpu (#41565)

* [KP]fix bug when TruncatedNormal cannot fall back in cpu

* delete useless comment

* delete useless comment
上级 91d6f47a
...@@ -1333,7 +1333,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1333,7 +1333,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// NOTE(Liu-xiandong): Determine whether the selected kernel is valid // NOTE(Liu-xiandong): Determine whether the selected kernel is valid
// If not, use the kernel registered in fluid. And if the fluid do not // If not, use the kernel registered in fluid. And if the fluid do not
// contains the related heterogeneous kernel, use phi CPU kernel. // contains the related heterogeneous kernel, use phi CPU kernel.
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport = bool is_xpu_unsupport =
paddle::platform::is_xpu_place(kernel_type_->place_) && paddle::platform::is_xpu_place(kernel_type_->place_) &&
!paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) || !paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
...@@ -1373,7 +1373,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1373,7 +1373,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport || is_xpu_unsupport
#endif #endif
) { #if defined(PADDLE_WITH_XPU_KP)
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
auto pt_cpu_kernel_key = auto pt_cpu_kernel_key =
FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this); FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this);
pt_kernel_.reset( pt_kernel_.reset(
......
...@@ -263,7 +263,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -263,7 +263,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport || is_xpu_unsupport
#endif #endif
) { #if defined(PADDLE_WITH_XPU_KP)
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key = auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op); FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册