Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d37cd740
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d37cd740
编写于
4月 09, 2020
作者:
A
Aurelius84
提交者:
GitHub
4月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Polish set_optimizer Interface (#23588)
上级
f301eb7f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
66 addition
and
37 deletion
+66
-37
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+48
-32
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
...d/tests/unittests/dygraph_to_static/test_cache_program.py
+15
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
...luid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
.../unittests/dygraph_to_static/test_save_inference_model.py
+2
-2
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
d37cd740
...
...
@@ -20,6 +20,7 @@ import six
import
textwrap
import
threading
import
warnings
from
collections
import
defaultdict
from
paddle.fluid
import
framework
from
paddle.fluid
import
core
,
executor
...
...
@@ -28,6 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStat
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid.data_feeder
import
check_type
__all__
=
[
'ProgramTranslator'
,
'convert_function_with_cache'
]
...
...
@@ -261,19 +263,20 @@ class ProgramTranslator(object):
else
:
self
.
_exe
=
exe
self
.
_program_cache
=
ProgramCache
()
self
.
_optimizer_info
=
None
self
.
_optimizer
=
None
self
.
_
already_minimized
=
Fals
e
self
.
_
loss_name
=
Non
e
# Once main_program is changed, should run startup_program.
self
.
_need_startup
=
True
def
get_output
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
Return
s
the output tensors for dygraph function and its arguments
Return the output tensors for dygraph function and its arguments
"""
if
in_dygraph_mode
():
warnings
.
warn
(
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use
the
it in "
"mode. We will just return dygraph output. Use it in "
"static mode if you would like to translate to static graph."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
...
...
@@ -286,12 +289,12 @@ class ProgramTranslator(object):
def
get_func
(
self
,
dygraph_func
):
"""
Return
s
the translated static function from dygraph function
Return the translated static function from dygraph function
"""
if
in_dygraph_mode
():
warnings
.
warn
(
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use
the
it in "
"mode. We will just return dygraph function. Use it in "
"static mode if you would like to translate to static graph."
)
return
dygraph_func
static_func
=
convert_function_with_cache
(
dygraph_func
)
...
...
@@ -299,7 +302,7 @@ class ProgramTranslator(object):
def
get_program
(
self
,
dygraph_func
,
*
args
,
**
kwargs
):
"""
Return
s
the translated static program and input/output variables from
Return the translated static program and input/output variables from
dygraph function.
"""
if
in_dygraph_mode
():
...
...
@@ -315,7 +318,7 @@ class ProgramTranslator(object):
def
get_code
(
self
,
dygraph_func
):
"""
Return
s
the translated static function code from dygraph code
Return the translated static function code from dygraph code
"""
# Get AST from dygraph function
raw_code
=
inspect
.
getsource
(
dygraph_func
)
...
...
@@ -332,7 +335,7 @@ class ProgramTranslator(object):
def
run
(
self
,
*
args
,
**
kwargs
):
"""
Execute
s
main_program and returns output Tensors.
Execute main_program and returns output Tensors.
"""
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
...
...
@@ -343,18 +346,18 @@ class ProgramTranslator(object):
return
outputs
def
set_optimizer
(
self
,
optimizer
,
loss_name
):
def
set_optimizer
(
self
,
optimizer
,
index_of_loss
=
0
):
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
check_type
(
index_of_loss
,
"index_of_loss"
,
int
,
"ProgramTranslator.set_optimizer"
)
self
.
_check_cache_valid
()
self
.
_optimizer
=
optimizer
if
not
isinstance
(
loss_name
,
six
.
string_types
):
if
self
.
_optimizer
and
self
.
_loss_name
:
raise
ValueError
(
"
Type of input loss_name should type(str), but received {}.
"
.
format
(
type
(
loss_name
)
))
self
.
_
loss_name
=
loss_name
"
{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop.
"
.
format
(
self
.
_optimizer
,
self
.
_loss_name
))
self
.
_
optimizer_info
=
(
optimizer
,
index_of_loss
)
def
save_inference_model
(
self
,
dirname
,
feed
=
None
,
fetch
=
None
):
"""
...
...
@@ -377,16 +380,16 @@ class ProgramTranslator(object):
def
_prepare
(
self
,
args
):
"""
Prepare
s
with feed_dict, fetch_list, optimizer and initialize vars
Prepare with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Update
s
batch_data for feed_dict
# Update batch_data for feed_dict
feed_dict
=
self
.
_update_batch_data
(
args
)
fetch_list
=
self
.
_program_cache
.
outputs
# Add
s
optimizer if needed.
if
self
.
_optimizer
and
not
self
.
_already_minimized
:
# Add optimizer if needed.
if
self
.
_optimizer
_info
and
self
.
_optimizer
is
None
:
self
.
_add_optimizer
()
if
self
.
_need_startup
:
...
...
@@ -397,7 +400,7 @@ class ProgramTranslator(object):
def
_check_cache_valid
(
self
):
"""
Check
s
whether the current program is consistent with `default_main_program`.
Check whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
...
...
@@ -408,7 +411,7 @@ class ProgramTranslator(object):
def
_update_batch_data
(
self
,
args
):
"""
Update
s
cached batch data while training program.
Update cached batch data while training program.
"""
feed_name_to_idx
=
self
.
_program_cache
.
feed_name_to_idx
feed_vars
=
self
.
_program_cache
.
inputs
...
...
@@ -421,27 +424,40 @@ class ProgramTranslator(object):
def
_add_optimizer
(
self
):
"""
Support
s
to set or update the optimizer used to minimize loss.
Support to set or update the optimizer used to minimize loss.
"""
optimizer
,
index_of_loss
=
self
.
_optimizer_info
outputs
=
self
.
_program_cache
.
outputs
outputs
=
[
outputs
]
if
not
isinstance
(
outputs
,
(
list
,
tuple
))
else
outputs
assert
abs
(
index_of_loss
)
<
len
(
outputs
),
\
"index_of_loss: {} shall not exceed the length of outputs: {}."
.
format
(
index_of_loss
,
len
(
outputs
))
loss_var
=
outputs
[
index_of_loss
]
check_type
(
loss_var
,
"loss_var"
,
framework
.
Variable
,
"ProgramTranslator._add_optimizer"
)
main_program
=
self
.
_program_cache
.
main_program
startup_program
=
self
.
_program_cache
.
startup_program
all_vars
=
main_program
.
block
(
0
).
vars
loss_var
=
all_vars
.
get
(
self
.
_loss_name
,
None
)
if
loss_var
is
None
:
if
all_vars
.
get
(
loss_var
.
name
,
None
)
is
None
:
raise
ValueError
(
"Can't find {} in main_program, please confirm whether the
loss input is correct
"
.
format
(
self
.
_loss_
name
))
# Add
s
optimizer to minimize loss
"Can't find {} in main_program, please confirm whether the
input loss is correct.
"
.
format
(
loss_var
.
name
))
# Add optimizer to minimize loss
with
framework
.
program_guard
(
main_program
,
startup_program
):
self
.
_
optimizer
.
minimize
(
loss_var
)
optimizer
.
minimize
(
loss_var
)
# Avoids to set optimizer repeatedly.
self
.
_
already_minimized
=
Tru
e
self
.
_optimizer
=
optimizer
self
.
_
loss_name
=
loss_var
.
nam
e
def
get_program_cache
(
self
):
"""
Return
s
the ProgramCache instance.
Return the ProgramCache instance.
"""
self
.
_check_cache_valid
()
return
self
.
_program_cache
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
浏览文件 @
d37cd740
...
...
@@ -76,9 +76,8 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
# TODO: Need a better interfaces to set optimizer.
program_translator
=
ProgramTranslator
()
program_translator
.
set_optimizer
(
adam
,
'avg_loss'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
static_net
(
self
.
data
)
...
...
@@ -110,6 +109,20 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
static_loss
))
def
test_exception
(
self
):
main_program
=
fluid
.
Program
()
loss_data
=
[]
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
program_translator
=
ProgramTranslator
()
with
self
.
assertRaisesRegexp
(
ValueError
,
"has already been set"
):
for
batch_id
in
range
(
self
.
batch_num
):
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
1
)
static_net
(
self
.
data
)
def
simple_func
(
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
浏览文件 @
d37cd740
...
...
@@ -58,7 +58,7 @@ class Linear(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
pre
=
self
.
fc
(
inputs
)
loss
=
fluid
.
layers
.
mean
(
pre
,
name
=
'avg_loss'
)
loss
=
fluid
.
layers
.
mean
(
pre
)
return
pre
,
loss
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_save_inference_model.py
浏览文件 @
d37cd740
...
...
@@ -39,7 +39,7 @@ class SimpleFcLayer(fluid.dygraph.Layer):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
y
=
self
.
_linear
(
x
)
z
=
self
.
_linear
(
y
)
out
=
fluid
.
layers
.
mean
(
z
,
name
=
'mean'
)
out
=
fluid
.
layers
.
mean
(
z
)
return
out
...
...
@@ -53,7 +53,7 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
program_translator
=
ProgramTranslator
.
get_instance
()
program_cache
=
ProgramTranslator
().
get_program_cache
adam
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
program_translator
.
set_optimizer
(
adam
,
'mean'
)
program_translator
.
set_optimizer
(
adam
,
index_of_loss
=
0
)
for
i
in
range
(
5
):
out
=
layer
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录