Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
72efef63
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
72efef63
编写于
2月 27, 2019
作者:
R
ruri
提交者:
GitHub
2月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15887 from shippingwang/cosine_decay_op
add cosine decay op, test=develop
上级
e40d56c3
733da7b2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
51 additions
and
1 deletion
+51
-1
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
python/paddle/fluid/layers/learning_rate_scheduler.py
python/paddle/fluid/layers/learning_rate_scheduler.py
+38
-1
python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
...dle/fluid/tests/unittests/test_learning_rate_scheduler.py
+12
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
72efef63
...
...
@@ -337,6 +337,7 @@ paddle.fluid.layers.polynomial_decay ArgSpec(args=['learning_rate', 'decay_steps
paddle.fluid.layers.piecewise_decay ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.noam_decay ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.append_LARS ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.cosine_decay ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.InitState.__init__ ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32'))
paddle.fluid.contrib.StateCell.__init__ ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.StateCell.compute_state ArgSpec(args=['self', 'inputs'], varargs=None, keywords=None, defaults=None)
...
...
python/paddle/fluid/layers/learning_rate_scheduler.py
浏览文件 @
72efef63
...
...
@@ -28,10 +28,12 @@ from . import ops
from
.
import
tensor
from
..initializer
import
init_on_cpu
from
..framework
import
default_main_program
,
Parameter
,
unique_name
,
name_scope
import
math
# Public names exported by this module. Every entry is listed exactly once
# and separated by a comma: two adjacent string literals without a comma are
# silently concatenated by Python into a single (nonexistent) name.
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'append_LARS',
    'cosine_decay',
]
...
...
@@ -307,6 +309,41 @@ def piecewise_decay(boundaries, values):
return
lr
def cosine_decay(learning_rate, step_each_epoch, epochs):
    """
    Decay the learning rate along a half-cosine curve, stepping once per epoch.

    When training a model, it is often recommended to lower the learning rate
    as the training progresses. The current epoch is derived from the global
    step counter as ``floor(global_step / step_each_epoch)`` and the returned
    rate is:

        decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    with default_main_program()._lr_schedule_guard():
        step_counter = _decay_step_counter()

        # The rate only changes at epoch boundaries, hence the floor.
        epoch_index = ops.floor(step_counter / step_each_epoch)

        # Intermediates are evaluated in the same order as the single-line
        # expression they replace: scale first, then the cosine term.
        half_rate = learning_rate * 0.5
        phase = epoch_index * math.pi / epochs
        cosine_term = ops.cos(phase) + 1
        return half_rate * cosine_term
def
append_LARS
(
params_grads
,
learning_rate
,
weight_decay
):
"""
Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
...
...
python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
浏览文件 @
72efef63
...
...
@@ -82,6 +82,13 @@ def piecewise_decay(global_step, boundaries, values):
return
values
[
len
(
values
)
-
1
]
def cosine_decay(global_step, learning_rate, step_each_epoch, epochs):
    """Pure-Python reference for the cosine learning-rate schedule.

    Args:
        global_step: Number of steps taken so far.
        learning_rate: The initial learning rate.
        step_each_epoch: The number of steps in an epoch.
        epochs: The number of epochs.

    Returns:
        ``learning_rate * 0.5 * (cos(epoch * pi / epochs) + 1)`` where
        ``epoch = floor(global_step / step_each_epoch)``.
    """
    # The rate is constant within an epoch, so only the epoch index matters.
    epoch_index = math.floor(global_step / step_each_epoch)
    phase = epoch_index * math.pi / epochs
    cosine_term = math.cos(phase) + 1
    return learning_rate * 0.5 * cosine_term
class
TestLearningRateDecay
(
unittest
.
TestCase
):
def
check_decay
(
self
,
python_decay_fn
,
fluid_decay_fn
,
kwargs
):
places
=
[
fluid
.
CPUPlace
()]
...
...
@@ -149,6 +156,11 @@ class TestLearningRateDecay(unittest.TestCase):
"boundaries"
:
[
3
,
6
,
9
],
"values"
:
[
0.1
,
0.2
,
0.3
,
0.4
]
}),
(
cosine_decay
,
layers
.
cosine_decay
,
{
"learning_rate"
:
0.1
,
"step_each_epoch"
:
100
,
"epochs"
:
120
}),
]
for
py_decay_fn
,
fluid_decay_fn
,
kwargs
in
decay_fns
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录