Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
99c593bc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
99c593bc
编写于
6月 28, 2023
作者:
Z
zqw_1997
提交者:
GitHub
6月 28, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add set_lr_scheduler api (#54752)
* demo1 * add test cases * modify the usage of StepDecay * refine
上级
63f242b6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
82 addition
and
0 deletion
+82
-0
python/paddle/optimizer/optimizer.py
python/paddle/optimizer/optimizer.py
+46
-0
test/legacy_test/test_imperative_optimizer_v2.py
test/legacy_test/test_imperative_optimizer_v2.py
+36
-0
未找到文件。
python/paddle/optimizer/optimizer.py
浏览文件 @
99c593bc
...
...
@@ -553,6 +553,52 @@ class Optimizer:
stop_gradient
=
True
,
)
@
framework
.
dygraph_only
def
set_lr_scheduler
(
self
,
scheduler
):
"""
:api_attr: imperative
Set the LRScheduler of the learning rate manually in the optimizer. If the optimizer already used LRScheduler previously,
this API will set it be the new one.
Args:
scheduler (LRScheduler): the LRScheduler of learning rate
Returns:
None
Examples:
.. code-block:: python
import paddle
linear = paddle.nn.Linear(10, 10)
adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters())
# set learning rate manually by class LRScheduler
scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2,4,6], gamma=0.8)
adam.set_lr_scheduler(scheduler)
lr = adam.get_lr()
print("current lr is {}".format(lr))
# current lr is 0.5
# set learning rate manually by another LRScheduler
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=5, gamma=0.6)
adam.set_lr_scheduler(scheduler)
lr = adam.get_lr()
print("current lr is {}".format(lr))
# current lr is 0.1
"""
from
paddle.optimizer.lr
import
LRScheduler
if
not
isinstance
(
scheduler
,
LRScheduler
):
raise
TypeError
(
"The type of 'scheduler' in optimizer.set_lr_schduler must be LRScheduler, but received %s."
%
(
type
(
scheduler
))
)
self
.
_learning_rate
=
scheduler
def
get_lr
(
self
):
"""
Get current learning rate of optimizer.
...
...
test/legacy_test/test_imperative_optimizer_v2.py
浏览文件 @
99c593bc
...
...
@@ -656,6 +656,42 @@ class TestOptimizerLearningRate(unittest.TestCase):
)
adam
.
set_lr
(
0.01
)
def
test_set_lr_scheduler
(
self
):
with
fluid
.
dygraph
.
guard
():
a
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
[
10
,
10
]).
astype
(
"float32"
)
linear
=
paddle
.
nn
.
Linear
(
10
,
10
)
a
=
fluid
.
dygraph
.
to_variable
(
a
)
b
=
linear
(
a
)
loss
=
paddle
.
mean
(
b
)
adam
=
paddle
.
optimizer
.
Adam
(
0.1
,
parameters
=
linear
.
parameters
())
# float to LRScheduler
scheduler
=
paddle
.
optimizer
.
lr
.
StepDecay
(
learning_rate
=
0.2
,
step_size
=
5
,
gamma
=
0.6
)
adam
.
set_lr_scheduler
(
scheduler
)
adam
.
minimize
(
loss
)
lr
=
adam
.
get_lr
()
np
.
testing
.
assert_allclose
(
lr
,
0.2
,
rtol
=
1e-06
,
atol
=
0.0
)
# LRScheduler to another LRScheduler
scheduler
=
paddle
.
optimizer
.
lr
.
MultiStepDecay
(
learning_rate
=
0.5
,
milestones
=
[
2
,
4
,
6
],
gamma
=
0.8
)
adam
.
set_lr_scheduler
(
scheduler
)
adam
.
minimize
(
loss
)
lr
=
adam
.
get_lr
()
np
.
testing
.
assert_allclose
(
lr
,
0.5
,
rtol
=
1e-06
,
atol
=
0.0
)
with
self
.
assertRaises
(
TypeError
):
scheduler_var
=
paddle
.
fluid
.
dygraph
.
StepDecay
(
0.5
,
step_size
=
3
)
adam
.
set_lr_scheduler
(
scheduler_var
)
class
TestImperativeMomentumOptimizer
(
TestImperativeOptimizerBase
):
def
get_optimizer_dygraph
(
self
,
parameter_list
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录