Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
29543da5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
29543da5
编写于
2月 23, 2021
作者:
S
Shibo Tao
提交者:
GitHub
2月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
export paddle.static.normalize_program method. test=develop (#31080)
上级
1d2bd35e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
103 addition
and
5 deletion
+103
-5
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+42
-0
python/paddle/static/__init__.py
python/paddle/static/__init__.py
+1
-0
python/paddle/static/io.py
python/paddle/static/io.py
+60
-5
未找到文件。
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
29543da5
...
@@ -356,6 +356,48 @@ class TestSaveInferenceModelNew(unittest.TestCase):
...
@@ -356,6 +356,48 @@ class TestSaveInferenceModelNew(unittest.TestCase):
self
.
assertRaises
(
TypeError
,
paddle
.
static
.
io
.
deserialize_persistables
,
self
.
assertRaises
(
TypeError
,
paddle
.
static
.
io
.
deserialize_persistables
,
None
,
None
,
None
)
None
,
None
,
None
)
def
test_normalize_program
(
self
):
init_program
=
fluid
.
default_startup_program
()
program
=
fluid
.
default_main_program
()
# fake program without feed/fetch
with
program_guard
(
program
,
init_program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
2
],
dtype
=
'float32'
)
y
=
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
layers
.
mean
(
cost
)
sgd_optimizer
=
optimizer
.
SGDOptimizer
(
learning_rate
=
0.001
)
sgd_optimizer
.
minimize
(
avg_cost
,
init_program
)
place
=
core
.
CPUPlace
()
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
init_program
,
feed
=
{},
fetch_list
=
[])
tensor_x
=
np
.
array
([[
1
,
1
],
[
1
,
2
],
[
5
,
2
]]).
astype
(
"float32"
)
tensor_y
=
np
.
array
([[
-
2
],
[
-
3
],
[
-
7
]]).
astype
(
"float32"
)
for
i
in
six
.
moves
.
xrange
(
3
):
exe
.
run
(
program
,
feed
=
{
'x'
:
tensor_x
,
'y'
:
tensor_y
},
fetch_list
=
[
avg_cost
])
# test if return type of serialize_program is bytes
res
=
paddle
.
static
.
normalize_program
(
program
,
[
x
,
y
],
[
avg_cost
])
self
.
assertTrue
(
isinstance
(
res
,
Program
))
# test program type
self
.
assertRaises
(
TypeError
,
paddle
.
static
.
normalize_program
,
None
,
[
x
,
y
],
[
avg_cost
])
# test feed_vars type
self
.
assertRaises
(
TypeError
,
paddle
.
static
.
normalize_program
,
program
,
[
'x'
,
'y'
],
[
avg_cost
])
# test fetch_vars type
self
.
assertRaises
(
TypeError
,
paddle
.
static
.
normalize_program
,
program
,
[
x
,
y
],
[
'avg_cost'
])
class
TestLoadInferenceModelError
(
unittest
.
TestCase
):
class
TestLoadInferenceModelError
(
unittest
.
TestCase
):
def
test_load_model_not_exist
(
self
):
def
test_load_model_not_exist
(
self
):
...
...
python/paddle/static/__init__.py
浏览文件 @
29543da5
...
@@ -59,6 +59,7 @@ from .io import deserialize_program #DEFINE_ALIAS
...
@@ -59,6 +59,7 @@ from .io import deserialize_program #DEFINE_ALIAS
from
.io
import
serialize_program
#DEFINE_ALIAS
from
.io
import
serialize_program
#DEFINE_ALIAS
from
.io
import
load_from_file
#DEFINE_ALIAS
from
.io
import
load_from_file
#DEFINE_ALIAS
from
.io
import
save_to_file
#DEFINE_ALIAS
from
.io
import
save_to_file
#DEFINE_ALIAS
from
.io
import
normalize_program
#DEFINE_ALIAS
from
..fluid
import
Scope
#DEFINE_ALIAS
from
..fluid
import
Scope
#DEFINE_ALIAS
from
.input
import
data
#DEFINE_ALIAS
from
.input
import
data
#DEFINE_ALIAS
from
.input
import
InputSpec
#DEFINE_ALIAS
from
.input
import
InputSpec
#DEFINE_ALIAS
...
...
python/paddle/static/io.py
浏览文件 @
29543da5
...
@@ -46,6 +46,7 @@ __all__ = [
...
@@ -46,6 +46,7 @@ __all__ = [
'deserialize_program'
,
'deserialize_program'
,
'deserialize_persistables'
,
'deserialize_persistables'
,
'load_from_file'
,
'load_from_file'
,
'normalize_program'
,
]
]
_logger
=
get_logger
(
_logger
=
get_logger
(
...
@@ -127,10 +128,64 @@ def _clone_var_in_block(block, var):
...
@@ -127,10 +128,64 @@ def _clone_var_in_block(block, var):
persistable
=
True
)
persistable
=
True
)
def
_
normalize_program
(
program
,
feed_vars
,
fetch_vars
):
def
normalize_program
(
program
,
feed_vars
,
fetch_vars
):
"""
"""
optimize program according feed_vars and fetch_vars.
:api_attr: Static Graph
Normalize/Optimize a program according to feed_vars and fetch_vars.
Args:
program(Program): Specify a program you want to optimize.
feed_vars(Variable | list[Variable]): Variables needed by inference.
fetch_vars(Variable | list[Variable]): Variables returned by inference.
Returns:
Program: Normalized/Optimized program.
Raises:
TypeError: If `program` is not a Program, an exception is thrown.
TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
path_prefix = "./infer_model"
# User defined network, here a softmax regession example
image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
predict = paddle.static.nn.fc(image, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, label)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
# normalize main program.
program = default_main_program()
normalized_program = paddle.static.normalize_program(program, [image], [predict])
"""
"""
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"program type must be `fluid.Program`, but received `%s`"
%
type
(
program
))
if
not
isinstance
(
feed_vars
,
list
):
feed_vars
=
[
feed_vars
]
if
not
all
(
isinstance
(
v
,
Variable
)
for
v
in
feed_vars
):
raise
TypeError
(
"feed_vars type must be a Variable or a list of Variable."
)
if
not
isinstance
(
fetch_vars
,
list
):
fetch_vars
=
[
fetch_vars
]
if
not
all
(
isinstance
(
v
,
Variable
)
for
v
in
fetch_vars
):
raise
TypeError
(
"fetch_vars type must be a Variable or a list of Variable."
)
# remind users to set auc_states to 0 if auc op were found.
# remind users to set auc_states to 0 if auc op were found.
for
op
in
program
.
global_block
().
ops
:
for
op
in
program
.
global_block
().
ops
:
# clear device of Op
# clear device of Op
...
@@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
...
@@ -255,7 +310,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
return
_serialize_program
(
program
)
return
_serialize_program
(
program
)
...
@@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
...
@@ -319,7 +374,7 @@ def serialize_persistables(feed_vars, fetch_vars, executor, **kwargs):
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
return
_serialize_persistables
(
program
,
executor
)
return
_serialize_persistables
(
program
,
executor
)
...
@@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
...
@@ -463,7 +518,7 @@ def save_inference_model(path_prefix, feed_vars, fetch_vars, executor,
_check_vars
(
'fetch_vars'
,
fetch_vars
)
_check_vars
(
'fetch_vars'
,
fetch_vars
)
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_get_valid_program
(
kwargs
.
get
(
'program'
,
None
))
program
=
_
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
program
=
normalize_program
(
program
,
feed_vars
,
fetch_vars
)
# serialize and save program
# serialize and save program
program_bytes
=
_serialize_program
(
program
)
program_bytes
=
_serialize_program
(
program
)
save_to_file
(
model_path
,
program_bytes
)
save_to_file
(
model_path
,
program_bytes
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录