未验证 提交 942ab42f 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse] Fix indices (#47190) (#47226)

当前无法从Tensor中获取到SparseTensor的sparse_dim,无法准确推断出indices的shape,所以目前先以3D点云模型为主,输入的SparseTensor的维度是5D的,其中非零元素是一维向量,所以indices是[4, -1]。
上级 99d8ba47
...@@ -20,7 +20,11 @@ namespace phi { ...@@ -20,7 +20,11 @@ namespace phi {
namespace sparse { namespace sparse {
void IndicesInferMeta(const MetaTensor& x, MetaTensor* out) { void IndicesInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims({-1}); // TODO(zhangkaihuo) Currently, we cannot get sparse_dim from tensor.
// correct shape is: shape[0] = x.sparse_dim()
// In the 3D point cloud model:
// the input x is 5-D tensor, non_zero_elements is 1-D tensor
out->set_dims({x.dims().size() - 1, -1});
out->set_dtype(DataType::INT32); out->set_dtype(DataType::INT32);
out->set_layout(DataLayout::NCHW); out->set_layout(DataLayout::NCHW);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册