diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 701fc7de6940ac7fde48da492206623fd2f2d811..692ebf6f332f15be552a223cab89eabbf5c4a69b 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1211,7 +1211,17 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                 << "` not found.";
       }
     }
-    if (pt_kernel_->IsValid()) {
+#ifdef PADDLE_WITH_XPU
+    bool is_xpu_unsupport =
+        paddle::platform::is_xpu_place(kernel_type_->place_) &&
+            !paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
+        paddle::platform::is_in_xpu_black_list(type_);
+#endif
+    if (pt_kernel_->IsValid()
+#ifdef PADDLE_WITH_XPU
+        && !is_xpu_unsupport
+#endif
+        ) {
       run_pten_kernel_ = true;
     } else {
       auto& all_op_kernels = AllOpKernels();
@@ -1220,13 +1230,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
           kernels_iter->second.find(*kernel_type_.get()) ==
               kernels_iter->second.end()
 #ifdef PADDLE_WITH_XPU
-          ||
-          paddle::platform::is_xpu_place(kernel_type_->place_) &&  // NOLINT
-              !paddle::platform::is_xpu_support_op(
-                  type_, *kernel_type_.get())  // NOLINT
-          || paddle::platform::is_in_xpu_black_list(type_)
+          || is_xpu_unsupport
 #endif
-              ) {
+          ) {
         auto pt_cpu_kernel_key =
             FallBackToCpu(*kernel_type_.get(), pt_kernel_key, *this);
         pt_kernel_.reset(
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 05218ba961fdd115bd0d28755ce14e03a1c01003..6d18b0a86f0911f38e1c51d61467bf9a01a6de21 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -161,6 +161,13 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   framework::KernelSignature pt_kernel_signature;
   phi::KernelKey pt_kernel_key;
   std::string pt_kernel_name;
+#ifdef PADDLE_WITH_XPU
+  bool is_xpu_unsupport =
+      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
+          !paddle::platform::is_xpu_support_op(op.Type(),
+                                               expected_kernel_key) ||
+      paddle::platform::is_in_xpu_black_list(op.Type());
+#endif
   if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
     pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
     VLOG(6) << pt_kernel_signature;
@@ -170,7 +177,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
     auto pt_kernel = phi::KernelFactory::Instance().SelectKernel(pt_kernel_name,
                                                                  pt_kernel_key);
 
-    if (pt_kernel.IsValid()) {
+    if (pt_kernel.IsValid()
+#ifdef PADDLE_WITH_XPU
+        && !is_xpu_unsupport
+#endif
+        ) {
       VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
               << " | kernel key: " << pt_kernel_key
               << " | kernel: " << pt_kernel;
@@ -197,13 +208,9 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
        kernels_iter->second.find(expected_kernel_key) ==
            kernels_iter->second.end())
 #ifdef PADDLE_WITH_XPU
-      ||
-      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
-          !paddle::platform::is_xpu_support_op(op.Type(),
-                                               expected_kernel_key) ||
-      paddle::platform::is_in_xpu_black_list(op.Type())
+      || is_xpu_unsupport
 #endif
-          ) {
+      ) {
     if (phi::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
       auto pt_cpu_kernel_key =
           FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
@@ -230,9 +237,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
 
 #ifdef PADDLE_WITH_XPU
   if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
-      (kernel_iter == kernels.end() ||
-       !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) ||
-       paddle::platform::is_in_xpu_black_list(op.Type()))) {
+      (kernel_iter == kernels.end() || is_xpu_unsupport)) {
     VLOG(3) << "missing XPU kernel: " << op.Type()
             << ", expected_kernel_key:" << expected_kernel_key
             << ", fallbacking to CPU one!";
diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc
index b4e7e127995ec2d0eeda788e9d6e6f9ccf12f8b1..a5b7b869b948dfb17b9f58a455bb336a4f021c4f 100644
--- a/paddle/phi/core/compat/convert_utils.cc
+++ b/paddle/phi/core/compat/convert_utils.cc
@@ -30,6 +30,8 @@ Backend TransToPtenBackend(const phi::Place& place) {
     return Backend::CPU;
   } else if (place.GetType() == phi::AllocationType::GPU) {
     return Backend::GPU;
+  } else if (place.GetType() == phi::AllocationType::XPU) {
+    return Backend::XPU;
   } else if (place.GetType() == phi::AllocationType::CUSTOM) {
     return static_cast<Backend>(
         static_cast<size_t>(Backend::NUM_BACKENDS) +