From 0ed80e09fcbac3d62e35dc07fa451ce1a32d4eb3 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 20 Nov 2020 19:48:38 +0800 Subject: [PATCH] Fix param base trainable set failed (#28756) * fix param base trainable set failed * add unittest * fix typo * polish comment --- python/paddle/fluid/framework.py | 7 ++++++- .../fluid/tests/unittests/test_layers.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 317cae815f4..2c9e9a12b05 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2858,6 +2858,12 @@ class Block(object): param = ParamBase(*args, **kwargs) else: param = Parameter(global_block, *args, **kwargs) + # NOTE: Why only set stop_gradient=False in static mode + # Because in dygraph mode, the `stop_gradient` and `trainable` + # are related, and `trainable` default vallue is `True` or + # it is specified by users, there is no need to set + # `stop_gradient` for ParamBase here. + param.stop_gradient = False if 'initializer' in kwargs: def _is_inited_by(block, var): @@ -2884,7 +2890,6 @@ class Block(object): pass else: initializer(param, self) - param.stop_gradient = False return param def append_op(self, *args, **kwargs): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 3908d65229a..8ae5264381e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3683,5 +3683,24 @@ class TestMetricsDetectionMap(unittest.TestCase): print(str(program)) +class ExampleNet(paddle.nn.Layer): + def __init__(self): + super(ExampleNet, self).__init__() + self.weight = self.create_parameter( + shape=[1, 1], attr=paddle.ParamAttr(trainable=False)) + + def forward(self): + # only for test parameter trainable attr + pass + + +class TestLayerParameterTrainableSet(unittest.TestCase): + def test_layer_parameter_set(self): + with fluid.dygraph.guard(): + net = ExampleNet() + self.assertFalse(net.weight.trainable) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() -- GitLab