diff --git a/paddle/phi/infermeta/sparse/unary.cc b/paddle/phi/infermeta/sparse/unary.cc index 45cb4f75e38c970beb661960583c44faebc3dd95..f80f18bbba857ab9f0ed8e59cff34f026869ddd8 100644 --- a/paddle/phi/infermeta/sparse/unary.cc +++ b/paddle/phi/infermeta/sparse/unary.cc @@ -20,7 +20,11 @@ namespace phi { namespace sparse { void IndicesInferMeta(const MetaTensor& x, MetaTensor* out) { - out->set_dims({-1}); + // TODO(zhangkaihuo) Currently, we cannot get sparse_dim from tensor. + // correct shape is: shape[0] = x.sparse_dim() + // In the 3D point cloud model: + // the input x is 5-D tensor, non_zero_elements is 1-D tensor + out->set_dims({x.dims().size() - 1, -1}); out->set_dtype(DataType::INT32); out->set_layout(DataLayout::NCHW); }