diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 2fa40320e558308dd2ea60dd5aa7c9e40ad6e210..19a9b808dd6f6d986189928741c49200c1d2b86e 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector& x) { (*kernel_fn)(*dev_ctx, input_x, kernel_out); } else { std::vector input_x(x.size()); + std::vector> temp_dense_tensots; + temp_dense_tensots.reserve(x.size()); for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x[i].impl().get(); + if (phi::DenseTensor::classof(x[i].impl().get())) { + temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {})); + input_x[i] = temp_dense_tensots.back().get(); + } else { + input_x[i] = x[i].impl().get(); + } } auto x_meta_vec = MakeMetaTensor(input_x); std::vector x_metas(x_meta_vec.size()); @@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector& x) { auto* kernel_fn = kernel.GetVariadicKernelFn(); (*kernel_fn)(*dev_ctx, input_x, kernel_out); + if (kernel_result.has_fallback_cpu) { + TransDataBackend(kernel_out, kernel_backend, kernel_out); + } } return api_output; diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 480882550dbcaddd42d50e7ca8a30df95242330a..bbfe10591f0f92beb7f5fd9736dcc808574ebbab 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_key, kernel_name)); - if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) - || paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) - + VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || + paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) +#else + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #endif ) { // Fallback CPU backend