Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
dca9b6c5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dca9b6c5
编写于
9月 05, 2019
作者:
M
mapingshuo
提交者:
Dong Daxiang
9月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add feed_var_names to Prune interface (#19589)
* Fix bug: add feed_vars to the prune function
上级
f45cb1c2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
48 addition
and
20 deletion
+48
-20
paddle/fluid/framework/prune.cc
paddle/fluid/framework/prune.cc
+16
-7
paddle/fluid/framework/prune.h
paddle/fluid/framework/prune.h
+5
-1
paddle/fluid/framework/prune_test.cc
paddle/fluid/framework/prune_test.cc
+13
-8
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-1
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+10
-2
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+1
-1
未找到文件。
paddle/fluid/framework/prune.cc
浏览文件 @
dca9b6c5
...
...
@@ -68,7 +68,8 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
void
prune_impl
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
,
int
block_id
,
int
parent_block_id
,
std
::
set
<
std
::
string
>*
dependent_vars
)
{
std
::
set
<
std
::
string
>*
dependent_vars
,
const
std
::
set
<
std
::
string
>
feed_var_names
)
{
auto
&
block
=
input
.
blocks
(
block_id
);
auto
&
ops
=
block
.
ops
();
...
...
@@ -94,7 +95,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// insert its input to the dependency graph
for
(
auto
&
var
:
op_desc
.
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
dependent_vars
->
insert
(
argu
);
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
dependent_vars
->
insert
(
argu
);
}
}
}
should_run
.
push_back
(
true
);
...
...
@@ -127,18 +130,22 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std
::
set
<
std
::
string
>
sub_block_dependent_vars
;
for
(
auto
&
var
:
op
->
inputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
sub_block_dependent_vars
.
insert
(
argu
);
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
sub_block_dependent_vars
.
insert
(
argu
);
}
}
}
for
(
auto
&
var
:
op
->
outputs
())
{
for
(
auto
&
argu
:
var
.
arguments
())
{
sub_block_dependent_vars
.
insert
(
argu
);
if
(
feed_var_names
.
count
(
argu
)
==
0
)
{
sub_block_dependent_vars
.
insert
(
argu
);
}
}
}
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl
(
input
,
output
,
GetSubBlockIndex
(
*
op
),
output_block_id
,
&
sub_block_dependent_vars
);
&
sub_block_dependent_vars
,
feed_var_names
);
}
}
}
...
...
@@ -178,10 +185,12 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
)
{
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
)
{
std
::
set
<
std
::
string
>
dependent_vars
;
output
->
clear_blocks
();
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
);
prune_impl
(
input
,
output
,
0
,
-
1
,
&
dependent_vars
,
feed_var_names
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/prune.h
浏览文件 @
dca9b6c5
...
...
@@ -14,13 +14,17 @@ limitations under the License. */
#pragma once
#include <set>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
proto
::
ProgramDesc
*
output
);
void
Prune
(
const
proto
::
ProgramDesc
&
input
,
const
std
::
set
<
std
::
string
>&
feed_var_names
,
proto
::
ProgramDesc
*
output
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/prune_test.cc
浏览文件 @
dca9b6c5
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/prune.h"
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "paddle/fluid/framework/attribute.h"
...
...
@@ -58,12 +59,13 @@ TEST(Prune, one_operator) {
f
::
proto
::
ProgramDesc
*
pdesc
=
program
.
Proto
();
f
::
proto
::
ProgramDesc
pruned
;
f
::
Prune
(
*
pdesc
,
&
pruned
);
std
::
set
<
std
::
string
>
feed_var_names
=
{};
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
0
);
feed_var_names
.
insert
(
"a"
);
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
0
)
->
set_is_target
(
true
);
f
::
Prune
(
*
pdesc
,
&
pruned
);
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
1
);
}
...
...
@@ -81,11 +83,11 @@ TEST(Prune, forward) {
block
);
f
::
proto
::
ProgramDesc
*
pdesc
=
program
.
Proto
();
std
::
set
<
std
::
string
>
feed_var_names
=
{
"a"
};
for
(
int
i
=
0
;
i
<
pdesc
->
blocks
(
0
).
ops_size
();
++
i
)
{
f
::
proto
::
ProgramDesc
pruned
;
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
i
)
->
set_is_target
(
true
);
f
::
Prune
(
*
pdesc
,
&
pruned
);
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
i
+
1
);
}
}
...
...
@@ -107,7 +109,8 @@ TEST(Prune, multi_input_op) {
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
3
)
->
set_is_target
(
true
);
f
::
proto
::
ProgramDesc
pruned
;
f
::
Prune
(
*
pdesc
,
&
pruned
);
std
::
set
<
std
::
string
>
feed_var_names
=
{
"a0"
,
"a1"
,
"a2"
};
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
4
);
}
...
...
@@ -126,7 +129,8 @@ TEST(Prune, multi_output_op) {
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
2
)
->
set_is_target
(
true
);
f
::
proto
::
ProgramDesc
pruned
;
f
::
Prune
(
*
pdesc
,
&
pruned
);
std
::
set
<
std
::
string
>
feed_var_names
=
{
"a"
};
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
2
);
}
...
...
@@ -146,6 +150,7 @@ TEST(Prune, multi_target) {
pdesc
->
mutable_blocks
(
0
)
->
mutable_ops
(
2
)
->
set_is_target
(
true
);
f
::
proto
::
ProgramDesc
pruned
;
f
::
Prune
(
*
pdesc
,
&
pruned
);
std
::
set
<
std
::
string
>
feed_var_names
=
{
"a"
};
f
::
Prune
(
*
pdesc
,
feed_var_names
,
&
pruned
);
PADDLE_ENFORCE_EQ
(
pruned
.
blocks
(
0
).
ops_size
(),
3
);
}
paddle/fluid/pybind/pybind.cc
浏览文件 @
dca9b6c5
...
...
@@ -749,13 +749,15 @@ All parameter, weight, gradient are variables in Paddle.
#endif
m
.
def
(
"prune"
,
[](
const
ProgramDesc
&
origin
,
const
std
::
set
<
std
::
string
>
&
feeded_var_names
,
const
std
::
vector
<
std
::
array
<
size_t
,
2
>>
&
targets
)
{
ProgramDesc
prog_with_targets
(
origin
);
for
(
const
auto
&
t
:
targets
)
{
prog_with_targets
.
MutableBlock
(
t
[
0
])
->
Op
(
t
[
1
])
->
SetIsTarget
(
true
);
}
proto
::
ProgramDesc
pruned_desc
;
Prune
(
*
prog_with_targets
.
Proto
(),
&
pruned_desc
);
Prune
(
*
prog_with_targets
.
Proto
(),
feeded_var_names
,
&
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
});
m
.
def
(
"empty_var_name"
,
...
...
python/paddle/fluid/framework.py
浏览文件 @
dca9b6c5
...
...
@@ -3247,7 +3247,7 @@ class Program(object):
p
.
_copy_dist_param_info_from
(
self
)
return
p
def
_prune
(
self
,
targets
):
def
_prune
(
self
,
feeded_var_names
,
targets
):
"""
Prune operators and variables which are not needed to generate
:code:`targets`.
...
...
@@ -3263,8 +3263,16 @@ class Program(object):
Program: A new, pruned program.
"""
if
not
isinstance
(
feeded_var_names
,
list
):
feeded_var_names
=
[
feeded_var_names
]
if
not
isinstance
(
targets
,
list
):
targets
=
[
targets
]
for
var
in
feeded_var_names
:
if
not
isinstance
(
var
,
six
.
string_types
):
raise
ValueError
(
"All feeded_var_names of prune() can only be "
"str."
)
targets_idx
=
[]
for
t
in
targets
:
if
not
isinstance
(
t
,
Operator
):
...
...
@@ -3291,7 +3299,7 @@ class Program(object):
targets_idx
.
append
([
t
.
block
.
idx
,
t
.
idx
])
res
=
Program
()
res
.
desc
=
core
.
prune
(
self
.
desc
,
targets_idx
)
res
.
desc
=
core
.
prune
(
self
.
desc
,
set
(
feeded_var_names
),
targets_idx
)
res
.
blocks
=
[
Block
(
res
,
i
)
for
i
in
six
.
moves
.
range
(
res
.
desc
.
num_blocks
())
]
...
...
python/paddle/fluid/io.py
浏览文件 @
dca9b6c5
...
...
@@ -1080,7 +1080,7 @@ def save_inference_model(dirname,
main_program
.
desc
.
flush
()
main_program
=
main_program
.
_prune
(
targets
=
target_vars
)
main_program
=
main_program
.
_prune
(
feeded_var_names
,
target_vars
)
main_program
=
main_program
.
_inference_optimize
(
prune_read_op
=
True
)
fetch_var_names
=
[
v
.
name
for
v
in
target_vars
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录