From 942ab42f05018e828ad9821d6a63be85e8221a04 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 25 Oct 2022 14:42:17 +0800 Subject: [PATCH] [Sparse] Fix indices (#47190) (#47226) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当前无法从Tensor中获取到SparseTensor的sparse_dim,无法准确推断出indices的shape,所以目前先以3D点云模型为主,输入的SparseTensor的维度是5D的,其中非零元素是一维向量,所以indices是[4, -1]。 --- paddle/phi/infermeta/sparse/unary.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/sparse/unary.cc b/paddle/phi/infermeta/sparse/unary.cc index 45cb4f75e38..f80f18bbba8 100644 --- a/paddle/phi/infermeta/sparse/unary.cc +++ b/paddle/phi/infermeta/sparse/unary.cc @@ -20,7 +20,11 @@ namespace phi { namespace sparse { void IndicesInferMeta(const MetaTensor& x, MetaTensor* out) { - out->set_dims({-1}); + // TODO(zhangkaihuo) Currently, we cannot get sparse_dim from tensor. + // correct shape is: shape[0] = x.sparse_dim() + // In the 3D point cloud model: + // the input x is 5-D tensor, non_zero_elements is 1-D tensor + out->set_dims({x.dims().size() - 1, -1}); out->set_dtype(DataType::INT32); out->set_layout(DataLayout::NCHW); } -- GitLab