未验证 提交 436144e9 编写于 作者: W WangXi 提交者: GitHub

fix adamw lr_to_coeff is fixed when dygraph (#30526) (#30559)

上级 832032c2
...@@ -98,15 +98,26 @@ class TestAdamWOp(unittest.TestCase): ...@@ -98,15 +98,26 @@ class TestAdamWOp(unittest.TestCase):
value = np.arange(26).reshape(2, 13).astype("float32") value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value) a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5) linear = paddle.nn.Linear(13, 5)
lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
wd = 0.1
adam = paddle.optimizer.AdamW( adam = paddle.optimizer.AdamW(
learning_rate=paddle.optimizer.lr.NoamDecay( learning_rate=lr,
d_model=512, warmup_steps=4000),
parameters=linear.parameters(), parameters=linear.parameters(),
apply_decay_param_fun=lambda name: True, apply_decay_param_fun=lambda name: True,
weight_decay=0.01) weight_decay=wd)
for _ in range(2):
out = linear(a) out = linear(a)
out.backward() out.backward()
lr_to_coeff = adam._lr_to_coeff
adam.step() adam.step()
for i, value in enumerate(lr_to_coeff.values()):
self.assertAlmostEqual(value.numpy()[0], 1.0 - lr() * wd)
self.assertEqual(len(adam._lr_to_coeff), 0)
lr.step()
adam.clear_gradients() adam.clear_gradients()
......
...@@ -173,7 +173,10 @@ class AdamW(Adam): ...@@ -173,7 +173,10 @@ class AdamW(Adam):
[param, grad]), framework.name_scope('weight decay'): [param, grad]), framework.name_scope('weight decay'):
self._params_name.add(param.name) self._params_name.add(param.name)
# If it has been calculated, the result will be reused # If it has been calculated, the result will be reused.
# NOTE(wangxi): In dygraph mode, apply_gradient will be executed
# every step, so need clear _lr_to_coeff every step,
# we do this in _create_optimization_pass
decay_coeff = self._lr_to_coeff.get(learning_rate, None) decay_coeff = self._lr_to_coeff.get(learning_rate, None)
if decay_coeff is None: if decay_coeff is None:
decay_coeff = 1.0 - learning_rate * self._coeff decay_coeff = 1.0 - learning_rate * self._coeff
...@@ -186,5 +189,12 @@ class AdamW(Adam): ...@@ -186,5 +189,12 @@ class AdamW(Adam):
self._append_decoupled_weight_decay(block, param_and_grad) self._append_decoupled_weight_decay(block, param_and_grad)
return super(AdamW, self)._append_optimize_op(block, param_and_grad) return super(AdamW, self)._append_optimize_op(block, param_and_grad)
def _create_optimization_pass(self, parameters_and_grads):
optimize_ops = super(
AdamW, self)._create_optimization_pass(parameters_and_grads)
# In dygraph mode, clear _lr_to_coeff after applied gradient
self._lr_to_coeff = dict()
return optimize_ops
def __str__(self): def __str__(self):
return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册