未验证 提交 31b92305 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix spectral_norm interface under eager mode (#45655)

* [Eager] fix spectral_norm interface under eager mode

* fix spectral_norm yaml config
上级 b0b8f1d7
......@@ -2496,7 +2496,7 @@
infer_meta :
func : SpectralNormInferMeta
kernel :
func : spectralnorm
func : spectral_norm
data_type : weight
backward : spectral_norm_grad
......
......@@ -3163,6 +3163,10 @@ class SpectralNorm(layers.Layer):
self.weight_v.stop_gradient = True
def forward(self, weight):
if in_dygraph_mode():
return _C_ops.spectral_norm(weight, self.weight_u, self.weight_v,
self._dim, self._power_iters, self._eps)
check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
......
......@@ -3933,7 +3933,6 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
dtype = weight.dtype
# create intput and parameters
inputs = {'Weight': weight}
input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert dim < len(input_shape), ("The input `dim` should be less than the "
......@@ -3947,14 +3946,19 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
dtype=dtype,
default_initializer=Normal(0., 1.))
u.stop_gradient = True
inputs['U'] = u
v = helper.create_parameter(attr=ParamAttr(),
shape=[w],
dtype=dtype,
default_initializer=Normal(0., 1.))
inputs['V'] = v
v.stop_gradient = True
if in_dygraph_mode():
return _C_ops.spectral_norm(weight, u, v, dim, power_iters, eps)
inputs = {'Weight': weight}
inputs['U'] = u
inputs['V'] = v
# create output
out = helper.create_variable(dtype=dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册