Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ec9c0874
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ec9c0874
编写于
3月 27, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement Expotential NatureExp Inversetime and Polynomal Decay
上级
4278be8c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
248 addition
and
53 deletion
+248
-53
python/paddle/fluid/imperative/learning_rate_scheduler.py
python/paddle/fluid/imperative/learning_rate_scheduler.py
+117
-1
python/paddle/fluid/layers/learning_rate_scheduler.py
python/paddle/fluid/layers/learning_rate_scheduler.py
+58
-37
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
...paddle/fluid/tests/unittests/test_imperative_optimizer.py
+73
-15
未找到文件。
python/paddle/fluid/imperative/learning_rate_scheduler.py
浏览文件 @
ec9c0874
...
...
@@ -16,7 +16,9 @@ from __future__ import print_function
from
..
import
unique_name
__all__
=
[
'PiecewiseDecay'
]
__all__
=
[
'PiecewiseDecay'
,
'NaturalExpDecay'
,
'ExponentialDecay'
,
'InverseTimeDecay'
]
class
LearningRateDecay
(
object
):
...
...
@@ -65,3 +67,117 @@ class PiecewiseDecay(LearningRateDecay):
if
self
.
step_num
<
self
.
boundaries
[
i
]:
return
self
.
vars
[
i
]
return
self
.
vars
[
len
(
self
.
values
)
-
1
]
class
NaturalExpDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
NaturalExpDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
*
layers
.
exp
(
-
1
*
self
.
decay_rate
*
div_res
)
return
decayed_lr
class
ExponentialDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
ExponentialDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
*
(
self
.
decay_rate
**
div_res
)
return
decayed_lr
class
InverseTimeDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
InverseTimeDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
/
(
1
+
self
.
decay_rate
*
div_res
)
return
decayed_lr
class
PolynomialDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
end_learning_rate
=
0.0001
,
power
=
1.0
,
cycle
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
PolynomialDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
end_learning_rate
=
end_learning_rate
self
.
power
=
power
self
.
cycle
=
cycle
def
step
(
self
):
from
..
import
layers
if
self
.
cycle
:
div_res
=
layers
.
ceil
(
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
))
zero_var
=
0.0
one_var
=
1.0
if
float
(
self
.
step_num
)
==
zero_var
:
div_res
=
one_var
decay_steps
=
self
.
decay_steps
*
div_res
else
:
global_step
=
global_step
if
global_step
<
self
.
decay_steps
else
self
.
decay_steps
decayed_lr
=
(
self
.
learning_rate
-
self
.
end_learning_rate
)
*
\
((
1
-
global_step
/
self
.
decay_steps
)
**
self
.
power
)
+
self
.
end_learning_rate
return
self
.
create_lr_var
(
decayed_lr
)
python/paddle/fluid/layers/learning_rate_scheduler.py
浏览文件 @
ec9c0874
...
...
@@ -115,14 +115,19 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""
with
default_main_program
().
_lr_schedule_guard
():
global_step
=
_decay_step_counter
()
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
ExponentialDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
decayed_lr
=
learning_rate
*
(
decay_rate
**
div_res
)
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
decayed_lr
=
learning_rate
*
(
decay_rate
**
div_res
)
return
decayed_lr
return
decayed_lr
def
natural_exp_decay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
):
...
...
@@ -144,14 +149,19 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
The decayed learning rate
"""
with
default_main_program
().
_lr_schedule_guard
():
global_step
=
_decay_step_counter
()
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
NaturalExpDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
decayed_lr
=
learning_rate
*
ops
.
exp
(
-
1
*
decay_rate
*
div_res
)
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
decayed_lr
=
learning_rate
*
ops
.
exp
(
-
1
*
decay_rate
*
div_res
)
return
decayed_lr
return
decayed_lr
def
inverse_time_decay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
):
...
...
@@ -190,15 +200,20 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost)
"""
with
default_main_program
().
_lr_schedule_guard
():
global_step
=
_decay_step_counter
()
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
InverseTimeDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
div_res
=
global_step
/
decay_steps
if
staircase
:
div_res
=
ops
.
floor
(
div_res
)
decayed_lr
=
learning_rate
/
(
1
+
decay_rate
*
div_res
)
decayed_lr
=
learning_rate
/
(
1
+
decay_rate
*
div_res
)
return
decayed_lr
return
decayed_lr
def
polynomial_decay
(
learning_rate
,
...
...
@@ -230,27 +245,33 @@ def polynomial_decay(learning_rate,
Variable: The decayed learning rate
"""
with
default_main_program
().
_lr_schedule_guard
():
global_step
=
_decay_step_counter
()
if
cycle
:
div_res
=
ops
.
ceil
(
global_step
/
decay_steps
)
zero_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
0.0
)
one_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
1.0
)
with
control_flow
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
==
zero_var
):
tensor
.
assign
(
input
=
one_var
,
output
=
div_res
)
decay_steps
=
decay_steps
*
div_res
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
PolynomialDecay
(
learning_rate
,
decay_steps
,
end_learning_rate
,
power
,
cycle
)
return
decay
else
:
decay_steps_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
float
(
decay_steps
))
global_step
=
nn
.
elementwise_min
(
x
=
global_step
,
y
=
decay_steps_var
)
global_step
=
_decay_step_counter
()
decayed_lr
=
(
learning_rate
-
end_learning_rate
)
*
\
((
1
-
global_step
/
decay_steps
)
**
power
)
+
end_learning_rate
return
decayed_lr
if
cycle
:
div_res
=
ops
.
ceil
(
global_step
/
decay_steps
)
zero_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
0.0
)
one_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
1.0
)
with
control_flow
.
Switch
()
as
switch
:
with
switch
.
case
(
global_step
==
zero_var
):
tensor
.
assign
(
input
=
one_var
,
output
=
div_res
)
decay_steps
=
decay_steps
*
div_res
else
:
decay_steps_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
float
(
decay_steps
))
global_step
=
nn
.
elementwise_min
(
x
=
global_step
,
y
=
decay_steps_var
)
decayed_lr
=
(
learning_rate
-
end_learning_rate
)
*
\
((
1
-
global_step
/
decay_steps
)
**
power
)
+
end_learning_rate
return
decayed_lr
def
piecewise_decay
(
boundaries
,
values
):
...
...
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
浏览文件 @
ec9c0874
...
...
@@ -22,7 +22,7 @@ import six
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.optimizer
import
SGDOptimizer
,
Adam
from
paddle.fluid.imperative.nn
import
FC
from
paddle.fluid.imperative.base
import
to_variable
from
test_imperative_base
import
new_program_scope
...
...
@@ -46,14 +46,9 @@ class TestImperativeOptimizerBase(unittest.TestCase):
self
.
batch_num
=
10
def
get_optimizer
(
self
):
bd
=
[
3
,
6
,
9
]
self
.
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
[
0.1
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]))
return
self
.
optimizer
raise
NotImplementedError
()
def
test_optimizer_float32
(
self
):
def
_check_mlp
(
self
):
seed
=
90
with
fluid
.
imperative
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
...
...
@@ -83,16 +78,14 @@ class TestImperativeOptimizerBase(unittest.TestCase):
dy_out
=
avg_loss
.
_numpy
()
if
batch_id
==
0
:
for
param
in
fluid
.
default_main_program
().
global_block
(
).
all_parameters
():
for
param
in
mlp
.
parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
_numpy
()
avg_loss
.
_backward
()
optimizer
.
minimize
(
avg_loss
)
mlp
.
clear_gradients
()
dy_param_value
=
{}
for
param
in
fluid
.
default_main_program
().
global_block
(
).
all_parameters
():
for
param
in
mlp
.
parameters
():
dy_param_value
[
param
.
name
]
=
param
.
_numpy
()
with
new_program_scope
():
...
...
@@ -102,7 +95,7 @@ class TestImperativeOptimizerBase(unittest.TestCase):
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
(
)
if
not
core
.
is_compiled_with_cuda
()
else
fluid
.
CUDAPlace
(
0
))
m
nist
=
MLP
(
'mlp'
)
m
lp
=
MLP
(
'mlp'
)
optimizer
=
self
.
get_optimizer
()
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
...
...
@@ -110,14 +103,14 @@ class TestImperativeOptimizerBase(unittest.TestCase):
img
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
cost
=
m
nist
(
img
)
cost
=
m
lp
(
img
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
cost
)
optimizer
.
minimize
(
avg_loss
)
# initialize params and fetch them
static_param_init_value
=
{}
static_param_name_list
=
[]
for
param
in
m
nist
.
parameters
():
for
param
in
m
lp
.
parameters
():
static_param_name_list
.
append
(
param
.
name
)
out
=
exe
.
run
(
fluid
.
default_startup_program
(),
...
...
@@ -156,5 +149,70 @@ class TestImperativeOptimizerBase(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
],
atol
=
1e-5
))
class
TestImperativeOptimizerPiecewiseDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
bd
=
[
3
,
6
,
9
]
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
[
0.1
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerNaturalExpDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
natural_exp_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerExponentialDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
exponential_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerInverseTimeDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
Adam
(
learning_rate
=
fluid
.
layers
.
inverse_time_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_adam
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerPolynomialDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
polynomial_decay
(
learning_rate
=
0.1
,
decay_steps
=
5
,
cycle
=
self
.
cycle
))
return
optimizer
def
test_sgd_cycle
(
self
):
self
.
cycle
=
True
self
.
_check_mlp
()
def
test_sgd
(
self
):
self
.
cycle
=
False
self
.
_check_mlp
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录