diff --git a/paddle/fluid/eager/accumulation/accumulation_node.cc b/paddle/fluid/eager/accumulation/accumulation_node.cc
index c95b08382b8ca655d94f73d4a5a323f962d731a0..92b74144e1e6f895fc0f62ef882fd2d55f8da283 100644
--- a/paddle/fluid/eager/accumulation/accumulation_node.cc
+++ b/paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -46,7 +46,25 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
   if (LIKELY(t.is_dense_tensor())) {
     if (LIKELY(tensor->is_dense_tensor())) {
       if (t.is_custom_device()) {
-        add__ad_func(*tensor, t);
+        auto* dev_ctx =
+            phi::DeviceContextPool::Instance().Get(tensor->place());
+        auto kernel_result =
+            phi::KernelFactory::Instance().SelectKernelOrThrowError(
+                "add",
+                phi::KernelKey(phi::TransToPhiBackend(tensor->place()),
+                               phi::DataLayout::ALL_LAYOUT,
+                               tensor->dtype()));
+        const auto& kernel = kernel_result.kernel;
+        using kernel_signature = void (*)(const phi::DeviceContext&,
+                                          const phi::DenseTensor&,
+                                          const phi::DenseTensor&,
+                                          phi::DenseTensor*);
+        auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+        (*kernel_fn)(
+            *dev_ctx,
+            *reinterpret_cast<phi::DenseTensor*>(tensor->impl().get()),
+            *reinterpret_cast<phi::DenseTensor*>(t.impl().get()),
+            reinterpret_cast<phi::DenseTensor*>(tensor->impl().get()));
       } else {
         paddle::imperative::TensorAdd<paddle::Tensor>(t, tensor);
       }
@@ -71,7 +89,25 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
       paddle::Tensor tensor_values(std::make_shared<phi::DenseTensor>(
           tensor_sparse->non_zero_elements()));
       if (t.is_custom_device()) {
-        add__ad_func(tensor_values, t_values);
+        auto* dev_ctx =
+            phi::DeviceContextPool::Instance().Get(tensor->place());
+        auto kernel_result =
+            phi::KernelFactory::Instance().SelectKernelOrThrowError(
+                "add_coo_coo",
+                phi::KernelKey(phi::TransToPhiBackend(tensor->place()),
+                               phi::DataLayout::ALL_LAYOUT,
+                               tensor->dtype()));
+        const auto& kernel = kernel_result.kernel;
+        using kernel_signature = void (*)(const phi::DeviceContext&,
+                                          const phi::SparseCooTensor&,
+                                          const phi::SparseCooTensor&,
+                                          phi::SparseCooTensor*);
+        auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+        (*kernel_fn)(
+            *dev_ctx,
+            *reinterpret_cast<phi::SparseCooTensor*>(tensor->impl().get()),
+            *reinterpret_cast<phi::SparseCooTensor*>(t.impl().get()),
+            reinterpret_cast<phi::SparseCooTensor*>(tensor->impl().get()));
       } else {
         paddle::imperative::TensorAdd<paddle::Tensor>(t_values, &tensor_values);
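
Reviewer-style note, not part of the diff: both hunks repeat the same dispatch sequence (look up the device context for the accumulated tensor's place, select a kernel by name/backend/layout/dtype, fetch its variadic function pointer, and call it with the output aliasing the first input). A possible follow-up, sketched below, would factor that into one templated helper in the same file; the helper name AddInPlaceViaPhiKernel and its placement are assumptions, and it relies only on the phi calls this diff already uses.

// Hypothetical helper for accumulation_node.cc (name and placement assumed);
// it mirrors the dispatch sequence added by both hunks above.
template <typename DenseOrCooT>
static void AddInPlaceViaPhiKernel(const std::string& kernel_name,
                                   paddle::Tensor* tensor,
                                   const paddle::Tensor& t) {
  // Device context matching the accumulated gradient's place.
  auto* dev_ctx = phi::DeviceContextPool::Instance().Get(tensor->place());
  // Resolve the kernel the same way the diff does for "add" / "add_coo_coo".
  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name,
      phi::KernelKey(phi::TransToPhiBackend(tensor->place()),
                     phi::DataLayout::ALL_LAYOUT,
                     tensor->dtype()));
  const auto& kernel = kernel_result.kernel;
  // Both kernels share the (ctx, x, y, out) variadic signature.
  using kernel_signature = void (*)(const phi::DeviceContext&,
                                    const DenseOrCooT&,
                                    const DenseOrCooT&,
                                    DenseOrCooT*);
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  // In-place accumulation: the output pointer aliases tensor's impl.
  (*kernel_fn)(*dev_ctx,
               *reinterpret_cast<DenseOrCooT*>(tensor->impl().get()),
               *reinterpret_cast<DenseOrCooT*>(t.impl().get()),
               reinterpret_cast<DenseOrCooT*>(tensor->impl().get()));
}

With such a helper, the dense branch would reduce to
AddInPlaceViaPhiKernel<phi::DenseTensor>("add", tensor, t); and the sparse-coo
branch to AddInPlaceViaPhiKernel<phi::SparseCooTensor>("add_coo_coo", tensor, t);
keeping the custom-device accumulation path in one place.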