未验证 提交 b1d44bfc 编写于 作者: R RedContritio 提交者: GitHub

Fix Python IndexError of case7: paddle.static.nn.spectral_norm (#49988)

* add dim check for spectral_norm

* add unittest out of range for spectral_norm

* use ValueError when dim out of range for spectral_norm

* update dim limit and add unittest for spectral_norm
上级 6f8ec229
......@@ -149,6 +149,24 @@ class TestSpectralNormOpError(unittest.TestCase):
# the data type of type must be float32 or float64
self.assertRaises(TypeError, test_weight_dtype)
def test_dim_out_of_range_1():
weight_3 = np.random.random((2, 4)).astype("float32")
tensor_3 = paddle.to_tensor(weight_3)
paddle.static.nn.spectral_norm(
tensor_3, dim=1382376303, power_iters=2
)
# the dim must be 0 or 1
self.assertRaises(ValueError, test_dim_out_of_range_1)
def test_dim_out_of_range_2():
weight_4 = np.random.random((2, 4)).astype("float32")
tensor_4 = paddle.to_tensor(weight_4)
paddle.static.nn.spectral_norm(tensor_4, dim=-1, power_iters=2)
# the dim must be 0 or 1
self.assertRaises(ValueError, test_dim_out_of_range_2)
class TestDygraphSpectralNormOpError(unittest.TestCase):
def test_errors(self):
......
......@@ -3418,11 +3418,12 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
# create intput and parameters
input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), (
"The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim)
)
if dim not in [0, 1]:
raise ValueError(
f"The input `dim` must be 0 (if weight in fc) or 1 (if weight in conv), but received dim={dim}"
)
h = input_shape[dim]
w = np.prod(input_shape) // h
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册