Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2ca0e118
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2ca0e118
编写于
4月 24, 2020
作者:
L
Leo Chen
提交者:
GitHub
4月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support fetch the feed var when use_prune=True, test=develop (#24110)
上级
fb0729ee
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
69 addition
and
5 deletion
+69
-5
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+16
-5
python/paddle/fluid/tests/unittests/test_prune.py
python/paddle/fluid/tests/unittests/test_prune.py
+53
-0
未找到文件。
python/paddle/fluid/framework.py
浏览文件 @
2ca0e118
...
...
@@ -4237,6 +4237,14 @@ class Program(object):
raise
ValueError
(
"All targets of Program._prune_with_input() can only be "
"Variable or Operator, but received %s."
%
type
(
t
))
# NOTEZ(zhiqiu): For variable to be fed in fetch_list, there two cases:
# (1) the variable is leaf, it has no op that generates it;
# (2) the variable is not leaf, and we need to prune the op that generates it.
# In both cases, wo can just skip target_op of that it.
if
name
in
feeded_var_names
:
continue
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
...
...
@@ -4253,11 +4261,14 @@ class Program(object):
else
:
target_op
=
op
break
t
=
target_op
if
t
is
None
:
raise
ValueError
(
"The target variable must have an "
"associated operator that generates it."
)
targets_idx
.
append
([
t
.
block
.
idx
,
t
.
idx
])
if
target_op
is
None
:
raise
ValueError
(
"The target variable used for pruning should have an "
"associated operator that generates it."
)
else
:
targets_idx
.
append
([
target_op
.
block
.
idx
,
target_op
.
idx
])
else
:
targets_idx
.
append
([
t
.
block
.
idx
,
t
.
idx
])
res
=
Program
()
res
.
desc
,
pruned_origin_block_id_map
=
core
.
prune
(
self
.
desc
,
...
...
python/paddle/fluid/tests/unittests/test_prune.py
浏览文件 @
2ca0e118
...
...
@@ -725,6 +725,59 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
self
.
assertTrue
(
np
.
array_equal
(
weight_with_prune
,
weight_expected
))
self
.
assertFalse
(
np
.
array_equal
(
weight_without_prune
,
weight_expected
))
def
test_prune_feed_var_in_fetchlist_1
(
self
):
# the variable to be fed is not leaf
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss1
,
loss2
,
w_param_attrs
)
=
self
.
net1
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
startup_program
)
weight_init
=
np
.
array
(
scope
.
find_var
(
w_param_attrs
.
name
).
get_tensor
())
x_np
=
np
.
random
.
random
(
size
=
(
10
,
2
)).
astype
(
'float32'
)
label_np
=
np
.
random
.
randint
(
1
,
size
=
(
10
,
1
)).
astype
(
'int64'
)
res
=
exe
.
run
(
program
,
feed
=
{
y
.
name
:
x_np
,
'label'
:
label_np
},
fetch_list
=
[
y
.
name
,
loss1
.
name
],
use_prune
=
True
)
self
.
assertIsNotNone
(
scope
.
find_var
(
loss1
.
name
))
self
.
assertIsNone
(
scope
.
find_var
(
loss2
.
name
))
self
.
assertIsNone
(
scope
.
find_var
(
x
.
name
))
weight
=
np
.
array
(
scope
.
find_var
(
w_param_attrs
.
name
).
get_tensor
())
self
.
assertTrue
(
np
.
array_equal
(
weight_init
,
weight
))
# weight unchanged
def
test_prune_feed_var_in_fetchlist_2
(
self
):
# the variable to be fed is leaf
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
program
,
startup_program
):
(
x
,
y
,
label
,
loss1
,
loss2
,
w_param_attrs
)
=
self
.
net1
()
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
.
run
(
startup_program
)
weight_init
=
np
.
array
(
scope
.
find_var
(
w_param_attrs
.
name
).
get_tensor
())
x_np
=
np
.
random
.
random
(
size
=
(
10
,
2
)).
astype
(
'float32'
)
label_np
=
np
.
random
.
randint
(
1
,
size
=
(
10
,
1
)).
astype
(
'int64'
)
res
=
exe
.
run
(
program
,
feed
=
{
x
.
name
:
x_np
,
'label'
:
label_np
},
fetch_list
=
[
x
.
name
,
loss1
.
name
],
use_prune
=
True
)
self
.
assertIsNotNone
(
scope
.
find_var
(
loss1
.
name
))
self
.
assertIsNone
(
scope
.
find_var
(
loss2
.
name
))
weight
=
np
.
array
(
scope
.
find_var
(
w_param_attrs
.
name
).
get_tensor
())
self
.
assertTrue
(
np
.
array_equal
(
weight_init
,
weight
))
# weight unchanged
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录