未验证 提交 7ffde4bc 编写于 作者: H HongyuJia 提交者: GitHub

simplify FallbackToCpu (#49124)

上级 ff79c144
...@@ -651,8 +651,8 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -651,8 +651,8 @@ void BuildOpFuncList(const platform::Place& place,
} else { } else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key, if (!op_with_kernel->SupportsKernelType(expected_kernel_key,
exec_ctx)) { exec_ctx)) {
auto phi_cpu_kernel_key = FallBackToCpu( auto phi_cpu_kernel_key =
expected_kernel_key, phi_kernel_key, *op_with_kernel); FallBackToCpu(phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel( op_with_kernel->ResetPhiKernel(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key))); phi_kernel_name, phi_cpu_kernel_key)));
......
...@@ -1808,8 +1808,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1808,8 +1808,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#endif #endif
) { ) {
fallback_to_cpu = true; fallback_to_cpu = true;
auto phi_cpu_kernel_key = auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, *this);
FallBackToCpu(*kernel_type_.get(), phi_kernel_key, *this);
phi_kernel_.reset( phi_kernel_.reset(
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_cpu_kernel_key))); phi_kernel_name, phi_cpu_kernel_key)));
......
...@@ -100,58 +100,59 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( ...@@ -100,58 +100,59 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
framework::TransToPhiDataType(kernel_type.data_type_)); framework::TransToPhiDataType(kernel_type.data_type_));
} }
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key,
const phi::KernelKey& kernel_key,
const framework::OperatorBase& op) { const framework::OperatorBase& op) {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expected_kernel_key.place_) || if (kernel_key.backend() == phi::Backend::XPU ||
paddle::platform::is_in_xpu_black_list(op.Type())) { paddle::platform::is_in_xpu_black_list(op.Type())) {
VLOG(3) << "phi missing XPU kernel: " << op.Type() VLOG(3) << "phi missing XPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!"; << ", fallback to CPU one!";
return phi::KernelKey( return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
} }
#endif #endif
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(expected_kernel_key.place_)) { if (kernel_key.backend() == phi::Backend::NPU) {
VLOG(3) << "phi missing NPU kernel: " << op.Type() VLOG(3) << "phi missing NPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!"; << ", fallback to CPU one!";
return phi::KernelKey( return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
} }
#endif #endif
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(expected_kernel_key.place_)) { if (kernel_key.backend() == phi::Backend::MLU) {
VLOG(3) << "phi missing MLU kernel: " << op.Type() VLOG(3) << "phi missing MLU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!"; << ", fallback to CPU one!";
return phi::KernelKey( return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
} }
#endif #endif
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
if (platform::is_ipu_place(expected_kernel_key.place_)) { if (kernel_key.backend() == phi::Backend::IPU) {
VLOG(3) << "phi missing IPU kernel: " << op.Type() VLOG(3) << "phi missing IPU kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!"; << ", fallback to CPU one!";
return phi::KernelKey( return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(expected_kernel_key.place_)) { auto place = phi::TransToPhiPlace(kernel_key.backend());
VLOG(3) << "phi missing " << expected_kernel_key.place_.GetDeviceType() if (platform::is_custom_place(place)) {
VLOG(3) << "phi missing " << place.GetDeviceType()
<< " kernel: " << op.Type() << " kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key << ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!"; << ", fallback to CPU one!";
return phi::KernelKey( return phi::KernelKey(
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype()); phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
} }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(expected_kernel_key.place_)) { if (kernel_key.backend() == phi::Backend::GPU ||
kernel_key.backend() == phi::Backend::GPUDNN) {
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"For GPU kernel, they must not fallback into CPU kernel.")); "For GPU kernel, they must not fallback into CPU kernel."));
} }
......
...@@ -43,8 +43,7 @@ namespace framework { ...@@ -43,8 +43,7 @@ namespace framework {
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key); OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);
phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type); phi::KernelKey TransOpKernelTypeToPhiKernelKey(const OpKernelType& kernel_type);
phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key, phi::KernelKey FallBackToCpu(const phi::KernelKey& kernel_key,
const phi::KernelKey& kernel_key,
const framework::OperatorBase& op); const framework::OperatorBase& op);
/* Kernel Args parse */ /* Kernel Args parse */
......
...@@ -393,8 +393,7 @@ PreparedOp PrepareImpl( ...@@ -393,8 +393,7 @@ PreparedOp PrepareImpl(
#endif #endif
) { ) {
if (has_phi_kernel) { if (has_phi_kernel) {
auto phi_cpu_kernel_key = auto phi_cpu_kernel_key = FallBackToCpu(phi_kernel_key, op);
FallBackToCpu(expected_kernel_key, phi_kernel_key, op);
auto& phi_cpu_kernel = auto& phi_cpu_kernel =
phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key); phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key);
if (phi_cpu_kernel.IsValid()) { if (phi_cpu_kernel.IsValid()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册