diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 37b67ea993f779bb1aace08384e9f7804b05d091..dc1e56f13f3b1d4557affb7ed44cdd68e4c88d1f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4890,6 +4890,7 @@ def split(input, num_or_sections, dim=-1, name=None): if isinstance(dim, Variable): dim = dim.numpy() dim = dim.item(0) + assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input.shape) + dim) if dim < 0 else dim attrs += ('axis', dim) @@ -4951,6 +4952,7 @@ def split(input, num_or_sections, dim=-1, name=None): dim.stop_gradient = True inputs['AxisTensor'] = dim else: + assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input_shape) + dim) if dim < 0 else dim attrs['axis'] = dim