From 4681f13b6612ee203ecb12e8174dac984fd0387e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 12 Oct 2022 15:25:02 +0800 Subject: [PATCH] deliver indices_dict (#46919) --- paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu | 1 + paddle/phi/kernels/sparse/impl/unary_kernel_impl.h | 3 +++ 2 files changed, 4 insertions(+) diff --git a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index e10d762c886..d499cdf54ab 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -169,6 +169,7 @@ void CoalesceCooGPUKernel(const GPUContext& dev_ctx, indexs_ptr, const_dims, out_nnz, sparse_dim, out_indices.data()); out->SetMember(out_indices, out_values, x.dims(), true); + out->SetIndicesDict(x.GetIndicesDict()); } template diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h index 9b8b33d4d3a..a4b89fd8132 100644 --- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h @@ -37,6 +37,7 @@ namespace sparse { EmptyLikeCooKernel(dev_ctx, x, out); \ phi::prefix##Kernel( \ dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \ + out->SetIndicesDict(x.GetIndicesDict()); \ } \ \ template \ @@ -105,6 +106,7 @@ void ScaleCooKernel(const Context& dev_ctx, bias, bias_after_scale, out->mutable_non_zero_elements()); + out->SetIndicesDict(x.GetIndicesDict()); } template @@ -155,6 +157,7 @@ void CastCooKernel(const Context& dev_ctx, meta.set_dtype(value_dtype); phi::CastKernel(dev_ctx, x_values, value_dtype, out_values); } + out->SetIndicesDict(x.GetIndicesDict()); } template -- GitLab