diff --git a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py index f286296fe5560699108e0388c7862307e6cde876..033ee7908866d7468e773e68befe519f560cd68b 100644 --- a/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_spectral_norm_op.py @@ -149,6 +149,24 @@ class TestSpectralNormOpError(unittest.TestCase): # the data type of type must be float32 or float64 self.assertRaises(TypeError, test_weight_dtype) + def test_dim_out_of_range_1(): + weight_3 = np.random.random((2, 4)).astype("float32") + tensor_3 = paddle.to_tensor(weight_3) + paddle.static.nn.spectral_norm( + tensor_3, dim=1382376303, power_iters=2 + ) + + # the dim must be 0 or 1 + self.assertRaises(ValueError, test_dim_out_of_range_1) + + def test_dim_out_of_range_2(): + weight_4 = np.random.random((2, 4)).astype("float32") + tensor_4 = paddle.to_tensor(weight_4) + paddle.static.nn.spectral_norm(tensor_4, dim=-1, power_iters=2) + + # the dim must be 0 or 1 + self.assertRaises(ValueError, test_dim_out_of_range_2) + class TestDygraphSpectralNormOpError(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py index 0b278eefa1551fe9e7bf8dd8be48738687843603..5da81feb3369d9e4894dbef35fa0f36d0f6a5feb 100644 --- a/python/paddle/static/nn/common.py +++ b/python/paddle/static/nn/common.py @@ -3418,11 +3418,12 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): # create intput and parameters input_shape = weight.shape assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." - assert dim < len(input_shape), ( - "The input `dim` should be less than the " - "rank of `weight`, but received dim=" - "{}".format(dim) - ) + + if dim not in [0, 1]: + raise ValueError( + f"The input `dim` must be 0 (if weight in fc) or 1 (if weight in conv), but received dim={dim}" + ) + h = input_shape[dim] w = np.prod(input_shape) // h