Unverified commit 30416052, authored by wangguanzhong, committed by GitHub

fix lr in param group (#34468)

* fix lr in param group

* add unittest for adamw
Parent 423ea978
@@ -147,5 +147,33 @@ class TestAdamWOpGroup(TestAdamWOp):
             adam.clear_gradients()
 
 
+class TestAdamWOpGroupWithLR(TestAdamWOp):
+    def test_adamw_op_dygraph(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear_1 = paddle.nn.Linear(13, 5)
+        linear_2 = paddle.nn.Linear(5, 3)
+        adam = paddle.optimizer.AdamW(
+            learning_rate=paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=[3, 6], values=[0.1, 0.2, 0.3]),
+            parameters=[{
+                'params': linear_1.parameters(),
+                'learning_rate': 0.1,
+            }, {
+                'params': linear_2.parameters(),
+                'weight_decay': 0.001,
+            }],
+            apply_decay_param_fun=lambda name: True,
+            weight_decay=0.01)
+
+        for _ in range(2):
+            out = linear_1(a)
+            out = linear_2(out)
+            out.backward()
+            adam.step()
+            adam.clear_gradients()
+
+
 if __name__ == "__main__":
     unittest.main()
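The new test combines a global `PiecewiseDecay` scheduler with a per-group `'learning_rate'` entry, which is exactly the combination this fix targets. Below is a minimal sketch of the intended semantics, assuming (as in Paddle-style optimizers) that the group-level value acts as a multiplier on the scheduler's current rate; `effective_lr` is a hypothetical helper for illustration, not part of this commit:

```python
import math

def effective_lr(global_lr, group_lr_multiplier=1.0):
    # global_lr comes from the scheduler (e.g. PiecewiseDecay above);
    # group_lr_multiplier defaults to 1.0 for groups that omit 'learning_rate'.
    return global_lr * group_lr_multiplier

# At PiecewiseDecay's first stage (values[0] == 0.1):
assert math.isclose(effective_lr(0.1, 0.1), 0.01)  # linear_1's group
assert math.isclose(effective_lr(0.1), 0.1)        # linear_2's group
```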
@@ -185,10 +185,9 @@ class AdamW(Adam):
         Raises:
             Exception: The type of coeff and parameter is not consistent.
         """
-        if not isinstance(param_and_grad, dict):
-            param, grad = param_and_grad
-        else:
-            param, grad = self._update_param_group(param_and_grad)
+        if isinstance(param_and_grad, dict):
+            param_and_grad = self._update_param_group(param_and_grad)
+        param, grad = param_and_grad
 
         if self._apply_decay_param_fun is not None \
                 and not self._apply_decay_param_fun(param.name):
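The refactor above replaces a two-branch unpack with a normalize-then-unpack pattern: a dict argument (a parameter group) is first converted by `_update_param_group`, after which a single unpack handles both input forms. A standalone sketch of the pattern (names hypothetical, not the Paddle API):

```python
def unpack_param_and_grad(param_and_grad, update_param_group):
    if isinstance(param_and_grad, dict):
        # The dict form carries group-level hyperparameters; normalize it
        # to a plain (param, grad) sequence before unpacking.
        param_and_grad = update_param_group(param_and_grad)
    param, grad = param_and_grad  # single unpack handles both forms
    return param, grad
```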
@@ -206,7 +206,6 @@ class Optimizer(object):
         self._param_device_map = dict()
         self.clear_gradients = self.clear_grad
         self._default_dict = {
-            'learning_rate': self._learning_rate,
             'weight_decay': self.regularization,
             'grad_clip': self._grad_clip
         }
@@ -1190,7 +1189,8 @@ class Optimizer(object):
             else:
                 regularization = weight_decay
             param.regularizer = regularization
-            param.optimize_attr['learning_rate'] = param_group['learning_rate']
+            param.optimize_attr['learning_rate'] = param_group.get(
+                'learning_rate', 1.)
 
         self._param_groups.append(param_group)
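Together, these two hunks are the core of the fix: `'learning_rate'` is dropped from `_default_dict`, so groups that omit it no longer inherit the global learning-rate object (which may be an `LRScheduler` instance rather than a float), and `param_group.get('learning_rate', 1.)` falls back to the neutral multiplier `1.0`. A small sketch of the lookup behavior, under the multiplier semantics assumed above (the dicts here are illustrative, not real parameter groups):

```python
group_with_lr = {'learning_rate': 0.1}  # like linear_1's group in the test
group_without_lr = {}                   # like linear_2's group

# After the fix, a missing key yields the neutral multiplier 1.0, so such
# a group simply follows the global (scheduled) learning rate.
print(group_with_lr.get('learning_rate', 1.))     # 0.1
print(group_without_lr.get('learning_rate', 1.))  # 1.0
```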