Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6e3e3f13
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e3e3f13
编写于
9月 23, 2020
作者:
C
Chen Weihang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add get inference program api
上级
827ac36f
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
281 addition
and
154 deletion
+281
-154
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+117
-78
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+127
-76
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+2
-0
python/paddle/fluid/tests/unittests/test_jit_save_load.py
python/paddle/fluid/tests/unittests/test_jit_save_load.py
+29
-0
python/paddle/jit/__init__.py
python/paddle/jit/__init__.py
+3
-0
python/paddle/static/__init__.py
python/paddle/static/__init__.py
+3
-0
未找到文件。
python/paddle/fluid/dygraph/jit.py
浏览文件 @
6e3e3f13
...
@@ -19,9 +19,11 @@ import pickle
...
@@ -19,9 +19,11 @@ import pickle
import
warnings
import
warnings
import
functools
import
functools
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
six
import
six
import
paddle
import
paddle
# deprecated module import
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid.compiler
import
BuildStrategy
,
CompiledProgram
,
ExecutionStrategy
from
paddle.fluid.compiler
import
BuildStrategy
,
CompiledProgram
,
ExecutionStrategy
from
paddle.fluid.data_feeder
import
check_type
from
paddle.fluid.data_feeder
import
check_type
...
@@ -644,6 +646,18 @@ class SaveLoadConfig(object):
...
@@ -644,6 +646,18 @@ class SaveLoadConfig(object):
self
.
_keep_name_table
=
value
self
.
_keep_name_table
=
value
# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
def
deprecate_save_load_configs
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
if
'configs'
in
kwargs
:
kwargs
[
'config'
]
=
kwargs
[
'configs'
]
kwargs
.
pop
(
'configs'
)
return
func
(
*
args
,
**
kwargs
)
return
wrapper
def
_get_input_var_names
(
inputs
,
input_spec
):
def
_get_input_var_names
(
inputs
,
input_spec
):
name_none_error
=
"The %s's name is None. "
\
name_none_error
=
"The %s's name is None. "
\
"When using jit.save, please set InputSepc's name in "
\
"When using jit.save, please set InputSepc's name in "
\
...
@@ -696,9 +710,9 @@ def _get_output_vars(outputs, output_spec):
...
@@ -696,9 +710,9 @@ def _get_output_vars(outputs, output_spec):
if
isinstance
(
var
,
Variable
):
if
isinstance
(
var
,
Variable
):
output_vars_dict
[
var
.
name
]
=
var
output_vars_dict
[
var
.
name
]
=
var
if
output_spec
is
None
:
if
output_spec
is
None
:
result_list
=
output_vars_dict
.
values
(
)
result_list
=
list
(
output_vars_dict
.
values
()
)
elif
output_spec
is
not
None
and
len
(
output_spec
)
==
len
(
output_vars_dict
):
elif
output_spec
is
not
None
and
len
(
output_spec
)
==
len
(
output_vars_dict
):
result_list
=
output_vars_dict
.
values
(
)
result_list
=
list
(
output_vars_dict
.
values
()
)
for
var
in
output_spec
:
for
var
in
output_spec
:
if
var
.
name
not
in
output_vars_dict
:
if
var
.
name
not
in
output_vars_dict
:
warnings
.
warn
(
name_no_exists_error
%
var
.
name
)
warnings
.
warn
(
name_no_exists_error
%
var
.
name
)
...
@@ -711,16 +725,95 @@ def _get_output_vars(outputs, output_spec):
...
@@ -711,16 +725,95 @@ def _get_output_vars(outputs, output_spec):
return
result_list
return
result_list
# NOTE(chenweihang): change jit.save/load argument `configs` to `config`
def
_infer_input_check
(
layer
,
input_spec
):
def
deprecate_save_load_configs
(
func
):
prog_translator
=
ProgramTranslator
()
@
functools
.
wraps
(
func
)
if
not
prog_translator
.
enable_to_static
:
def
wrapper
(
*
args
,
**
kwargs
):
raise
RuntimeError
(
if
'configs'
in
kwargs
:
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
kwargs
[
'config'
]
=
kwargs
[
'configs'
]
)
kwargs
.
pop
(
'configs'
)
if
not
isinstance
(
layer
,
Layer
):
return
func
(
*
args
,
**
kwargs
)
raise
TypeError
(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
%
type
(
layer
))
return
wrapper
# avoid change user given input_spec
inner_input_spec
=
None
if
input_spec
is
not
None
:
if
not
isinstance
(
input_spec
,
list
):
raise
TypeError
(
"The input input_spec should be 'list', but received input_spec's type is %s."
%
type
(
input_spec
))
inner_input_spec
=
[]
for
var
in
input_spec
:
if
isinstance
(
var
,
paddle
.
static
.
InputSpec
):
inner_input_spec
.
append
(
var
)
elif
isinstance
(
var
,
(
core
.
VarBase
,
Variable
)):
inner_input_spec
.
append
(
paddle
.
static
.
InputSpec
.
from_tensor
(
var
))
else
:
raise
TypeError
(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
%
type
(
var
))
return
inner_input_spec
def
_get_concrete_program_from_layer
(
layer
,
inner_input_spec
):
# TODO(chenweihang): add support for other method, not only forward
if
isinstance
(
layer
.
forward
,
StaticLayer
):
concrete_program
=
layer
.
forward
.
concrete_program
else
:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
static_forward
=
declarative
(
layer
.
forward
,
input_spec
=
inner_input_spec
)
concrete_program
=
static_forward
.
concrete_program
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# avoid needless warning
inner_input_spec
=
None
return
concrete_program
def
_build_input_and_output
(
concrete_program
,
inner_input_spec
,
config
):
# NOTE(chenweihang): [ Get input variables name ]
# There are two cases, whether to prune the inputs or not
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names
=
_get_input_var_names
(
concrete_program
.
inputs
,
inner_input_spec
)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars
=
_get_output_vars
(
concrete_program
.
outputs
,
config
.
output_spec
)
return
input_var_names
,
output_vars
# NOTE: This function is not exposed to users, only used for paddle2onnx now
@
switch_to_static_graph
def
get_inference_program
(
layer
,
input_spec
=
None
,
config
=
None
):
# 1. input check
inner_input_spec
=
_infer_input_check
(
layer
,
input_spec
)
if
config
is
None
:
config
=
SaveLoadConfig
()
# 2. get program from Layer
concrete_program
=
_get_concrete_program_from_layer
(
layer
,
inner_input_spec
)
# 3. build input & output of save_infernece_model
input_var_names
,
output_vars
=
_build_input_and_output
(
concrete_program
,
inner_input_spec
,
config
)
# 4. only get inference program
inference_program
=
paddle
.
fluid
.
io
.
get_inference_program
(
input_var_names
,
output_vars
,
concrete_program
.
main_program
.
clone
())
return
inference_program
@
deprecate_save_load_configs
@
deprecate_save_load_configs
...
@@ -830,72 +923,18 @@ def save(layer, model_path, input_spec=None, config=None):
...
@@ -830,72 +923,18 @@ def save(layer, model_path, input_spec=None, config=None):
model_path = "linear.example.model"
model_path = "linear.example.model"
paddle.jit.save(layer, model_path)
paddle.jit.save(layer, model_path)
"""
"""
# 1. input check
# 1. input check
prog_translator
=
ProgramTranslator
()
inner_input_spec
=
_infer_input_check
(
layer
,
input_spec
)
if
not
prog_translator
.
enable_to_static
:
raise
RuntimeError
(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if
not
isinstance
(
layer
,
Layer
):
raise
TypeError
(
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
%
type
(
layer
))
configs
=
config
if
configs
is
None
:
configs
=
SaveLoadConfig
()
# avoid change user given input_spec
if
config
is
None
:
inner_input_spec
=
None
config
=
SaveLoadConfig
()
if
input_spec
is
not
None
:
if
not
isinstance
(
input_spec
,
list
):
raise
TypeError
(
"The input input_spec should be 'list', but received input_spec's type is %s."
%
type
(
input_spec
))
inner_input_spec
=
[]
for
var
in
input_spec
:
if
isinstance
(
var
,
paddle
.
static
.
InputSpec
):
inner_input_spec
.
append
(
var
)
elif
isinstance
(
var
,
(
core
.
VarBase
,
Variable
)):
inner_input_spec
.
append
(
paddle
.
static
.
InputSpec
.
from_tensor
(
var
))
else
:
raise
TypeError
(
"The element in input_spec list should be 'Variable' or `paddle.static.InputSpec`, but received element's type is %s."
%
type
(
var
))
# 2. get program from Layer
# 2. get program from Layer
# TODO(chenweihang): add support for other method, not only forward
concrete_program
=
_get_concrete_program_from_layer
(
layer
,
inner_input_spec
)
if
isinstance
(
layer
.
forward
,
StaticLayer
):
concrete_program
=
layer
.
forward
.
concrete_program
else
:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
static_forward
=
declarative
(
layer
.
forward
,
input_spec
=
inner_input_spec
)
concrete_program
=
static_forward
.
concrete_program
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# avoid needless warning
inner_input_spec
=
None
# 3. build input & output of save_infernece_model
# 3. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
input_var_names
,
output_vars
=
_build_input_and_output
(
# There are two cases, whether to prune the inputs or not
concrete_program
,
inner_input_spec
,
config
)
# - not prune inputs (recommend):
# - the len(input_spec) == len((concrete_program.inputs) - 1
# - here can use concrete_program.inputs directly
# - prune inputs:
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names
=
_get_input_var_names
(
concrete_program
.
inputs
,
inner_input_spec
)
# NOTE(chenweihang): [ Get output variables ]
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
output_vars
=
_get_output_vars
(
concrete_program
.
outputs
,
configs
.
output_spec
)
# NOTE(chenweihang): we maintain the mapping of variable name to
# NOTE(chenweihang): we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
# structured name, the buffer variable (non-persistable)
...
@@ -927,8 +966,8 @@ def save(layer, model_path, input_spec=None, config=None):
...
@@ -927,8 +966,8 @@ def save(layer, model_path, input_spec=None, config=None):
from
paddle.fluid.io
import
save_inference_model
from
paddle.fluid.io
import
save_inference_model
# VARIABLE_FILENAME keep nameing style consistent with '__model__'
# VARIABLE_FILENAME keep nameing style consistent with '__model__'
if
config
s
.
params_filename
is
None
:
if
config
.
params_filename
is
None
:
config
s
.
params_filename
=
VARIABLE_FILENAME
config
.
params_filename
=
VARIABLE_FILENAME
with
scope_guard
(
scope
):
with
scope_guard
(
scope
):
save_inference_model
(
save_inference_model
(
...
@@ -937,11 +976,11 @@ def save(layer, model_path, input_spec=None, config=None):
...
@@ -937,11 +976,11 @@ def save(layer, model_path, input_spec=None, config=None):
target_vars
=
output_vars
,
target_vars
=
output_vars
,
executor
=
Executor
(
_current_expected_place
()),
executor
=
Executor
(
_current_expected_place
()),
main_program
=
concrete_program
.
main_program
.
clone
(),
main_program
=
concrete_program
.
main_program
.
clone
(),
model_filename
=
config
s
.
model_filename
,
model_filename
=
config
.
model_filename
,
params_filename
=
None
params_filename
=
None
if
config
s
.
separate_params
else
configs
.
params_filename
,
if
config
.
separate_params
else
config
.
params_filename
,
export_for_deployment
=
config
s
.
_export_for_deployment
,
export_for_deployment
=
config
.
_export_for_deployment
,
program_only
=
config
s
.
_program_only
)
program_only
=
config
.
_program_only
)
# NOTE(chenweihang): [ Save extra variable info ]
# NOTE(chenweihang): [ Save extra variable info ]
# save_inference_model will lose some important variable information, including:
# save_inference_model will lose some important variable information, including:
...
...
python/paddle/fluid/io.py
浏览文件 @
6e3e3f13
...
@@ -22,10 +22,11 @@ import logging
...
@@ -22,10 +22,11 @@ import logging
import
pickle
import
pickle
import
contextlib
import
contextlib
from
functools
import
reduce
from
functools
import
reduce
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
# ddeprecated module import
from
paddle.fluid
import
layers
from
paddle.fluid
import
layers
from
paddle.fluid.executor
import
Executor
,
global_scope
from
paddle.fluid.executor
import
Executor
,
global_scope
from
paddle.fluid.evaluator
import
Evaluator
from
paddle.fluid.evaluator
import
Evaluator
...
@@ -220,6 +221,113 @@ def _get_valid_program(main_program):
...
@@ -220,6 +221,113 @@ def _get_valid_program(main_program):
return
main_program
return
main_program
def
_feed_fetch_check
(
feeded_var_names
,
target_vars
,
export_for_deployment
=
True
):
if
isinstance
(
feeded_var_names
,
six
.
string_types
):
feeded_var_names
=
[
feeded_var_names
]
elif
export_for_deployment
:
if
len
(
feeded_var_names
)
>
0
:
# TODO(paddle-dev): polish these code blocks
if
not
(
bool
(
feeded_var_names
)
and
all
(
isinstance
(
name
,
six
.
string_types
)
for
name
in
feeded_var_names
)):
raise
ValueError
(
"'feed_var_names' should be a list of str."
)
if
isinstance
(
target_vars
,
Variable
):
target_vars
=
[
target_vars
]
elif
export_for_deployment
:
if
not
(
bool
(
target_vars
)
and
all
(
isinstance
(
var
,
Variable
)
for
var
in
target_vars
)):
raise
ValueError
(
"'target_vars' should be a list of Variable."
)
def
_auc_states_check_and_remind
(
main_program
):
all_ops
=
main_program
.
global_block
().
ops
for
op
in
all_ops
:
# clear device of Op
device_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpDeviceAttrName
()
op
.
_set_attr
(
device_attr_name
,
""
)
if
op
.
type
==
'auc'
:
warnings
.
warn
(
"please ensure that you have set the auc states to zeros before saving inference model"
)
break
def
_update_target_vars
(
target_vars
,
main_program
):
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with
program_guard
(
main_program
):
uniq_target_vars
=
[]
for
i
,
var
in
enumerate
(
target_vars
):
if
isinstance
(
var
,
Variable
):
var
=
layers
.
scale
(
var
,
1.
,
name
=
"save_infer_model/scale_{}"
.
format
(
i
))
uniq_target_vars
.
append
(
var
)
target_vars
=
uniq_target_vars
return
target_vars
def
_get_train_program
(
feeded_var_names
,
target_vars
,
main_program
):
# 1. feed & fetch check
_feed_fetch_check
(
feeded_var_names
,
target_vars
,
False
)
# 2. remind user to set auc_states to zeros if the program contains auc op
_auc_states_check_and_remind
(
main_program
)
# 3. update input target_vars to fix bug
target_vars
=
_update_target_vars
(
target_vars
,
main_program
)
return
main_program
def
_serialization
(
main_program
,
model_basename
):
with
open
(
model_basename
,
"wb"
)
as
f
:
f
.
write
(
main_program
.
desc
.
serialize_to_string
())
# NOTE: This function is not exposed to users, only used for paddle2onnx now
@
dygraph_not_support
def
get_inference_program
(
feeded_var_names
,
target_vars
,
main_program
):
# 1. feed & fetch check
_feed_fetch_check
(
feeded_var_names
,
target_vars
)
# 2. remind user to set auc_states to zeros if the program contains auc op
_auc_states_check_and_remind
(
main_program
)
# 3. update input target_vars to fix bug
target_vars
=
_update_target_vars
(
target_vars
,
main_program
)
# 4. build inference program
main_program
=
main_program
.
clone
()
global_block
=
main_program
.
global_block
()
need_to_remove_op_index
=
[]
for
i
,
op
in
enumerate
(
global_block
.
ops
):
op
.
desc
.
set_is_target
(
False
)
if
op
.
type
==
"feed"
or
op
.
type
==
"fetch"
:
need_to_remove_op_index
.
append
(
i
)
for
index
in
need_to_remove_op_index
[::
-
1
]:
global_block
.
_remove_op
(
index
)
main_program
.
desc
.
flush
()
main_program
=
main_program
.
_prune_with_input
(
feeded_var_names
=
feeded_var_names
,
targets
=
target_vars
)
main_program
=
main_program
.
_inference_optimize
(
prune_read_op
=
True
)
fetch_var_names
=
[
v
.
name
for
v
in
target_vars
]
prepend_feed_ops
(
main_program
,
feeded_var_names
)
append_fetch_ops
(
main_program
,
fetch_var_names
)
main_program
.
desc
.
_set_version
()
paddle
.
fluid
.
core
.
save_op_compatible_info
(
main_program
.
desc
)
return
main_program
@
dygraph_not_support
@
dygraph_not_support
def
save_vars
(
executor
,
def
save_vars
(
executor
,
dirname
,
dirname
,
...
@@ -1257,50 +1365,16 @@ def save_inference_model(dirname,
...
@@ -1257,50 +1365,16 @@ def save_inference_model(dirname,
# "./infer_model".
# "./infer_model".
"""
"""
if
isinstance
(
feeded_var_names
,
six
.
string_types
):
# 1. get main program
feeded_var_names
=
[
feeded_var_names
]
elif
export_for_deployment
:
if
len
(
feeded_var_names
)
>
0
:
# TODO(paddle-dev): polish these code blocks
if
not
(
bool
(
feeded_var_names
)
and
all
(
isinstance
(
name
,
six
.
string_types
)
for
name
in
feeded_var_names
)):
raise
ValueError
(
"'feed_var_names' should be a list of str."
)
if
isinstance
(
target_vars
,
Variable
):
target_vars
=
[
target_vars
]
elif
export_for_deployment
:
if
not
(
bool
(
target_vars
)
and
all
(
isinstance
(
var
,
Variable
)
for
var
in
target_vars
)):
raise
ValueError
(
"'target_vars' should be a list of Variable."
)
main_program
=
_get_valid_program
(
main_program
)
main_program
=
_get_valid_program
(
main_program
)
# remind user to set auc_states to zeros if the program contains auc op
# When export_for_deployment is true, we modify the program online so that
all_ops
=
main_program
.
global_block
().
ops
# it can only be loaded for inference directly. If it's false, the whole
for
op
in
all_ops
:
# original program and related meta are saved so that future usage can be
# clear device of Op
# more flexible.
device_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpDeviceAttrName
()
origin_program
=
main_program
.
clone
()
op
.
_set_attr
(
device_attr_name
,
""
)
if
op
.
type
==
'auc'
:
warnings
.
warn
(
"please ensure that you have set the auc states to zeros before saving inference model"
)
break
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
# TODO(Superjomn) add an IR pass to remove 1-scale op.
with
program_guard
(
main_program
):
uniq_target_vars
=
[]
for
i
,
var
in
enumerate
(
target_vars
):
if
isinstance
(
var
,
Variable
):
var
=
layers
.
scale
(
var
,
1.
,
name
=
"save_infer_model/scale_{}"
.
format
(
i
))
uniq_target_vars
.
append
(
var
)
target_vars
=
uniq_target_vars
target_var_name_list
=
[
var
.
name
for
var
in
target_vars
]
# 2. dirname check & create
# when a pserver and a trainer running on the same machine, mkdir may conflict
# when a pserver and a trainer running on the same machine, mkdir may conflict
save_dirname
=
dirname
save_dirname
=
dirname
try
:
try
:
...
@@ -1310,57 +1384,34 @@ def save_inference_model(dirname,
...
@@ -1310,57 +1384,34 @@ def save_inference_model(dirname,
if
e
.
errno
!=
errno
.
EEXIST
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
raise
# 3. model_filename check & create
if
model_filename
is
not
None
:
if
model_filename
is
not
None
:
model_basename
=
os
.
path
.
basename
(
model_filename
)
model_basename
=
os
.
path
.
basename
(
model_filename
)
else
:
else
:
model_basename
=
"__model__"
model_basename
=
"__model__"
model_basename
=
os
.
path
.
join
(
save_dirname
,
model_basename
)
model_basename
=
os
.
path
.
join
(
save_dirname
,
model_basename
)
# When export_for_deployment is true, we modify the program online so that
# 4. get & serialize program
# it can only be loaded for inference directly. If it's false, the whole
# original program and related meta are saved so that future usage can be
# more flexible.
origin_program
=
main_program
.
clone
()
if
export_for_deployment
:
if
export_for_deployment
:
main_program
=
main_program
.
clone
()
main_program
=
get_inference_program
(
feeded_var_names
,
target_vars
,
global_block
=
main_program
.
global_block
()
main_program
)
need_to_remove_op_index
=
[]
_serialization
(
main_program
,
model_basename
)
for
i
,
op
in
enumerate
(
global_block
.
ops
):
op
.
desc
.
set_is_target
(
False
)
if
op
.
type
==
"feed"
or
op
.
type
==
"fetch"
:
need_to_remove_op_index
.
append
(
i
)
for
index
in
need_to_remove_op_index
[::
-
1
]:
global_block
.
_remove_op
(
index
)
main_program
.
desc
.
flush
()
main_program
=
main_program
.
_prune_with_input
(
feeded_var_names
=
feeded_var_names
,
targets
=
target_vars
)
main_program
=
main_program
.
_inference_optimize
(
prune_read_op
=
True
)
fetch_var_names
=
[
v
.
name
for
v
in
target_vars
]
prepend_feed_ops
(
main_program
,
feeded_var_names
)
append_fetch_ops
(
main_program
,
fetch_var_names
)
main_program
.
desc
.
_set_version
()
paddle
.
fluid
.
core
.
save_op_compatible_info
(
main_program
.
desc
)
with
open
(
model_basename
,
"wb"
)
as
f
:
f
.
write
(
main_program
.
desc
.
serialize_to_string
())
else
:
else
:
# TODO(panyx0718): Save more information so that it can also be used
# TODO(panyx0718): Save more information so that it can also be used
# for training and more flexible post-processing.
# for training and more flexible post-processing.
with
open
(
model_basename
+
".main_program"
,
"wb"
)
as
f
:
main_program
=
_get_train_program
(
feeded_var_names
,
target_vars
,
f
.
write
(
main_program
.
desc
.
serialize_to_string
())
main_program
)
_serialization
(
main_program
,
model_basename
+
".main_program"
)
# 5. get target var_name list & judge whether serialize program only
target_var_name_list
=
[
var
.
name
for
var
in
target_vars
]
if
program_only
:
if
program_only
:
warnings
.
warn
(
warnings
.
warn
(
"save_inference_model specified the param `program_only` to True, It will not save params of Program."
"save_inference_model specified the param `program_only` to True, It will not save params of Program."
)
)
return
target_var_name_list
return
target_var_name_list
# 6. save persistables
main_program
.
_copy_dist_param_info_from
(
origin_program
)
main_program
.
_copy_dist_param_info_from
(
origin_program
)
if
params_filename
is
not
None
:
if
params_filename
is
not
None
:
...
...
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
6e3e3f13
...
@@ -23,6 +23,7 @@ import paddle.fluid.core as core
...
@@ -23,6 +23,7 @@ import paddle.fluid.core as core
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
warnings
import
warnings
import
paddle
import
paddle.fluid.executor
as
executor
import
paddle.fluid.executor
as
executor
import
paddle.fluid.layers
as
layers
import
paddle.fluid.layers
as
layers
import
paddle.fluid.optimizer
as
optimizer
import
paddle.fluid.optimizer
as
optimizer
...
@@ -201,4 +202,5 @@ class TestLoadInferenceModelError(unittest.TestCase):
...
@@ -201,4 +202,5 @@ class TestLoadInferenceModelError(unittest.TestCase):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_jit_save_load.py
浏览文件 @
6e3e3f13
...
@@ -755,5 +755,34 @@ class TestJitSaveLoadNoParamLayer(unittest.TestCase):
...
@@ -755,5 +755,34 @@ class TestJitSaveLoadNoParamLayer(unittest.TestCase):
self
.
assertTrue
(
np
.
array_equal
(
out
,
load_out
))
self
.
assertTrue
(
np
.
array_equal
(
out
,
load_out
))
class
TestJitGetInferenceProgram
(
unittest
.
TestCase
):
def
setUp
(
self
):
# enable dygraph mode
paddle
.
disable_static
()
def
test_get_inference_program
(
self
):
layer
=
LinearNet
(
784
,
1
)
train
(
layer
)
model_path
=
"model.jit_get_inference_program"
paddle
.
jit
.
save
(
layer
,
model_path
)
infer_program
=
paddle
.
jit
.
get_inference_program
(
layer
)
# the program of jit.load is different with original inference program
model_file_path
=
os
.
path
.
join
(
model_path
,
"__model__"
)
load_program_desc
=
fluid
.
dygraph
.
io
.
_load_program_desc
(
model_file_path
)
load_program
=
fluid
.
dygraph
.
io
.
_build_program_by_desc
(
load_program_desc
)
self
.
assertEqual
(
infer_program
.
num_blocks
,
load_program
.
num_blocks
)
self
.
assertEqual
(
len
(
infer_program
.
global_block
().
ops
),
len
(
load_program
.
global_block
().
ops
))
self
.
assertEqual
(
len
(
infer_program
.
global_block
().
vars
),
len
(
load_program
.
global_block
().
vars
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/jit/__init__.py
浏览文件 @
6e3e3f13
...
@@ -21,6 +21,9 @@ from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
...
@@ -21,6 +21,9 @@ from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
from
..fluid.dygraph
import
ProgramTranslator
#DEFINE_ALIAS
from
..fluid.dygraph
import
ProgramTranslator
#DEFINE_ALIAS
from
..fluid.dygraph.io
import
TranslatedLayer
#DEFINE_ALIAS
from
..fluid.dygraph.io
import
TranslatedLayer
#DEFINE_ALIAS
# NOTE: This function is not exposed to users, only used for paddle2onnx now
from
..fluid.dygraph.jit
import
get_inference_program
#DEFINE_ALIAS
__all__
=
[
__all__
=
[
'save'
,
'load'
,
'TracedLayer'
,
'to_static'
,
'ProgramTranslator'
,
'save'
,
'load'
,
'TracedLayer'
,
'to_static'
,
'ProgramTranslator'
,
'TranslatedLayer'
,
'set_code_level'
,
'set_verbosity'
'TranslatedLayer'
,
'set_code_level'
,
'set_verbosity'
...
...
python/paddle/static/__init__.py
浏览文件 @
6e3e3f13
...
@@ -43,3 +43,6 @@ from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS
...
@@ -43,3 +43,6 @@ from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS
from
..fluid.param_attr
import
WeightNormParamAttr
#DEFINE_ALIAS
from
..fluid.param_attr
import
WeightNormParamAttr
#DEFINE_ALIAS
from
..tensor.io
import
save
#DEFINE_ALIAS
from
..tensor.io
import
save
#DEFINE_ALIAS
from
..tensor.io
import
load
#DEFINE_ALIAS
from
..tensor.io
import
load
#DEFINE_ALIAS
# NOTE: This function is not exposed to users, only used for paddle2onnx now
from
..fluid.io
import
get_inference_program
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录