Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d37cd740
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d37cd740
编写于
4月 09, 2020
作者:
A
Aurelius84
提交者:
GitHub
4月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Polish set_optimizer Interface (#23588)
上级
f301eb7f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
66 addition
and
37 deletion
+66
-37
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+48
-32
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
...d/tests/unittests/dygraph_to_static/test_cache_program.py
+15
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
...luid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
.../unittests/dygraph_to_static/test_save_inference_model.py
+2
-2
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
d37cd740
...
...
@@ -20,6 +20,7 @@ import six
import
textwrap
import
threading
import
warnings
from
collections
import
defaultdict
from
paddle.fluid
import
framework
from
paddle.fluid
import
core
,
executor
...
...
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid.data_feeder
import
check_type
__all__
=
[
'ProgramTranslator'
,
'convert_function_with_cache'
]
...
...
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
else
:
self
.
_exe
=
exe
self
.
_program_cache
=
ProgramCache
()
self
.
_optimizer_info
=
None
self
.
_optimizer
=
None
self
.
_
already_minimized
=
Fals
e
self
.
_
loss_name
=
Non
e
# Once main_program is changed, should run startup_program.
self
.
_need_startup
=
True
def
get_output
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
Return
s
the output tensors for dygraph function and its arguments
Return the output tensors for dygraph function and its arguments
"""
if
in_dygraph_mode
():
warnings
.
warn
(
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use
the
it in "
"mode. We will just return dygraph output. Use it in "
"static mode if you would like to translate to static graph."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
...
...
@@ -286,12 +289,12 @@ class ProgramTranslator(object):
def
get_func
(
self
,
dygraph_func
):
"""
Return
s
the translated static function from dygraph function
Return the translated static function from dygraph function
"""
if
in_dygraph_mode
():
warnings
.
warn
(
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use
the
it in "
"mode. We will just return dygraph function. Use it in "
"static mode if you would like to translate to static graph."
)
return
dygraph_func
static_func
=
convert_function_with_cache
(
dygraph_func
)
...
...
@@ -299,7 +302,7 @@ class ProgramTranslator(object):
def
get_program
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
Return
s
the translated static program and input/output variables from
Return the translated static program and input/output variables from
dygraph function.
"""
if
in_dygraph_mode
():
...
...
@@ -315,7 +318,7 @@ class ProgramTranslator(object):
def
get_code
(
self
,
dygraph_func
):
"""
Return
s
the translated static function code from dygraph code
Return the translated static function code from dygraph code
"""
# Get AST from dygraph function
raw_code
=
inspect
.
getsource
(
dygraph_func
)
...
...
@@ -332,7 +335,7 @@ class ProgramTranslator(object):
def
run
(
self
,
*
args
,
**
kwargs
):
"""
Execute
s
main_program and returns output Tensors.
Execute main_program and returns output Tensors.
"""
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
...
...
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
return
outputs
def
set_optimizer
(
self
,
optimizer
,
loss_name
):
def
set_optimizer
(
self
,
optimizer
,
index_of_loss
=
0
):
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
check_type
(
index_of_loss
,
"index_of_loss"
,
int
,
"ProgramTranslator.set_optimizer"
)
self
.
_check_cache_valid
()
self
.
_optimizer
=
optimizer
if
not
isinstance
(
loss_name
,
six
.
string_types
):
if
self
.
_optimizer
and
self
.
_loss_name
:
raise
ValueError
(
"
Type of input loss_name should type(str), but received {}.
"
.
format
(
type
(
loss_name
)
))
self
.
_
loss_name
=
loss_name
"
{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop.
"
.
format
(
self
.
_optimizer
,
self
.
_loss_name
))
self
.
_
optimizer_info
=
(
optimizer
,
index_of_loss
)
def
save_inference_model
(
self
,
dirname
,
feed
=
None
,
fetch
=
None
):
"""
...
...
@@ -377,16 +380,16 @@ class ProgramTranslator(object):
def
_prepare
(
self
,
args
):
"""
Prepare
s
with feed_dict, fetch_list, optimizer and initialize vars
Prepare with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Update
s
batch_data for feed_dict
# Update batch_data for feed_dict
feed_dict
=
self
.
_update_batch_data
(
args
)
fetch_list
=
self
.
_program_cache
.
outputs
# Add
s
optimizer if needed.
if
self
.
_optimizer
and
not
self
.
_already_minimized
:
# Add optimizer if needed.
if
self
.
_optimizer
_info
and
self
.
_optimizer
is
None
:
self
.
_add_optimizer
()
if
self
.
_need_startup
:
...
...
@@ -397,7 +400,7 @@ class ProgramTranslator(object):
def
_check_cache_valid
(
self
):
"""
Check
s
whether the current program is consistent with `default_main_program`.
Check whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
...
...
@@ -408,7 +411,7 @@ class ProgramTranslator(object):
def
_update_batch_data
(
self
,
args
):
"""
Update
s
cached batch data while training program.
Update cached batch data while training program.
"""
feed_name_to_idx
=
self
.
_program_cache
.
feed_name_to_idx
feed_vars
=
self
.
_program_cache
.
inputs
...
...
@@ -421,27 +424,40 @@ class ProgramTranslator(object):
def
_add_optimizer
(
self
):
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
optimizer
,
index_of_loss
=
self
.
_optimizer_info
outputs
=
self
.
_program_cache
.
outputs
outputs
=
[
outputs
]
if
not
isinstance
(
outputs
,
(
list
,
tuple
))
else
outputs
assert
abs
(
index_of_loss
)
<
len
(
outputs
),
\
"index_of_loss: {} shall not exceed the length of outputs: {}."
.
format
(
index_of_loss
,
len
(
outputs
))
loss_var
=
outputs
[
index_of_loss
]
check_type
(
loss_var
,
"loss_var"
,
framework
.
Variable
,
"ProgramTranslator._add_optimizer"
)
main_program
=
self
.
_program_cache
.
main_program
startup_program
=
self
.
_program_cache
.
startup_program
all_vars
=
main_program
.
block
(
0
).
vars
loss_var
=
all_vars
.
get
(
self
.
_loss_name
,
None
)
if
loss_var
is
None
:
if
all_vars
.
get
(
loss_var
.
name
,
None
)
is
None
:
raise
ValueError
(
"Can't find {} in main_program, please confirm whether the
loss input is correct
"
.
format
(
self
.
_loss_
name
))
# Add
s
optimizer to minimize loss
"Can't find {} in main_program, please confirm whether the
input loss is correct.
"
.
format
(
loss_var
.
name
))
# Add optimizer to minimize loss
with
framework
.
program_guard
(
main_program
,
startup_program
):
self
.
_
optimizer
.
minimize
(
loss_var
)
optimizer
.
minimize
(
loss_var
)
# Avoids to set optimizer repeatedly.
self
.
_
already_minimized
=
Tru
e
self
.
_optimizer
=
optimizer
self
.
_
loss_name
=
loss_var
.
nam
e
def
get_program_cache
(
self
):
"""
Return
s
the ProgramCache instance.
Return the ProgramCache instance.
"""
self
.
_check_cache_valid
()
return
self
.
_program_cache
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
浏览文件 @
d37cd740
...
...
@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
# TODO: Need a better interfaces to set optimizer.
program_translator
=
ProgramTranslator
()
program_translator
.
set_optimizer
(
adam
,
'avg_loss'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
static_net
(
self
.
data
)
...
...
@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
static_loss
))
def
test_exception
(
self
):
main_program
=
fluid
.
Program
()
loss_data
=
[]
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
program_translator
=
ProgramTranslator
()
with
self
.
assertRaisesRegexp
(
ValueError
,
"has already been set"
):
for
batch_id
in
range
(
self
.
batch_num
):
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
static_net
(
self
.
data
)
def
simple_func
(
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
浏览文件 @
d37cd740
...
...
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
pre
=
self
.
fc
(
inputs
)
loss
=
fluid
.
layers
.
mean
(
pre
,
name
=
'avg_loss'
)
loss
=
fluid
.
layers
.
mean
(
pre
)
return
pre
,
loss
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
浏览文件 @
d37cd740
...
...
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
y
=
self
.
_linear
(
x
)
z
=
self
.
_linear
(
y
)
out
=
fluid
.
layers
.
mean
(
z
,
name
=
'mean'
)
out
=
fluid
.
layers
.
mean
(
z
)
return
out
...
...
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
program_translator
=
ProgramTranslator
.
get_instance
()
program_cache
=
ProgramTranslator
().
get_program_cache
adam
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
program_translator
.
set_optimizer
(
adam
,
'mean'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
0
)
for
i
in
range
(
5
):
out
=
layer
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录