From c1394c6ae1a0293ca7800d6614f955f09c52785e Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Sun, 10 Apr 2022 21:59:09 +0800 Subject: [PATCH] [KP]fix bug when TruncatedNormal cannot fall back in cpu (#41565) * [KP]fix bug when TruncatedNormal cannot fall back in cpu * delete useless comment * delete useless comment --- paddle/fluid/framework/operator.cc | 7 +++++-- paddle/fluid/imperative/prepared_operator.cc | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 6af07caaf88..e6577f662ae 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1333,7 +1333,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // NOTE(Liu-xiandong): Determine whether the selected kernel is valid // If not, use the kernel registered in fluid. And if the fluid do not // contains the related heterogeneous kernel, use phi CPU kernel. -#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) +#if defined(PADDLE_WITH_XPU) bool is_xpu_unsupport = paddle::platform::is_xpu_place(kernel_type_->place_) && !paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) || @@ -1373,7 +1373,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || is_xpu_unsupport #endif - ) { +#if defined(PADDLE_WITH_XPU_KP) + || (is_xpu_unsupport && !is_xpu_kp_support) +#endif + ) { auto pt_cpu_kernel_key = FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this); pt_kernel_.reset( diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index b56d113937d..0ad5e808b1d 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -263,7 +263,10 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) || is_xpu_unsupport #endif - ) { +#if defined(PADDLE_WITH_XPU_KP) + || (is_xpu_unsupport && !is_xpu_kp_support) +#endif + ) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) { auto pt_cpu_kernel_key = FallBackToCpu(expected_kernel_key, pt_kernel_key, op); -- GitLab