Unverified commit b96869bc, authored by Guo Sheng, committed by GitHub

Fix lr setting of AdamW when lr is an instance of LRScheduler (#28300)

* Fix lr setting of AdamW when lr is an instance of LRScheduler.
test=develop

* Fix static graph test mode in test_adamw_op.py.
test=develop
Parent 57e4411a
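
For context, the bug this commit fixes only appears when AdamW is constructed with an LRScheduler instead of a float learning rate. Below is a minimal dygraph sketch of that configuration; the layer size, NoamDecay settings, and random input are illustrative assumptions, not values taken from this commit.

import numpy as np
import paddle

paddle.disable_static()  # dygraph mode (the default in Paddle 2.x)
linear = paddle.nn.Linear(13, 5)
scheduler = paddle.optimizer.lr.NoamDecay(d_model=512, warmup_steps=4000)
opt = paddle.optimizer.AdamW(
    learning_rate=scheduler,          # an LRScheduler, not a float
    parameters=linear.parameters(),
    weight_decay=0.01)

x = paddle.to_tensor(np.random.random([2, 13]).astype("float32"))
loss = linear(x).mean()
loss.backward()
opt.step()          # the optimizer reads the scheduler's current value here
scheduler.step()    # advance the schedule for the next step
opt.clear_grad()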
@@ -47,6 +47,7 @@ class TestAdamWOp(unittest.TestCase):
         assert (adam.__str__() is not None)
 
     def test_adamw_op(self):
+        paddle.enable_static()
         place = fluid.CPUPlace()
         shape = [2, 3, 8, 8]
         exe = fluid.Executor(place)
@@ -75,6 +76,7 @@ class TestAdamWOp(unittest.TestCase):
         data_np = np.random.random(shape).astype('float32')
         rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss])
         assert rets[0] is not None
+        paddle.disable_static()
 
     def test_adamw_op_invalid_input(self):
         paddle.disable_static()
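
The two added lines above bracket the static-graph test: test_adamw_op now switches into static mode with paddle.enable_static() and back out with paddle.disable_static(), so the dygraph tests that follow are unaffected. A minimal sketch of the same bracketing pattern, assuming the Paddle 2.x paddle.static API (the toy network here is not the test's network):

import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    data = static.data(name="data", shape=[2, 13], dtype="float32")
    loss = paddle.mean(paddle.nn.Linear(13, 5)(data))

exe = static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
rets = exe.run(main_prog,
               feed={"data": np.random.random([2, 13]).astype("float32")},
               fetch_list=[loss])
paddle.disable_static()  # restore dygraph mode for subsequent tests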
@@ -89,6 +91,22 @@ class TestAdamWOp(unittest.TestCase):
             adam = paddle.optimizer.AdamW(
                 0.1, epsilon=-1, parameters=linear.parameters())
 
+    def test_adamw_lr_decay(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear = paddle.nn.Linear(13, 5)
+
+        adam = paddle.optimizer.AdamW(
+            learning_rate=paddle.optimizer.lr.NoamDecay(
+                d_model=512, warmup_steps=4000),
+            parameters=linear.parameters(),
+            apply_decay_param_fun=lambda name: True,
+            weight_decay=0.01)
+        out = linear(a)
+        out.backward()
+        adam.step()
+        adam.clear_gradients()
 
 
 if __name__ == "__main__":
     unittest.main()
@@ -57,7 +57,7 @@ class AdamW(Adam):
            The default value is 1e-08.
        weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
        apply_decay_param_fun (function|None, optional): If it is not None,
-            only tensors that makes apply_decay_param_fun(Tensor)==True
+            only tensors that makes apply_decay_param_fun(Tensor.name)==True
            will be updated. It only works when we want to specify tensors.
            Default: None.
        grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
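
The docstring fix above clarifies that apply_decay_param_fun receives Tensor.name, i.e. the parameter's name string, not the Tensor object. A brief usage sketch; the ".b_0" bias-name suffix is an assumption about Paddle's default parameter naming rather than something stated in this diff:

import paddle

paddle.disable_static()
linear = paddle.nn.Linear(13, 5)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=linear.parameters(),
    weight_decay=0.01,
    # Return True to apply weight decay to a parameter; filter by name string.
    apply_decay_param_fun=lambda name: not name.endswith(".b_0"))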
@@ -168,7 +168,7 @@ class AdamW(Adam):
         if isinstance(self._learning_rate, float):
             learning_rate = self._learning_rate
         else:
-            self._learning_rate()
+            learning_rate = self._learning_rate()
         with param.block.program._optimized_guard(
                 [param, grad]), framework.name_scope('weight decay'):
             if param.name not in self._params_name:
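
The one-line change above is the core of the fix: when self._learning_rate is an LRScheduler, the old branch called the scheduler but discarded its value, so learning_rate was left unset for the decoupled weight-decay update that follows in this method. A standalone illustration of why the current learning-rate value matters for that update; this is a NoamDecay-style toy, not Paddle's implementation:

import numpy as np

def noam_lr(step, d_model=512, warmup_steps=4000):
    # Noam schedule: the lr changes every step, so the decay term below must
    # use the value returned by the scheduler at the current step.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

def decoupled_weight_decay(param, lr_value, coeff=0.01):
    # AdamW-style decoupled decay: the coefficient is scaled by the current lr
    # and applied directly to the parameter, separately from the Adam update.
    return param - lr_value * coeff * param

param = np.ones(3, dtype="float32")
for step in range(1, 4):
    param = decoupled_weight_decay(param, noam_lr(step))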