From 3380778f8b644abdef47b8f0b0a577ed745f2c59 Mon Sep 17 00:00:00 2001 From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com> Date: Mon, 9 Aug 2021 15:11:15 +0800 Subject: [PATCH] limit chunk.axis (#34630) --- python/paddle/fluid/layers/nn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 37b67ea993f..dc1e56f13f3 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4890,6 +4890,7 @@ def split(input, num_or_sections, dim=-1, name=None): if isinstance(dim, Variable): dim = dim.numpy() dim = dim.item(0) + assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input.shape) + dim) if dim < 0 else dim attrs += ('axis', dim) @@ -4951,6 +4952,7 @@ def split(input, num_or_sections, dim=-1, name=None): dim.stop_gradient = True inputs['AxisTensor'] = dim else: + assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input_shape) + dim) if dim < 0 else dim attrs['axis'] = dim -- GitLab