未验证 提交 183c0dd0 编写于 作者: S Surya 提交者: GitHub

validated axis argument in swapaxes

validated axis argument in swapaxes for raising exception for invalid axis argument
上级 ad5a72f2
...@@ -819,6 +819,11 @@ def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring ...@@ -819,6 +819,11 @@ def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring
x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x) x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
return x return x
return nest.map_structure(f, axes) return nest.map_structure(f, axes)
if axis2 < -array_ops.rank(a) or axis2 >= array_ops.rank(a):
raise ValueError(
f"Argument `axis` = {axis2} not in range "
f"[{-array_ops.rank(a)}, {array_ops.rank(a)})")
if (a.shape.rank is not None and if (a.shape.rank is not None and
isinstance(axis1, int) and isinstance(axis2, int)): isinstance(axis1, int) and isinstance(axis2, int)):
...@@ -1066,10 +1071,6 @@ def stack(arrays, axis=0): # pylint: disable=missing-function-docstring ...@@ -1066,10 +1071,6 @@ def stack(arrays, axis=0): # pylint: disable=missing-function-docstring
arrays = asarray(arrays) arrays = asarray(arrays)
if axis == 0: if axis == 0:
return arrays return arrays
elif axis < -array_ops.rank(arrays) or axis >= array_ops.rank(arrays):
raise ValueError(
f"Argument `axis` = {axis} not in range of "
f"[{-array_ops.rank(arrays)}, {array_ops.rank(arrays)})")
else: else:
return swapaxes(arrays, 0, axis) return swapaxes(arrays, 0, axis)
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册