未验证 提交 292fd200 编写于 作者: J JYChen 提交者: GitHub

fix trainable (#56104)

上级 4036c937
......@@ -7170,6 +7170,8 @@ class Parameter(Variable, metaclass=ParameterMetaClass):
)
self.trainable = kwargs.get('trainable', True)
self.stop_gradient = not self.trainable
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None)
......
......@@ -20,6 +20,8 @@ from simple_nets import init_data
import paddle
from paddle import fluid
paddle.enable_static()
def test_trainable():
x = paddle.static.data(name='image', shape=[-1, 784], dtype='float32')
......@@ -68,12 +70,12 @@ class TestTrainable(unittest.TestCase):
self.check_trainable(
test_trainable,
feed_dict,
op_count={'adam': 1, 'scale': 0, 'mul_grad': 1},
op_count={'adam': 1, 'scale': 0, 'mul_grad': 0},
)
self.check_trainable(
test_trainable,
feed_dict,
op_count={'adamax': 1, 'scale': 1, 'mul_grad': 1},
op_count={'adamax': 1, 'scale': 1, 'mul_grad': 0},
optimizer=paddle.optimizer.Adamax(learning_rate=0.2),
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册