未验证 提交 52de7fd8 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #17043 from tink2123/fix_split

fix split for dimension judgment
...@@ -4883,7 +4883,7 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -4883,7 +4883,7 @@ def split(input, num_or_sections, dim=-1, name=None):
assert num_or_sections > 1, 'num_or_sections must be more than 1.' assert num_or_sections > 1, 'num_or_sections must be more than 1.'
num = num_or_sections num = num_or_sections
else: else:
assert len(num_or_sections) < input_shape[ assert len(num_or_sections) <= input_shape[
dim], 'len(num_or_sections) must not be more than input.shape[dim].' dim], 'len(num_or_sections) must not be more than input.shape[dim].'
num = len(num_or_sections) num = len(num_or_sections)
outs = [ outs = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册