Unverified commit 602cb6a5 — PaddlePaddle / Paddle
Authored by qingqing01 on Jul 04, 2019; committed via GitHub on Jul 04, 2019
Enhance linear_lr_warmup (#18463)
* Make it support a float/int learning rate as input.
Parent: 74538573

Showing 3 changed files with 69 additions and 8 deletions (+69 −8)
paddle/fluid/API.spec: +1 −1
python/paddle/fluid/layers/learning_rate_scheduler.py: +13 −6
python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py: +55 −1
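Before the per-file diffs, a minimal usage sketch of what the change enables: linear_lr_warmup accepting a plain float (or int) learning rate as well as a learning-rate Variable. This snippet is not part of the commit; it assumes the Fluid API at this commit (the signatures recorded in API.spec below), and the cosine_decay/SGD pairing is only illustrative.

    import paddle.fluid as fluid

    # Scalar input (newly supported): warm up from 0.1/3 to 0.1 over 100 steps,
    # then hold the constant rate 0.1 afterwards.
    lr = fluid.layers.linear_lr_warmup(
        learning_rate=0.1, warmup_steps=100, start_lr=0.1 / 3., end_lr=0.1)

    # Variable input (already supported): warm up, then hand over to a decay schedule.
    decayed = fluid.layers.cosine_decay(
        learning_rate=0.1, step_each_epoch=1000, epochs=120)
    lr_var = fluid.layers.linear_lr_warmup(
        learning_rate=decayed, warmup_steps=100, start_lr=0.1 / 3., end_lr=0.1)

    sgd = fluid.optimizer.SGD(learning_rate=lr_var)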
paddle/fluid/API.spec

@@ -406,7 +406,7 @@ paddle.fluid.layers.polynomial_decay (ArgSpec(args=['learning_rate', 'decay_step
 paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None), ('document', 'd9f654117542c6b702963dda107a247f'))
 paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'fd57228fb76195e66bbcc8d8e42c494d'))
 paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', 'f0d65d8c89d0fe78051ca689daa15e35'))
-paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', '0b529386b62cc73d27b711a5f618f3e4'))
+paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', 'dc7292c456847ba41cfd318e9f7f4363'))
 paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05'))
 paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50'))
python/paddle/fluid/layers/learning_rate_scheduler.py

@@ -23,6 +23,7 @@ strategy according to this module.
 from __future__ import print_function
 
 import math
+import numbers
 
 from . import control_flow
 from . import nn
@@ -30,6 +31,7 @@ from . import ops
 from . import tensor
 from ..initializer import init_on_cpu
 from ..framework import default_main_program, Parameter, unique_name, name_scope
+from ..framework import Variable
 from ..dygraph import base as imperative_base
 from ..dygraph import learning_rate_scheduler as imperate_lr
@@ -450,8 +452,8 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     Args:
         learning_rate (float | Variable): A float value or Variable.
         warmup_steps (int): The warmup steps.
-        start_lr (float): The start learning of warmup.
-        end_lr (float): The end learning of warmup.
+        start_lr (float): The start learning rate of warmup.
+        end_lr (float): The end learning rate of warmup.
 
     Returns:
         The decayed learning rate in warmup period.
@@ -470,14 +472,16 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
             warmup_steps, start_lr, end_lr)
     """
-    assert (isinstance(end_lr, float))
-    assert (isinstance(start_lr, float))
-    linear_step = end_lr - start_lr
+    dtype = 'float32'
+    if isinstance(learning_rate, Variable):
+        dtype = learning_rate.dtype
+
+    linear_step = float(end_lr) - float(start_lr)
     with default_main_program()._lr_schedule_guard():
         lr = tensor.create_global_var(
             shape=[1],
             value=0.0,
-            dtype='float32',
+            dtype=dtype,
             persistable=True,
             name="learning_rate_warmup")
@@ -489,5 +493,8 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
                     float(warmup_steps))
                 tensor.assign(decayed_lr, lr)
             with switch.default():
+                if not isinstance(learning_rate, Variable):
+                    learning_rate = tensor.fill_constant(
+                        shape=[1], dtype=dtype, value=float(learning_rate))
                 tensor.assign(learning_rate, lr)
     return lr
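For reference, the graph built above is a two-way switch: while the step counter is below warmup_steps the rate is linearly interpolated from start_lr to end_lr; afterwards the (now possibly scalar) learning_rate is assigned as-is. The following pure-Python sketch of the resulting values uses the same numbers as the new unit test below; it is illustrative only and not part of the commit.

    def warmup_schedule(step, learning_rate=0.2, warmup_steps=10,
                        start_lr=0.1 / 3., end_lr=0.1):
        # Mirror of the op graph: interpolate during warmup, then hold learning_rate.
        if step < warmup_steps:
            return start_lr + (end_lr - start_lr) * (float(step) / warmup_steps)
        return learning_rate

    # Steps 0, 5, 9 fall in the warmup ramp; steps 10+ return the scalar 0.2.
    print([round(warmup_schedule(s), 4) for s in (0, 5, 9, 10, 15)])
    # -> [0.0333, 0.0667, 0.0933, 0.2, 0.2]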
python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py

@@ -185,7 +185,7 @@ class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
         startup_prog = fluid.Program()
 
         warmup_steps = 10
-        start_lr = 1. / 3.
+        start_lr = 0.1 / 3.
         end_lr = 0.1
 
         with fluid.program_guard(main_prog, startup_prog):
@@ -212,5 +212,59 @@ class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
                     str(step), str(python_decayed_lr), str(lr_val[0])))
 
 
+class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
+    def run_scalar_lr(self, place, lr, start_lr, end_lr):
+        main_prog = fluid.Program()
+        startup_prog = fluid.Program()
+
+        warmup_steps = 10
+
+        with fluid.program_guard(main_prog, startup_prog):
+            decayed_lr = layers.linear_lr_warmup(lr, warmup_steps, start_lr,
+                                                 end_lr)
+
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
+
+        for step in range(20):
+            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
+            if step < warmup_steps:
+                expected_lr = linear_lr_warmup(
+                    float(step), warmup_steps, start_lr, end_lr)
+            else:
+                expected_lr = lr
+            self.assertAlmostEqual(
+                expected_lr,
+                lr_val[0],
+                msg='Test failed, step {0}, expected {1}, but got {2}'.format(
+                    step, expected_lr, lr_val[0]))
+
+    def test_scalar_lr(self):
+        def run_places(lr, start_lr, end_lr):
+            places = [fluid.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(fluid.CUDAPlace(0))
+            for p in places:
+                self.run_scalar_lr(p, lr, start_lr, end_lr)
+
+        # float
+        lr = 0.2
+        start_lr = 0.1 / 3.
+        end_lr = 0.2
+        run_places(lr, start_lr, end_lr)
+
+        # int end_lr
+        lr = 2.
+        start_lr = 0.1 / 3.
+        end_lr = 1
+        run_places(lr, start_lr, end_lr)
+
+        # int
+        lr = 1
+        start_lr = 0
+        end_lr = 1
+        run_places(lr, start_lr, end_lr)
+
+
 if __name__ == '__main__':
     unittest.main()
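The expected values in run_scalar_lr come from a pure-Python linear_lr_warmup helper defined earlier in this test file, outside the lines shown in this diff. A hedged sketch of what that helper computes (the exact formulation in the repository may differ slightly):

    def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
        # Linear interpolation from start_lr to end_lr over warmup_steps steps.
        linear_step = end_lr - start_lr
        return start_lr + linear_step * (global_step / warmup_steps)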