From 31b923056bba41fbe85e263372c3a95bceff7b90 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 2 Sep 2022 14:20:24 +0800 Subject: [PATCH] [Eager] fix spectral_norm interface under eager mode (#45655) * [Eager] fix spectral_norm interface under eager mode * fix spectral_norm yaml config --- paddle/phi/api/yaml/legacy_api.yaml | 2 +- python/paddle/fluid/dygraph/nn.py | 4 ++++ python/paddle/fluid/layers/nn.py | 10 +++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 0afe20af5c8..ff218b1756f 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2496,7 +2496,7 @@ infer_meta : func : SpectralNormInferMeta kernel : - func : spectralnorm + func : spectral_norm data_type : weight backward : spectral_norm_grad diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index a21edd8cb57..09030e9441c 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -3163,6 +3163,10 @@ class SpectralNorm(layers.Layer): self.weight_v.stop_gradient = True def forward(self, weight): + if in_dygraph_mode(): + return _C_ops.spectral_norm(weight, self.weight_u, self.weight_v, + self._dim, self._power_iters, self._eps) + check_variable_and_dtype(weight, "weight", ['float32', 'float64'], 'SpectralNorm') inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4a17ad0b4ca..cf138cc5a4f 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3933,7 +3933,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): dtype = weight.dtype # create intput and parameters - inputs = {'Weight': weight} input_shape = weight.shape assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." assert dim < len(input_shape), ("The input `dim` should be less than the " @@ -3947,14 +3946,19 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): dtype=dtype, default_initializer=Normal(0., 1.)) u.stop_gradient = True - inputs['U'] = u v = helper.create_parameter(attr=ParamAttr(), shape=[w], dtype=dtype, default_initializer=Normal(0., 1.)) - inputs['V'] = v v.stop_gradient = True + if in_dygraph_mode(): + return _C_ops.spectral_norm(weight, u, v, dim, power_iters, eps) + + inputs = {'Weight': weight} + inputs['U'] = u + inputs['V'] = v + # create output out = helper.create_variable(dtype=dtype) -- GitLab