PaddlePaddle / PaddleSlim

Commit b81f27a1 (unverified)
Fix pruning for yolov4 (#313)
Authored Jun 05, 2020 by whs; committed via GitHub on Jun 05, 2020
Parent: 44e359c4
Showing 5 changed files with 46 additions and 35 deletions (+46, -35)
paddleslim/prune/criterion.py     +9  -9
paddleslim/prune/group_param.py   +3  -3
paddleslim/prune/idx_selector.py  +4  -3
paddleslim/prune/prune_walker.py  +19 -14
paddleslim/prune/pruner.py        +11 -6
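The common thread across the five files is that pruning-group entries gain a trailing pruned_idx element, so channel offsets discovered while walking the graph travel with each parameter. YOLOv4-style networks concatenate branches, so the same logical channel sits at a shifted index inside downstream parameters; losing that offset is, presumably, the bug this commit fixes. A minimal sketch of the format change as seen by the criteria, with made-up names and shapes:

    import numpy as np

    # Illustrative only: the name, weight shape, and axis below are
    # assumptions, not values from the commit.
    weights = np.random.rand(32, 16, 3, 3)          # hypothetical conv weight, OIHW

    old_entry = ("conv1_weights", weights, 0)       # (name, value, axis)
    new_entry = ("conv1_weights", weights, 0, [0])  # (name, value, axis, pruned_idx)

The index selectors see the same extension on their side: (name, axis, score) becomes (name, axis, score, pruned_idx).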
paddleslim/prune/criterion.py

@@ -43,11 +43,11 @@ def l1_norm(group, graph):
         list: A list of tuple storing l1-norm on given axis.
     """
     scores = []
-    for name, value, axis in group:
+    for name, value, axis, pruned_idx in group:
         reduce_dims = [i for i in range(len(value.shape)) if i != axis]
         score = np.sum(np.abs(value), axis=tuple(reduce_dims))
-        scores.append((name, axis, score))
+        scores.append((name, axis, score, pruned_idx))
     return scores

@@ -55,7 +55,7 @@ def l1_norm(group, graph):
 @CRITERION.register
 def geometry_median(group, graph):
     scores = []
-    name, value, axis = group[0]
+    name, value, axis, _ = group[0]
     assert (len(value.shape) == 4)

     def get_distance_sum(value, out_idx):

@@ -73,8 +73,8 @@ def geometry_median(group, graph):
     tmp = np.array(dist_sum_list)
-    for name, value, axis in group:
-        scores.append((name, axis, tmp))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, tmp, idx))
     return scores

@@ -97,7 +97,7 @@ def bn_scale(group, graph):
     assert (isinstance(graph, GraphWrapper))
     # step1: Get first convolution
-    conv_weight, value, axis = group[0]
+    conv_weight, value, axis, _ = group[0]
     param_var = graph.var(conv_weight)
     conv_op = param_var.outputs()[0]

@@ -111,12 +111,12 @@ def bn_scale(group, graph):
     # steps3: Find scale of bn
     score = None
-    for name, value, aixs in group:
+    for name, value, aixs, _ in group:
         if bn_scale_param == name:
             score = np.abs(value.reshape([-1]))
     scores = []
-    for name, value, axis in group:
-        scores.append((name, axis, score))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, score, idx))
     return scores
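A standalone sketch of the patched l1_norm: every criterion now unpacks the fourth element and passes it through untouched. The graph argument is unused in the shown hunk, so the sketch drops it; the toy weight and group are assumptions.

    import numpy as np

    def l1_norm_sketch(group):
        # Entries follow the new (name, value, axis, pruned_idx) form.
        scores = []
        for name, value, axis, pruned_idx in group:
            # Sum |w| over every dim except the pruned axis.
            reduce_dims = [i for i in range(len(value.shape)) if i != axis]
            score = np.sum(np.abs(value), axis=tuple(reduce_dims))
            scores.append((name, axis, score, pruned_idx))
        return scores

    w = np.ones((4, 2, 3, 3))                  # hypothetical conv weight
    print(l1_norm_sketch([("w", w, 0, [0])]))  # one score per output channel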
paddleslim/prune/group_param.py

@@ -57,21 +57,21 @@ def collect_convs(params, graph, visited={}):
         conv_op = param.outputs()[0]
         walker = conv2d_walker(
             conv_op, pruned_params=pruned_params, visited=visited)
-        walker.prune(param, pruned_axis=0, pruned_idx=[])
+        walker.prune(param, pruned_axis=0, pruned_idx=[0])
         groups.append(pruned_params)
     visited = set()
     uniq_groups = []
     for group in groups:
         repeat_group = False
         simple_group = []
-        for param, axis, _ in group:
+        for param, axis, pruned_idx in group:
             param = param.name()
             if axis == 0:
                 if param in visited:
                     repeat_group = True
                 else:
                     visited.add(param)
-            simple_group.append((param, axis))
+            simple_group.append((param, axis, pruned_idx))
         if not repeat_group:
             uniq_groups.append(simple_group)
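A hedged reading of why the walk is now seeded with pruned_idx=[0] instead of []: walkers appear to shift indices as they cross concat-style ops, so the seed index 0 ends up recording each parameter's channel offset within the group. concat_offset below is a hypothetical stand-in for that bookkeeping, not a PaddleSlim API.

    def concat_offset(seed_idx, branch_offset):
        # A concat walker would add each input's starting channel to the
        # indices it propagates (assumed behavior, for illustration).
        return [i + branch_offset for i in seed_idx]

    print(concat_offset([0], 0))    # first concat input  -> [0]
    print(concat_offset([0], 256))  # second concat input -> [256]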
paddleslim/prune/idx_selector.py

@@ -52,7 +52,7 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    name, axis, score = group[
+    name, axis, score, _ = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()

@@ -60,8 +60,9 @@ def default_idx_selector(group, ratio):
     pruned_idx = sorted_idx[:pruned_num]
     idxs = []
-    for name, axis, score in group:
-        idxs.append((name, axis, pruned_idx))
+    for name, axis, score, offsets in group:
+        r_idx = [i + offsets[0] for i in pruned_idx]
+        idxs.append((name, axis, r_idx))
     return idxs
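This is where the offsets pay off: the channel indices chosen on the root convolution are shifted by offsets[0] for every parameter in the group before being emitted. A toy run, with made-up names and a made-up 256-channel offset:

    # Sketch of the new offset handling in default_idx_selector.
    pruned_idx = [1, 3]                     # channels picked on the root conv
    group = [("root_w", 0, None, [0]),      # (name, axis, score, offsets)
             ("concat_w", 1, None, [256])]  # starts 256 channels into a concat
    idxs = [(name, axis, [i + offsets[0] for i in pruned_idx])
            for name, axis, _score, offsets in group]
    print(idxs)  # [('root_w', 0, [1, 3]), ('concat_w', 1, [257, 259])]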
paddleslim/prune/prune_walker.py

@@ -77,9 +77,10 @@ class PruneWorker(object):
         if op.type() in SKIP_OPS:
             _logger.warn("Skip operator [{}]".format(op.type()))
             return
-        _logger.warn(
-            "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
-            format(op.type()))
+        # _logger.warn(
+        #     "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+        #     format(op.type()))
         cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))

@@ -263,6 +264,8 @@ class elementwise_op(PruneWorker):
                 if name == "Y":
                     actual_axis = pruned_axis - axis
                 in_var = self.op.inputs(name)[0]
+                if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
+                    continue
                 pre_ops = in_var.inputs()
                 for op in pre_ops:
                     self._prune_op(op, in_var, actual_axis, pruned_idx)

@@ -270,15 +273,17 @@ class elementwise_op(PruneWorker):
             else:
                 if var in self.op.inputs("X"):
                     in_var = self.op.inputs("Y")[0]
-                    if in_var.is_parameter():
-                        self.pruned_params.append(
-                            (in_var, pruned_axis - axis, pruned_idx))
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
+                    if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                        if in_var.is_parameter():
+                            self.pruned_params.append(
+                                (in_var, pruned_axis - axis, pruned_idx))
+                        pre_ops = in_var.inputs()
+                        for op in pre_ops:
+                            self._prune_op(op, in_var, pruned_axis - axis,
+                                           pruned_idx)
                 elif var in self.op.inputs("Y"):
                     in_var = self.op.inputs("X")[0]
-                    pre_ops = in_var.inputs()
-                    pruned_axis = pruned_axis + axis
-                    for op in pre_ops:
+                    if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                        pre_ops = in_var.inputs()
+                        pruned_axis = pruned_axis + axis
+                        for op in pre_ops:
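The guard added to elementwise_op skips counterpart inputs of shape [1]: such a tensor is broadcast across all channels, so channel pruning must not walk into it (YOLOv4 graphs multiply feature maps by scalars in exactly this pattern). A tiny sketch of the predicate; needs_pruning is a hypothetical helper, not a PaddleSlim API:

    def needs_pruning(shape):
        # Mirror of the new condition: shape [1] means scalar broadcast.
        return not (len(shape) == 1 and shape[0] == 1)

    print(needs_pruning([1]))               # False: scalar broadcast, skip
    print(needs_pruning([255]))             # True: per-channel tensor
    print(needs_pruning([1, 255, 13, 13]))  # True: full feature map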
paddleslim/prune/pruner.py

@@ -90,12 +90,14 @@ class Pruner():
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            _logger.info("pruning: {}".format(param))
             if graph.var(param) is None:
                 _logger.warn(
                     "Variable[{}] to be pruned is not in current graph.".
                     format(param))
                 continue
-            group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
+            group = collect_convs([param], graph, visited)[0]  # [(name, axis, pruned_idx)]
             if group is None or len(group) == 0:
                 continue
             if only_graph and self.idx_selector.__name__ == "default_idx_selector":

@@ -103,30 +105,33 @@ class Pruner():
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, axis in group:
+                for name, axis, _ in group:
                     pruned_params.append((name, axis, pruned_idx))
             else:
                 assert ((not self.pruned_weights),
                         "The weights have been pruned once.")
                 group_values = []
-                for name, axis in group:
+                for name, axis, pruned_idx in group:
                     values = np.array(scope.find_var(name).get_tensor())
-                    group_values.append((name, values, axis))
+                    group_values.append((name, values, axis, pruned_idx))
-                scores = self.criterion(group_values,
-                                        graph)  # [(name, axis, score)]
+                scores = self.criterion(
+                    group_values, graph)  # [(name, axis, score, pruned_idx)]
                 pruned_params.extend(self.idx_selector(scores, ratio))

         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
+            print("{}\t{}\t{}".format(param, pruned_axis, len(pruned_idx)))
             if param not in merge_pruned_params:
                 merge_pruned_params[param] = {}
             if pruned_axis not in merge_pruned_params[param]:
                 merge_pruned_params[param][pruned_axis] = []
             merge_pruned_params[param][pruned_axis].append(pruned_idx)
+        print("param name: stage.0.conv_layer.conv.weights; idx: {}".format(
+            merge_pruned_params["stage.0.conv_layer.conv.weights"][1]))
         for param_name in merge_pruned_params:
            for pruned_axis in merge_pruned_params[param_name]:
                pruned_idx = np.concatenate(merge_pruned_params[param_name][
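Two of the added lines are bare print calls that read like debug leftovers; in particular, the hard-coded lookup of stage.0.conv_layer.conv.weights would raise a KeyError on any model that lacks that parameter. The merge step itself gathers every index list collected for the same (parameter, axis) pair and concatenates them before pruning, roughly as in this toy sketch (names and indices are made up):

    import numpy as np

    # Toy reconstruction of the merge step.
    pruned_params = [("w", 0, [1, 3]), ("w", 0, [3, 5])]
    merge_pruned_params = {}
    for param, axis, idx in pruned_params:
        merge_pruned_params.setdefault(param, {}).setdefault(axis, []).append(idx)
    print(np.concatenate(merge_pruned_params["w"][0]))  # [1 3 3 5]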