diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0afe20af5c85489e769d30a2fb6e45add4fb916e..ff218b1756ffc5811205fb82716802eab622324a 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2496,7 +2496,7 @@ infer_meta : func : SpectralNormInferMeta kernel : - func : spectralnorm + func : spectral_norm data_type : weight backward : spectral_norm_grad diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index a21edd8cb5744374ceb9bef8fc67a2352505442f..09030e9441c89dcc95e2add3de72f4e8383bef0e 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -3163,6 +3163,10 @@ class SpectralNorm(layers.Layer): self.weight_v.stop_gradient = True def forward(self, weight): + if in_dygraph_mode(): + return _C_ops.spectral_norm(weight, self.weight_u, self.weight_v, + self._dim, self._power_iters, self._eps) + check_variable_and_dtype(weight, "weight", ['float32', 'float64'], 'SpectralNorm') inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4a17ad0b4ca1e7a968dc998868af10a581bb2fca..cf138cc5a4f05e1904c838e634c62c817baaa6b3 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3933,7 +3933,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): dtype = weight.dtype # create intput and parameters - inputs = {'Weight': weight} input_shape = weight.shape assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." assert dim < len(input_shape), ("The input `dim` should be less than the " @@ -3947,14 +3946,19 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): dtype=dtype, default_initializer=Normal(0., 1.)) u.stop_gradient = True - inputs['U'] = u v = helper.create_parameter(attr=ParamAttr(), shape=[w], dtype=dtype, default_initializer=Normal(0., 1.)) - inputs['V'] = v v.stop_gradient = True + if in_dygraph_mode(): + return _C_ops.spectral_norm(weight, u, v, dim, power_iters, eps) + + inputs = {'Weight': weight} + inputs['U'] = u + inputs['V'] = v + # create output out = helper.create_variable(dtype=dtype)