From b1d44bfc7b2aeb3c55167b11f04eeca0c9a25a21 Mon Sep 17 00:00:00 2001 From: RedContritio Date: Mon, 30 Jan 2023 18:18:04 +0800 Subject: [PATCH] Fix Python IndexError of case7: paddle.static.nn.spectral_norm (#49988) * add dim check for spectral_norm * add unittest out of range for spectral_norm * use ValueError when dim out of range for spectral_norm * update dim limit and add unittest for spectral_norm --- .../tests/unittests/test_spectral_norm_op.py | 18 ++++++++++++++++++ python/paddle/static/nn/common.py | 11 ++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index f286296fe55..033ee790886 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -149,6 +149,24 @@ class TestSpectralNormOpError(unittest.TestCase): # the data type of type must be float32 or float64 self.assertRaises(TypeError, test_weight_dtype) + def test_dim_out_of_range_1(): + weight_3 = np.random.random((2, 4)).astype("float32") + tensor_3 = paddle.to_tensor(weight_3) + paddle.static.nn.spectral_norm( + tensor_3, dim=1382376303, power_iters=2 + ) + + # the dim must be 0 or 1 + self.assertRaises(ValueError, test_dim_out_of_range_1) + + def test_dim_out_of_range_2(): + weight_4 = np.random.random((2, 4)).astype("float32") + tensor_4 = paddle.to_tensor(weight_4) + paddle.static.nn.spectral_norm(tensor_4, dim=-1, power_iters=2) + + # the dim must be 0 or 1 + self.assertRaises(ValueError, test_dim_out_of_range_2) + class TestDygraphSpectralNormOpError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 0b278eefa15..5da81feb336 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -3418,11 +3418,12 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): # create intput and parameters input_shape = weight.shape assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." - assert dim < len(input_shape), ( - "The input `dim` should be less than the " - "rank of `weight`, but received dim=" - "{}".format(dim) - ) + + if dim not in [0, 1]: + raise ValueError( + f"The input `dim` must be 0 (if weight in fc) or 1 (if weight in conv), but received dim={dim}" + ) + h = input_shape[dim] w = np.prod(input_shape) // h -- GitLab