Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d37cd740
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d37cd740
编写于
4月 09, 2020
作者:
A
Aurelius84
提交者:
GitHub
4月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Polish set_optimizer Interface (#23588)
上级
f301eb7f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
66 addition
and
37 deletion
+66
-37
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+48
-32
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
...d/tests/unittests/dygraph_to_static/test_cache_program.py
+15
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
...luid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
.../unittests/dygraph_to_static/test_save_inference_model.py
+2
-2
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
d37cd740
...
@@ -20,6 +20,7 @@ import six
...
@@ -20,6 +20,7 @@ import six
import
textwrap
import
textwrap
import
threading
import
threading
import
warnings
import
warnings
from
collections
import
defaultdict
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
from
paddle.fluid
import
core
,
executor
from
paddle.fluid
import
core
,
executor
...
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
...
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid.data_feeder
import
check_type
__all__
=
[
'ProgramTranslator'
,
'convert_function_with_cache'
]
__all__
=
[
'ProgramTranslator'
,
'convert_function_with_cache'
]
...
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
...
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
else
:
else
:
self
.
_exe
=
exe
self
.
_exe
=
exe
self
.
_program_cache
=
ProgramCache
()
self
.
_program_cache
=
ProgramCache
()
self
.
_optimizer_info
=
None
self
.
_optimizer
=
None
self
.
_optimizer
=
None
self
.
_
already_minimized
=
Fals
e
self
.
_
loss_name
=
Non
e
# Once main_program is changed, should run startup_program.
# Once main_program is changed, should run startup_program.
self
.
_need_startup
=
True
self
.
_need_startup
=
True
def
get_output
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
def
get_output
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
"""
Return
s
the output tensors for dygraph function and its arguments
Return the output tensors for dygraph function and its arguments
"""
"""
if
in_dygraph_mode
():
if
in_dygraph_mode
():
warnings
.
warn
(
warnings
.
warn
(
"The ProgramTranslator.get_output doesn't work in dygraph "
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use
the
it in "
"mode. We will just return dygraph output. Use it in "
"static mode if you would like to translate to static graph."
)
"static mode if you would like to translate to static graph."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
return
dygraph_func
(
*
args
,
**
kwargs
)
...
@@ -286,12 +289,12 @@ class ProgramTranslator(object):
...
@@ -286,12 +289,12 @@ class ProgramTranslator(object):
def
get_func
(
self
,
dygraph_func
):
def
get_func
(
self
,
dygraph_func
):
"""
"""
Return
s
the translated static function from dygraph function
Return the translated static function from dygraph function
"""
"""
if
in_dygraph_mode
():
if
in_dygraph_mode
():
warnings
.
warn
(
warnings
.
warn
(
"The ProgramTranslator.get_func doesn't work in dygraph "
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use
the
it in "
"mode. We will just return dygraph function. Use it in "
"static mode if you would like to translate to static graph."
)
"static mode if you would like to translate to static graph."
)
return
dygraph_func
return
dygraph_func
static_func
=
convert_function_with_cache
(
dygraph_func
)
static_func
=
convert_function_with_cache
(
dygraph_func
)
...
@@ -299,7 +302,7 @@ class ProgramTranslator(object):
...
@@ -299,7 +302,7 @@ class ProgramTranslator(object):
def
get_program
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
def
get_program
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
"""
Return
s
the translated static program and input/output variables from
Return the translated static program and input/output variables from
dygraph function.
dygraph function.
"""
"""
if
in_dygraph_mode
():
if
in_dygraph_mode
():
...
@@ -315,7 +318,7 @@ class ProgramTranslator(object):
...
@@ -315,7 +318,7 @@ class ProgramTranslator(object):
def
get_code
(
self
,
dygraph_func
):
def
get_code
(
self
,
dygraph_func
):
"""
"""
Return
s
the translated static function code from dygraph code
Return the translated static function code from dygraph code
"""
"""
# Get AST from dygraph function
# Get AST from dygraph function
raw_code
=
inspect
.
getsource
(
dygraph_func
)
raw_code
=
inspect
.
getsource
(
dygraph_func
)
...
@@ -332,7 +335,7 @@ class ProgramTranslator(object):
...
@@ -332,7 +335,7 @@ class ProgramTranslator(object):
def
run
(
self
,
*
args
,
**
kwargs
):
def
run
(
self
,
*
args
,
**
kwargs
):
"""
"""
Execute
s
main_program and returns output Tensors.
Execute main_program and returns output Tensors.
"""
"""
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
...
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
...
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
return
outputs
return
outputs
def
set_optimizer
(
self
,
optimizer
,
loss_name
):
def
set_optimizer
(
self
,
optimizer
,
index_of_loss
=
0
):
"""
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
"""
check_type
(
index_of_loss
,
"index_of_loss"
,
int
,
"ProgramTranslator.set_optimizer"
)
self
.
_check_cache_valid
()
self
.
_check_cache_valid
()
self
.
_optimizer
=
optimizer
if
self
.
_optimizer
and
self
.
_loss_name
:
if
not
isinstance
(
loss_name
,
six
.
string_types
):
raise
ValueError
(
raise
ValueError
(
"
Type of input loss_name should type(str), but received {}.
"
.
"
{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop.
"
.
format
(
type
(
loss_name
)
))
format
(
self
.
_optimizer
,
self
.
_loss_name
))
self
.
_
loss_name
=
loss_name
self
.
_
optimizer_info
=
(
optimizer
,
index_of_loss
)
def
save_inference_model
(
self
,
dirname
,
feed
=
None
,
fetch
=
None
):
def
save_inference_model
(
self
,
dirname
,
feed
=
None
,
fetch
=
None
):
"""
"""
...
@@ -377,16 +380,16 @@ class ProgramTranslator(object):
...
@@ -377,16 +380,16 @@ class ProgramTranslator(object):
def
_prepare
(
self
,
args
):
def
_prepare
(
self
,
args
):
"""
"""
Prepare
s
with feed_dict, fetch_list, optimizer and initialize vars
Prepare with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
by running startup_program.
"""
"""
# Update
s
batch_data for feed_dict
# Update batch_data for feed_dict
feed_dict
=
self
.
_update_batch_data
(
args
)
feed_dict
=
self
.
_update_batch_data
(
args
)
fetch_list
=
self
.
_program_cache
.
outputs
fetch_list
=
self
.
_program_cache
.
outputs
# Add
s
optimizer if needed.
# Add optimizer if needed.
if
self
.
_optimizer
and
not
self
.
_already_minimized
:
if
self
.
_optimizer
_info
and
self
.
_optimizer
is
None
:
self
.
_add_optimizer
()
self
.
_add_optimizer
()
if
self
.
_need_startup
:
if
self
.
_need_startup
:
...
@@ -397,7 +400,7 @@ class ProgramTranslator(object):
...
@@ -397,7 +400,7 @@ class ProgramTranslator(object):
def
_check_cache_valid
(
self
):
def
_check_cache_valid
(
self
):
"""
"""
Check
s
whether the current program is consistent with `default_main_program`.
Check whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
If does, the cached program and other properties are not available and should be reset.
"""
"""
...
@@ -408,7 +411,7 @@ class ProgramTranslator(object):
...
@@ -408,7 +411,7 @@ class ProgramTranslator(object):
def
_update_batch_data
(
self
,
args
):
def
_update_batch_data
(
self
,
args
):
"""
"""
Update
s
cached batch data while training program.
Update cached batch data while training program.
"""
"""
feed_name_to_idx
=
self
.
_program_cache
.
feed_name_to_idx
feed_name_to_idx
=
self
.
_program_cache
.
feed_name_to_idx
feed_vars
=
self
.
_program_cache
.
inputs
feed_vars
=
self
.
_program_cache
.
inputs
...
@@ -421,27 +424,40 @@ class ProgramTranslator(object):
...
@@ -421,27 +424,40 @@ class ProgramTranslator(object):
def
_add_optimizer
(
self
):
def
_add_optimizer
(
self
):
"""
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
"""
optimizer
,
index_of_loss
=
self
.
_optimizer_info
outputs
=
self
.
_program_cache
.
outputs
outputs
=
[
outputs
]
if
not
isinstance
(
outputs
,
(
list
,
tuple
))
else
outputs
assert
abs
(
index_of_loss
)
<
len
(
outputs
),
\
"index_of_loss: {} shall not exceed the length of outputs: {}."
.
format
(
index_of_loss
,
len
(
outputs
))
loss_var
=
outputs
[
index_of_loss
]
check_type
(
loss_var
,
"loss_var"
,
framework
.
Variable
,
"ProgramTranslator._add_optimizer"
)
main_program
=
self
.
_program_cache
.
main_program
main_program
=
self
.
_program_cache
.
main_program
startup_program
=
self
.
_program_cache
.
startup_program
startup_program
=
self
.
_program_cache
.
startup_program
all_vars
=
main_program
.
block
(
0
).
vars
all_vars
=
main_program
.
block
(
0
).
vars
loss_var
=
all_vars
.
get
(
self
.
_loss_name
,
None
)
if
loss_var
is
None
:
if
all_vars
.
get
(
loss_var
.
name
,
None
)
is
None
:
raise
ValueError
(
raise
ValueError
(
"Can't find {} in main_program, please confirm whether the
loss input is correct
"
"Can't find {} in main_program, please confirm whether the
input loss is correct.
"
.
format
(
self
.
_loss_
name
))
.
format
(
loss_var
.
name
))
# Add
s
optimizer to minimize loss
# Add optimizer to minimize loss
with
framework
.
program_guard
(
main_program
,
startup_program
):
with
framework
.
program_guard
(
main_program
,
startup_program
):
self
.
_
optimizer
.
minimize
(
loss_var
)
optimizer
.
minimize
(
loss_var
)
# Avoids to set optimizer repeatedly.
self
.
_optimizer
=
optimizer
self
.
_
already_minimized
=
Tru
e
self
.
_
loss_name
=
loss_var
.
nam
e
def
get_program_cache
(
self
):
def
get_program_cache
(
self
):
"""
"""
Return
s
the ProgramCache instance.
Return the ProgramCache instance.
"""
"""
self
.
_check_cache_valid
()
self
.
_check_cache_valid
()
return
self
.
_program_cache
return
self
.
_program_cache
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
浏览文件 @
d37cd740
...
@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
...
@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
static_net
=
self
.
dygraph_class
()
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
# set optimizer
# TODO: Need a better interfaces to set optimizer.
program_translator
=
ProgramTranslator
()
program_translator
=
ProgramTranslator
()
program_translator
.
set_optimizer
(
adam
,
'avg_loss'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
for
batch_id
in
range
(
self
.
batch_num
):
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
static_net
(
self
.
data
)
pred
,
avg_loss
=
static_net
(
self
.
data
)
...
@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
...
@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
static_loss
))
static_loss
))
def
test_exception
(
self
):
main_program
=
fluid
.
Program
()
loss_data
=
[]
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
program_translator
=
ProgramTranslator
()
with
self
.
assertRaisesRegexp
(
ValueError
,
"has already been set"
):
for
batch_id
in
range
(
self
.
batch_num
):
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
static_net
(
self
.
data
)
def
simple_func
(
x
):
def
simple_func
(
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
浏览文件 @
d37cd740
...
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
...
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
pre
=
self
.
fc
(
inputs
)
pre
=
self
.
fc
(
inputs
)
loss
=
fluid
.
layers
.
mean
(
pre
,
name
=
'avg_loss'
)
loss
=
fluid
.
layers
.
mean
(
pre
)
return
pre
,
loss
return
pre
,
loss
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
浏览文件 @
d37cd740
...
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
...
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
fluid
.
dygraph
.
to_variable
(
x
)
y
=
self
.
_linear
(
x
)
y
=
self
.
_linear
(
x
)
z
=
self
.
_linear
(
y
)
z
=
self
.
_linear
(
y
)
out
=
fluid
.
layers
.
mean
(
z
,
name
=
'mean'
)
out
=
fluid
.
layers
.
mean
(
z
)
return
out
return
out
...
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
...
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
program_translator
=
ProgramTranslator
.
get_instance
()
program_translator
=
ProgramTranslator
.
get_instance
()
program_cache
=
ProgramTranslator
().
get_program_cache
program_cache
=
ProgramTranslator
().
get_program_cache
adam
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
adam
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
program_translator
.
set_optimizer
(
adam
,
'mean'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
0
)
for
i
in
range
(
5
):
for
i
in
range
(
5
):
out
=
layer
(
x
)
out
=
layer
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录