diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py
index e7033d845116afaa60c07ad6c4aabb866fca98b7..9b77dae1afed2d58601724fed033119cffe6a8e6 100644
--- a/python/paddle/fluid/tests/unittests/test_adamw_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py
@@ -98,16 +98,27 @@ class TestAdamWOp(unittest.TestCase):
         value = np.arange(26).reshape(2, 13).astype("float32")
         a = paddle.to_tensor(value)
         linear = paddle.nn.Linear(13, 5)
+
+        lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
+        wd = 0.1
         adam = paddle.optimizer.AdamW(
-            learning_rate=paddle.optimizer.lr.NoamDecay(
-                d_model=512, warmup_steps=4000),
+            learning_rate=lr,
             parameters=linear.parameters(),
             apply_decay_param_fun=lambda name: True,
-            weight_decay=0.01)
-        out = linear(a)
-        out.backward()
-        adam.step()
-        adam.clear_gradients()
+            weight_decay=wd)
+
+        for _ in range(2):
+            out = linear(a)
+            out.backward()
+            lr_to_coeff = adam._lr_to_coeff
+            adam.step()
+
+            for coeff in lr_to_coeff.values():
+                self.assertAlmostEqual(coeff.numpy()[0], 1.0 - lr() * wd)
+            self.assertEqual(len(adam._lr_to_coeff), 0)
+
+            lr.step()
+            adam.clear_gradients()
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index ff560e81343765376d6997755e7e08bec9052129..cd3955d5f06d7846151527663bb96a90d20f0ddd 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -173,7 +173,10 @@ class AdamW(Adam):
                 [param, grad]), framework.name_scope('weight decay'):
             self._params_name.add(param.name)
 
-            # If it has been calculated, the result will be reused
+            # If it has been calculated, the result will be reused.
+            # NOTE(wangxi): In dygraph mode, apply_gradient is executed
+            # every step, so _lr_to_coeff needs to be cleared every step;
+            # we do this in _create_optimization_pass.
             decay_coeff = self._lr_to_coeff.get(learning_rate, None)
             if decay_coeff is None:
                 decay_coeff = 1.0 - learning_rate * self._coeff
@@ -186,5 +189,12 @@ class AdamW(Adam):
             self._append_decoupled_weight_decay(block, param_and_grad)
         return super(AdamW, self)._append_optimize_op(block, param_and_grad)
 
+    def _create_optimization_pass(self, parameters_and_grads):
+        optimize_ops = super(
+            AdamW, self)._create_optimization_pass(parameters_and_grads)
+        # In dygraph mode, clear _lr_to_coeff after the gradients are applied
+        self._lr_to_coeff = dict()
+        return optimize_ops
+
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
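
For context, the following is a minimal, self-contained sketch of the pattern this diff exercises; it is not the Paddle implementation, and the names (ToyDecoupledWeightDecay, _decay_coeff, step) are made up for illustration. It shows the decoupled weight-decay coefficient 1.0 - lr * coeff being cached per learning-rate value and the cache being cleared after every optimization pass, which is what lets a scheduled (changing) learning rate produce a fresh coefficient on every step.

# Hypothetical sketch, not Paddle's AdamW: caches the decoupled weight-decay
# coefficient per learning-rate value and clears the cache after each pass.
class ToyDecoupledWeightDecay:  # hypothetical name, for illustration only
    def __init__(self, coeff=0.1):
        self._coeff = coeff
        self._lr_to_coeff = {}  # maps lr value -> 1.0 - lr * coeff

    def _decay_coeff(self, lr):
        # Reuse the coefficient if it has already been computed for this lr.
        decay_coeff = self._lr_to_coeff.get(lr, None)
        if decay_coeff is None:
            decay_coeff = 1.0 - lr * self._coeff
            self._lr_to_coeff[lr] = decay_coeff
        return decay_coeff

    def step(self, params, grads, lr):
        # Decoupled weight decay: shrink each parameter, then take the
        # gradient step (plain SGD stands in for Adam here).
        c = self._decay_coeff(lr)
        for i, (p, g) in enumerate(zip(params, grads)):
            params[i] = c * p - lr * g
        # Mirrors clearing _lr_to_coeff at the end of the optimization pass:
        # the next step recomputes the coefficient with the next scheduled lr.
        self._lr_to_coeff = {}


if __name__ == "__main__":
    opt = ToyDecoupledWeightDecay(coeff=0.1)
    params, grads = [1.0, 2.0], [0.5, 0.5]
    for lr in (0.10, 0.05):  # stands in for a decaying lr schedule
        opt.step(params, grads, lr)
        assert len(opt._lr_to_coeff) == 0  # cache is empty after every step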