未验证 提交 0111a354 编写于 作者: R Ryan 提交者: GitHub

[Fix IndexError] add unstack axis check (#49943)

* add unstack axis check

* IndexErr -> ValueError

* add static select
上级 f6e874bc
......@@ -84,5 +84,29 @@ class TestStackOp6(TestUnStackOpBase):
self.axis = 2
class TestUnstackZeroInputOp(unittest.TestCase):
def unstack_zero_input_static(self):
paddle.enable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
paddle.unstack(x, axis=1)
def unstack_zero_input_dynamic(self):
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32')
paddle.unstack(x, axis=1)
def test_type_error(self):
paddle.disable_static()
self.assertRaises(ValueError, self.unstack_zero_input_dynamic)
self.assertRaises(ValueError, self.unstack_zero_input_static)
paddle.disable_static()
if __name__ == '__main__':
unittest.main()
......@@ -543,6 +543,10 @@ def unstack(x, axis=0, num=None):
y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5]
"""
if not (-x.ndim <= axis < x.ndim):
raise ValueError(
'`axis` must be in the range [-{0}, {0})'.format(x.ndim)
)
if in_dygraph_mode():
if num is None:
num = x.shape[axis]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册