Unverified commit 292fd200 authored by JYChen, committed by GitHub

fix trainable (#56104)

Parent 4036c937
@@ -7170,6 +7170,8 @@ class Parameter(Variable, metaclass=ParameterMetaClass):
             )
         self.trainable = kwargs.get('trainable', True)
+        self.stop_gradient = not self.trainable
         self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
         self.regularizer = kwargs.get('regularizer', None)
@@ -20,6 +20,8 @@ from simple_nets import init_data
 import paddle
 from paddle import fluid
+
+paddle.enable_static()

 def test_trainable():
     x = paddle.static.data(name='image', shape=[-1, 784], dtype='float32')
@@ -68,12 +70,12 @@ class TestTrainable(unittest.TestCase):
         self.check_trainable(
             test_trainable,
             feed_dict,
-            op_count={'adam': 1, 'scale': 0, 'mul_grad': 1},
+            op_count={'adam': 1, 'scale': 0, 'mul_grad': 0},
         )
         self.check_trainable(
             test_trainable,
             feed_dict,
-            op_count={'adamax': 1, 'scale': 1, 'mul_grad': 1},
+            op_count={'adamax': 1, 'scale': 1, 'mul_grad': 0},
             optimizer=paddle.optimizer.Adamax(learning_rate=0.2),
         )