From 7ffde4bc050f8196c0d993acf5e51f8fe9c82829 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Mon, 19 Dec 2022 20:31:20 +0800 Subject: [PATCH] simplify FallbackToCpu (#49124) --- .../interpreter/interpreter_util.cc | 4 +- paddle/fluid/framework/operator.cc | 3 +- paddle/fluid/framework/phi_utils.cc | 39 ++++++++++--------- paddle/fluid/framework/phi_utils.h | 3 +- paddle/fluid/imperative/prepared_operator.cc | 3 +- 5 files changed, 25 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index f6b8fcb6bac..407da71c8c6 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -651,8 +651,8 @@ void BuildOpFuncList(const platform::Place& place, } else { if (!op_with_kernel->SupportsKernelType(expected_kernel_key, exec_ctx)) { - auto phi_cpu_kernel_key = FallBackToCpu( - expected_kernel_key, phi_kernel_key, *op_with_kernel); + auto phi_cpu_kernel_key = + FallBackToCpu(phi_kernel_key, *op_with_kernel); op_with_kernel->ResetPhiKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_name, phi_cpu_kernel_key))); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5cb1cca0529..9773d90c5cd 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1808,8 +1808,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, #endif ) { fallback_to_cpu = true; - auto phi_cpu_kernel_key = - FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this); + auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this); phi_kernel_.reset( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_name, phi_cpu_kernel_key))); diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 165a8430759..53c35fc41c0 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -100,58 +100,59 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( framework::TransToPhiDataType(kernel_type.data_type_)); } -phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, - const phi::KernelKey& kernel_key, +phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key, const framework::OperatorBase& op) { #ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(expected_kernel_key.place_) || + if (kernel_key.backend() == phi::Backend::XPU || paddle::platform::is_in_xpu_black_list(op.Type())) { VLOG(3) << "phi missing XPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_ASCEND_CL - if (platform::is_npu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::NPU) { VLOG(3) << "phi missing NPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_MLU - if (platform::is_mlu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::MLU) { VLOG(3) << "phi missing MLU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_IPU - if (platform::is_ipu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::IPU) { VLOG(3) << "phi missing IPU kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE - if (platform::is_custom_place(expected_kernel_key.place_)) { - VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() + auto place = phi::TransToPhiPlace(kernel_key.backend()); + if (platform::is_custom_place(place)) { + VLOG(3) << "phi missing " << place.GetDeviceType() << " kernel: " << op.Type() - << ", expected_kernel_key:" << expected_kernel_key - << ", fallbacking to CPU one!"; + << ", expected_kernel_key:" << kernel_key + << ", fallback to CPU one!"; return phi::KernelKey( phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); } #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(expected_kernel_key.place_)) { + if (kernel_key.backend() == phi::Backend::GPU || + kernel_key.backend() == phi::Backend::GPUDNN) { PADDLE_THROW(platform::errors::Unavailable( "For GPU kernel, they must not fallback into CPU kernel.")); } diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 050a51a0f10..602528f5bb0 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -43,8 +43,7 @@ namespace framework { OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type); -phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, - const phi::KernelKey& kernel_key, +phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key, const framework::OperatorBase& op); /* Kernel Args parse */ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index b0cd6b07a40..5eb045a0c52 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -393,8 +393,7 @@ PreparedOp PrepareImpl( #endif ) { if (has_phi_kernel) { - auto phi_cpu_kernel_key = - FallBackToCpu(expected_kernel_key, phi_kernel_key, op); + auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op); auto& phi_cpu_kernel = phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key); if (phi_cpu_kernel.IsValid()) { -- GitLab