diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index f6b8fcb6bac7c505e914a3014ac0f3886eb11d7b..407da71c8c66337ea7613588a2e9a23f1243f1db 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -651,8 +651,8 @@ void BuildOpFuncList(const platform::Place& place, } else { if (!op_with_kernel->SupportsKernelType(expected_kernel_key, exec_ctx)) { - auto phi_cpu_kernel_key = FallBackToCpu( - expected_kernel_key, phi_kernel_key, *op_with_kernel); + auto phi_cpu_kernel_key = + FallBackToCpu(phi_kernel_key, *op_with_kernel); op_with_kernel->ResetPhiKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_name, phi_cpu_kernel_key))); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5cb1cca0529027cbcff1c113e7736e4d1b00e15e..9773d90c5cd12cfc32cd834c68aa657df74de286 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1808,8 +1808,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, #endif ) { fallback_to_cpu = true; - auto phi_cpu_kernel_key = - FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this); + auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this); phi_kernel_.reset( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_name, phi_cpu_kernel_key))); diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 165a84307591215d669e3145e821550b149d4006..53c35fc41c07885346a8f5c0f6fdaec7224895d8 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -100,58 +100,59 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( framework::TransToPhiDataType(kernel_type.data_type_)); } -phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, - const phi::KernelKey& kernel_key, +phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key, const framework::OperatorBase& op) { #ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(expected_kernel_key.place_) || + if (kernel_key.backend() == phi::Backend::XPU || paddle::platform::is_in_xpu_black_list(op.Type())) { VLOG(3) << "phi missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_ASCEND_CL - if (platform::is_npu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::NPU) { VLOG(3) << "phi missing NPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_MLU - if (platform::is_mlu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::MLU) { VLOG(3) << "phi missing MLU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_IPU - if (platform::is_ipu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::IPU) { VLOG(3) << "phi missing IPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE - if (platform::is_custom_place(expected_kernel_key.place_)) { - VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() + auto place = phi::TransToPhiPlace(kernel_key.backend()); + if (platform::is_custom_place(place)) { + VLOG(3) << "phi missing " << place.GetDeviceType() << " kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::GPU || + kernel_key.backend() == phi::Backend::GPUDNN) { PADDLE_THROW(platform::errors::Unavailable( "For GPU kernel, they must not fallback into CPU kernel.")); } diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 050a51a0f1077636abb0267528973643d7a24ff9..602528f5bb061603b90d350246baa8ce1992453f 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -43,8 +43,7 @@ namespace framework { OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type); -phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, - const phi::KernelKey& kernel_key, +phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key, const framework::OperatorBase& op); /* Kernel Args parse */ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index b0cd6b07a40469c475bb8a443722bcf5ab39c2b4..5eb045a0c522396e42450ca14ece251f626292e3 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -393,8 +393,7 @@ PreparedOp PrepareImpl( #endif ) { if (has_phi_kernel) { - auto phi_cpu_kernel_key = - FallBackToCpu(expected_kernel_key, phi_kernel_key, op); + auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op); auto& phi_cpu_kernel = phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key); if (phi_cpu_kernel.IsValid()) {