Unverified commit 99c593bc, authored by zqw_1997, committed by GitHub

Add set_lr_scheduler api (#54752)

* demo1

* add test cases

* modify the usage of StepDecay

* refine
Parent 63f242b6
@@ -553,6 +553,52 @@ class Optimizer:
                    stop_gradient=True,
                )

    @framework.dygraph_only
    def set_lr_scheduler(self, scheduler):
        """
        :api_attr: imperative

        Set the LRScheduler of the learning rate manually in the optimizer. If the optimizer
        already uses an LRScheduler, this API will replace it with the new one.

        Args:
            scheduler (LRScheduler): the LRScheduler of the learning rate

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(10, 10)
                adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

                # set the learning rate manually with an LRScheduler instance
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8)
                adam.set_lr_scheduler(scheduler)
                lr = adam.get_lr()
                print("current lr is {}".format(lr))
                # current lr is 0.5

                # set the learning rate manually with another LRScheduler
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=5, gamma=0.6)
                adam.set_lr_scheduler(scheduler)
                lr = adam.get_lr()
                print("current lr is {}".format(lr))
                # current lr is 0.1

        """
        from paddle.optimizer.lr import LRScheduler

        if not isinstance(scheduler, LRScheduler):
            raise TypeError(
                "The type of 'scheduler' in optimizer.set_lr_scheduler must be LRScheduler, but received %s."
                % (type(scheduler))
            )
        self._learning_rate = scheduler

    def get_lr(self):
        """
        Get current learning rate of optimizer.
......
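Beyond the docstring snippet above, the following is a minimal, hypothetical dygraph training-loop sketch (not part of this diff) of how the new set_lr_scheduler API might be combined with scheduler.step(); the layer, input data, and hyperparameters are illustrative assumptions.

import paddle

# illustrative toy setup (assumed): a small linear layer and random input
linear = paddle.nn.Linear(10, 10)
x = paddle.rand([4, 10])
adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

# swap in a scheduler after the optimizer was constructed with a plain float lr
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.2, step_size=5, gamma=0.6)
adam.set_lr_scheduler(scheduler)

for epoch in range(10):
    loss = paddle.mean(linear(x))
    loss.backward()
    adam.step()
    adam.clear_grad()
    scheduler.step()  # advance the schedule; adam.get_lr() then reflects the updated value
    print(epoch, adam.get_lr())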
@@ -656,6 +656,42 @@ class TestOptimizerLearningRate(unittest.TestCase):
                )
                adam.set_lr(0.01)

    def test_set_lr_scheduler(self):
        with fluid.dygraph.guard():
            a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
            a = fluid.dygraph.to_variable(a)
            b = linear(a)
            loss = paddle.mean(b)
            adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())

            # float to LRScheduler
            scheduler = paddle.optimizer.lr.StepDecay(
                learning_rate=0.2, step_size=5, gamma=0.6
            )
            adam.set_lr_scheduler(scheduler)
            adam.minimize(loss)
            lr = adam.get_lr()
            np.testing.assert_allclose(lr, 0.2, rtol=1e-06, atol=0.0)

            # LRScheduler to another LRScheduler
            scheduler = paddle.optimizer.lr.MultiStepDecay(
                learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8
            )
            adam.set_lr_scheduler(scheduler)
            adam.minimize(loss)
            lr = adam.get_lr()
            np.testing.assert_allclose(lr, 0.5, rtol=1e-06, atol=0.0)

            with self.assertRaises(TypeError):
                scheduler_var = paddle.fluid.dygraph.StepDecay(0.5, step_size=3)
                adam.set_lr_scheduler(scheduler_var)


class TestImperativeMomentumOptimizer(TestImperativeOptimizerBase):
    def get_optimizer_dygraph(self, parameter_list):
......