Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
5312f46e
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5312f46e
编写于
1月 06, 2020
作者:
W
whs
提交者:
GitHub
1月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix prune worker. (#27)
上级
790a9ffb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
23 addition
and
34 deletion
+23
-34
paddleslim/prune/prune_walker.py
paddleslim/prune/prune_walker.py
+23
-34
未找到文件。
paddleslim/prune/prune_walker.py
浏览文件 @
5312f46e
...
...
@@ -49,14 +49,18 @@ class PruneWorker(object):
pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indexes to be pruned in `pruned_axis` of root variable.
"""
if
self
.
_visit
(
var
,
pruned_axis
):
self
.
_prune
(
var
,
pruned_axis
,
pruned_idx
)
def
_visit
(
self
,
var
,
pruned_axis
):
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
var
.
name
()])
if
pruned_axis
not
in
self
.
visited
:
self
.
visited
[
pruned_axis
]
=
{}
if
key
in
self
.
visited
[
pruned_axis
]:
return
return
False
else
:
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_prune
(
var
,
pruned_axis
,
pruned_idx
)
return
True
def
_prune
(
self
,
var
,
pruned_axis
,
pruned_idx
):
raise
NotImplementedError
(
'Abstract method.'
)
...
...
@@ -83,7 +87,7 @@ class conv2d(PruneWorker):
super
(
conv2d
,
self
).
__init__
(
op
,
pruned_params
,
visited
)
def
_prune
(
self
,
var
,
pruned_axis
,
pruned_idx
):
data_format
=
sef
.
op
.
attr
(
"data_format"
)
data_format
=
se
l
f
.
op
.
attr
(
"data_format"
)
channel_axis
=
1
if
data_format
==
"NHWC"
:
channel_axis
=
3
...
...
@@ -91,8 +95,7 @@ class conv2d(PruneWorker):
assert
pruned_axis
==
channel_axis
,
"The Input of conv2d can only be pruned at channel axis, but got {}; var: {}"
.
format
(
pruned_axis
,
var
.
name
())
filter_var
=
self
.
op
.
inputs
(
"Filter"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
filter_var
.
name
()])
self
.
visited
[
1
][
key
]
=
True
self
.
_visit
(
filter_var
,
1
)
self
.
pruned_params
.
append
((
filter_var
,
1
,
pruned_idx
))
for
op
in
filter_var
.
outputs
():
self
.
_prune_op
(
op
,
filter_var
,
1
,
pruned_idx
)
...
...
@@ -110,16 +113,14 @@ class conv2d(PruneWorker):
self
.
pruned_params
.
append
(
(
self
.
op
.
inputs
(
"Bias"
),
channel_axis
,
pruned_idx
))
output_var
=
self
.
op
.
outputs
(
"Output"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
output_var
.
name
()])
self
.
visited
[
channel_axis
][
key
]
=
True
self
.
_visit
(
output_var
,
channel_axis
)
next_ops
=
output_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
output_var
,
channel_axis
,
pruned_idx
)
elif
pruned_axis
==
1
:
input_var
=
self
.
op
.
inputs
(
"Input"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
input_var
.
name
()])
self
.
visited
[
channel_axis
][
key
]
=
True
self
.
_visit
(
input_var
,
channel_axis
)
pre_ops
=
input_var
.
inputs
()
for
op
in
pre_ops
:
self
.
_prune_op
(
op
,
input_var
,
channel_axis
,
pruned_idx
)
...
...
@@ -128,8 +129,7 @@ class conv2d(PruneWorker):
pruned_axis
,
var
.
name
())
filter_var
=
self
.
op
.
inputs
(
"Filter"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
filter_var
.
name
()])
self
.
visited
[
0
][
key
]
=
True
self
.
_visit
(
filter_var
,
0
)
self
.
pruned_params
.
append
((
filter_var
,
0
,
pruned_idx
))
...
...
@@ -158,8 +158,7 @@ class batch_norm(PruneWorker):
if
var
in
self
.
op
.
outputs
(
"Y"
):
in_var
=
self
.
op
.
inputs
(
"X"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
in_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
in_var
,
pruned_axis
)
pre_ops
=
in_var
.
inputs
()
for
op
in
pre_ops
:
self
.
_prune_op
(
op
,
in_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -171,8 +170,7 @@ class batch_norm(PruneWorker):
self
.
pruned_params
.
append
((
param_var
,
0
,
pruned_idx
))
out_var
=
self
.
op
.
outputs
(
"Y"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -214,8 +212,7 @@ class elementwise_op(PruneWorker):
self
.
_prune_op
(
op
,
in_var
,
pruned_axis
,
pruned_idx
)
out_var
=
self
.
op
.
outputs
(
"Out"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -253,8 +250,7 @@ class activation(PruneWorker):
self
.
_prune_op
(
op
,
in_var
,
pruned_axis
,
pruned_idx
)
out_var
=
self
.
op
.
outputs
(
self
.
output_name
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -317,8 +313,7 @@ class sum(PruneWorker):
for
op
in
pre_ops
:
self
.
_prune_op
(
op
,
in_var
,
pruned_axis
,
pruned_idx
)
out_var
=
self
.
op
.
outputs
(
"Out"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -363,8 +358,7 @@ class concat(PruneWorker):
start
+=
v
.
shape
()[
pruned_axis
]
out_var
=
self
.
op
.
outputs
(
"Out"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
idx
,
visited
=
{})
...
...
@@ -373,8 +367,7 @@ class concat(PruneWorker):
for
op
in
v
.
inputs
():
self
.
_prune_op
(
op
,
v
,
pruned_axis
,
pruned_idx
)
out_var
=
self
.
op
.
outputs
(
"Out"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
out_var
.
name
()])
self
.
visited
[
pruned_axis
][
key
]
=
True
self
.
_visit
(
out_var
,
pruned_axis
)
next_ops
=
out_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
out_var
,
pruned_axis
,
pruned_idx
)
...
...
@@ -386,7 +379,7 @@ class depthwise_conv2d(PruneWorker):
super
(
depthwise_conv2d
,
self
).
__init__
(
op
,
pruned_params
,
visited
)
def
_prune
(
self
,
var
,
pruned_axis
,
pruned_idx
):
data_format
=
sef
.
op
.
attr
(
"data_format"
)
data_format
=
se
l
f
.
op
.
attr
(
"data_format"
)
channel_axis
=
1
if
data_format
==
"NHWC"
:
channel_axis
=
3
...
...
@@ -396,8 +389,7 @@ class depthwise_conv2d(PruneWorker):
filter_var
=
self
.
op
.
inputs
(
"Filter"
)[
0
]
self
.
pruned_params
.
append
((
filter_var
,
0
,
pruned_idx
))
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
filter_var
.
name
()])
self
.
visited
[
0
][
key
]
=
True
self
.
_visit
(
filter_var
,
0
)
new_groups
=
filter_var
.
shape
()[
0
]
-
len
(
pruned_idx
)
self
.
op
.
set_attr
(
"groups"
,
new_groups
)
...
...
@@ -425,8 +417,7 @@ class depthwise_conv2d(PruneWorker):
self
.
_prune_op
(
op
,
var
,
0
,
pruned_idx
)
output_var
=
self
.
op
.
outputs
(
"Output"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
output_var
.
name
()])
self
.
visited
[
channel_axis
][
key
]
=
True
self
.
_visit
(
output_var
,
channel_axis
)
next_ops
=
output_var
.
outputs
()
for
op
in
next_ops
:
self
.
_prune_op
(
op
,
output_var
,
channel_axis
,
pruned_idx
)
...
...
@@ -436,8 +427,7 @@ class depthwise_conv2d(PruneWorker):
assert
pruned_axis
==
channel_axis
filter_var
=
self
.
op
.
inputs
(
"Filter"
)[
0
]
self
.
pruned_params
.
append
((
filter_var
,
0
,
pruned_idx
))
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
filter_var
.
name
()])
self
.
visited
[
0
][
key
]
=
True
self
.
_visit
(
filter_var
,
0
)
new_groups
=
filter_var
.
shape
()[
0
]
-
len
(
pruned_idx
)
op
.
set_attr
(
"groups"
,
new_groups
)
...
...
@@ -450,8 +440,7 @@ class depthwise_conv2d(PruneWorker):
(
self
.
op
.
inputs
(
"Bias"
)[
0
],
channel_axis
,
pruned_idx
))
in_var
=
self
.
op
.
inputs
(
"Input"
)[
0
]
key
=
"_"
.
join
([
str
(
self
.
op
.
idx
()),
in_var
.
name
()])
self
.
visited
[
channel_axis
][
key
]
=
True
self
.
_visit
(
in_var
,
channel_axis
)
pre_ops
=
in_var
.
inputs
()
for
op
in
pre_ops
:
self
.
_prune_op
(
op
,
in_var
,
channel_axis
,
pruned_idx
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录