未验证 提交 0ed80e09 编写于 作者: C Chen Weihang 提交者: GitHub

Fix param base trainable set failed (#28756)

* fix param base trainable set failed

* add unittest

* fix typo

* polish comment
上级 b969c32a
......@@ -2858,6 +2858,12 @@ class Block(object):
param = ParamBase(*args, **kwargs)
else:
param = Parameter(global_block, *args, **kwargs)
# NOTE: Why only set stop_gradient=False in static mode
# Because in dygraph mode, the `stop_gradient` and `trainable`
# are related, and `trainable` default vallue is `True` or
# it is specified by users, there is no need to set
# `stop_gradient` for ParamBase here.
param.stop_gradient = False
if 'initializer' in kwargs:
def _is_inited_by(block, var):
......@@ -2884,7 +2890,6 @@ class Block(object):
pass
else:
initializer(param, self)
param.stop_gradient = False
return param
def append_op(self, *args, **kwargs):
......
......@@ -3683,5 +3683,24 @@ class TestMetricsDetectionMap(unittest.TestCase):
print(str(program))
class ExampleNet(paddle.nn.Layer):
def __init__(self):
super(ExampleNet, self).__init__()
self.weight = self.create_parameter(
shape=[1, 1], attr=paddle.ParamAttr(trainable=False))
def forward(self):
# only for test parameter trainable attr
pass
class TestLayerParameterTrainableSet(unittest.TestCase):
def test_layer_parameter_set(self):
with fluid.dygraph.guard():
net = ExampleNet()
self.assertFalse(net.weight.trainable)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册