未验证 提交 c1394c6a 编写于 作者: L Liu-xiandong 提交者: GitHub

[KP]fix bug when TruncatedNormal cannot fall back in cpu (#41565)

* [KP]fix bug when TruncatedNormal cannot fall back in cpu

* delete useless comment

* delete useless comment
上级 91d6f47a
......@@ -1333,7 +1333,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// NOTE(Liu-xiandong): Determine whether the selected kernel is valid
// If not, use the kernel registered in fluid. And if the fluid do not
// contains the related heterogeneous kernel, use phi CPU kernel.
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
!paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
......@@ -1373,7 +1373,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport
#endif
) {
#if defined(PADDLE_WITH_XPU_KP)
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
auto pt_cpu_kernel_key =
FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this);
pt_kernel_.reset(
......
......@@ -263,7 +263,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| is_xpu_unsupport
#endif
) {
#if defined(PADDLE_WITH_XPU_KP)
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册