From 1aa2bde0e27fe7a1f1699e8aff0c6de5621844f2 Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Fri, 20 Aug 2021 11:21:25 +0800 Subject: [PATCH] [bug fix] fix spectral_norm bug (#35005) --- python/paddle/fluid/dygraph/nn.py | 6 ++++++ python/paddle/fluid/layers/nn.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 608e85acec3..d9a431990c1 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -3062,6 +3062,12 @@ class SpectralNorm(layers.Layer): self._dtype = dtype self._weight_shape = list(weight_shape) + assert np.prod(self._weight_shape) > 0,\ + "Any dimension of `weight_shape` cannot be equal to 0." + assert dim < len(self._weight_shape), \ + ("The input `dim` should be less than the " + "length of `weight_shape`, but received dim=" + "{}".format(dim)) h = self._weight_shape[self._dim] w = np.prod(self._weight_shape) // h diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d0d15e92bfb..bd7ecfeee65 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3720,6 +3720,10 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): # create intput and parameters inputs = {'Weight': weight} input_shape = weight.shape + assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." + assert dim < len(input_shape), ("The input `dim` should be less than the " + "rank of `weight`, but received dim=" + "{}".format(dim)) h = input_shape[dim] w = np.prod(input_shape) // h -- GitLab