Unverified commit 7a0a576e, authored by WangXi, committed by GitHub

fix adamw lr_to_coeff is fixed when dygraph (#30526)

Parent 59ad6ff3
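Background for this fix: AdamW caches the decoupled weight-decay coefficient 1.0 - learning_rate * coeff in _lr_to_coeff, keyed by the learning rate it received. In dygraph mode the optimize pass runs on every step, so without clearing that cache the coefficient computed on the first step keeps being reused even after the LR scheduler changes the rate. Below is a minimal plain-Python sketch of this stale-cache pattern; the names (FakeLrTensor, decay_coeff) are hypothetical and are not Paddle APIs.

class FakeLrTensor:
    """Stand-in for a learning-rate tensor that a scheduler updates in place."""
    def __init__(self, value):
        self.value = value

_lr_to_coeff = {}

def decay_coeff(lr_tensor, coeff):
    # Cached by the tensor object, so the first computed value is reused
    # even after lr_tensor.value changes.
    if lr_tensor not in _lr_to_coeff:
        _lr_to_coeff[lr_tensor] = 1.0 - lr_tensor.value * coeff
    return _lr_to_coeff[lr_tensor]

lr = FakeLrTensor(0.1)
print(decay_coeff(lr, 0.01))  # 0.999, correct for lr = 0.1
lr.value = 0.05               # the scheduler advances the learning rate
print(decay_coeff(lr, 0.01))  # still 0.999: the coefficient stayed fixed

The patch below clears the cache after every optimization pass, so the coefficient is recomputed from the current learning rate on each step.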
......
@@ -98,16 +98,27 @@ class TestAdamWOp(unittest.TestCase):
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
         linear = paddle.nn.Linear(13, 5)
+
+        lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
+        wd = 0.1
         adam = paddle.optimizer.AdamW(
-            learning_rate=paddle.optimizer.lr.NoamDecay(
-                d_model=512, warmup_steps=4000),
+            learning_rate=lr,
             parameters=linear.parameters(),
             apply_decay_param_fun=lambda name: True,
-            weight_decay=0.01)
-        out = linear(a)
-        out.backward()
-        adam.step()
-        adam.clear_gradients()
+            weight_decay=wd)
+
+        for _ in range(2):
+            out = linear(a)
+            out.backward()
+            lr_to_coeff = adam._lr_to_coeff
+            adam.step()
+
+            for i, value in enumerate(lr_to_coeff.values()):
+                self.assertAlmostEqual(value.numpy()[0], 1.0 - lr() * wd)
+            self.assertEqual(len(adam._lr_to_coeff), 0)
+
+            lr.step()
+            adam.clear_gradients()

 if __name__ == "__main__":
......
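A note on the test above: adam._lr_to_coeff is captured before adam.step() because, with this patch, the step populates the dict the optimizer currently holds and then swaps in a fresh empty one. The captured reference therefore still sees the coefficients written during the step, while the attribute itself is empty afterwards, which is exactly what the two assertions check. A tiny plain-Python sketch of that aliasing (hypothetical names, not the Paddle implementation):

class TinyOpt:
    def __init__(self):
        self._lr_to_coeff = {}

    def step(self, lr, coeff):
        # Populate the dict the optimizer currently holds...
        self._lr_to_coeff[lr] = 1.0 - lr * coeff
        # ...then swap in a fresh dict so the next step starts clean.
        self._lr_to_coeff = {}

opt = TinyOpt()
snapshot = opt._lr_to_coeff      # captured before step(), as in the test
opt.step(lr=0.1, coeff=0.01)
print(snapshot)                  # {0.1: 0.999} -- the populated coefficients
print(opt._lr_to_coeff)          # {}           -- cleared for the next step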
......
@@ -173,7 +173,10 @@ class AdamW(Adam):
             [param, grad]), framework.name_scope('weight decay'):
             self._params_name.add(param.name)

-            # If it has been calculated, the result will be reused
+            # If it has been calculated, the result will be reused.
+            # NOTE(wangxi): In dygraph mode, apply_gradient will be executed
+            # every step, so need clear _lr_to_coeff every step,
+            # we do this in _create_optimization_pass
             decay_coeff = self._lr_to_coeff.get(learning_rate, None)
             if decay_coeff is None:
                 decay_coeff = 1.0 - learning_rate * self._coeff
......
@@ -186,5 +189,12 @@ class AdamW(Adam):
         self._append_decoupled_weight_decay(block, param_and_grad)
         return super(AdamW, self)._append_optimize_op(block, param_and_grad)

+    def _create_optimization_pass(self, parameters_and_grads):
+        optimize_ops = super(
+            AdamW, self)._create_optimization_pass(parameters_and_grads)
+        # In dygraph mode, clear _lr_to_coeff after applied gradient
+        self._lr_to_coeff = dict()
+        return optimize_ops
+
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])