Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5e99f31b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5e99f31b
编写于
9月 29, 2019
作者:
M
mapingshuo
提交者:
Dong Daxiang
9月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a new interface _prune_with_input (#20022)
* add a default value for _prune interface * modify document
上级
6f184775
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
158 addition
and
2 deletion
+158
-2
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+56
-1
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+2
-1
python/paddle/fluid/tests/unittests/test_prune.py
python/paddle/fluid/tests/unittests/test_prune.py
+100
-0
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
5e99f31b
...
...
@@ -3597,7 +3597,7 @@ class Program(object):
p
.
_copy_dist_param_info_from
(
self
)
return
p
def
_prune
(
self
,
feeded_var_names
,
targets
):
def
_prune
(
self
,
targets
):
"""
Prune operators and variables which are not needed to generate
:code:`targets`.
...
...
@@ -3611,8 +3611,63 @@ class Program(object):
Returns:
Program: A new, pruned program.
"""
if
not
isinstance
(
targets
,
list
):
targets
=
[
targets
]
targets_idx
=
[]
for
t
in
targets
:
if
not
isinstance
(
t
,
Operator
):
if
isinstance
(
t
,
Variable
):
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
# variable here.
t
.
op
=
None
global_block
=
self
.
global_block
()
for
idx
,
op
in
enumerate
(
global_block
.
ops
):
if
t
.
name
in
op
.
output_arg_names
:
t
.
op
=
op
break
t
=
t
.
op
if
t
is
None
:
raise
ValueError
(
"The target variable must have an "
"associated operator that generates it."
)
else
:
raise
ValueError
(
"All targets of prune() can only be "
"Variable or Operator."
)
targets_idx
.
append
([
t
.
block
.
idx
,
t
.
idx
])
res
=
Program
()
res
.
desc
=
core
.
prune
(
self
.
desc
,
set
(),
targets_idx
)
res
.
blocks
=
[
Block
(
res
,
i
)
for
i
in
six
.
moves
.
range
(
res
.
desc
.
num_blocks
())
]
res
.
_sync_with_cpp
()
return
res
def
_prune_with_input
(
self
,
feeded_var_names
,
targets
):
"""
Prune operators and variables which are not needed to generate
:code:`targets`. Prune operators and variables which are needed
to generate feeded_var
Notes: This is a very low level API. Users should not use this API
directly. This API is in flux and not stable.
Args:
feeded_var_names(list|str): A list of variable names from where
pruning start. If it is set as [], this API works just like _prune()
targets(list|Variable|Operator): A list of variables or operators
need to be pruned
Returns:
Program: A new, pruned program.
"""
if
not
isinstance
(
feeded_var_names
,
list
):
feeded_var_names
=
[
feeded_var_names
]
if
not
isinstance
(
targets
,
list
):
...
...
python/paddle/fluid/io.py
浏览文件 @
5e99f31b
...
...
@@ -1121,7 +1121,8 @@ def save_inference_model(dirname,
main_program
.
desc
.
flush
()
main_program
=
main_program
.
_prune
(
feeded_var_names
,
target_vars
)
main_program
=
main_program
.
_prune_with_input
(
feeded_var_names
=
feeded_var_names
,
targets
=
target_vars
)
main_program
=
main_program
.
_inference_optimize
(
prune_read_op
=
True
)
fetch_var_names
=
[
v
.
name
for
v
in
target_vars
]
...
...
python/paddle/fluid/tests/unittests/test_prune.py
0 → 100644
浏览文件 @
5e99f31b
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.framework
as
framework
import
paddle.compat
as
cpt
class
TestPrune
(
unittest
.
TestCase
):
def
net
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
2
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
1
],
dtype
=
"int64"
)
y
=
fluid
.
layers
.
fc
(
input
=
[
x
],
size
=
2
,
act
=
"softmax"
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
y
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
x
=
loss
)
return
x
,
y
,
label
,
loss
def
test_prune_with_input
(
self
):
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
block
=
program
.
global_block
()
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss
)
=
self
.
net
()
self
.
assertEqual
(
len
(
block
.
ops
),
5
)
self
.
assertEqual
([
op
.
type
for
op
in
block
.
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
pruned_program
=
program
.
_prune_with_input
(
feeded_var_names
=
[
y
.
name
,
label
.
name
],
targets
=
[
loss
])
self
.
assertEqual
(
len
(
pruned_program
.
global_block
().
ops
),
2
)
self
.
assertEqual
([
op
.
type
for
op
in
pruned_program
.
global_block
().
ops
],
[
"cross_entropy2"
,
"mean"
])
def
test_prune
(
self
):
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
block
=
program
.
global_block
()
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss
)
=
self
.
net
()
self
.
assertEqual
(
len
(
block
.
ops
),
5
)
self
.
assertEqual
([
op
.
type
for
op
in
block
.
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
pruned_program
=
program
.
_prune
(
targets
=
[
loss
])
self
.
assertEqual
(
len
(
pruned_program
.
global_block
().
ops
),
5
)
self
.
assertEqual
(
[
op
.
type
for
op
in
pruned_program
.
global_block
().
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
def
test_prune_target_not_list
(
self
):
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
block
=
program
.
global_block
()
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss
)
=
self
.
net
()
self
.
assertEqual
(
len
(
block
.
ops
),
5
)
self
.
assertEqual
([
op
.
type
for
op
in
block
.
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
pruned_program
=
program
.
_prune
(
targets
=
loss
)
self
.
assertEqual
(
len
(
pruned_program
.
global_block
().
ops
),
5
)
self
.
assertEqual
(
[
op
.
type
for
op
in
pruned_program
.
global_block
().
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
def
test_prune_target_none
(
self
):
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
block
=
program
.
global_block
()
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss
)
=
self
.
net
()
self
.
assertEqual
(
len
(
block
.
ops
),
5
)
self
.
assertEqual
([
op
.
type
for
op
in
block
.
ops
],
[
"mul"
,
"elementwise_add"
,
"softmax"
,
"cross_entropy2"
,
"mean"
])
try
:
pruned_program
=
program
.
_prune
(
targets
=
None
)
except
ValueError
as
e
:
self
.
assertEqual
(
"All targets of prune() can only be Variable or Operator."
,
cpt
.
get_exception_message
(
e
))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录