未验证 提交 1aa2bde0 编写于 作者: S shangliang Xu 提交者: GitHub

[bug fix] fix spectral_norm bug (#35005)

上级 096b0f2e
...@@ -3062,6 +3062,12 @@ class SpectralNorm(layers.Layer): ...@@ -3062,6 +3062,12 @@ class SpectralNorm(layers.Layer):
self._dtype = dtype self._dtype = dtype
self._weight_shape = list(weight_shape) self._weight_shape = list(weight_shape)
assert np.prod(self._weight_shape) > 0,\
"Any dimension of `weight_shape` cannot be equal to 0."
assert dim < len(self._weight_shape), \
("The input `dim` should be less than the "
"length of `weight_shape`, but received dim="
"{}".format(dim))
h = self._weight_shape[self._dim] h = self._weight_shape[self._dim]
w = np.prod(self._weight_shape) // h w = np.prod(self._weight_shape) // h
......
...@@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): ...@@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
# create intput and parameters # create intput and parameters
inputs = {'Weight': weight} inputs = {'Weight': weight}
input_shape = weight.shape input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), ("The input `dim` should be less than the "
"rank of `weight`, but received dim="
"{}".format(dim))
h = input_shape[dim] h = input_shape[dim]
w = np.prod(input_shape) // h w = np.prod(input_shape) // h
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册