diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6bf437b4dfbd06847122d64dada14ad3accc8181..cd3b7354ed9eb2a0a0542ff0a18f1c9922e4cbe2 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4883,7 +4883,7 @@ def split(input, num_or_sections, dim=-1, name=None): assert num_or_sections > 1, 'num_or_sections must be more than 1.' num = num_or_sections else: - assert len(num_or_sections) < input_shape[ + assert len(num_or_sections) <= input_shape[ dim], 'len(num_or_sections) must not be more than input.shape[dim].' num = len(num_or_sections) outs = [