PaddlePaddle / PaddleSlim

Commit e72ce197 (unverified)
Authored by ceci3 on Dec 20, 2022. Committed via GitHub on Dec 20, 2022.

support vit prune (#1590)
* support vit prune
* update
* add unittest
Parent: 5662a660

Showing 6 changed files with 141 additions and 25 deletions (+141, -25):
- paddleslim/auto_compression/transformer_pruner.py (+67, -17)
- paddleslim/common/patterns.py (+4, -3)
- paddleslim/common/patterns_common.py (+3, -2)
- paddleslim/common/transformer_pattern.py (+11, -1)
- paddleslim/quant/quanter.py (+1, -2)
- tests/act/test_act_prune.py (+55, -0)
paddleslim/auto_compression/transformer_pruner.py
```diff
@@ -287,8 +287,10 @@ class TransformerPruner:
     def _preprocess_patterns(self, patterns, graph):
         """ Preprocess pattern of the program, get some info need by reorder"""
-        input_mask_op = patterns['input_mask']
-        layer_num = int((len(patterns) - 1) / 2)
+        input_mask_op = patterns.get('input_mask', None)
+        layer_num = int(
+            (len(patterns) - 1) / 2) if input_mask_op is not None else int(
+                (len(patterns) / 2))
 
         ### get real head number
         head_num = -1
```
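The `input_mask` pattern exists for BERT-style encoders but not for ViT, so `_preprocess_patterns` now tolerates a missing entry and counts layers accordingly: one non-layer key plus two pattern keys per layer when the mask is present, two keys per layer otherwise. A minimal sketch of that counting logic (the example pattern dicts are hypothetical):

```python
# Hypothetical pattern dicts: two pattern entries (MHA + FFN) per layer,
# plus an 'input_mask' entry for BERT-style models only.
bert_patterns = {'input_mask': object(), 'MHA$0': object(), 'FFN$0': object()}
vit_patterns = {'MHA$0': object(), 'FFN$0': object()}

def count_layers(patterns):
    input_mask_op = patterns.get('input_mask', None)
    if input_mask_op is not None:
        return int((len(patterns) - 1) / 2)  # discount the mask entry
    return int(len(patterns) / 2)            # ViT: every entry is a layer pattern

assert count_layers(bert_patterns) == 1
assert count_layers(vit_patterns) == 1
```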
```diff
@@ -395,8 +397,6 @@ class TransformerPruner:
                     shape=[program.global_block().var(w_name).shape[1]],
                     dtype='float32'))
 
         exe.run(paddle.static.default_startup_program())
         ### need to send a dataloader with label
         for batch_id, data in enumerate(dataloader()):
             outs = exe.run(program, feed=data, fetch_list=fetch_list)
```
```diff
@@ -445,13 +445,21 @@ class TransformerPruner:
             new_w = np.take(np_w, index, axis=dim)
             pd_w.set(new_w, place)
 
+        if int(len(qkv) / 2) == 1:
+            q_index = index
+            k_index = index + 768
+            v_index = index + (768 * 2)
+            qkv_index = np.append(np.append(q_index, k_index), v_index)
+        else:
+            qkv_index = index
+
         for w_idx, weight_name in enumerate(qkv):
             if w_idx % 2 == 0:
                 ### reorder qkv weight
-                reorder_head_matrix(weight_name, index, dim=1)
+                reorder_head_matrix(weight_name, qkv_index, dim=1)
             else:
                 ### reorder qkv bias
-                reorder_head_matrix(weight_name, index, dim=0)
+                reorder_head_matrix(weight_name, qkv_index, dim=0)
 
         ### reorder attention output weight
         reorder_head_matrix(attn_out[0], index, dim=0)
```
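When `len(qkv) / 2 == 1`, the attention layer stores Q, K and V fused into a single weight/bias pair, so a head-reorder index computed for one projection has to be replicated across the three column blocks; the 768 offset is the hidden size of ViT-Base, hardcoded in this hunk. A toy illustration of the index construction (the kept-column choice is made up):

```python
import numpy as np

hidden = 768                     # ViT-Base hidden size, as hardcoded above
index = np.array([64, 65, 66])   # hypothetical columns of a head to move first
q_index = index                  # Q block starts at column 0
k_index = index + hidden         # K block starts at column 768
v_index = index + (hidden * 2)   # V block starts at column 1536
qkv_index = np.append(np.append(q_index, k_index), v_index)
print(qkv_index)  # [  64   65   66  832  833  834 1600 1601 1602]
# reorder_head_matrix then gathers these positions with np.take:
# dim=1 for the fused weight's columns, dim=0 for the fused bias.
```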
```diff
@@ -507,7 +515,13 @@ class TransformerPruner:
             op.desc.set_input(
                 'X', input_var_name[:int(len(input_var_name) * new_inputs_len)])
 
-    def _prune_weight(self, graph, scope, place, pruned_name, pruned_ratio):
+    def _prune_weight(self,
+                      graph,
+                      scope,
+                      place,
+                      pruned_name,
+                      pruned_ratio,
+                      fuse_qkv=False):
         """ Prune every weight in program """
         param = graph.var(pruned_name)
         _var = scope.find_var(param.name())
```
```diff
@@ -516,26 +530,62 @@ class TransformerPruner:
         param_t = _var.get_tensor()
         pruned_ratio = [pruned_ratio[1]] if len(
             param_t.shape()) == 1 else pruned_ratio
-        pruned_shape = np.multiply(param_t.shape(), pruned_ratio)
-        pruned_shape = list(map(int, pruned_shape))
-        param.set_shape(pruned_shape)
-        if len(pruned_shape) == 2:
-            pruned_param = np.array(param_t)[:pruned_shape[0], :pruned_shape[1]]
+        origin_shape = param_t.shape()
+
+        def process_qkv(qkv_param, pruned_ratio):
+            qkv_param_shape = qkv_param.shape()
+            if len(qkv_param_shape) == 2:
+                tmp_qkv_param_shape = [qkv_param_shape[0], -1, 3]
+            else:
+                tmp_qkv_param_shape = [-1, 3]
+            tmp_param = np.reshape(qkv_param, tmp_qkv_param_shape)
+            tmp_pruned_ratio = pruned_ratio + [1.0]
+            tmp_pruned_shape = np.multiply(tmp_param.shape, tmp_pruned_ratio)
+            tmp_pruned_shape = list(map(int, tmp_pruned_shape))
+            if len(qkv_param_shape) == 2:
+                tmp_prune_qkv_param = tmp_param[:tmp_pruned_shape[0], :
+                                                tmp_pruned_shape[1], :
+                                                tmp_pruned_shape[2]]
+                pruned_param = np.reshape(tmp_prune_qkv_param,
+                                          (qkv_param_shape[0], -1))
+            else:
+                tmp_prune_qkv_param = tmp_param[:tmp_pruned_shape[0], :
+                                                tmp_pruned_shape[1]]
+                pruned_param = np.reshape(tmp_prune_qkv_param, (-1))
+            return pruned_param
+
+        if fuse_qkv:
+            pruned_param = process_qkv(param_t, pruned_ratio)
+            param.set_shape(pruned_param.shape)
+            param_t.set(pruned_param, place)
         else:
-            pruned_param = np.array(param_t)[:pruned_shape[0]]
-        param_t.set(pruned_param, place)
+            pruned_shape = np.multiply(param_t.shape(), pruned_ratio)
+            pruned_shape = list(map(int, pruned_shape))
+            param.set_shape(pruned_shape)
+            if len(pruned_shape) == 2:
+                pruned_param = np.array(param_t)[:pruned_shape[0], :
+                                                 pruned_shape[1]]
+            else:
+                pruned_param = np.array(param_t)[:pruned_shape[0]]
+            param_t.set(pruned_param, place)
 
     def _prune_transformer(self, scope, place, graph, pruned_dict):
         """ Prune transformer program """
+        qkv_weights_name = []
+        if (len(self.mha_weight[0]['P1']) // 2 == 1):
+            for _, mha_weights_name in self.mha_weight.items():
+                qkv_weights_name.extend(mha_weights_name['P1'])
         for name, value in pruned_dict.items():
             ### prune weight
-            self._prune_weight(graph, scope, place, name, value)
+            fuse_qkv = False
+            if name in qkv_weights_name:
+                fuse_qkv = True
+            self._prune_weight(graph, scope, place, name, value, fuse_qkv)
         graph.infer_shape()
         return graph.program
 
     def prune(self):
         ### get input_mask op and start to prune input_mask op
-        if self.input_mask_op.type == 'stack':
+        if self.input_mask_op is not None and self.input_mask_op.type == 'stack':
             self._update_input_mask_inputs(self.inference_program,
                                            self.input_mask_op, self.width_mult)
```
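`process_qkv` prunes a fused QKV parameter without collapsing its three-way structure: it reshapes the pruned axis into `(..., 3)`, extends the ratio with `1.0` so the 3-way axis is kept whole, slices, and reshapes back. A shape-only walkthrough at toy sizes (hidden size 4, keep-ratio 0.5 on the fused axis):

```python
import numpy as np

qkv_param = np.zeros((4, 12), dtype='float32')  # [hidden, 3 * hidden]
pruned_ratio = [1.0, 0.5]                       # keep all rows, half the fused width

tmp = np.reshape(qkv_param, [qkv_param.shape[0], -1, 3])       # (4, 4, 3)
tmp_ratio = pruned_ratio + [1.0]                               # never prune the 3-way axis
tmp_shape = list(map(int, np.multiply(tmp.shape, tmp_ratio)))  # [4, 2, 3]
pruned = np.reshape(tmp[:tmp_shape[0], :tmp_shape[1], :tmp_shape[2]],
                    (qkv_param.shape[0], -1))
print(pruned.shape)  # (4, 6): the fused width shrinks from 12 to 6, Q/K/V together
```

The 1-D branch handles the fused bias the same way, and the `fuse_qkv` flag is set per weight name by `_prune_transformer` in the same hunk.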
```diff
@@ -555,7 +605,7 @@ class TransformerPruner:
             pruned_shape[-1] = int(origin_shape[-1] * self.width_mult)
             op.set_attr('shape', pruned_shape)
-        elif len(origin_shape) == 4:
+        elif len(origin_shape) == 4 or len(origin_shape) == 5:
             pruned_shape[-2] = int(origin_shape[-2] * self.width_mult)
             op.set_attr('shape', pruned_shape)
```
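The new 5-D case matches ViT-style attention, where the fused QKV output is typically reshaped to `[batch, tokens, 3, num_heads, head_dim]`; axis `-2` is then the head count that head pruning shrinks. A hypothetical example of the attribute update (shape values are illustrative for ViT-Base):

```python
# Hypothetical reshape attr of a ViT fused-QKV op: [batch, tokens, 3, heads, head_dim]
origin_shape = [0, 197, 3, 12, 64]   # 0 copies the batch dim from the input in Paddle
width_mult = 0.75

pruned_shape = list(origin_shape)
pruned_shape[-2] = int(origin_shape[-2] * width_mult)
print(pruned_shape)  # [0, 197, 3, 9, 64]: 12 heads pruned down to 9
```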
paddleslim/common/patterns.py
```diff
@@ -101,14 +101,15 @@ def get_patterns(program, only_final_node=True):
             if (not inp1._var.persistable) and (not inp2._var.persistable):
                 sc_path = []
                 shortcut_start_op = []
-                is_sc = is_shortcut(op, graph, sc_path, shortcut_start_op)
+                is_sc, target_op_idx = is_shortcut(op, graph, sc_path,
+                                                   shortcut_start_op)
                 if is_sc:
                     out_var_name = op.all_outputs()[0]._var.name
                     shortcut_start_op = shortcut_start_op[0]
-                    next_op = graph.next_ops(op)
+                    next_ops = graph.next_ops(op)
                     pattern_ops, pattern_ops_type = traversal_ops(
-                        shortcut_start_op, graph, next_op[0].idx())
+                        shortcut_start_op, graph, target_op_idx)
 
                     pattern_name = shortcut_start_op.type() + '$' + str(op.idx())
```
paddleslim/common/patterns_common.py
```diff
@@ -132,5 +132,6 @@ def is_shortcut(op, graph, sc_path, shortcut_start_op):
             if n_op.idx() != op.idx():
                 sc_path.append(p_op.type())
                 sc_path.append(n_op.type())
-                return _find_next_target_op(n_op, graph, op.idx(), sc_path)
-    return False
+                return _find_next_target_op(n_op, graph, op.idx(),
+                                            sc_path), op.idx()
+    return False, -1
```
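`is_shortcut` now returns a pair, the shortcut flag plus the index of the op where pattern traversal should stop (`-1` when no shortcut is found), and `get_patterns` in patterns.py above unpacks it as `is_sc, target_op_idx` instead of re-deriving the stop point from `graph.next_ops(op)`. A minimal mock of the new contract (not the real graph API):

```python
# Mock of the two-value contract introduced above.
def is_shortcut_mock(found_target_idx):
    if found_target_idx is not None:
        return True, found_target_idx  # mirrors: return _find_next_target_op(...), op.idx()
    return False, -1                   # mirrors: return False, -1

is_sc, target_op_idx = is_shortcut_mock(42)
assert (is_sc, target_op_idx) == (True, 42)
assert is_shortcut_mock(None) == (False, -1)
```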
paddleslim/common/transformer_pattern.py
```diff
@@ -18,8 +18,18 @@ from .patterns_common import *
 __all__ = ['preprocess_transformer_patterns']
 
 
+def _find_gemm_op(op, graph):
+    while op.type() not in ['mul', 'matmul', 'matmul_v2']:
+        next_op = find_weight_op(op, graph)
+        op = next_op
+    return op
+
+
 def _append_transformer_prune_params(op, graph, block_num, params_dict):
     for next_op in graph.next_ops(op):
+        if next_op.type() == 'elementwise_add':
+            continue
+        next_op = _find_gemm_op(next_op, graph)
         if next_op.type() in ['mul', 'matmul', 'matmul_v2'
                               ] and is_dynamic_weight_op(next_op):
             if block_num not in params_dict:
@@ -30,7 +40,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
             params_dict[block_num]['P1'].append(
                 get_weight(has_bias(next_op, graph)))
             op = next_op
-            next_op = find_weight_op(op, graph)
+            next_op = _find_gemm_op(find_weight_op(op, graph), graph)
             if next_op:
                 params_dict[block_num]['P2'] = [get_weight(next_op)]
                 params_dict[block_num]['P2'].append(
```
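`_find_gemm_op` exists because ViT inserts extra ops (reshapes, transposes, scales) between a block's projections, so the op reached by `find_weight_op` is no longer guaranteed to be a GEMM. A runnable toy of the walk, with a stub standing in for the real `find_weight_op` from `patterns_common`:

```python
# Toy graph: transpose2 -> scale -> matmul_v2, each op knowing its successor.
class StubOp:
    def __init__(self, op_type, successor=None):
        self._type, self.successor = op_type, successor

    def type(self):
        return self._type

def find_weight_op_stub(op, graph):
    return op.successor  # stand-in for the real graph traversal

def _find_gemm_op(op, graph):
    # Walk forward until a GEMM-type op is reached, as in the hunk above.
    while op.type() not in ['mul', 'matmul', 'matmul_v2']:
        op = find_weight_op_stub(op, graph)
    return op

gemm = StubOp('matmul_v2')
chain = StubOp('transpose2', StubOp('scale', gemm))
assert _find_gemm_op(chain, graph=None) is gemm
```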
paddleslim/quant/quanter.py
```diff
@@ -428,8 +428,7 @@ def quant_aware(program,
             quant_bits=config['activation_bits'],
             skip_pattern=config['not_quant_pattern'],
             quantizable_op_type=quant_dequant_ops,
-            is_test=is_test,
-            scale_dict=scale_dict)
+            is_test=is_test)
         quant_dequant_pass.apply(main_graph)
```
tests/act/test_act_prune.py
```diff
@@ -274,5 +274,60 @@ class ACTChannelPrune(unittest.TestCase):
         os.system('rm -rf asp_output')
 
 
+class ACTViTPrune(ACTChannelPrune):
+    def __init__(self, *args, **kwargs):
+        super(ACTViTPrune, self).__init__(*args, **kwargs)
+        if not os.path.exists('ViT_base_patch16_224_infer'):
+            os.system(
+                'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ViT_base_patch16_224_infer.tar'
+            )
+            os.system('tar -xf ViT_base_patch16_224_infer.tar')
+        if not os.path.exists('ILSVRC2012_data_demo'):
+            os.system(
+                'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
+            )
+            os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
+        self.train_dataloader, self.eval_dataloader = self.create_dataloader()
+
+    def test_act_vit_transformer_prune(self):
+        def eval_function(exe, compiled_test_program, test_feed_names,
+                          test_fetch_list):
+            res = eval_func(compiled_test_program, exe, test_feed_names,
+                            test_fetch_list, self.eval_dataloader)
+            return res
+
+        configs = {
+            'Distillation': {},
+            'TransformerPrune': {
+                'pruned_ratio': 0.1
+            },
+            'TrainConfig': {
+                'epochs': 1,
+                'eval_iter': 1000,
+                'learning_rate': 5.0e-03,
+                'optimizer_builder': {
+                    'optimizer': {
+                        'type': 'SGD'
+                    },
+                    "weight_decay": 0.0005,
+                }
+            }
+        }
+
+        ac = AutoCompression(
+            model_dir='./ViT_base_patch16_224_infer',
+            model_filename="inference.pdmodel",
+            params_filename="inference.pdiparams",
+            save_dir="vit_prune_output",
+            config=configs,
+            train_dataloader=self.train_dataloader,
+            eval_callback=eval_function,  # eval_function to verify accuracy
+            eval_dataloader=self.eval_dataloader)
+
+        ac.compress()
+        os.system('rm -rf vit_prune_output')
+
+
 if __name__ == '__main__':
     unittest.main()
```
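The new `ACTViTPrune` case inherits `ACTChannelPrune`'s dataloader setup, fetches the ViT-Base inference model and the ImageNet demo subset on first run, and drives `TransformerPrune` with `pruned_ratio: 0.1` through `AutoCompression.compress()`. Because the module ends with `unittest.main()`, it can be run directly with `python tests/act/test_act_prune.py`.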