Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
9dcbe669
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9dcbe669
编写于
6月 16, 2020
作者:
J
Jason
提交者:
GitHub
6月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #154 from SunAhong1993/syf_prune
add prune configs and prompt
上级
2299445e
b12b6b4f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
35 addition
and
1 deletion
+35
-1
paddlex/cv/models/load_model.py
paddlex/cv/models/load_model.py
+1
-0
paddlex/cv/models/slim/prune.py
paddlex/cv/models/slim/prune.py
+2
-0
paddlex/cv/models/slim/prune_config.py
paddlex/cv/models/slim/prune_config.py
+32
-1
未找到文件。
paddlex/cv/models/load_model.py
浏览文件 @
9dcbe669
...
...
@@ -108,6 +108,7 @@ def load_model(model_dir, fixed_input_shape=None):
logging
.
info
(
"Model[{}] loaded."
.
format
(
info
[
'Model'
]))
model
.
trainable
=
False
model
.
status
=
status
return
model
...
...
paddlex/cv/models/slim/prune.py
浏览文件 @
9dcbe669
...
...
@@ -158,6 +158,7 @@ def prune_program(model, prune_params_ratios=None):
prune_params_ratios (dict): 由裁剪参数名和裁剪率组成的字典,当为None时
使用默认裁剪参数名和裁剪率。默认为None。
"""
assert
model
.
status
==
'Normal'
,
'Only the models saved while training are supported!'
place
=
model
.
places
[
0
]
train_prog
=
model
.
train_prog
eval_prog
=
model
.
test_prog
...
...
@@ -235,6 +236,7 @@ def cal_params_sensitivities(model, save_file, eval_dataset, batch_size=8):
其中``weight_0``是卷积Kernel名;``sensitivities['weight_0']``是一个字典,key是裁剪率,value是敏感度。
"""
assert
model
.
status
==
'Normal'
,
'Only the models saved while training are supported!'
if
os
.
path
.
exists
(
save_file
):
os
.
remove
(
save_file
)
...
...
paddlex/cv/models/slim/prune_config.py
浏览文件 @
9dcbe669
...
...
@@ -19,6 +19,8 @@ import paddle.fluid as fluid
import
paddlex
sensitivities_data
=
{
'AlexNet'
:
'https://bj.bcebos.com/paddlex/slim_prune/alexnet_sensitivities.data'
,
'ResNet18'
:
'https://bj.bcebos.com/paddlex/slim_prune/resnet18.sensitivities'
,
'ResNet34'
:
...
...
@@ -41,6 +43,10 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large.sensitivities'
,
'MobileNetV3_small'
:
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small.sensitivities'
,
'MobileNetV3_large_ssld'
:
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_large_ssld_sensitivities.data'
,
'MobileNetV3_small_ssld'
:
'https://bj.bcebos.com/paddlex/slim_prune/mobilenetv3_small_ssld_sensitivities.data'
,
'DenseNet121'
:
'https://bj.bcebos.com/paddlex/slim_prune/densenet121.sensitivities'
,
'DenseNet161'
:
...
...
@@ -51,6 +57,8 @@ sensitivities_data = {
'https://bj.bcebos.com/paddlex/slim_prune/xception41.sensitivities'
,
'Xception65'
:
'https://bj.bcebos.com/paddlex/slim_prune/xception65.sensitivities'
,
'ShuffleNetV2'
:
'https://bj.bcebos.com/paddlex/slim_prune/shufflenetv2_sensitivities.data'
,
'YOLOv3_MobileNetV1'
:
'https://bj.bcebos.com/paddlex/slim_prune/yolov3_mobilenetv1.sensitivities'
,
'YOLOv3_MobileNetV3_large'
:
...
...
@@ -143,7 +151,8 @@ def get_prune_params(model):
if
model_type
.
startswith
(
'ResNet'
)
or
\
model_type
.
startswith
(
'DenseNet'
)
or
\
model_type
.
startswith
(
'DarkNet'
)
or
\
model_type
.
startswith
(
'AlexNet'
):
model_type
.
startswith
(
'AlexNet'
)
or
\
model_type
.
startswith
(
'ShuffleNetV2'
):
for
block
in
program
.
blocks
:
for
param
in
block
.
all_parameters
():
pd_var
=
fluid
.
global_scope
().
find_var
(
param
.
name
)
...
...
@@ -152,6 +161,28 @@ def get_prune_params(model):
prune_names
.
append
(
param
.
name
)
if
model_type
==
'AlexNet'
:
prune_names
.
remove
(
'conv5_weights'
)
if
model_type
==
'ShuffleNetV2'
:
not_prune_names
=
[
'stage_2_1_conv5_weights'
,
'stage_2_1_conv3_weights'
,
'stage_2_2_conv3_weights'
,
'stage_2_3_conv3_weights'
,
'stage_2_4_conv3_weights'
,
'stage_3_1_conv5_weights'
,
'stage_3_1_conv3_weights'
,
'stage_3_2_conv3_weights'
,
'stage_3_3_conv3_weights'
,
'stage_3_4_conv3_weights'
,
'stage_3_5_conv3_weights'
,
'stage_3_6_conv3_weights'
,
'stage_3_7_conv3_weights'
,
'stage_3_8_conv3_weights'
,
'stage_4_1_conv5_weights'
,
'stage_4_1_conv3_weights'
,
'stage_4_2_conv3_weights'
,
'stage_4_3_conv3_weights'
,
'stage_4_4_conv3_weights'
,]
for
name
in
not_prune_names
:
prune_names
.
remove
(
name
)
elif
model_type
==
"MobileNetV1"
:
prune_names
.
append
(
"conv1_weights"
)
for
param
in
program
.
global_block
().
all_parameters
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录