Commit 56414c7d authored by: songyouwei, committed by: hong

move private weight fields to public ones (#21982)

* move private weight properties to public ones
test=develop

* revert changes to FC
test=develop

* fix unittest
test=develop

* fix unittest
test=develop

* fix coverage
test=develop

* fix merged dev
test=develop

* bug fix
test=develop
Parent b852ef73
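The net effect of this diff for user code: layer parameters that previously lived in private fields behind @property wrappers (_filter_param, _w, _scale, _bias_param, ...) are now created directly as the public weight and bias attributes, so models and tests read them as plain attributes. Below is a minimal before/after sketch, assuming the fluid 1.x dygraph API of this era; the Embedding constructor arguments are illustrative placeholders, not taken from this PR.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Illustrative constructor call; the exact Embedding signature varies across 1.x releases.
    emb = fluid.dygraph.Embedding("emb", size=[100, 16])
    ids = fluid.dygraph.to_variable(np.array([[1, 2, 3]]).astype("int64"))
    loss = fluid.layers.reduce_mean(emb(ids))
    loss.backward()

    # Before this PR, callers reached into the private field:
    #     emb._w.numpy(); emb._w.gradient()
    # After it, the parameter is the public attribute created by create_parameter():
    print(emb.weight.numpy().shape)
    print(emb.weight.gradient().shape)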
@@ -206,38 +206,22 @@ class Conv2D(layers.Layer):
             std = (2.0 / filter_elem_num)**0.5
             return Normal(0.0, std, 0)
 
-        self._filter_param = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=filter_shape,
             dtype=self._dtype,
             default_initializer=_get_default_param_initializer())
 
-        self._bias_param = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=[self._num_filters],
             dtype=self._dtype,
             is_bias=True)
 
-    @property
-    def weight(self):
-        return self._filter_param
-
-    @weight.setter
-    def weight(self, value):
-        self._filter_param = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, input):
         inputs = {
             'Input': [input],
-            'Filter': [self._filter_param],
+            'Filter': [self.weight],
         }
         attrs = {
             'strides': self._stride,
@@ -252,8 +236,8 @@ class Conv2D(layers.Layer):
             outs = core.ops.conv2d(inputs, attrs)
             pre_bias = outs['Output'][0]
 
-            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias,
-                                                            self._bias_param, 1)
+            pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias,
+                                                            1)
             return dygraph_utils._append_activation_in_dygraph(pre_act,
                                                                 self._act)
@@ -265,18 +249,18 @@ class Conv2D(layers.Layer):
             type=self._l_type,
             inputs={
                 'Input': input,
-                'Filter': self._filter_param,
+                'Filter': self.weight,
             },
             outputs={"Output": pre_bias},
             attrs=attrs)
 
-        if self._bias_param is not None:
+        if self.bias is not None:
             pre_act = self._helper.create_variable_for_type_inference(
                 dtype=self._dtype)
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [pre_bias],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
         else:
@@ -441,34 +425,18 @@ class Conv3D(layers.Layer):
             std = (2.0 / filter_elem_num)**0.5
             return Normal(0.0, std, 0)
 
-        self._filter_param = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=filter_shape,
             dtype=self._dtype,
             default_initializer=_get_default_param_initializer())
 
-        self._bias_param = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=[self._num_filters],
             dtype=self._dtype,
             is_bias=True)
 
-    @property
-    def weight(self):
-        return self._filter_param
-
-    @weight.setter
-    def weight(self, value):
-        self._filter_param = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, input):
         pre_bias = self._helper.create_variable_for_type_inference(
             dtype=self._dtype)
@@ -477,7 +445,7 @@ class Conv3D(layers.Layer):
             type='conv3d',
             inputs={
                 'Input': input,
-                'Filter': self._filter_param,
+                'Filter': self.weight,
             },
             outputs={"Output": pre_bias},
             attrs={
@@ -489,13 +457,13 @@ class Conv3D(layers.Layer):
                 'use_mkldnn': False
             })
 
-        if self._bias_param is not None:
+        if self.bias is not None:
             pre_act = self._helper.create_variable_for_type_inference(
                 dtype=self._dtype)
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [pre_bias],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
         else:
@@ -681,38 +649,22 @@ class Conv3DTranspose(layers.Layer):
         filter_shape = [self._num_channels, self._num_filters // self._groups
                         ] + self._filter_size
-        self._img_filter = self.create_parameter(
+        self.weight = self.create_parameter(
             dtype=self._dtype, shape=filter_shape, attr=self._param_attr)
         if self._bias_attr:
-            self._bias_param = self.create_parameter(
+            self.bias = self.create_parameter(
                 attr=self._bias_attr,
                 shape=[self._num_filters],
                 dtype=self._dtype,
                 is_bias=True)
 
-    @property
-    def weight(self):
-        return self._img_filter
-
-    @weight.setter
-    def weight(self, value):
-        self._img_filter = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, input):
         pre_bias = self._helper.create_variable_for_type_inference(
             dtype=self._dtype)
         self._helper.append_op(
             type="conv3d_transpose",
             inputs={'Input': [input],
-                    'Filter': [self._img_filter]},
+                    'Filter': [self.weight]},
             outputs={'Output': pre_bias},
             attrs={
                 'strides': self._stride,
@@ -728,7 +680,7 @@ class Conv3DTranspose(layers.Layer):
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [pre_bias],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
         else:
@@ -1345,21 +1297,19 @@ class BatchNorm(layers.Layer):
         param_shape = [num_channels]
 
         # create parameter
-        self._scale = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=param_shape,
             dtype=self._dtype,
             default_initializer=Constant(1.0))
-        if use_global_stats and self._param_attr.learning_rate == 0.:
-            self._scale.stop_gradient = True
+        self.weight.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0.
 
-        self._bias = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=param_shape,
             dtype=self._dtype,
             is_bias=True)
-        if use_global_stats and self._param_attr.learning_rate == 0.:
-            self._bias.stop_gradient = True
+        self.bias.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0.
 
         self._mean = self.create_parameter(
             attr=ParamAttr(
@@ -1408,8 +1358,8 @@ class BatchNorm(layers.Layer):
             type="batch_norm",
             inputs={
                 "X": input,
-                "Scale": self._scale,
-                "Bias": self._bias,
+                "Scale": self.weight,
+                "Bias": self.bias,
                 "Mean": self._mean,
                 "Variance": self._variance
             },
@@ -1559,20 +1509,12 @@ class Embedding(layers.Layer):
         if self._remote_prefetch:
             assert self._is_sparse is True and self._is_distributed is False
 
-        self._w = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=self._size,
             dtype=self._dtype,
             is_bias=False)
 
-    @property
-    def weight(self):
-        return self._w
-
-    @weight.setter
-    def weight(self, value):
-        self._w = value
-
     def forward(self, input):
         attrs = {
             'is_sparse': self._is_sparse,
@@ -1582,7 +1524,7 @@ class Embedding(layers.Layer):
         }
 
         if in_dygraph_mode():
-            inputs = {'Ids': [input], 'W': [self._w]}
+            inputs = {'Ids': [input], 'W': [self.weight]}
             outs = core.ops.lookup_table_v2(inputs, attrs)
             return outs['Out'][0]
@@ -1590,7 +1532,7 @@ class Embedding(layers.Layer):
         self._helper.append_op(
             type='lookup_table_v2',
             inputs={'Ids': input,
-                    'W': self._w},
+                    'W': self.weight},
             outputs={'Out': out},
             attrs=attrs)
@@ -1686,7 +1628,7 @@ class LayerNorm(layers.Layer):
         self._dtype = dtype
         param_shape = [np.prod(self._normalized_shape)]
         if self._scale:
-            self._scale_w = self.create_parameter(
+            self.weight = self.create_parameter(
                 attr=self._param_attr,
                 shape=param_shape,
                 dtype=self._dtype,
@@ -1697,7 +1639,7 @@ class LayerNorm(layers.Layer):
 
         if self._shift:
             assert self._bias_attr is not False
-            self._bias_w = self.create_parameter(
+            self.bias = self.create_parameter(
                 attr=self._bias_attr,
                 shape=param_shape,
                 dtype=self._dtype,
@@ -1721,9 +1663,9 @@ class LayerNorm(layers.Layer):
         inputs = dict()
         inputs['X'] = input
         if self._scale:
-            inputs['Scale'] = self._scale_w
+            inputs['Scale'] = self.weight
         if self._shift:
-            inputs['Bias'] = self._bias_w
+            inputs['Bias'] = self.bias
 
         # create output
         mean_out = self._helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
@@ -1878,35 +1820,19 @@ class GRUUnit(layers.Layer):
         self._dtype = dtype
         size = size // 3
         # create weight
-        self._weight = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=param_attr, shape=[size, 3 * size], dtype=dtype)
 
         # create bias
         bias_size = [1, 3 * size]
         self._bias_size = bias_size
-        self._bias = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
 
-    @property
-    def weight(self):
-        return self._weight
-
-    @weight.setter
-    def weight(self, value):
-        self._weight = value
-
-    @property
-    def bias(self):
-        return self._bias
-
-    @bias.setter
-    def bias(self, value):
-        self._bias = value
-
     def forward(self, input, hidden):
-        inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self._weight}
-        if self._bias:
-            inputs['Bias'] = self._bias
+        inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self.weight}
+        if self.bias:
+            inputs['Bias'] = self.bias
         gate = self._helper.create_variable_for_type_inference(self._dtype)
         reset_hidden_pre = self._helper.create_variable_for_type_inference(
@@ -2122,35 +2048,19 @@ class NCE(layers.Layer):
             'remote_prefetch': remote_prefetch
         }
 
-        self._w = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=[self._num_total_classes, dim],
             is_bias=False,
             dtype=self._dtype)
         if self._bias_attr:
-            self._b = self.create_parameter(
+            self.bias = self.create_parameter(
                 attr=self._bias_attr,
                 shape=[self._num_total_classes, 1],
                 is_bias=True,
                 dtype=self._dtype)
-            self._inputs['Bias'] = self._b
-        self._inputs['Weight'] = self._w
-
-    @property
-    def weight(self):
-        return self._w
-
-    @weight.setter
-    def weight(self, value):
-        self._w = value
-
-    @property
-    def bias(self):
-        return self._b
-
-    @bias.setter
-    def bias(self, value):
-        self._b = value
+            self._inputs['Bias'] = self.bias
+        self._inputs['Weight'] = self.weight
 
     def forward(self, input, label, sample_weight=None):
         assert isinstance(input, Variable)
@@ -2243,28 +2153,19 @@ class PRelu(layers.Layer):
             self._alpha_shape = [1, input_shape[1], 1, 1]
         elif self._mode == 'element':
             self._alpha_shape = input_shape
-        self._alpha = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=self._alpha_shape,
             dtype='float32',
             is_bias=False,
             default_initializer=Constant(1.0))
 
-    @property
-    def weight(self):
-        return self._alpha
-
-    @weight.setter
-    def weight(self, value):
-        self._alpha = value
-
     def forward(self, input):
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
             type="prelu",
             inputs={"X": input,
-                    'Alpha': self._alpha},
+                    'Alpha': self.weight},
             attrs={"mode": self._mode},
             outputs={"Out": out})
         return out
@@ -2345,38 +2246,22 @@ class BilinearTensorProduct(layers.Layer):
         self._dtype = dtype
 
         param_shape = [self._output_dim, self._input1_dim, self._input2_dim]
-        self._w = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=param_shape,
             dtype=self._dtype,
             is_bias=False)
         bias_size = [1, self._output_dim]
-        self._bias_param = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=bias_size,
             dtype=self._dtype,
             is_bias=True)
 
-    @property
-    def weight(self):
-        return self._w
-
-    @weight.setter
-    def weight(self, value):
-        self._w = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, x, y):
-        self._inputs = {"X": x, "Y": y, "Weight": self._w}
-        if self._bias_param:
-            self._inputs["Bias"] = self._bias_param
+        self._inputs = {"X": x, "Y": y, "Weight": self.weight}
+        if self.bias:
+            self._inputs["Bias"] = self.bias
         if self._name is not None:
             out = self._helper.create_variable(
                 name=".".join([self.full_name(), self._name]),
@@ -2569,38 +2454,22 @@ class Conv2DTranspose(layers.Layer):
         filter_shape = [self._num_channels, self._num_filters // self._groups
                         ] + self._filter_size
 
-        self._img_filter = self.create_parameter(
+        self.weight = self.create_parameter(
             dtype=self._dtype, shape=filter_shape, attr=self._param_attr)
 
-        self._bias_param = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=[self._num_filters],
             dtype=self._dtype,
             is_bias=True)
 
-    @property
-    def weight(self):
-        return self._img_filter
-
-    @weight.setter
-    def weight(self, value):
-        self._img_filter = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, input):
         pre_bias = self._helper.create_variable_for_type_inference(
             dtype=input.dtype)
         self._helper.append_op(
             type=self._op_type,
             inputs={'Input': [input],
-                    'Filter': [self._img_filter]},
+                    'Filter': [self.weight]},
             outputs={'Output': pre_bias},
             attrs={
                 'output_size': self._output_size,
@@ -2611,13 +2480,13 @@ class Conv2DTranspose(layers.Layer):
                 'use_cudnn': self._use_cudnn
             })
 
-        if self._bias_param is not None:
+        if self.bias is not None:
             pre_act = self._helper.create_variable_for_type_inference(
                 dtype=self._dtype)
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [pre_bias],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
         else:
@@ -2682,10 +2551,10 @@ class SequenceConv(layers.Layer):
     def _build_once(self, input):
         self._dtype = self._helper.input_dtype(input)
         filter_shape = [self._filter_size * input.shape[1], self._num_filters]
-        self._filter_param = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr, shape=filter_shape, dtype=self._dtype)
 
-        self._bias_param = self.create_parameter(
+        self.bias = self.create_parameter(
             attr=self._bias_attr,
             shape=[self._num_filters],
             dtype=self._dtype,
@@ -2697,7 +2566,7 @@ class SequenceConv(layers.Layer):
             type='sequence_conv',
             inputs={
                 'X': [input],
-                'Filter': [self._filter_param],
+                'Filter': [self.weight],
             },
             outputs={"Out": pre_bias},
             attrs={
@@ -2706,13 +2575,13 @@ class SequenceConv(layers.Layer):
                 'contextLength': self._filter_size
             })
 
-        if self._bias_param is not None:
+        if self.bias is not None:
             pre_act = self._helper.create_variable_for_type_inference(
                 dtype=self._dtype)
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [pre_bias],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_act]},
                 attrs={'axis': 1})
         else:
@@ -2784,7 +2653,7 @@ class RowConv(layers.Layer):
     def _build_once(self, input):
         self._dtype = self._helper.input_dtype(input)
         filter_shape = [self._future_context_size + 1, input.shape[1]]
-        self._filter_param = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=filter_shape,
             dtype=self._dtype,
@@ -2795,7 +2664,7 @@ class RowConv(layers.Layer):
         self._helper.append_op(
             type='row_conv',
             inputs={'X': [input],
-                    'Filter': [self._filter_param]},
+                    'Filter': [self.weight]},
             outputs={'Out': [out]})
         return self._helper.append_activation(out, act=self._act)
@@ -2858,26 +2727,25 @@ class GroupNorm(layers.Layer):
             raise ValueError("unsupported data layout:" + data_layout)
 
         param_shape = [self._channels]
-        if self._bias_attr:
-            self._bias = self.create_parameter(
-                attr=self._bias_attr,
-                shape=param_shape,
-                dtype=self._dtype,
-                is_bias=True)
-
-        if self._param_attr:
-            self._scale = self.create_parameter(
-                attr=self._param_attr,
-                shape=param_shape,
-                dtype=self._dtype,
-                default_initializer=Constant(1.0))
+        self.weight = self.create_parameter(
+            attr=self._param_attr or False,
+            shape=param_shape,
+            dtype=self._dtype,
+            default_initializer=Constant(1.0))
+
+        self.bias = self.create_parameter(
+            attr=self._bias_attr or False,
+            shape=param_shape,
+            dtype=self._dtype,
+            is_bias=True)
 
     def forward(self, input):
         inputs = {'X': input}
-        if self._bias_attr:
-            inputs['Bias'] = self._bias
-        if self._param_attr:
-            inputs['Scale'] = self._scale
+        if self.bias:
+            inputs['Bias'] = self.bias
+        if self.weight:
+            inputs['Scale'] = self.weight
 
         # create output
         mean_out = self._helper.create_variable_for_type_inference(
@@ -2976,22 +2844,22 @@ class SpectralNorm(layers.Layer):
         h = self._weight_shape[self._dim]
         w = np.prod(self._weight_shape) // h
 
-        self.u = self.create_parameter(
+        self.weight_u = self.create_parameter(
             attr=ParamAttr(),
             shape=[h],
             dtype=self._dtype,
             default_initializer=Normal(0., 1.))
-        self.u.stop_gradient = True
+        self.weight_u.stop_gradient = True
 
-        self.v = self.create_parameter(
+        self.weight_v = self.create_parameter(
             attr=ParamAttr(),
             shape=[w],
             dtype=self._dtype,
             default_initializer=Normal(0., 1.))
-        self.v.stop_gradient = True
+        self.weight_v.stop_gradient = True
 
     def forward(self, weight):
-        inputs = {'Weight': weight, 'U': self.u, 'V': self.v}
+        inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v}
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
             type="spectral_norm",
@@ -3073,49 +2941,30 @@ class TreeConv(layers.Layer):
         self._dtype = dtype
         w_shape = [self._feature_size, 3, self._output_size, self._num_filters]
         if self._bias_attr:
-            self._bias_param = self.create_parameter(
+            self.bias = self.create_parameter(
                 attr=self._bias_attr,
                 shape=[self._num_filters],
                 dtype=self._dtype,
                 is_bias=True)
-        self.W = self.create_parameter(
+        self.weight = self.create_parameter(
             attr=self._param_attr,
             shape=w_shape,
             dtype=self._dtype,
             is_bias=False)
 
-    @property
-    def weight(self):
-        return self.W
-
-    @weight.setter
-    def weight(self, value):
-        self.W = value
-
-    @property
-    def bias(self):
-        return self._bias_param
-
-    @bias.setter
-    def bias(self, value):
-        self._bias_param = value
-
     def forward(self, nodes_vector, edge_set):
         if self._name:
             out = self.create_variable(
                 name=self._name, dtype=self._dtype, persistable=False)
         else:
             out = self._helper.create_variable_for_type_inference(
                 dtype=self._dtype)
         self._helper.append_op(
             type='tree_conv',
             inputs={
                 'NodesVector': nodes_vector,
                 'EdgeSet': edge_set,
-                'Filter': self.W
+                'Filter': self.weight
             },
             outputs={'Out': out, },
             attrs={'max_depth': self._max_depth})
@@ -3125,7 +2974,7 @@ class TreeConv(layers.Layer):
             self._helper.append_op(
                 type='elementwise_add',
                 inputs={'X': [out],
-                        'Y': [self._bias_param]},
+                        'Y': [self.bias]},
                 outputs={'Out': [pre_activation]},
                 attrs={'axis': 1})
         else:
......
@@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss = case1(v1, v2)
             loss.backward()
-            self.assertTrue(case1.fc2._w._grad_ivar() is not None)
-            self.assertTrue(case1.fc1._w._grad_ivar() is not None)
+            self.assertTrue(case1.fc2.weight._grad_ivar() is not None)
+            self.assertTrue(case1.fc1.weight._grad_ivar() is not None)
 
     def test_auto_prune2(self):
         with fluid.dygraph.guard():
@@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss = case2(v1, v2)
             loss.backward()
-            self.assertTrue(case2.fc2._w._grad_ivar() is None)
-            self.assertTrue(case2.fc1._w._grad_ivar() is not None)
+            self.assertTrue(case2.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case2.fc1.weight._grad_ivar() is not None)
 
     def test_auto_prune3(self):
         with fluid.dygraph.guard():
@@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part2 = case3(v1, v2, 1)
             loss.backward()
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
 
     def test_auto_prune4(self):
@@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part2 = case4(v1, v2, 1)
             part2.backward()
-            self.assertTrue(case4.fc._w._grad_ivar() is not None)
+            self.assertTrue(case4.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 1).all())
 
     def test_auto_prune5(self):
@@ -202,7 +202,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             v2 = fluid.dygraph.to_variable(value2)
             loss, part1, part2 = case4(v1, v2, 2)
             part1.backward()
-            self.assertTrue(case4.fc._w._grad_ivar() is not None)
+            self.assertTrue(case4.fc.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
 
     def test_auto_prune6(self):
@@ -220,7 +220,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out1.stop_gradient = True
             out = fluid.layers.concat(input=[out1, out2, c], axis=1)
             out.backward()
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
 
     def test_auto_prune7(self):
@@ -239,7 +239,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out = fluid.layers.concat(input=[out1, out2, c], axis=1)
             backward_strategy = fluid.dygraph.BackwardStrategy()
             out.backward(backward_strategy)
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
 
     def test_auto_prune8(self):
@@ -253,17 +253,17 @@ class TestImperativeAutoPrune(unittest.TestCase):
             b = fluid.dygraph.to_variable(value1)
             c = fluid.dygraph.to_variable(value2)
             out1 = fc(a)
-            fc_origin = fc._w.numpy()
+            fc_origin = fc.weight.numpy()
             out2 = fc2(out1)
-            fc2_origin = fc2._w.numpy()
-            fc2._w.stop_gradient = True
+            fc2_origin = fc2.weight.numpy()
+            fc2.weight.stop_gradient = True
             out2.backward()
             optimizer = fluid.optimizer.SGD(
                 learning_rate=0.003,
                 parameter_list=(fc.parameters() + fc2.parameters()))
             optimizer.minimize(out2)
-            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
-            self.assertFalse(np.array_equal(fc_origin, fc._w.numpy()))
+            self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy()))
+            self.assertFalse(np.array_equal(fc_origin, fc.weight.numpy()))
 
     def test_auto_prune9(self):
         with fluid.dygraph.guard():
@@ -276,19 +276,19 @@ class TestImperativeAutoPrune(unittest.TestCase):
             b = fluid.dygraph.to_variable(value1)
             c = fluid.dygraph.to_variable(value2)
             out1 = fc(a)
-            fc_origin = fc._w.numpy()
+            fc_origin = fc.weight.numpy()
             out2 = fc2(out1)
-            fc2_origin = fc2._w.numpy()
+            fc2_origin = fc2.weight.numpy()
             out2.stop_gradient = True
             out2.backward()
             optimizer = fluid.optimizer.SGD(
                 learning_rate=0.003,
                 parameter_list=(fc.parameters() + fc2.parameters()))
             optimizer.minimize(out2)
-            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
-            self.assertTrue(np.array_equal(fc_origin, fc._w.numpy()))
+            self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy()))
+            self.assertTrue(np.array_equal(fc_origin, fc.weight.numpy()))
             try:
-                fc2._w.gradient()
+                fc2.weight.gradient()
             except ValueError as e:
                 assert type(e) == ValueError
@@ -309,7 +309,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             backward_strategy = fluid.dygraph.BackwardStrategy()
             backward_strategy.sort_sum_gradient = True
             out.backward(backward_strategy)
-            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((fc.weight.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
 
     def test_auto_prune_with_optimizer(self):
@@ -336,10 +336,10 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss.backward()
             _, params_grads = optimizer.minimize(loss, grad_clip=grad_clip)
             for items in params_grads:
-                assert items[0].name is not model.embed1._w.name
-                assert items[0].name is not model.fc1._w.name
-            assert model.embed1._w._grad_ivar() is None
-            assert model.fc1._w._grad_ivar() is None
+                assert items[0].name is not model.embed1.weight.name
+                assert items[0].name is not model.fc1.weight.name
+            assert model.embed1.weight._grad_ivar() is None
+            assert model.fc1.weight._grad_ivar() is None
 
         with fluid.dygraph.guard(place):
             model = MyLayer2("mylayer", vocab_size, size)
@@ -355,10 +355,10 @@ class TestImperativeAutoPrune(unittest.TestCase):
             loss.backward()
             optimizer.minimize(loss, grad_clip=grad_clip)
             for items in params_grads:
-                assert items[0].name is not model.embed1._w.name
-                assert items[0].name is not model.fc1._w.name
-            assert model.embed1._w._grad_ivar() is None
-            assert model.fc1._w._grad_ivar() is None
+                assert items[0].name is not model.embed1.weight.name
+                assert items[0].name is not model.fc1.weight.name
+            assert model.embed1.weight._grad_ivar() is None
+            assert model.fc1.weight._grad_ivar() is None
 
     def test_case2_prune_no_grad_branch(self):
         with fluid.dygraph.guard():
@@ -369,8 +369,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             case3 = AutoPruneLayer2("l2")
             loss = case3(v1, v2)
             loss.backward()
-            self.assertTrue(case3.fc2._w._grad_ivar() is None)
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
 
     def test_case2_prune_no_grad_branch(self):
         with fluid.dygraph.guard():
@@ -381,8 +381,8 @@ class TestImperativeAutoPrune(unittest.TestCase):
             case3 = AutoPruneLayer2("l2")
             loss = case3(v1, v2)
             loss.backward()
-            self.assertTrue(case3.fc2._w._grad_ivar() is None)
-            self.assertTrue(case3.fc._w._grad_ivar() is not None)
+            self.assertTrue(case3.fc2.weight._grad_ivar() is None)
+            self.assertTrue(case3.fc.weight._grad_ivar() is not None)
 
     def test_case3_prune_no_grad_branch2(self):
         with fluid.dygraph.guard():
@@ -395,7 +395,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             out = fluid.layers.one_hot(input=label, depth=100)
             loss = fluid.layers.mean(out)
             loss.backward()
-            self.assertTrue(fc._w._grad_ivar() is None)
+            self.assertTrue(fc.weight._grad_ivar() is None)
 
     def test_case4_with_no_grad_op_maker(self):
         with fluid.dygraph.guard():
......
@@ -342,7 +342,7 @@ class TestImperative(unittest.TestCase):
             out = mlp(var_inp)
             dy_out = out.numpy()
             out.backward()
-            dy_grad = mlp._fc1._w.gradient()
+            dy_grad = mlp._fc1.weight.gradient()
 
         with fluid.dygraph.guard():
             var_inp2 = fluid.dygraph.base.to_variable(np_inp)
@@ -352,7 +352,7 @@ class TestImperative(unittest.TestCase):
             backward_strategy = fluid.dygraph.BackwardStrategy()
             backward_strategy.sort_sum_gradient = True
             out2.backward(backward_strategy)
-            dy_grad2 = mlp2._fc1._w.gradient()
+            dy_grad2 = mlp2._fc1.weight.gradient()
 
         with new_program_scope():
             inp = fluid.layers.data(
@@ -360,7 +360,7 @@ class TestImperative(unittest.TestCase):
             mlp = MLP("mlp")
             out = mlp(inp)
             param_grads = fluid.backward.append_backward(
-                out, parameter_list=[mlp._fc1._w.name])[0]
+                out, parameter_list=[mlp._fc1.weight.name])[0]
             exe = fluid.Executor(fluid.CPUPlace(
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             exe.run(fluid.default_startup_program())
......
@@ -59,7 +59,7 @@ class SimpleNet(fluid.Layer):
         x_emb = self.embedding(input)
         projection = fluid.layers.matmul(
             x_emb, fluid.layers.transpose(
-                self.embedding._w, perm=[1, 0]))
+                self.embedding.weight, perm=[1, 0]))
         projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
         projection = fluid.layers.reshape(
             projection, shape=[-1, self.vocab_size])
......
@@ -61,7 +61,7 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb, emb = simplenet(input)
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -73,11 +73,11 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb.backward(backward_strategy)
                 adam.minimize(input_emb)  # grad_clip=grad_clip
-                emb._w.gradient()
+                emb.weight.gradient()
                 emb.clear_gradients()
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -108,7 +108,7 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb, emb = simplenet(input)
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
@@ -120,11 +120,11 @@ class TestSimpleNet(unittest.TestCase):
                 input_emb.backward(backward_strategy)
                 adam.minimize(input_emb, grad_clip=grad_clip)
-                emb._w.gradient()
+                emb.weight.gradient()
                 emb.clear_gradients()
                 try:
-                    emb._w.gradient()
+                    emb.weight.gradient()
                 except ValueError as e:
                     assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str(
                         e)
......
@@ -66,7 +66,7 @@ class SimpleNet(fluid.Layer):
         fc = fluid.layers.elementwise_add(fc, self.softmax_bias)
         projection = fluid.layers.matmul(
             fc, fluid.layers.transpose(
-                self.embedding._w, perm=[1, 0]))
+                self.embedding.weight, perm=[1, 0]))
         projection = fluid.layers.reshape(
             projection, shape=[-1, self.vocab_size])
         loss = fluid.layers.softmax_with_cross_entropy(
......
@@ -843,7 +843,7 @@ class WrapDecoderLayer(Layer):
         if self._weight_sharing:
             predict = fluid.layers.matmul(
                 x=dec_output_reshape,
-                y=self._prepare_decoder_layer._input_emb._w,
+                y=self._prepare_decoder_layer._input_emb.weight,
                 transpose_y=True)
         else:
             predict = self._fc(dec_output_reshape)
@@ -917,7 +917,7 @@ class TransFormer(Layer):
             is_sparse=is_sparse)
         if weight_sharing:
-            self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
+            self._wrap_decoder_layer._prepare_decoder_layer._input_emb.weight = self._wrap_encoder_layer._prepare_encoder_layer._input_emb.weight
 
     def forward(self, enc_inputs, dec_inputs, label, weights):
         enc_output = self._wrap_encoder_layer(enc_inputs)
......
@@ -368,7 +368,7 @@ class TestLayer(LayerTest):
                 filter_size=[2, 2],
                 bias_attr=False)
             dy_ret = conv2d(base.to_variable(images))
-            self.assertTrue(conv2d._bias_param is None)
+            self.assertTrue(conv2d.bias is None)
 
         self.assertTrue(np.allclose(static_ret, dy_ret_value))
         self.assertTrue(np.allclose(static_ret, static_ret2))
......