Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3f4c088a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3f4c088a
编写于
8月 09, 2019
作者:
C
chengduo
提交者:
GitHub
8月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
prune the feed op in compiler (#18997)
test=develop
上级
d2360332
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
95 additions
and
0 deletions
+95
-0
python/paddle/fluid/compiler.py
python/paddle/fluid/compiler.py
+10
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py
...nittests/test_parallel_executor_run_load_infer_program.py
+85
-0
未找到文件。
python/paddle/fluid/compiler.py
浏览文件 @
3f4c088a
...
...
@@ -45,6 +45,15 @@ def _is_pserver_mode(main_program):
return
False
def _prune_feed_ops(program):
    """Drop every ``feed`` op from *program*'s global block, in place.

    Used before compiling a loaded inference program: the persisted program
    may carry its original feed ops, which must not survive compilation.
    """
    block = program.global_block()
    # Collect the positions of all feed ops first, then delete back-to-front
    # so the not-yet-removed indices are never shifted by an earlier removal.
    feed_positions = [
        idx for idx, op in enumerate(block.ops) if op.type == "feed"
    ]
    for idx in reversed(feed_positions):
        block._remove_op(idx)
class
CompiledProgram
(
object
):
"""
Compiles to Graph for execution.
...
...
@@ -100,6 +109,7 @@ class CompiledProgram(object):
# don't not create a new program here.
self
.
_program
=
None
elif
isinstance
(
program_or_graph
,
framework
.
Program
):
_prune_feed_ops
(
program_or_graph
)
self
.
_graph
=
core
.
Graph
(
program_or_graph
.
desc
)
self
.
_program
=
program_or_graph
else
:
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py
0 → 100644
浏览文件 @
3f4c088a
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
simple_nets
import
simple_fc_net
,
init_data
class TestMNIST(unittest.TestCase):
    """Check that a saved inference program can be re-run in parallel.

    Trains/saves a simple FC net with the plain ``Executor``, reloads it via
    ``load_inference_model`` (whose program contains feed/fetch ops), runs it
    through ``ParallelExecutor``, and asserts both paths yield the same loss.
    """

    @classmethod
    def setUpClass(cls):
        # Filenames under the current directory used by save/load below.
        cls.save_dirname = "./"
        cls.model_filename = \
            "test_parallel_executor_run_load_infer_program_model"
        cls.params_filename = \
            "test_parallel_executor_run_load_infer_program_parameter"
        cls.place = fluid.CPUPlace()
        cls.exe = fluid.Executor(cls.place)
        images, labels = init_data()
        # One [image, label] pair per sample, fed as a single batch.
        cls.batch_data = [[im, lb] for im, lb in zip(images, labels)]

    def test_simple_fc(self):
        # Reference loss from the plain Executor run (also saves the model).
        exe_loss = self.run_with_executor()

        load_result = fluid.io.load_inference_model(
            self.save_dirname, self.exe, self.model_filename,
            self.params_filename)
        inference_program, feed_target_names, fetch_targets = load_result

        # The loaded program still works under ParallelExecutor because the
        # compiler prunes its feed ops.
        train_exe = fluid.ParallelExecutor(
            use_cuda=False, main_program=inference_program)
        feed_vars = []
        for var_name in ["image", "label"]:
            feed_vars.append(inference_program.global_block().var(var_name))
        feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars)

        pe_loss = train_exe.run(
            feed=feeder.feed(self.batch_data),
            fetch_list=[fetch_targets[0].name])
        assert exe_loss == pe_loss

    def run_with_executor(self):
        """Run one step with the plain Executor and persist the inference model.

        Returns the fetched loss so the caller can compare it against the
        ParallelExecutor result.
        """
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = simple_fc_net()
            feed_vars = [
                main.global_block().var(name) for name in ["image", "label"]
            ]
            feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars)
            self.exe.run(startup)
            loss_data = self.exe.run(
                main,
                feed=feeder.feed(self.batch_data),
                fetch_list=[loss.name])
            fluid.io.save_inference_model(
                self.save_dirname, ["image", "label"], [loss],
                self.exe,
                model_filename=self.model_filename,
                params_filename=self.params_filename,
                main_program=main)
        return loss_data
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录