未验证 提交 31b92305 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix spectral_norm interface under eager mode (#45655)

* [Eager] fix spectral_norm interface under eager mode

* fix spectral_norm yaml config
上级 b0b8f1d7
...@@ -2496,7 +2496,7 @@ ...@@ -2496,7 +2496,7 @@
infer_meta : infer_meta :
func : SpectralNormInferMeta func : SpectralNormInferMeta
kernel : kernel :
func : spectralnorm func : spectral_norm
data_type : weight data_type : weight
backward : spectral_norm_grad backward : spectral_norm_grad
......
...@@ -3163,6 +3163,10 @@ class SpectralNorm(layers.Layer): ...@@ -3163,6 +3163,10 @@ class SpectralNorm(layers.Layer):
self.weight_v.stop_gradient = True self.weight_v.stop_gradient = True
def forward(self, weight): def forward(self, weight):
if in_dygraph_mode():
return _C_ops.spectral_norm(weight, self.weight_u, self.weight_v,
self._dim, self._power_iters, self._eps)
check_variable_and_dtype(weight, "weight", ['float32', 'float64'], check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm') 'SpectralNorm')
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
......
...@@ -3933,7 +3933,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): ...@@ -3933,7 +3933,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
dtype = weight.dtype dtype = weight.dtype
# create intput and parameters # create intput and parameters
inputs = {'Weight': weight}
input_shape = weight.shape input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0." assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), ("The input `dim` should be less than the " assert dim < len(input_shape), ("The input `dim` should be less than the "
...@@ -3947,14 +3946,19 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None): ...@@ -3947,14 +3946,19 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
dtype=dtype, dtype=dtype,
default_initializer=Normal(0., 1.)) default_initializer=Normal(0., 1.))
u.stop_gradient = True u.stop_gradient = True
inputs['U'] = u
v = helper.create_parameter(attr=ParamAttr(), v = helper.create_parameter(attr=ParamAttr(),
shape=[w], shape=[w],
dtype=dtype, dtype=dtype,
default_initializer=Normal(0., 1.)) default_initializer=Normal(0., 1.))
inputs['V'] = v
v.stop_gradient = True v.stop_gradient = True
if in_dygraph_mode():
return _C_ops.spectral_norm(weight, u, v, dim, power_iters, eps)
inputs = {'Weight': weight}
inputs['U'] = u
inputs['V'] = v
# create output # create output
out = helper.create_variable(dtype=dtype) out = helper.create_variable(dtype=dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册