From 43ad0b17951dfc858dc4136afffab7c4fb6cda9e Mon Sep 17 00:00:00 2001 From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com> Date: Fri, 21 Oct 2022 13:24:04 +0800 Subject: [PATCH] Fix the bug where a device memory address appears when the abs_grad kernel falls back to CPU. test=kunlun (#47186) --- paddle/phi/api/lib/api_custom_impl.cc | 12 +++++++++++- paddle/phi/core/kernel_factory.cc | 8 +++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 2fa40320e55..19a9b808dd6 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector& x) { (*kernel_fn)(*dev_ctx, input_x, kernel_out); } else { std::vector input_x(x.size()); + std::vector> temp_dense_tensots; + temp_dense_tensots.reserve(x.size()); for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x[i].impl().get(); + if (phi::DenseTensor::classof(x[i].impl().get())) { + temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {})); + input_x[i] = temp_dense_tensots.back().get(); + } else { + input_x[i] = x[i].impl().get(); + } } auto x_meta_vec = MakeMetaTensor(input_x); std::vector x_metas(x_meta_vec.size()); @@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector& x) { auto* kernel_fn = kernel.GetVariadicKernelFn(); (*kernel_fn)(*dev_ctx, input_x, kernel_out); + if (kernel_result.has_fallback_cpu) { + TransDataBackend(kernel_out, kernel_backend, kernel_out); + } } return api_output; diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 480882550db..bbfe10591f0 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_key, kernel_name)); - if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #if defined(PADDLE_WITH_XPU) && 
!defined(PADDLE_WITH_XPU_KP) - || paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) - + VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || + paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) +#else + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #endif ) { // Fallback CPU backend -- GitLab