未验证 提交 43ad0b17 编写于 作者: L Leo Guo 提交者: GitHub

Fix the bug where the device memory address appears in abs_grad kernel...

Fix the bug where the device memory address appears in abs_grad kernel fallback to CPU. test=kunlun (#47186)
上级 340009d6
...@@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) { ...@@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
(*kernel_fn)(*dev_ctx, input_x, kernel_out); (*kernel_fn)(*dev_ctx, input_x, kernel_out);
} else { } else {
std::vector<const phi::TensorBase*> input_x(x.size()); std::vector<const phi::TensorBase*> input_x(x.size());
std::vector<std::shared_ptr<phi::DenseTensor>> temp_dense_tensots;
temp_dense_tensots.reserve(x.size());
for (size_t i = 0; i < input_x.size(); ++i) { for (size_t i = 0; i < input_x.size(); ++i) {
input_x[i] = x[i].impl().get(); if (phi::DenseTensor::classof(x[i].impl().get())) {
temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {}));
input_x[i] = temp_dense_tensots.back().get();
} else {
input_x[i] = x[i].impl().get();
}
} }
auto x_meta_vec = MakeMetaTensor(input_x); auto x_meta_vec = MakeMetaTensor(input_x);
std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size()); std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
...@@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector<Tensor>& x) { ...@@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, kernel_out); (*kernel_fn)(*dev_ctx, input_x, kernel_out);
if (kernel_result.has_fallback_cpu) {
TransDataBackend(kernel_out, kernel_backend, kernel_out);
}
} }
return api_output; return api_output;
......
...@@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( ...@@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_key, kernel_key,
kernel_name)); kernel_name));
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
|| paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) ||
paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name))
#else
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
#endif #endif
) { ) {
// Fallback CPU backend // Fallback CPU backend
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册