未验证 提交 3380778f 编写于 作者: W wangzhen38 提交者: GitHub

limit chunk.axis (#34630)

上级 aab4d6e4
......@@ -4890,6 +4890,7 @@ def split(input, num_or_sections, dim=-1, name=None):
if isinstance(dim, Variable):
dim = dim.numpy()
dim = dim.item(0)
assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim)
......@@ -4951,6 +4952,7 @@ def split(input, num_or_sections, dim=-1, name=None):
dim.stop_gradient = True
inputs['AxisTensor'] = dim
else:
assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
dim = (len(input_shape) + dim) if dim < 0 else dim
attrs['axis'] = dim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册