From 56414c7daf40ea1d6f5d34afef712cacf8e125e9 Mon Sep 17 00:00:00 2001 From: songyouwei Date: Thu, 2 Jan 2020 16:03:58 +0800 Subject: [PATCH] move private weight fields to public ones (#21982) * move private weight properties to public ones test=develop * revert changes to FC test=develop * fix unittest test=develop * fix unittest test=develop * fix coverage test=develop * fix merged dev test=develop * bug fix test=develop --- python/paddle/fluid/dygraph/nn.py | 315 +++++------------- .../unittests/test_imperative_auto_prune.py | 66 ++-- .../tests/unittests/test_imperative_basic.py | 6 +- ..._imperative_lod_tensor_to_selected_rows.py | 2 +- .../test_imperative_selected_rows.py | 12 +- ..._imperative_selected_rows_to_lod_tensor.py | 2 +- ..._imperative_transformer_sorted_gradient.py | 4 +- .../fluid/tests/unittests/test_layers.py | 2 +- 8 files changed, 129 insertions(+), 280 deletions(-) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 3292a7ee4eb..f660396780f 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -206,38 +206,22 @@ class Conv2D(layers.Layer): std = (2.0 / filter_elem_num)**0.5 return Normal(0.0, std, 0) - self._filter_param = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=filter_shape, dtype=self._dtype, default_initializer=_get_default_param_initializer()) - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, is_bias=True) - @property - def weight(self): - return self._filter_param - - @weight.setter - def weight(self, value): - self._filter_param = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def bias(self, value): - self._bias_param = value - def forward(self, input): inputs = { 'Input': [input], - 'Filter': [self._filter_param], + 'Filter': [self.weight], } attrs = { 'strides': self._stride, @@ -252,8 +236,8 @@ class Conv2D(layers.Layer): outs = core.ops.conv2d(inputs, attrs) pre_bias = outs['Output'][0] - pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, - self._bias_param, 1) + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias, + 1) return dygraph_utils._append_activation_in_dygraph(pre_act, self._act) @@ -265,18 +249,18 @@ class Conv2D(layers.Layer): type=self._l_type, inputs={ 'Input': input, - 'Filter': self._filter_param, + 'Filter': self.weight, }, outputs={"Output": pre_bias}, attrs=attrs) - if self._bias_param is not None: + if self.bias is not None: pre_act = self._helper.create_variable_for_type_inference( dtype=self._dtype) self._helper.append_op( type='elementwise_add', inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_act]}, attrs={'axis': 1}) else: @@ -441,34 +425,18 @@ class Conv3D(layers.Layer): std = (2.0 / filter_elem_num)**0.5 return Normal(0.0, std, 0) - self._filter_param = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=filter_shape, dtype=self._dtype, default_initializer=_get_default_param_initializer()) - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, is_bias=True) - @property - def weight(self): - return self._filter_param - - @weight.setter - def weight(self, value): - self._filter_param = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def 
bias(self, value): - self._bias_param = value - def forward(self, input): pre_bias = self._helper.create_variable_for_type_inference( dtype=self._dtype) @@ -477,7 +445,7 @@ class Conv3D(layers.Layer): type='conv3d', inputs={ 'Input': input, - 'Filter': self._filter_param, + 'Filter': self.weight, }, outputs={"Output": pre_bias}, attrs={ @@ -489,13 +457,13 @@ class Conv3D(layers.Layer): 'use_mkldnn': False }) - if self._bias_param is not None: + if self.bias is not None: pre_act = self._helper.create_variable_for_type_inference( dtype=self._dtype) self._helper.append_op( type='elementwise_add', inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_act]}, attrs={'axis': 1}) else: @@ -681,38 +649,22 @@ class Conv3DTranspose(layers.Layer): filter_shape = [self._num_channels, self._num_filters // self._groups ] + self._filter_size - self._img_filter = self.create_parameter( + self.weight = self.create_parameter( dtype=self._dtype, shape=filter_shape, attr=self._param_attr) if self._bias_attr: - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, is_bias=True) - @property - def weight(self): - return self._img_filter - - @weight.setter - def weight(self, value): - self._img_filter = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def bias(self, value): - self._bias_param = value - def forward(self, input): pre_bias = self._helper.create_variable_for_type_inference( dtype=self._dtype) self._helper.append_op( type="conv3d_transpose", inputs={'Input': [input], - 'Filter': [self._img_filter]}, + 'Filter': [self.weight]}, outputs={'Output': pre_bias}, attrs={ 'strides': self._stride, @@ -728,7 +680,7 @@ class Conv3DTranspose(layers.Layer): self._helper.append_op( type='elementwise_add', inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_act]}, attrs={'axis': 1}) else: @@ -1345,21 +1297,19 @@ class BatchNorm(layers.Layer): param_shape = [num_channels] # create parameter - self._scale = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=param_shape, dtype=self._dtype, default_initializer=Constant(1.0)) - if use_global_stats and self._param_attr.learning_rate == 0.: - self._scale.stop_gradient = True + self.weight.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0. - self._bias = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=param_shape, dtype=self._dtype, is_bias=True) - if use_global_stats and self._param_attr.learning_rate == 0.: - self._bias.stop_gradient = True + self.bias.stop_gradient = use_global_stats and self._param_attr.learning_rate == 0. 
self._mean = self.create_parameter( attr=ParamAttr( @@ -1408,8 +1358,8 @@ class BatchNorm(layers.Layer): type="batch_norm", inputs={ "X": input, - "Scale": self._scale, - "Bias": self._bias, + "Scale": self.weight, + "Bias": self.bias, "Mean": self._mean, "Variance": self._variance }, @@ -1559,20 +1509,12 @@ class Embedding(layers.Layer): if self._remote_prefetch: assert self._is_sparse is True and self._is_distributed is False - self._w = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=self._size, dtype=self._dtype, is_bias=False) - @property - def weight(self): - return self._w - - @weight.setter - def weight(self, value): - self._w = value - def forward(self, input): attrs = { 'is_sparse': self._is_sparse, @@ -1582,7 +1524,7 @@ class Embedding(layers.Layer): } if in_dygraph_mode(): - inputs = {'Ids': [input], 'W': [self._w]} + inputs = {'Ids': [input], 'W': [self.weight]} outs = core.ops.lookup_table_v2(inputs, attrs) return outs['Out'][0] @@ -1590,7 +1532,7 @@ class Embedding(layers.Layer): self._helper.append_op( type='lookup_table_v2', inputs={'Ids': input, - 'W': self._w}, + 'W': self.weight}, outputs={'Out': out}, attrs=attrs) @@ -1686,7 +1628,7 @@ class LayerNorm(layers.Layer): self._dtype = dtype param_shape = [np.prod(self._normalized_shape)] if self._scale: - self._scale_w = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=param_shape, dtype=self._dtype, @@ -1697,7 +1639,7 @@ class LayerNorm(layers.Layer): if self._shift: assert self._bias_attr is not False - self._bias_w = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=param_shape, dtype=self._dtype, @@ -1721,9 +1663,9 @@ class LayerNorm(layers.Layer): inputs = dict() inputs['X'] = input if self._scale: - inputs['Scale'] = self._scale_w + inputs['Scale'] = self.weight if self._shift: - inputs['Bias'] = self._bias_w + inputs['Bias'] = self.bias # create output mean_out = self._helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) @@ -1878,35 +1820,19 @@ class GRUUnit(layers.Layer): self._dtype = dtype size = size // 3 # create weight - self._weight = self.create_parameter( + self.weight = self.create_parameter( attr=param_attr, shape=[size, 3 * size], dtype=dtype) # create bias bias_size = [1, 3 * size] self._bias_size = bias_size - self._bias = self.create_parameter( + self.bias = self.create_parameter( attr=bias_attr, shape=bias_size, dtype=dtype, is_bias=True) - @property - def weight(self): - return self._weight - - @weight.setter - def weight(self, value): - self._weight = value - - @property - def bias(self): - return self._bias - - @bias.setter - def bias(self, value): - self._bias = value - def forward(self, input, hidden): - inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self._weight} - if self._bias: - inputs['Bias'] = self._bias + inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': self.weight} + if self.bias: + inputs['Bias'] = self.bias gate = self._helper.create_variable_for_type_inference(self._dtype) reset_hidden_pre = self._helper.create_variable_for_type_inference( @@ -2122,35 +2048,19 @@ class NCE(layers.Layer): 'remote_prefetch': remote_prefetch } - self._w = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=[self._num_total_classes, dim], is_bias=False, dtype=self._dtype) if self._bias_attr: - self._b = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, 
shape=[self._num_total_classes, 1], is_bias=True, dtype=self._dtype) - self._inputs['Bias'] = self._b - self._inputs['Weight'] = self._w - - @property - def weight(self): - return self._w - - @weight.setter - def weight(self, value): - self._w = value - - @property - def bias(self): - return self._b - - @bias.setter - def bias(self, value): - self._b = value + self._inputs['Bias'] = self.bias + self._inputs['Weight'] = self.weight def forward(self, input, label, sample_weight=None): assert isinstance(input, Variable) @@ -2243,28 +2153,19 @@ class PRelu(layers.Layer): self._alpha_shape = [1, input_shape[1], 1, 1] elif self._mode == 'element': self._alpha_shape = input_shape - self._alpha = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=self._alpha_shape, dtype='float32', is_bias=False, default_initializer=Constant(1.0)) - @property - def weight(self): - return self._alpha - - @weight.setter - def weight(self, value): - self._alpha = value - def forward(self, input): - out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( type="prelu", inputs={"X": input, - 'Alpha': self._alpha}, + 'Alpha': self.weight}, attrs={"mode": self._mode}, outputs={"Out": out}) return out @@ -2345,38 +2246,22 @@ class BilinearTensorProduct(layers.Layer): self._dtype = dtype param_shape = [self._output_dim, self._input1_dim, self._input2_dim] - self._w = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=param_shape, dtype=self._dtype, is_bias=False) bias_size = [1, self._output_dim] - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=bias_size, dtype=self._dtype, is_bias=True) - @property - def weight(self): - return self._w - - @weight.setter - def weight(self, value): - self._w = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def bias(self, value): - self._bias_param = value - def forward(self, x, y): - self._inputs = {"X": x, "Y": y, "Weight": self._w} - if self._bias_param: - self._inputs["Bias"] = self._bias_param + self._inputs = {"X": x, "Y": y, "Weight": self.weight} + if self.bias: + self._inputs["Bias"] = self.bias if self._name is not None: out = self._helper.create_variable( name=".".join([self.full_name(), self._name]), @@ -2569,38 +2454,22 @@ class Conv2DTranspose(layers.Layer): filter_shape = [self._num_channels, self._num_filters // self._groups ] + self._filter_size - self._img_filter = self.create_parameter( + self.weight = self.create_parameter( dtype=self._dtype, shape=filter_shape, attr=self._param_attr) - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, is_bias=True) - @property - def weight(self): - return self._img_filter - - @weight.setter - def weight(self, value): - self._img_filter = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def bias(self, value): - self._bias_param = value - def forward(self, input): pre_bias = self._helper.create_variable_for_type_inference( dtype=input.dtype) self._helper.append_op( type=self._op_type, inputs={'Input': [input], - 'Filter': [self._img_filter]}, + 'Filter': [self.weight]}, outputs={'Output': pre_bias}, attrs={ 'output_size': self._output_size, @@ -2611,13 +2480,13 @@ class Conv2DTranspose(layers.Layer): 'use_cudnn': self._use_cudnn }) - if self._bias_param is not None: + if self.bias is not None: pre_act = 
self._helper.create_variable_for_type_inference( dtype=self._dtype) self._helper.append_op( type='elementwise_add', inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_act]}, attrs={'axis': 1}) else: @@ -2682,10 +2551,10 @@ class SequenceConv(layers.Layer): def _build_once(self, input): self._dtype = self._helper.input_dtype(input) filter_shape = [self._filter_size * input.shape[1], self._num_filters] - self._filter_param = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=filter_shape, dtype=self._dtype) - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, @@ -2697,7 +2566,7 @@ class SequenceConv(layers.Layer): type='sequence_conv', inputs={ 'X': [input], - 'Filter': [self._filter_param], + 'Filter': [self.weight], }, outputs={"Out": pre_bias}, attrs={ @@ -2706,13 +2575,13 @@ class SequenceConv(layers.Layer): 'contextLength': self._filter_size }) - if self._bias_param is not None: + if self.bias is not None: pre_act = self._helper.create_variable_for_type_inference( dtype=self._dtype) self._helper.append_op( type='elementwise_add', inputs={'X': [pre_bias], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_act]}, attrs={'axis': 1}) else: @@ -2784,7 +2653,7 @@ class RowConv(layers.Layer): def _build_once(self, input): self._dtype = self._helper.input_dtype(input) filter_shape = [self._future_context_size + 1, input.shape[1]] - self._filter_param = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=filter_shape, dtype=self._dtype, @@ -2795,7 +2664,7 @@ class RowConv(layers.Layer): self._helper.append_op( type='row_conv', inputs={'X': [input], - 'Filter': [self._filter_param]}, + 'Filter': [self.weight]}, outputs={'Out': [out]}) return self._helper.append_activation(out, act=self._act) @@ -2858,26 +2727,25 @@ class GroupNorm(layers.Layer): raise ValueError("unsupported data layout:" + data_layout) param_shape = [self._channels] - if self._bias_attr: - self._bias = self.create_parameter( - attr=self._bias_attr, - shape=param_shape, - dtype=self._dtype, - is_bias=True) - if self._param_attr: - self._scale = self.create_parameter( - attr=self._param_attr, - shape=param_shape, - dtype=self._dtype, - default_initializer=Constant(1.0)) + self.weight = self.create_parameter( + attr=self._param_attr or False, + shape=param_shape, + dtype=self._dtype, + default_initializer=Constant(1.0)) + + self.bias = self.create_parameter( + attr=self._bias_attr or False, + shape=param_shape, + dtype=self._dtype, + is_bias=True) def forward(self, input): inputs = {'X': input} - if self._bias_attr: - inputs['Bias'] = self._bias - if self._param_attr: - inputs['Scale'] = self._scale + if self.bias: + inputs['Bias'] = self.bias + if self.weight: + inputs['Scale'] = self.weight # create output mean_out = self._helper.create_variable_for_type_inference( @@ -2976,22 +2844,22 @@ class SpectralNorm(layers.Layer): h = self._weight_shape[self._dim] w = np.prod(self._weight_shape) // h - self.u = self.create_parameter( + self.weight_u = self.create_parameter( attr=ParamAttr(), shape=[h], dtype=self._dtype, default_initializer=Normal(0., 1.)) - self.u.stop_gradient = True + self.weight_u.stop_gradient = True - self.v = self.create_parameter( + self.weight_v = self.create_parameter( attr=ParamAttr(), shape=[w], dtype=self._dtype, default_initializer=Normal(0., 1.)) - self.v.stop_gradient = 
True + self.weight_v.stop_gradient = True def forward(self, weight): - inputs = {'Weight': weight, 'U': self.u, 'V': self.v} + inputs = {'Weight': weight, 'U': self.weight_u, 'V': self.weight_v} out = self._helper.create_variable_for_type_inference(self._dtype) self._helper.append_op( type="spectral_norm", @@ -3073,49 +2941,30 @@ class TreeConv(layers.Layer): self._dtype = dtype w_shape = [self._feature_size, 3, self._output_size, self._num_filters] if self._bias_attr: - self._bias_param = self.create_parameter( + self.bias = self.create_parameter( attr=self._bias_attr, shape=[self._num_filters], dtype=self._dtype, is_bias=True) - self.W = self.create_parameter( + self.weight = self.create_parameter( attr=self._param_attr, shape=w_shape, dtype=self._dtype, is_bias=False) - @property - def weight(self): - return self.W - - @weight.setter - def weight(self, value): - self.W = value - - @property - def bias(self): - return self._bias_param - - @bias.setter - def bias(self, value): - self._bias_param = value - def forward(self, nodes_vector, edge_set): - if self._name: out = self.create_variable( name=self._name, dtype=self._dtype, persistable=False) else: - out = self._helper.create_variable_for_type_inference( dtype=self._dtype) - self._helper.append_op( type='tree_conv', inputs={ 'NodesVector': nodes_vector, 'EdgeSet': edge_set, - 'Filter': self.W + 'Filter': self.weight }, outputs={'Out': out, }, attrs={'max_depth': self._max_depth}) @@ -3125,7 +2974,7 @@ class TreeConv(layers.Layer): self._helper.append_op( type='elementwise_add', inputs={'X': [out], - 'Y': [self._bias_param]}, + 'Y': [self.bias]}, outputs={'Out': [pre_activation]}, attrs={'axis': 1}) else: diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index a3a5ce883a3..6ab4a72e836 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -153,8 +153,8 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss = case1(v1, v2) loss.backward() - self.assertTrue(case1.fc2._w._grad_ivar() is not None) - self.assertTrue(case1.fc1._w._grad_ivar() is not None) + self.assertTrue(case1.fc2.weight._grad_ivar() is not None) + self.assertTrue(case1.fc1.weight._grad_ivar() is not None) def test_auto_prune2(self): with fluid.dygraph.guard(): @@ -166,8 +166,8 @@ class TestImperativeAutoPrune(unittest.TestCase): loss = case2(v1, v2) loss.backward() - self.assertTrue(case2.fc2._w._grad_ivar() is None) - self.assertTrue(case2.fc1._w._grad_ivar() is not None) + self.assertTrue(case2.fc2.weight._grad_ivar() is None) + self.assertTrue(case2.fc1.weight._grad_ivar() is not None) def test_auto_prune3(self): with fluid.dygraph.guard(): @@ -178,7 +178,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part2 = case3(v1, v2, 1) loss.backward() - self.assertTrue(case3.fc._w._grad_ivar() is not None) + self.assertTrue(case3.fc.weight._grad_ivar() is not None) self.assertTrue((part2.gradient() == 0).all()) def test_auto_prune4(self): @@ -190,7 +190,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part2 = case4(v1, v2, 1) part2.backward() - self.assertTrue(case4.fc._w._grad_ivar() is not None) + self.assertTrue(case4.fc.weight._grad_ivar() is not None) self.assertTrue((part2.gradient() == 1).all()) def test_auto_prune5(self): @@ -202,7 
+202,7 @@ class TestImperativeAutoPrune(unittest.TestCase): v2 = fluid.dygraph.to_variable(value2) loss, part1, part2 = case4(v1, v2, 2) part1.backward() - self.assertTrue(case4.fc._w._grad_ivar() is not None) + self.assertTrue(case4.fc.weight._grad_ivar() is not None) self.assertTrue((part2.gradient() == 0).all()) def test_auto_prune6(self): @@ -220,7 +220,7 @@ class TestImperativeAutoPrune(unittest.TestCase): out1.stop_gradient = True out = fluid.layers.concat(input=[out1, out2, c], axis=1) out.backward() - self.assertTrue((fc._w.gradient() == 0).all()) + self.assertTrue((fc.weight.gradient() == 0).all()) self.assertTrue((out1.gradient() == 0).all()) def test_auto_prune7(self): @@ -239,7 +239,7 @@ class TestImperativeAutoPrune(unittest.TestCase): out = fluid.layers.concat(input=[out1, out2, c], axis=1) backward_strategy = fluid.dygraph.BackwardStrategy() out.backward(backward_strategy) - self.assertTrue((fc._w.gradient() == 0).all()) + self.assertTrue((fc.weight.gradient() == 0).all()) self.assertTrue((out1.gradient() == 0).all()) def test_auto_prune8(self): @@ -253,17 +253,17 @@ class TestImperativeAutoPrune(unittest.TestCase): b = fluid.dygraph.to_variable(value1) c = fluid.dygraph.to_variable(value2) out1 = fc(a) - fc_origin = fc._w.numpy() + fc_origin = fc.weight.numpy() out2 = fc2(out1) - fc2_origin = fc2._w.numpy() - fc2._w.stop_gradient = True + fc2_origin = fc2.weight.numpy() + fc2.weight.stop_gradient = True out2.backward() optimizer = fluid.optimizer.SGD( learning_rate=0.003, parameter_list=(fc.parameters() + fc2.parameters())) optimizer.minimize(out2) - self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy())) - self.assertFalse(np.array_equal(fc_origin, fc._w.numpy())) + self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy())) + self.assertFalse(np.array_equal(fc_origin, fc.weight.numpy())) def test_auto_prune9(self): with fluid.dygraph.guard(): @@ -276,19 +276,19 @@ class TestImperativeAutoPrune(unittest.TestCase): b = fluid.dygraph.to_variable(value1) c = fluid.dygraph.to_variable(value2) out1 = fc(a) - fc_origin = fc._w.numpy() + fc_origin = fc.weight.numpy() out2 = fc2(out1) - fc2_origin = fc2._w.numpy() + fc2_origin = fc2.weight.numpy() out2.stop_gradient = True out2.backward() optimizer = fluid.optimizer.SGD( learning_rate=0.003, parameter_list=(fc.parameters() + fc2.parameters())) optimizer.minimize(out2) - self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy())) - self.assertTrue(np.array_equal(fc_origin, fc._w.numpy())) + self.assertTrue(np.array_equal(fc2_origin, fc2.weight.numpy())) + self.assertTrue(np.array_equal(fc_origin, fc.weight.numpy())) try: - fc2._w.gradient() + fc2.weight.gradient() except ValueError as e: assert type(e) == ValueError @@ -309,7 +309,7 @@ class TestImperativeAutoPrune(unittest.TestCase): backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy.sort_sum_gradient = True out.backward(backward_strategy) - self.assertTrue((fc._w.gradient() == 0).all()) + self.assertTrue((fc.weight.gradient() == 0).all()) self.assertTrue((out1.gradient() == 0).all()) def test_auto_prune_with_optimizer(self): @@ -336,10 +336,10 @@ class TestImperativeAutoPrune(unittest.TestCase): loss.backward() _, params_grads = optimizer.minimize(loss, grad_clip=grad_clip) for items in params_grads: - assert items[0].name is not model.embed1._w.name - assert items[0].name is not model.fc1._w.name - assert model.embed1._w._grad_ivar() is None - assert model.fc1._w._grad_ivar() is None + assert items[0].name is not model.embed1.weight.name + assert 
items[0].name is not model.fc1.weight.name + assert model.embed1.weight._grad_ivar() is None + assert model.fc1.weight._grad_ivar() is None with fluid.dygraph.guard(place): model = MyLayer2("mylayer", vocab_size, size) @@ -355,10 +355,10 @@ class TestImperativeAutoPrune(unittest.TestCase): loss.backward() optimizer.minimize(loss, grad_clip=grad_clip) for items in params_grads: - assert items[0].name is not model.embed1._w.name - assert items[0].name is not model.fc1._w.name - assert model.embed1._w._grad_ivar() is None - assert model.fc1._w._grad_ivar() is None + assert items[0].name is not model.embed1.weight.name + assert items[0].name is not model.fc1.weight.name + assert model.embed1.weight._grad_ivar() is None + assert model.fc1.weight._grad_ivar() is None def test_case2_prune_no_grad_branch(self): with fluid.dygraph.guard(): @@ -369,8 +369,8 @@ class TestImperativeAutoPrune(unittest.TestCase): case3 = AutoPruneLayer2("l2") loss = case3(v1, v2) loss.backward() - self.assertTrue(case3.fc2._w._grad_ivar() is None) - self.assertTrue(case3.fc._w._grad_ivar() is not None) + self.assertTrue(case3.fc2.weight._grad_ivar() is None) + self.assertTrue(case3.fc.weight._grad_ivar() is not None) def test_case2_prune_no_grad_branch(self): with fluid.dygraph.guard(): @@ -381,8 +381,8 @@ class TestImperativeAutoPrune(unittest.TestCase): case3 = AutoPruneLayer2("l2") loss = case3(v1, v2) loss.backward() - self.assertTrue(case3.fc2._w._grad_ivar() is None) - self.assertTrue(case3.fc._w._grad_ivar() is not None) + self.assertTrue(case3.fc2.weight._grad_ivar() is None) + self.assertTrue(case3.fc.weight._grad_ivar() is not None) def test_case3_prune_no_grad_branch2(self): with fluid.dygraph.guard(): @@ -395,7 +395,7 @@ class TestImperativeAutoPrune(unittest.TestCase): out = fluid.layers.one_hot(input=label, depth=100) loss = fluid.layers.mean(out) loss.backward() - self.assertTrue(fc._w._grad_ivar() is None) + self.assertTrue(fc.weight._grad_ivar() is None) def test_case4_with_no_grad_op_maker(self): with fluid.dygraph.guard(): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 7c3b32ebe69..8f1e2fdd2a3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -342,7 +342,7 @@ class TestImperative(unittest.TestCase): out = mlp(var_inp) dy_out = out.numpy() out.backward() - dy_grad = mlp._fc1._w.gradient() + dy_grad = mlp._fc1.weight.gradient() with fluid.dygraph.guard(): var_inp2 = fluid.dygraph.base.to_variable(np_inp) @@ -352,7 +352,7 @@ class TestImperative(unittest.TestCase): backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy.sort_sum_gradient = True out2.backward(backward_strategy) - dy_grad2 = mlp2._fc1._w.gradient() + dy_grad2 = mlp2._fc1.weight.gradient() with new_program_scope(): inp = fluid.layers.data( @@ -360,7 +360,7 @@ class TestImperative(unittest.TestCase): mlp = MLP("mlp") out = mlp(inp) param_grads = fluid.backward.append_backward( - out, parameter_list=[mlp._fc1._w.name])[0] + out, parameter_list=[mlp._fc1.weight.name])[0] exe = fluid.Executor(fluid.CPUPlace( ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0)) exe.run(fluid.default_startup_program()) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py index 9bd6f039d91..c7a9e202e1f 100644 --- 
a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py @@ -59,7 +59,7 @@ class SimpleNet(fluid.Layer): x_emb = self.embedding(input) projection = fluid.layers.matmul( x_emb, fluid.layers.transpose( - self.embedding._w, perm=[1, 0])) + self.embedding.weight, perm=[1, 0])) projection = fluid.layers.elementwise_add(projection, self.softmax_bias) projection = fluid.layers.reshape( projection, shape=[-1, self.vocab_size]) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py index 6a9c20a53d2..5f5656deac3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py @@ -61,7 +61,7 @@ class TestSimpleNet(unittest.TestCase): input_emb, emb = simplenet(input) try: - emb._w.gradient() + emb.weight.gradient() except ValueError as e: assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str( e) @@ -73,11 +73,11 @@ class TestSimpleNet(unittest.TestCase): input_emb.backward(backward_strategy) adam.minimize(input_emb) # grad_clip=grad_clip - emb._w.gradient() + emb.weight.gradient() emb.clear_gradients() try: - emb._w.gradient() + emb.weight.gradient() except ValueError as e: assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str( e) @@ -108,7 +108,7 @@ class TestSimpleNet(unittest.TestCase): input_emb, emb = simplenet(input) try: - emb._w.gradient() + emb.weight.gradient() except ValueError as e: assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str( e) @@ -120,11 +120,11 @@ class TestSimpleNet(unittest.TestCase): input_emb.backward(backward_strategy) adam.minimize(input_emb, grad_clip=grad_clip) - emb._w.gradient() + emb.weight.gradient() emb.clear_gradients() try: - emb._w.gradient() + emb.weight.gradient() except ValueError as e: assert "has no grad, Please set Variable.stop_gradient=False, or check if this is the first and only variable need grad, if so, please set its pre-Variable's stop_gradient=False, to make sure it has gradient" in str( e) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py index 4d06d6da209..3db655b788e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows_to_lod_tensor.py @@ -66,7 +66,7 @@ class SimpleNet(fluid.Layer): fc = fluid.layers.elementwise_add(fc, self.softmax_bias) projection = fluid.layers.matmul( fc, fluid.layers.transpose( - self.embedding._w, perm=[1, 0])) + self.embedding.weight, perm=[1, 0])) projection = fluid.layers.reshape( projection, shape=[-1, self.vocab_size]) loss = fluid.layers.softmax_with_cross_entropy( diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py 
b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py index d5dc6e84782..c57d10a24aa 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -843,7 +843,7 @@ class WrapDecoderLayer(Layer): if self._weight_sharing: predict = fluid.layers.matmul( x=dec_output_reshape, - y=self._prepare_decoder_layer._input_emb._w, + y=self._prepare_decoder_layer._input_emb.weight, transpose_y=True) else: predict = self._fc(dec_output_reshape) @@ -917,7 +917,7 @@ class TransFormer(Layer): is_sparse=is_sparse) if weight_sharing: - self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w + self._wrap_decoder_layer._prepare_decoder_layer._input_emb.weight = self._wrap_encoder_layer._prepare_encoder_layer._input_emb.weight def forward(self, enc_inputs, dec_inputs, label, weights): enc_output = self._wrap_encoder_layer(enc_inputs) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c58ba9e3873..c5eb2a9be9e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -368,7 +368,7 @@ class TestLayer(LayerTest): filter_size=[2, 2], bias_attr=False) dy_ret = conv2d(base.to_variable(images)) - self.assertTrue(conv2d._bias_param is None) + self.assertTrue(conv2d.bias is None) self.assertTrue(np.allclose(static_ret, dy_ret_value)) self.assertTrue(np.allclose(static_ret, static_ret2)) -- GitLab
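
Note (placed after the patch trailer, so it is not part of the applied diff): the change follows one pattern throughout python/paddle/fluid/dygraph/nn.py. Each layer's parameters are now created directly as the public attributes self.weight / self.bias instead of private fields (self._filter_param, self._w, self._bias_param, ...) re-exported through @property/@setter pairs, and the unit tests switch from fc._w / emb._w to fc.weight / emb.weight accordingly. The sketch below is a minimal illustration of that before/after shape, assuming the 1.7-era fluid.dygraph API used elsewhere in this patch; ExampleLayer, its name scope, and its dimensions are hypothetical and do not correspond to any layer touched here.

    # Illustrative sketch only -- "ExampleLayer" is hypothetical, not a layer from this patch.
    import numpy as np
    import paddle.fluid as fluid


    class ExampleLayer(fluid.dygraph.Layer):
        def __init__(self, name_scope, in_dim, out_dim):
            super(ExampleLayer, self).__init__(name_scope)
            # Before this patch, a layer kept its parameter in a private field and
            # re-exported it through a property pair, e.g.:
            #
            #     self._w = self.create_parameter(
            #         attr=fluid.ParamAttr(), shape=[in_dim, out_dim],
            #         dtype='float32', is_bias=False)
            #
            #     @property
            #     def weight(self):
            #         return self._w
            #
            #     @weight.setter
            #     def weight(self, value):
            #         self._w = value
            #
            # After this patch, the parameter is created directly as the public
            # attribute, so the property boilerplate disappears:
            self.weight = self.create_parameter(
                attr=fluid.ParamAttr(),
                shape=[in_dim, out_dim],
                dtype='float32',
                is_bias=False)
            self.bias = self.create_parameter(
                attr=fluid.ParamAttr(),
                shape=[out_dim],
                dtype='float32',
                is_bias=True)

        def forward(self, x):
            out = fluid.layers.matmul(x, self.weight)
            return fluid.layers.elementwise_add(out, self.bias)


    with fluid.dygraph.guard():
        layer = ExampleLayer("example", 4, 3)
        x = fluid.dygraph.to_variable(
            np.random.random((2, 4)).astype('float32'))
        loss = fluid.layers.reduce_mean(layer(x))
        loss.backward()
        # User code and the updated unit tests read the public attributes directly,
        # e.g. layer.weight / layer.bias instead of layer._w / layer._filter_param:
        print(layer.weight.numpy().shape)            # (4, 3)
        print(layer.weight.gradient() is not None)   # True after backward()

Beyond the pure rename, two layers fold conditional logic into the new attributes: BatchNorm now assigns self.weight.stop_gradient / self.bias.stop_gradient the boolean expression (use_global_stats and self._param_attr.learning_rate == 0.) unconditionally instead of setting it to True inside an if, and GroupNorm always creates weight and bias, passing attr or False so a disabled ParamAttr still yields a well-defined public attribute.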