未验证 提交 749bc240 编写于 作者: J Jiawei Wang 提交者: GitHub

cherry-pick #36021 fix unique/unstack zero tensor (#36163)

* fix unique unstack dim 0

* fix unique_op format
上级 1db28fd9
......@@ -403,7 +403,10 @@ class UniqueKernel : public framework::OpKernel<T> {
bool return_index = context.Attr<bool>("return_index");
bool return_inverse = context.Attr<bool>("return_inverse");
bool return_counts = context.Attr<bool>("return_counts");
if (x->numel() == 0) {
out->mutable_data<T>(context.GetPlace());
return;
}
if (axis_vec.empty()) {
framework::VisitDataTypeTiny(
data_type,
......
......@@ -149,7 +149,7 @@ class UnStackKernel : public framework::OpKernel<T> {
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
}
auto dy_data = dy->data<T>();
if (dy->numel() == 0) return;
int pre = 1;
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
int total_num = dy->numel();
......
......@@ -10315,6 +10315,8 @@ def unstack(x, axis=0, num=None):
if in_dygraph_mode():
if num == None:
num = x.shape[axis]
if num == 0:
return []
return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num)
helper = LayerHelper('unstack', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册