From 183c0dd0a768416fb7b553c7b3051c18fcaae301 Mon Sep 17 00:00:00 2001 From: Surya <116063290+SuryanarayanaY@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:02:54 +0530 Subject: [PATCH] validated axis argument in swapaxes validated axis argument in swapaxes for raising exception for invalid axis argument --- tensorflow/python/ops/numpy_ops/np_array_ops.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 316f7f82141..62af34908d7 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -819,6 +819,11 @@ def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x) return x return nest.map_structure(f, axes) + + if axis2 < -array_ops.rank(a) or axis2 >= array_ops.rank(a): + raise ValueError( + f"Argument `axis` = {axis2} not in range " + f"[{-array_ops.rank(a)}, {array_ops.rank(a)})") if (a.shape.rank is not None and isinstance(axis1, int) and isinstance(axis2, int)): @@ -1066,10 +1071,6 @@ def stack(arrays, axis=0): # pylint: disable=missing-function-docstring arrays = asarray(arrays) if axis == 0: return arrays - elif axis < -array_ops.rank(arrays) or axis >= array_ops.rank(arrays): - raise ValueError( - f"Argument `axis` = {axis} not in range of " - f"[{-array_ops.rank(arrays)}, {array_ops.rank(arrays)})") else: return swapaxes(arrays, 0, axis) arrays = _promote_dtype(*arrays) # pylint: disable=protected-access -- GitLab