Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
d8d73ff3
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d8d73ff3
编写于
3月 31, 2019
作者:
Q
Qiyang Min
提交者:
GitHub
3月 31, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15584 from velconia/imperative_lr_scheduler
Support imperative learning rate scheduler
上级
1ebd7434
64b09294
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
773 addition
and
230 deletion
+773
-230
python/paddle/fluid/dygraph/__init__.py
python/paddle/fluid/dygraph/__init__.py
+4
-0
python/paddle/fluid/dygraph/learning_rate_scheduler.py
python/paddle/fluid/dygraph/learning_rate_scheduler.py
+224
-0
python/paddle/fluid/layers/learning_rate_scheduler.py
python/paddle/fluid/layers/learning_rate_scheduler.py
+111
-71
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+57
-20
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-2
python/paddle/fluid/tests/unittests/test_imperative_mnist.py
python/paddle/fluid/tests/unittests/test_imperative_mnist.py
+217
-0
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
...paddle/fluid/tests/unittests/test_imperative_optimizer.py
+158
-137
未找到文件。
python/paddle/fluid/dygraph/__init__.py
浏览文件 @
d8d73ff3
...
...
@@ -32,6 +32,9 @@ from .profiler import *
from
.
import
checkpoint
from
.checkpoint
import
*
from
.
import
learning_rate_scheduler
from
.learning_rate_scheduler
import
*
__all__
=
[]
__all__
+=
layers
.
__all__
__all__
+=
base
.
__all__
...
...
@@ -39,3 +42,4 @@ __all__ += nn.__all__
__all__
+=
tracer
.
__all__
__all__
+=
profiler
.
__all__
__all__
+=
checkpoint
.
__all__
__all__
+=
learning_rate_scheduler
.
__all__
python/paddle/fluid/dygraph/learning_rate_scheduler.py
0 → 100644
浏览文件 @
d8d73ff3
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
math
from
..
import
unique_name
__all__
=
[
'NoamDecay'
,
'PiecewiseDecay'
,
'NaturalExpDecay'
,
'ExponentialDecay'
,
'InverseTimeDecay'
,
'PolynomialDecay'
,
'CosineDecay'
]
class
LearningRateDecay
(
object
):
"""
Base class of learning rate decay
"""
def
__init__
(
self
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
self
.
step_num
=
begin
self
.
step_size
=
step
self
.
dtype
=
dtype
def
__call__
(
self
):
lr
=
self
.
step
()
if
isinstance
(
lr
,
float
):
lr
=
self
.
create_lr_var
(
lr
)
self
.
step_num
+=
self
.
step_size
return
lr
def
create_lr_var
(
self
,
lr
):
from
..
import
layers
lr
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
value
=
float
(
lr
),
dtype
=
self
.
dtype
,
persistable
=
True
)
return
lr
def
step
(
self
):
raise
NotImplementedError
()
class
PiecewiseDecay
(
LearningRateDecay
):
def
__init__
(
self
,
boundaries
,
values
,
begin
,
step
=
1
,
dtype
=
'float32'
):
super
(
PiecewiseDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
boundaries
=
boundaries
self
.
values
=
values
self
.
vars
=
[]
for
value
in
values
:
self
.
vars
.
append
(
self
.
create_lr_var
(
value
))
def
step
(
self
):
for
i
in
range
(
len
(
self
.
boundaries
)):
if
self
.
step_num
<
self
.
boundaries
[
i
]:
return
self
.
vars
[
i
]
return
self
.
vars
[
len
(
self
.
values
)
-
1
]
class
NaturalExpDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
NaturalExpDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
*
layers
.
exp
(
-
1
*
self
.
decay_rate
*
div_res
)
return
decayed_lr
class
ExponentialDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
ExponentialDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
*
(
self
.
decay_rate
**
div_res
)
return
decayed_lr
class
InverseTimeDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
decay_rate
,
staircase
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
InverseTimeDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
decay_rate
=
decay_rate
self
.
staircase
=
staircase
def
step
(
self
):
from
..
import
layers
div_res
=
self
.
create_lr_var
(
self
.
step_num
/
self
.
decay_steps
)
if
self
.
staircase
:
div_res
=
layers
.
floor
(
div_res
)
decayed_lr
=
self
.
learning_rate
/
(
1
+
self
.
decay_rate
*
div_res
)
return
decayed_lr
class
PolynomialDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
decay_steps
,
end_learning_rate
=
0.0001
,
power
=
1.0
,
cycle
=
False
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
PolynomialDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
decay_steps
=
decay_steps
self
.
end_learning_rate
=
end_learning_rate
self
.
power
=
power
self
.
cycle
=
cycle
def
step
(
self
):
from
..
import
layers
tmp_step_num
=
self
.
step_num
tmp_decay_steps
=
self
.
decay_steps
if
self
.
cycle
:
div_res
=
layers
.
ceil
(
self
.
create_lr_var
(
tmp_step_num
/
float
(
self
.
decay_steps
)))
if
tmp_step_num
==
0
:
div_res
=
self
.
create_lr_var
(
1.0
)
tmp_decay_steps
=
self
.
decay_steps
*
div_res
else
:
tmp_step_num
=
self
.
create_lr_var
(
tmp_step_num
if
tmp_step_num
<
self
.
decay_steps
else
self
.
decay_steps
)
decayed_lr
=
(
self
.
learning_rate
-
self
.
end_learning_rate
)
*
\
((
1
-
tmp_step_num
/
tmp_decay_steps
)
**
self
.
power
)
+
self
.
end_learning_rate
return
decayed_lr
class
CosineDecay
(
LearningRateDecay
):
def
__init__
(
self
,
learning_rate
,
step_each_epoch
,
epochs
,
begin
=
0
,
step
=
1
,
dtype
=
'float32'
):
super
(
CosineDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
learning_rate
=
learning_rate
self
.
step_each_epoch
=
step_each_epoch
self
.
epochs
=
epochs
def
step
(
self
):
from
..
import
layers
cur_epoch
=
layers
.
floor
(
self
.
create_lr_var
(
self
.
step_num
/
self
.
step_each_epoch
))
decayed_lr
=
self
.
learning_rate
*
0.5
*
(
layers
.
cos
(
cur_epoch
*
math
.
pi
/
self
.
epochs
)
+
1
)
return
decayed_lr
class
NoamDecay
(
LearningRateDecay
):
def
__init__
(
self
,
d_model
,
warmup_steps
,
begin
=
1
,
step
=
1
,
dtype
=
'float32'
):
super
(
NoamDecay
,
self
).
__init__
(
begin
,
step
,
dtype
)
self
.
d_model
=
d_model
self
.
warmup_steps
=
warmup_steps
def
step
(
self
):
from
..
import
layers
a
=
self
.
create_lr_var
(
self
.
step_num
**-
0.5
)
b
=
self
.
create_lr_var
((
self
.
warmup_steps
**-
1.5
)
*
self
.
step_num
)
lr_value
=
(
self
.
d_model
**-
0.5
)
*
layers
.
elementwise_min
(
a
,
b
)
return
lr_value
python/paddle/fluid/layers/learning_rate_scheduler.py
浏览文件 @
d8d73ff3
...
...
@@ -22,13 +22,16 @@ strategy according to this module.
from
__future__
import
print_function
import
math
from
.
import
control_flow
from
.
import
nn
from
.
import
ops
from
.
import
tensor
from
..initializer
import
init_on_cpu
from
..framework
import
default_main_program
,
Parameter
,
unique_name
,
name_scope
import
math
from
..dygraph
import
base
as
imperative_base
from
..dygraph
import
learning_rate_scheduler
as
imperate_lr
__all__
=
[
'exponential_decay'
,
'natural_exp_decay'
,
'inverse_time_decay'
,
...
...
@@ -66,6 +69,10 @@ def noam_decay(d_model, warmup_steps):
The decayed learning rate.
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
NoamDecay
(
d_model
,
warmup_steps
)
return
decay
else
:
global_step
=
_decay_step_counter
(
1
)
a
=
global_step
**-
0.5
...
...
@@ -112,6 +119,11 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
ExponentialDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
...
...
@@ -141,6 +153,11 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
The decayed learning rate
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
NaturalExpDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
...
...
@@ -187,6 +204,11 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost)
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
InverseTimeDecay
(
learning_rate
,
decay_steps
,
decay_rate
,
staircase
)
return
decay
else
:
global_step
=
_decay_step_counter
()
div_res
=
global_step
/
decay_steps
...
...
@@ -227,6 +249,11 @@ def polynomial_decay(learning_rate,
Variable: The decayed learning rate
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
PolynomialDecay
(
learning_rate
,
decay_steps
,
end_learning_rate
,
power
,
cycle
)
return
decay
else
:
global_step
=
_decay_step_counter
()
if
cycle
:
...
...
@@ -243,7 +270,8 @@ def polynomial_decay(learning_rate,
else
:
decay_steps_var
=
tensor
.
fill_constant
(
shape
=
[
1
],
dtype
=
'float32'
,
value
=
float
(
decay_steps
))
global_step
=
nn
.
elementwise_min
(
x
=
global_step
,
y
=
decay_steps_var
)
global_step
=
nn
.
elementwise_min
(
x
=
global_step
,
y
=
decay_steps_var
)
decayed_lr
=
(
learning_rate
-
end_learning_rate
)
*
\
((
1
-
global_step
/
decay_steps
)
**
power
)
+
end_learning_rate
...
...
@@ -279,6 +307,10 @@ def piecewise_decay(boundaries, values):
if
len
(
values
)
-
len
(
boundaries
)
!=
1
:
raise
ValueError
(
"len(values) - len(boundaries) should be 1"
)
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
PiecewiseDecay
(
boundaries
,
values
,
0
)
return
decay
else
:
global_step
=
_decay_step_counter
()
lr
=
tensor
.
create_global_var
(
...
...
@@ -336,6 +368,11 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
learning_rate = base_lr, step_each_epoch=10000, epochs=120)
"""
with
default_main_program
().
_lr_schedule_guard
():
if
imperative_base
.
enabled
():
decay
=
imperate_lr
.
CosineDecay
(
learning_rate
,
step_each_epoch
,
epochs
)
return
decay
else
:
global_step
=
_decay_step_counter
()
cur_epoch
=
ops
.
floor
(
global_step
/
step_each_epoch
)
...
...
@@ -363,6 +400,9 @@ def append_LARS(params_grads, learning_rate, weight_decay):
/ (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
"""
assert
not
imperative_base
.
enabled
(
),
"append_LARS is NOT supported in dygraph mode now"
def
_balanced_weight
(
param_norm
,
grad_norm
):
if
weight_decay
==
1.0
:
return
grad_norm
+
param_norm
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
d8d73ff3
...
...
@@ -30,6 +30,8 @@ from .initializer import Constant
from
.layer_helper
import
LayerHelper
from
.layers
import
ops
from
.regularizer
import
append_regularization_ops
from
.dygraph
import
base
as
imperative_base
from
.dygraph.learning_rate_scheduler
import
LearningRateDecay
from
paddle.fluid
import
core
from
paddle.fluid.layers
import
tensor
from
functools
import
reduce
...
...
@@ -53,9 +55,19 @@ class Optimizer(object):
"""
def
__init__
(
self
,
learning_rate
,
regularization
=
None
,
name
=
None
):
if
framework
.
_in_dygraph_mode
():
if
not
isinstance
(
learning_rate
,
float
)
and
\
not
isinstance
(
learning_rate
,
LearningRateDecay
):
raise
TypeError
(
"learning rate should be float or LearningRateDecay, got %s here"
%
type
(
learning_rate
))
else
:
if
not
isinstance
(
learning_rate
,
float
)
and
\
not
isinstance
(
learning_rate
,
framework
.
Variable
):
raise
TypeError
(
"learning rate should be float or Variable"
)
raise
TypeError
(
"learning rate should be float or Variable, got %s here"
%
type
(
learning_rate
))
self
.
_name
=
name
self
.
regularization
=
regularization
self
.
_learning_rate
=
learning_rate
...
...
@@ -79,6 +91,30 @@ class Optimizer(object):
return
self
.
_opti_name_list
def
_create_global_learning_rate
(
self
):
if
imperative_base
.
enabled
():
# create learning rate Variable
if
isinstance
(
self
.
_learning_rate
,
float
):
lr
=
self
.
_global_learning_rate
()
if
isinstance
(
lr
,
framework
.
Variable
):
return
else
:
self
.
_learning_rate_map
[
framework
.
default_main_program
(
)]
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
value
=
float
(
self
.
_learning_rate
),
dtype
=
'float32'
if
self
.
_dtype
is
None
else
self
.
_dtype
,
persistable
=
True
)
# get learning rate Variable from LearningRateDecay
elif
isinstance
(
self
.
_learning_rate
,
LearningRateDecay
):
self
.
_learning_rate_map
[
framework
.
default_main_program
(
)]
=
self
.
_learning_rate
()
else
:
raise
TypeError
(
"optimizer's learning rate must be float or LearningRateDecay"
)
else
:
lr
=
self
.
_global_learning_rate
()
if
isinstance
(
lr
,
framework
.
Variable
):
...
...
@@ -87,7 +123,8 @@ class Optimizer(object):
if
not
isinstance
(
self
.
_learning_rate
,
float
):
raise
TypeError
(
"learning rate variable is create outside optimizer,"
"can not create new learning rate variable for new program"
)
"can not create new learning rate variable for new program"
)
# create learning rate in the current main program
self
.
_learning_rate_map
[
framework
.
default_main_program
(
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
d8d73ff3
...
...
@@ -78,7 +78,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
list
(
REMOVE_ITEM TEST_OPS test_bilinear_interp_op
)
list
(
REMOVE_ITEM TEST_OPS test_nearest_interp_op
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_resnet
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_
optimizer
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_
mnist
)
list
(
REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer
)
foreach
(
TEST_OP
${
TEST_OPS
}
)
py_test_modules
(
${
TEST_OP
}
MODULES
${
TEST_OP
}
)
...
...
@@ -89,7 +89,7 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules
(
test_nearest_interp_op MODULES test_nearest_interp_op SERIAL
)
py_test_modules
(
test_imperative_resnet MODULES test_imperative_resnet ENVS
FLAGS_cudnn_deterministic=1
)
py_test_modules
(
test_imperative_
optimizer MODULES test_imperative_optimizer
ENVS
py_test_modules
(
test_imperative_
mnist MODULES test_imperative_mnist
ENVS
FLAGS_cudnn_deterministic=1
)
if
(
WITH_DISTRIBUTE
)
py_test_modules
(
test_dist_train MODULES test_dist_train SERIAL
)
...
...
python/paddle/fluid/tests/unittests/test_imperative_mnist.py
0 → 100644
浏览文件 @
d8d73ff3
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
contextlib
import
unittest
import
numpy
as
np
import
six
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
FC
from
paddle.fluid.dygraph.base
import
to_variable
from
test_imperative_base
import
new_program_scope
class
SimpleImgConvPool
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
filter_size
,
pool_size
,
pool_stride
,
pool_padding
=
0
,
pool_type
=
'max'
,
global_pooling
=
False
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConvPool
,
self
).
__init__
(
name_scope
)
self
.
_conv2d
=
Conv2D
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
use_cudnn
)
self
.
_pool2d
=
Pool2D
(
self
.
full_name
(),
pool_size
=
pool_size
,
pool_type
=
pool_type
,
pool_stride
=
pool_stride
,
pool_padding
=
pool_padding
,
global_pooling
=
global_pooling
,
use_cudnn
=
use_cudnn
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
x
=
self
.
_pool2d
(
x
)
return
x
class
MNIST
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
):
super
(
MNIST
,
self
).
__init__
(
name_scope
)
self
.
_simple_img_conv_pool_1
=
SimpleImgConvPool
(
self
.
full_name
(),
1
,
20
,
5
,
2
,
2
,
act
=
"relu"
)
self
.
_simple_img_conv_pool_2
=
SimpleImgConvPool
(
self
.
full_name
(),
20
,
50
,
5
,
2
,
2
,
act
=
"relu"
)
pool_2_shape
=
50
*
4
*
4
SIZE
=
10
scale
=
(
2.0
/
(
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
FC
(
self
.
full_name
(),
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
class
TestImperativeMnist
(
unittest
.
TestCase
):
def
test_mnist_float32
(
self
):
seed
=
90
epoch_num
=
1
with
fluid
.
dygraph
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
mnist
=
MNIST
(
"mnist"
)
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
dy_param_init_value
=
{}
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
dy_x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
128
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
_stop_gradient
=
True
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
dy_out
=
avg_loss
.
_numpy
()
if
epoch
==
0
and
batch_id
==
0
:
for
param
in
mnist
.
parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
_numpy
()
avg_loss
.
_backward
()
sgd
.
minimize
(
avg_loss
)
mnist
.
clear_gradients
()
dy_param_value
=
{}
for
param
in
mnist
.
parameters
():
dy_param_value
[
param
.
name
]
=
param
.
_numpy
()
with
new_program_scope
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
(
)
if
not
core
.
is_compiled_with_cuda
()
else
fluid
.
CUDAPlace
(
0
))
mnist
=
MNIST
(
"mnist"
)
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
img
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
sgd
.
minimize
(
avg_loss
)
# initialize params and fetch them
static_param_init_value
=
{}
static_param_name_list
=
[]
for
param
in
mnist
.
parameters
():
static_param_name_list
.
append
(
param
.
name
)
out
=
exe
.
run
(
fluid
.
default_startup_program
(),
fetch_list
=
static_param_name_list
)
for
i
in
range
(
len
(
static_param_name_list
)):
static_param_init_value
[
static_param_name_list
[
i
]]
=
out
[
i
]
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
static_x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
([
128
,
1
])
fetch_list
=
[
avg_loss
.
name
]
fetch_list
.
extend
(
static_param_name_list
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"pixel"
:
static_x_data
,
"label"
:
y_data
},
fetch_list
=
fetch_list
)
static_param_value
=
{}
static_out
=
out
[
0
]
for
i
in
range
(
1
,
len
(
out
)):
static_param_value
[
static_param_name_list
[
i
-
1
]]
=
out
[
i
]
self
.
assertTrue
(
np
.
allclose
(
dy_x_data
.
all
(),
static_x_data
.
all
()))
for
key
,
value
in
six
.
iteritems
(
static_param_init_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_init_value
[
key
]))
self
.
assertTrue
(
np
.
allclose
(
static_out
,
dy_out
))
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
],
atol
=
1e-5
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_imperative_optimizer.py
浏览文件 @
d8d73ff3
...
...
@@ -22,130 +22,70 @@ import six
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
FC
from
paddle.fluid.optimizer
import
SGDOptimizer
,
Adam
from
paddle.fluid.dygraph.nn
import
FC
from
paddle.fluid.dygraph.base
import
to_variable
from
test_imperative_base
import
new_program_scope
class
SimpleImgConvPool
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
filter_size
,
pool_size
,
pool_stride
,
pool_padding
=
0
,
pool_type
=
'max'
,
global_pooling
=
False
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConvPool
,
self
).
__init__
(
name_scope
)
self
.
_conv2d
=
Conv2D
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
use_cudnn
)
self
.
_pool2d
=
Pool2D
(
self
.
full_name
(),
pool_size
=
pool_size
,
pool_type
=
pool_type
,
pool_stride
=
pool_stride
,
pool_padding
=
pool_padding
,
global_pooling
=
global_pooling
,
use_cudnn
=
use_cudnn
)
class
MLP
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
MLP
,
self
).
__init__
(
name_scope
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
x
=
self
.
_pool2d
(
x
)
return
x
class
MNIST
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
):
super
(
MNIST
,
self
).
__init__
(
name_scope
)
self
.
_fc1
=
FC
(
self
.
full_name
(),
10
)
self
.
_fc2
=
FC
(
self
.
full_name
(),
10
)
self
.
_simple_img_conv_pool_1
=
SimpleImgConvPool
(
self
.
full_name
(),
1
,
20
,
5
,
2
,
2
,
act
=
"relu"
)
def
forward
(
self
,
inputs
):
y
=
self
.
_fc1
(
inputs
)
y
=
self
.
_fc2
(
y
)
return
y
self
.
_simple_img_conv_pool_2
=
SimpleImgConvPool
(
self
.
full_name
(),
20
,
50
,
5
,
2
,
2
,
act
=
"relu"
)
pool_2_shape
=
50
*
4
*
4
SIZE
=
10
scale
=
(
2.0
/
(
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
FC
(
self
.
full_name
(),
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
class
TestImperativeOptimizerBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_num
=
20
def
forward
(
self
,
inputs
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
def
get_optimizer
(
self
):
raise
NotImplementedError
()
class
TestDygraphMnist
(
unittest
.
TestCase
):
def
test_mnist_float32
(
self
):
def
_check_mlp
(
self
):
seed
=
90
epoch_num
=
1
with
fluid
.
dygraph
.
guard
():
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
m
nist
=
MNIST
(
"mnist"
)
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
m
lp
=
MLP
(
'mlp'
)
optimizer
=
self
.
get_optimizer
(
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
dy_param_init_value
=
{}
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>=
self
.
batch_num
:
break
dy_x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
128
,
1
)
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
128
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
_stop_gradient
=
True
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
cost
=
mlp
(
img
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
cost
)
dy_out
=
avg_loss
.
_numpy
()
if
epoch
==
0
and
batch_id
==
0
:
for
param
in
mnist
.
parameters
():
if
batch_id
==
0
:
for
param
in
mlp
.
parameters
():
dy_param_init_value
[
param
.
name
]
=
param
.
_numpy
()
avg_loss
.
_backward
()
sgd
.
minimize
(
avg_loss
)
mnist
.
clear_gradients
()
optimizer
.
minimize
(
avg_loss
)
mlp
.
clear_gradients
()
dy_param_value
=
{}
for
param
in
mnist
.
parameters
():
for
param
in
mlp
.
parameters
():
dy_param_value
[
param
.
name
]
=
param
.
_numpy
()
with
new_program_scope
():
...
...
@@ -155,23 +95,22 @@ class TestDygraphMnist(unittest.TestCase):
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
(
)
if
not
core
.
is_compiled_with_cuda
()
else
fluid
.
CUDAPlace
(
0
))
m
nist
=
MNIST
(
"mnist"
)
sgd
=
SGDOptimizer
(
learning_rate
=
1e-3
)
m
lp
=
MLP
(
'mlp'
)
optimizer
=
self
.
get_optimizer
(
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
128
,
drop_last
=
True
)
img
=
fluid
.
layers
.
data
(
name
=
'pixel'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
cost
=
mnist
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
sgd
.
minimize
(
avg_loss
)
cost
=
mlp
(
img
)
avg_loss
=
fluid
.
layers
.
reduce_mean
(
cost
)
optimizer
.
minimize
(
avg_loss
)
# initialize params and fetch them
static_param_init_value
=
{}
static_param_name_list
=
[]
for
param
in
m
nist
.
parameters
():
for
param
in
m
lp
.
parameters
():
static_param_name_list
.
append
(
param
.
name
)
out
=
exe
.
run
(
fluid
.
default_startup_program
(),
...
...
@@ -180,18 +119,18 @@ class TestDygraphMnist(unittest.TestCase):
for
i
in
range
(
len
(
static_param_name_list
)):
static_param_init_value
[
static_param_name_list
[
i
]]
=
out
[
i
]
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>=
self
.
batch_num
:
break
static_x_data
=
np
.
array
(
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
([
128
,
1
])
[
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
[
128
,
1
])
fetch_list
=
[
avg_loss
.
name
]
fetch_list
.
extend
(
static_param_name_list
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"pixel"
:
static_x_data
,
"label"
:
y_data
},
fetch_list
=
fetch_list
)
...
...
@@ -199,10 +138,7 @@ class TestDygraphMnist(unittest.TestCase):
static_param_value
=
{}
static_out
=
out
[
0
]
for
i
in
range
(
1
,
len
(
out
)):
static_param_value
[
static_param_name_list
[
i
-
1
]]
=
out
[
i
]
self
.
assertTrue
(
np
.
allclose
(
dy_x_data
.
all
(),
static_x_data
.
all
()))
static_param_value
[
static_param_name_list
[
i
-
1
]]
=
out
[
i
]
for
key
,
value
in
six
.
iteritems
(
static_param_init_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_init_value
[
key
]))
...
...
@@ -210,7 +146,92 @@ class TestDygraphMnist(unittest.TestCase):
self
.
assertTrue
(
np
.
allclose
(
static_out
,
dy_out
))
for
key
,
value
in
six
.
iteritems
(
static_param_value
):
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
],
atol
=
1e-5
))
self
.
assertTrue
(
np
.
allclose
(
value
,
dy_param_value
[
key
]))
class
TestImperativeOptimizerPiecewiseDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
bd
=
[
3
,
6
,
9
]
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
[
0.1
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerNaturalExpDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
natural_exp_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerExponentialDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
exponential_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerInverseTimeDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
Adam
(
learning_rate
=
fluid
.
layers
.
inverse_time_decay
(
learning_rate
=
0.1
,
decay_steps
=
10000
,
decay_rate
=
0.5
,
staircase
=
True
))
return
optimizer
def
test_adam
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerPolynomialDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
polynomial_decay
(
learning_rate
=
0.1
,
decay_steps
=
5
,
cycle
=
self
.
cycle
))
return
optimizer
def
test_sgd_cycle
(
self
):
self
.
cycle
=
True
self
.
_check_mlp
()
def
test_sgd
(
self
):
self
.
cycle
=
False
self
.
_check_mlp
()
class
TestImperativeOptimizerCosineDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
cosine_decay
(
learning_rate
=
0.1
,
step_each_epoch
=
10000
,
epochs
=
120
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
class
TestImperativeOptimizerNoamDecay
(
TestImperativeOptimizerBase
):
def
get_optimizer
(
self
):
optimizer
=
SGDOptimizer
(
learning_rate
=
fluid
.
layers
.
noam_decay
(
d_model
=
512
,
warmup_steps
=
8000
))
return
optimizer
def
test_sgd
(
self
):
self
.
_check_mlp
()
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录