提交 ec6499bc 编写于 作者: M miraiwk

reback nn.py

上级 982a2b5a
......@@ -1375,7 +1375,7 @@ class BatchNorm(layers.Layer):
outputs = {
"Y": [batch_norm_out],
"MeanOut": [],
"MeanOut": [mean_out],
"VarianceOut": [variance_out],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
......@@ -3031,11 +3031,9 @@ class SpectralNorm(layers.Layer):
dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: 0.
power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
fix_state(bool, optional): whether to update the two vectors `u` and `v`. Default: True.
name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Returns:
None
......@@ -3057,12 +3055,10 @@ class SpectralNorm(layers.Layer):
dim=0,
power_iters=1,
eps=1e-12,
fix_state=True,
dtype='float32'):
super(SpectralNorm, self).__init__()
self._power_iters = power_iters
self._eps = eps
self._fix_state = fix_state
self._dim = dim
self._dtype = dtype
......@@ -3084,31 +3080,10 @@ class SpectralNorm(layers.Layer):
default_initializer=Normal(0., 1.))
self.weight_v.stop_gradient = True
if fix_state:
self.out_weight_u = self.create_parameter(
attr=ParamAttr(),
shape=[h],
dtype=self._dtype,
default_initializer=Normal(0., 1.))
self.out_weight_u.stop_gradient = True
self.out_weight_v = self.create_parameter(
attr=ParamAttr(),
shape=[w],
dtype=self._dtype,
default_initializer=Normal(0., 1.))
self.out_weight_v.stop_gradient = True
else:
self.out_weight_u = self.weight_u
self.out_weight_v = self.weight_v
def forward(self, weight):
check_variable_and_dtype(weight, "weight", ['float32', 'float64'],
'SpectralNorm')
inputs = {
'Weight': weight, 'U': self.weight_u, 'V': self.weight_v,
'UOut': self.out_weight_u, 'VOut': self.out_weight_v,
}
inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="spectral_norm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册