未验证 提交 46f9d9b7 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix accumulation node (#56205)

上级 6e40fc1d
......@@ -46,7 +46,25 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
if (t.is_custom_device()) {
add__ad_func(*tensor, t);
auto* dev_ctx =
phi::DeviceContextPool::Instance().Get(tensor->place());
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"add",
phi::KernelKey(phi::TransToPhiBackend(tensor->place()),
phi::DataLayout::ALL_LAYOUT,
tensor->dtype()));
const auto& kernel = kernel_result.kernel;
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx,
*reinterpret_cast<phi::DenseTensor*>(tensor->impl().get()),
*reinterpret_cast<phi::DenseTensor*>(t.impl().get()),
reinterpret_cast<phi::DenseTensor*>(tensor->impl().get()));
} else {
paddle::imperative::TensorAdd<paddle::Tensor>(t, tensor);
}
......@@ -71,7 +89,25 @@ static void CopyOrAddTensor(paddle::Tensor* tensor,
paddle::Tensor tensor_values(std::make_shared<phi::DenseTensor>(
tensor_sparse->non_zero_elements()));
if (t.is_custom_device()) {
add__ad_func(tensor_values, t_values);
auto* dev_ctx =
phi::DeviceContextPool::Instance().Get(tensor->place());
auto kernel_result =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"add_coo_coo",
phi::KernelKey(phi::TransToPhiBackend(tensor->place()),
phi::DataLayout::ALL_LAYOUT,
tensor->dtype()));
const auto& kernel = kernel_result.kernel;
using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::SparseCooTensor&,
const phi::SparseCooTensor&,
phi::SparseCooTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(
*dev_ctx,
*reinterpret_cast<phi::SparseCooTensor*>(tensor->impl().get()),
*reinterpret_cast<phi::SparseCooTensor*>(t.impl().get()),
reinterpret_cast<phi::SparseCooTensor*>(tensor->impl().get()));
} else {
paddle::imperative::TensorAdd<paddle::Tensor>(t_values,
&tensor_values);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册