Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
d3aeda6f
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d3aeda6f
编写于
4月 26, 2021
作者:
M
minghaoBD
提交者:
GitHub
4月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[doc] Update unstructured pruning docs (#727)
上级
98d610f9
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
397 addition
and
61 deletion
+397
-61
demo/dygraph/unstructured_pruning/README.md
demo/dygraph/unstructured_pruning/README.md
+55
-23
demo/dygraph/unstructured_pruning/evaluate.py
demo/dygraph/unstructured_pruning/evaluate.py
+1
-1
demo/dygraph/unstructured_pruning/train.py
demo/dygraph/unstructured_pruning/train.py
+12
-10
demo/unstructured_prune/README.md
demo/unstructured_prune/README.md
+54
-21
demo/unstructured_prune/train.py
demo/unstructured_prune/train.py
+0
-1
docs/zh_cn/api_cn/dygraph/pruners/index.rst
docs/zh_cn/api_cn/dygraph/pruners/index.rst
+1
-0
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
+113
-0
docs/zh_cn/api_cn/static/prune/prune_index.rst
docs/zh_cn/api_cn/static/prune/prune_index.rst
+1
-0
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
+121
-0
paddleslim/dygraph/prune/unstructured_pruner.py
paddleslim/dygraph/prune/unstructured_pruner.py
+23
-1
paddleslim/prune/unstructured_pruner.py
paddleslim/prune/unstructured_pruner.py
+0
-2
tests/dygraph/test_unstructured_prune.py
tests/dygraph/test_unstructured_prune.py
+15
-1
tests/test_unstructured_pruner.py
tests/test_unstructured_pruner.py
+1
-1
未找到文件。
demo/dygraph/unstructured_pruning/README.md
浏览文件 @
d3aeda6f
# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)
# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)
示例
## 简介
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,
`MobileNetV1`
在
`ImageNet`
上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用
`CIFAR-10`
数据集。当前示例目前支持
`MobileNetV1`
,使用其他模型可以按照下面的训练代码示例进行API调用。
## 版本要求
```
bash
...
...
@@ -13,12 +15,25 @@ paddleslim>=2.1.0
请参照github安装
[
paddlepaddle
](
https://github.com/PaddlePaddle/Paddle
)
和
[
paddleslim
](
https://github.com/PaddlePaddle/PaddleSlim
)
。
## 使用
## 数据准备
本示例支持
`CIFAR-10`
和
`ImageNet`
两种数据。默认情况下,会自动下载并使用
`CIFAR-10`
数据,如果需要使用
`ImageNet`
数据。请按以下步骤操作:
-
根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
-
使用
`train.py`
和
`evaluate.py`
运行脚本时,指定
`--data`
选项为
`imagenet`
。
如果想要使用自定义的数据集,需要重写
`../../imagenet_reader.py`
文件,并在
`train.py`
中调用实现。
## 下载预训练模型
训练前:
-
训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py/evaluate.py文件中调用实现。
-
开发者可以通过重写paddleslim.dygraph.prune.unstructured_pruner.py中的UnstructuredPruner.mask_parameters()和UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
-
开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。
该示例中直接使用
`paddle.vision.models`
模块提供的针对
`ImageNet`
分类任务的预训练模型。 对预训练好的模型剪裁后,需要在目标数据集上进行重新训练,以便恢复因剪裁损失的精度。
## 自定义稀疏化方法
默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作:
-
开发者可以通过重写
`paddleslim.dygraph.prune.unstructured_pruner.py`
中的
`UnstructuredPruner.mask_parameters()`
和
`UnstructuredPruner.update_threshold()`
来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
-
开发可以在初始化
`UnstructuredPruner`
时,传入自定义的
`skip_params_func`
,来定义哪些参数不参与剪裁。
`skip_params_func`
示例代码如下(路径:
`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`
。默认为所有的归一化层的参数不参与剪裁。
```
python
def
_get_skip_params
(
model
):
...
...
@@ -39,21 +54,43 @@ def _get_skip_params(model):
return
skip_params
```
训练:
## 训练
按照阈值剪裁:
```
bash
python3.7 train.py
--data
imagenet
--lr
0.05
--pruning_mode
threshold
--threshold
0.01
```
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```
bash
python3.7 train.py
--data
imagenet
--lr
0.05
--pruning_mode
ratio
--ratio
0.5
```
GPU多卡训练:
```
bash
python3 train.py
--data
cifar10
--lr
0.1
--pruning_mode
ratio
--ratio
=
0.5
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
python3.7
-m
paddle.distributed.launch
\
--gpus
=
"0,1,2,3"
\
--log_dir
=
"train_mbv1_imagenet_threshold_001_log"
\
train.py
--data
imagenet
--lr
0.05
--pruning_mode
threshold
--threshold
0.01
```
推理
:
恢复训练(请替代命令中的
`dir/to/the/saved/pruned/model`
和
`INTERRUPTED_EPOCH`
)
:
```
bash
python3
eval
--pruned_model
models/
--data
cifar10
python3.7 train.py
--data
imagenet
--lr
0.05
--pruning_mode
threshold
--threshold
0.01
\
--pretrained_model
dir
/to/the/saved/pruned/model
--resume_epoch
INTERRUPTED_EPOCH
```
## 推理:
```
bash
python3.7
eval
--pruned_model
models/
--data
imagenet
```
剪裁训练代码示例:
```
python
model
=
mobilenet_v1
(
num_classes
=
class_dim
,
pretrained
=
True
)
#STEP1: initialize the pruner
pruner
=
UnstructuredPruner
(
model
,
mode
=
'
ratio'
,
ratio
=
0.5
)
pruner
=
UnstructuredPruner
(
model
,
mode
=
'
threshold'
,
threshold
=
0.01
)
for
epoch
in
range
(
epochs
):
for
batch_id
,
data
in
enumerate
(
train_loader
):
...
...
@@ -80,27 +117,22 @@ for epoch in range(epochs):
```
python
model
=
mobilenet_v1
(
num_classes
=
class_dim
,
pretrained
=
True
)
model
.
set_state_dict
(
paddle
.
load
(
"model-pruned.pdparams"
))
print
(
UnstructuredPruner
.
total_sparse
(
model
))
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
print
(
UnstructuredPruner
.
total_sparse
(
model
))
test
()
```
更多使用参数请参照shell文件或者运行如下命令查看:
```
bash
python train
--h
python evaluate
--h
python
3.7
train
--h
python
3.7
evaluate
--h
```
## 实验结果
(刚开始在动态图代码验证,以下为静态图代码上的结果)
## 实验结果
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -41.35% | 75.29%(-0.95%) | 0.005 | 0.05 | 10w |
| YOLO v3 | VOC |threshold | -53.00% | 75.00%(-1.24%) | 0.005 | 0.075 | 10w |
## TODO
-
[ ] 完成实验,验证动态图下的效果,并得到压缩模型。
-
[ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。
| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 | 102k iterations |
demo/dygraph/unstructured_pruning/evaluate.py
浏览文件 @
d3aeda6f
...
...
@@ -5,7 +5,7 @@ import argparse
import
numpy
as
np
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
))
from
paddleslim
.dygraph.prune.unstructured_pruner
import
UnstructuredPruner
from
paddleslim
import
UnstructuredPruner
from
utility
import
add_arguments
,
print_arguments
import
paddle.vision.transforms
as
T
import
paddle.nn.functional
as
F
...
...
demo/dygraph/unstructured_pruning/train.py
浏览文件 @
d3aeda6f
...
...
@@ -3,7 +3,7 @@ import os
import
sys
import
argparse
import
numpy
as
np
from
paddleslim
.dygraph.prune.unstructured_pruner
import
UnstructuredPruner
from
paddleslim
import
UnstructuredPruner
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
))
from
utility
import
add_arguments
,
print_arguments
...
...
@@ -35,6 +35,7 @@ parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90],
add_arg
(
'data'
,
str
,
"cifar10"
,
"Which data to use. 'cifar10' or 'imagenet'."
)
add_arg
(
'log_period'
,
int
,
100
,
"Log period in batches."
)
add_arg
(
'test_period'
,
int
,
1
,
"Test period in epoches."
)
add_arg
(
'pretrained_model'
,
str
,
None
,
"The pretrained model the load. Default: None."
)
add_arg
(
'model_path'
,
str
,
"./models"
,
"The path to save model."
)
add_arg
(
'model_period'
,
int
,
10
,
"The period to save model in epochs."
)
add_arg
(
'resume_epoch'
,
int
,
-
1
,
"The epoch to resume training."
)
...
...
@@ -117,12 +118,13 @@ def compress(args):
# model definition
model
=
mobilenet_v1
(
num_classes
=
class_dim
,
pretrained
=
True
)
dp_model
=
paddle
.
DataParallel
(
model
)
if
args
.
pretrained_model
is
not
None
:
model
.
set_state_dict
(
paddle
.
load
(
args
.
pretrained_model
))
opt
,
learning_rate
=
create_optimizer
(
args
,
step_per_epoch
,
dp_
model
)
opt
,
learning_rate
=
create_optimizer
(
args
,
step_per_epoch
,
model
)
def
test
(
epoch
):
dp_
model
.
eval
()
model
.
eval
()
acc_top1_ns
=
[]
acc_top5_ns
=
[]
for
batch_id
,
data
in
enumerate
(
valid_loader
):
...
...
@@ -133,7 +135,7 @@ def compress(args):
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
end_time
=
time
.
time
()
logits
=
dp_
model
(
x_data
)
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
acc_top1
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
5
)
...
...
@@ -157,7 +159,7 @@ def compress(args):
acc_top5_ns
,
dtype
=
"object"
))))
def
train
(
epoch
):
dp_
model
.
train
()
model
.
train
()
for
batch_id
,
data
in
enumerate
(
train_loader
):
start_time
=
time
.
time
()
x_data
=
data
[
0
]
...
...
@@ -165,7 +167,7 @@ def compress(args):
if
args
.
data
==
'cifar10'
:
y_data
=
paddle
.
unsqueeze
(
y_data
,
1
)
logits
=
dp_
model
(
x_data
)
logits
=
model
(
x_data
)
loss
=
F
.
cross_entropy
(
logits
,
y_data
)
acc_top1
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
1
)
acc_top5
=
paddle
.
metric
.
accuracy
(
logits
,
y_data
,
k
=
5
)
...
...
@@ -183,7 +185,7 @@ def compress(args):
pruner
.
step
()
pruner
=
UnstructuredPruner
(
dp_
model
,
model
,
mode
=
args
.
pruning_mode
,
ratio
=
args
.
ratio
,
threshold
=
args
.
threshold
)
...
...
@@ -193,11 +195,11 @@ def compress(args):
pruner
.
update_params
()
_logger
.
info
(
"The current density of the pruned model is: {}%"
.
format
(
round
(
100
*
UnstructuredPruner
.
total_sparse
(
dp_
model
),
2
)))
round
(
100
*
UnstructuredPruner
.
total_sparse
(
model
),
2
)))
test
(
i
)
if
i
>
args
.
resume_epoch
and
i
%
args
.
model_period
==
0
:
pruner
.
update_params
()
paddle
.
save
(
dp_
model
.
state_dict
(),
paddle
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
args
.
model_path
,
"model-pruned.pdparams"
))
paddle
.
save
(
opt
.
state_dict
(),
os
.
path
.
join
(
args
.
model_path
,
"opt-pruned.pdopt"
))
...
...
demo/unstructured_prune/README.md
浏览文件 @
d3aeda6f
# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式)
# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式)
示例
## 简介
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,
`MobileNetV1`
在
`ImageNet`
上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用
`MNIST`
数据集。当前示例目前支持
`MobileNetV1`
,使用其他模型可以按照下面的
**训练代码示例**
进>行API调用。
## 版本要求
```
bash
...
...
@@ -11,15 +13,36 @@ paddlepaddle>=2.0.0
paddleslim>
=
2.1.0
```
请参照github安装
[
paddlepaddle
](
https://github.com/PaddlePaddle/Paddle
)
和
[
paddleslim
](
https://github.com/PaddlePaddle/PaddleSlim
)
。
请参照github安装
[
PaddlePaddle
](
https://github.com/PaddlePaddle/Paddle
)
和
[
PaddleSlim
](
https://github.com/PaddlePaddle/PaddleSlim
)
。
## 数据准备
本示例支持
`MNIST`
和
`ImageNet`
两种数据。默认情况下,会自动下载并使用
`MNIST`
数据,如果需要使用
`ImageNet`
数据。请按以下步骤操作:
-
根据分类模型中
[
ImageNet数据准备文档
](
https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87
)
下载数据到
`PaddleSlim/demo/data/ILSVRC2012`
路径下。
-
使用
`train.py`
和
`evaluate.py`
运行脚本时,指定
`--data`
选项为
`imagenet`
。
如果想要使用自定义的数据集,需要重写
`../imagenet_reader.py`
文件,并在
`train.py`
中调用实现。
## 下载预训练模型
如果使用
`ImageNet`
数据,建议在预训练模型的基础上进行剪裁,请从
[
这里
](
http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
)
下载预训练模型。
下载并解压预训练模型到当前路径:
```
wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar
tar -xf MobileNetV1_pretrained.tar
```
使用
`train.py`
脚本时,指定
`--pretrained_model`
加载预训练模型,
`MNIST`
数据无需指定。
## 自定义稀疏化方法
## 使用
默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作:
训练前:
-
预训练模型下载,并放到某目录下,通过train.py中的--pretrained_model设置。
-
训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py文件中调用实现。
-
开发者可以通过重写paddleslim.prune.unstructured_pruner.py中的UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
-
开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。
-
可以通过重写
`paddleslim.prune.unstructured_pruner.py`
中的
`UnstructuredPruner.update_threshold()`
来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
-
可以在初始化
`UnstructuredPruner`
时,传入自定义的
`skip_params_func`
,来定义哪些参数不参与剪裁。
`skip_params_func`
示例代码如下(路径:
`paddleslim.prune.unstructured_pruner._get_skip_params()`
)。默认为所有的归一化层的参数不参与剪裁。
```
python
def
_get_skip_params
(
program
):
...
...
@@ -41,12 +64,25 @@ def _get_skip_params(program):
return
skip_params
```
训练:
## 训练
按照阈值剪裁:
```
bash
CUDA_VISIBLE_DEVICES
=
2,3 python3.7 train.py
--data
mnist
--lr
0.1
--pruning_mode
ratio
--ratio
=
0.5
CUDA_VISIBLE_DEVICES
=
2,3 python3.7 train.py
--data
imagenet
--lr
0.05
--pruning_mode
threshold
--threshold
0.01
```
推理:
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```
bash
CUDA_VISIBLE_DEVICES
=
2,3 python3.7 train.py
--data
imagenet
--lr
0.05
--pruning_mode
ratio
--ratio
0.5
```
恢复训练(请替代命令中的
`dir/to/the/saved/pruned/model`
和
`INTERRUPTED_EPOCH`
):
```
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH
```
## 推理
```
bash
CUDA_VISIBLE_DEVICES
=
0 python3.7 evaluate.py
--pruned_model
models/
--data
imagenet
```
...
...
@@ -70,7 +106,8 @@ opt, learning_rate = create_optimizer(args, step_per_epoch)
opt
.
minimize
(
avg_cost
)
#STEP1: initialize the pruner
pruner
=
UnstructuredPruner
(
paddle
.
static
.
default_main_program
(),
mode
=
'ratio'
,
ratio
=
0.5
,
place
=
place
)
pruner
=
UnstructuredPruner
(
paddle
.
static
.
default_main_program
(),
mode
=
'threshold'
,
threshold
=
0.01
,
place
=
place
)
# 按照阈值剪裁
# pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place) # 按照比例剪裁
exe
.
run
(
paddle
.
static
.
default_startup_program
())
paddle
.
fluid
.
io
.
load_vars
(
exe
,
args
.
pretrained_model
)
...
...
@@ -103,7 +140,8 @@ for epoch in range(epochs):
```
python
# intialize the model instance in static mode
# load weights
print
(
UnstructuredPruner
.
total_sparse
(
paddle
.
static
.
default_main_program
()))
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
print
(
UnstructuredPruner
.
total_sparse
(
paddle
.
static
.
default_main_program
()))
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
test
()
```
...
...
@@ -118,11 +156,6 @@ python3.7 evaluate.py --h
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.0
0
5 | - | 68 |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.05 | - | 68 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -55.15% | 75.45%(-0.79%) | 0.005 | 0.05 |12.8w|
## TODO
-
[ ] 完成实验,验证动态图下的效果,并得到压缩模型。
-
[ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。
| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 |102k iterations|
demo/unstructured_prune/train.py
浏览文件 @
d3aeda6f
...
...
@@ -140,7 +140,6 @@ def compress(args):
pruner
=
UnstructuredPruner
(
paddle
.
static
.
default_main_program
(),
batch_size
=
args
.
batch_size
,
mode
=
args
.
pruning_mode
,
ratio
=
args
.
ratio
,
threshold
=
args
.
threshold
,
...
...
docs/zh_cn/api_cn/dygraph/pruners/index.rst
浏览文件 @
d3aeda6f
...
...
@@ -7,3 +7,4 @@ Pruners
l1norm_filter_pruner.rst
l2norm_filter_pruner.rst
fpgm_filter_pruner.rst
unstructured_pruner.rst
docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst
0 → 100644
浏览文件 @
d3aeda6f
非结构化稀疏
================
UnstructuredPruner
----------
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.3, skip_params_func=None)
`源代码 <https://github.com/minghaoBD/PaddleSlim/blob/update_unstructured_pruning_docs/paddleslim/dygraph/prune/unstructured_pruner.py>`_
对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置0。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。
**参数:**
- **model(paddle.nn.Layer)** - 待剪裁的动态图模型。
- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0。
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
**返回:** 一个UnstructuredPruner类的实例。
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
..
.. py:method:: paddleslim.UnstructuredPruner.step()
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
pruner.step()
..
.. py:method:: paddleslim.UnstructuredPruner.update_params()
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
pruner.update_params()
..
.. py:method:: paddleslim.UnstructuredPruner.total_sparse(model)
UnstructuredPruner中的静态方法,用于计算给定的模型(model)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
- **model(paddle.nn.Layer)** - 要计算稠密度的目标网络。
**返回:**
- **density(float)** - 模型的稠密度。
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
density = UnstructuredPruner.total_sparse(model)
..
.. py:method:: paddleslim.UnstructuredPruner.summarize_weights(model, ratio=0.1)
该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)。
**参数:**
- **model(paddle.nn.Layer)** - 要分析权重分布的目标网络。
- **ratio(float)** - 需要查看的比例情况,具体如上方法描述。
**返回:**
- **threshold(float)** - 和输入ratio对应的阈值。开发者可以根据该阈值初始化UnstructuredPruner。
**示例代码:**
此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning>`_
.. code-block:: python
from paddleslim import UnstructuredPruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
threshold = pruner.summarize_weights(model, ratio=0.1)
..
docs/zh_cn/api_cn/static/prune/prune_index.rst
浏览文件 @
d3aeda6f
...
...
@@ -6,3 +6,4 @@
:maxdepth: 1
prune_api.rst
unstructured_prune_api.rst
docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst
0 → 100644
浏览文件 @
d3aeda6f
非结构化稀疏
================
UnstrucuturedPruner
----------
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.5, threshold=1e-5, scope=None, place=None, skip_params_func=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置0。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。
**参数:**
- **program(paddle.static.Program)** - 一个paddle.static.Program对象,是待剪裁的模型。
- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0。
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。
- **scope(paddle.static.Scope)** - 一个paddle.static.Scope对象,存储了所有变量的数值,默认(None)时表示paddle.static.global_scope。
- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
**返回:** 一个UnstructuredPruner类的实例
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner()
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step()
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
pruner.step()
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params()
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
pruner.update_params()
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse(program)
UnstructuredPruner中的静态方法,用于计算给定的模型(program)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
- **program(paddle.static.Program)** - 要计算稠密度的目标网络。
**返回:**
- **density(float)** - 模型的稠密度。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.summarize_weights(program, ratio=0.1)
该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)。
**参数:**
- **program(paddle.static.Program)** - 要分析权重分布的目标网络。
- **ratio(float)** - 需要查看的比例情况,具体如上方法描述。
**返回:**
- **threshold(float)** - 和输入ratio对应的阈值。开发者可以根据该阈值初始化UnstructuredPruner。
**示例代码:**
此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 <https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/unstructured_prune>`_
.. code-block:: python
from paddleslim.prune import UnstructuredPruner
pruner = UnstructuredPruner(
paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0])
threshold = pruner.summarize_weights(paddle.static.default_main_program(), 1.0)
..
paddleslim/dygraph/prune/unstructured_pruner.py
浏览文件 @
d3aeda6f
...
...
@@ -80,6 +80,28 @@ class UnstructuredPruner():
self
.
threshold
=
np
.
sort
(
np
.
abs
(
params_flatten
))[
max
(
0
,
round
(
self
.
ratio
*
total_length
)
-
1
)].
item
()
def
summarize_weights
(
self
,
model
,
ratio
=
0.1
):
"""
The function is used to get the weights corresponding to a given ratio
when you are uncertain about the threshold in __init__() function above.
For example, when given 0.1 as ratio, the function will print the weight value,
the abs(weights) lower than which count for 10% of the total numbers.
Args:
- model(paddle.nn.Layer): The model which have all the parameters.
- ratio(float): The ratio illustrated above.
Return:
- threshold(float): a threshold corresponding to the input ratio.
"""
data
=
[]
for
name
,
sub_layer
in
model
.
named_sublayers
():
if
not
self
.
_should_prune_layer
(
sub_layer
):
continue
for
param
in
sub_layer
.
parameters
(
include_sublayers
=
False
):
data
.
append
(
np
.
array
(
param
.
value
().
get_tensor
()).
flatten
())
data
=
np
.
concatenate
(
data
,
axis
=
0
)
threshold
=
np
.
sort
(
np
.
abs
(
data
))[
max
(
0
,
int
(
ratio
*
len
(
data
)
-
1
))]
return
threshold
def
step
(
self
):
"""
Update the threshold after each optimization step.
...
...
@@ -116,7 +138,7 @@ class UnstructuredPruner():
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- model(
Paddle.Model
): The sparse model.
- model(
paddle.nn.Layer
): The sparse model.
Returns:
- ratio(float): The model's density.
"""
...
...
paddleslim/prune/unstructured_pruner.py
浏览文件 @
d3aeda6f
...
...
@@ -12,7 +12,6 @@ class UnstructuredPruner():
Args:
- program(paddle.static.Program): The model to be pruned.
- batch_size(int): batch size.
- mode(str): the mode to prune the model, must be selected from 'ratio' and 'threshold'.
- ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.5.
- threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-5.
...
...
@@ -23,7 +22,6 @@ class UnstructuredPruner():
def
__init__
(
self
,
program
,
batch_size
,
mode
,
ratio
=
0.5
,
threshold
=
1e-5
,
...
...
tests/dygraph/test_unstructured_prune.py
浏览文件 @
d3aeda6f
...
...
@@ -3,7 +3,7 @@ sys.path.append("../../")
import
unittest
import
paddle
import
numpy
as
np
from
paddleslim
.dygraph.prune.unstructured_pruner
import
UnstructuredPruner
from
paddleslim
import
UnstructuredPruner
from
paddle.vision.models
import
mobilenet_v1
...
...
@@ -37,6 +37,20 @@ class TestUnstructuredPruner(unittest.TestCase):
self
.
pruner
.
update_params
()
self
.
assertEqual
(
cur_density
,
UnstructuredPruner
.
total_sparse
(
self
.
net
))
def
test_summarize_weights
(
self
):
max_value
=
-
float
(
"inf"
)
threshold
=
self
.
pruner
.
summarize_weights
(
self
.
net
,
1.0
)
for
name
,
sub_layer
in
self
.
net
.
named_sublayers
():
if
not
self
.
pruner
.
_should_prune_layer
(
sub_layer
):
continue
for
param
in
sub_layer
.
parameters
(
include_sublayers
=
False
):
max_value
=
max
(
max_value
,
np
.
max
(
np
.
abs
(
np
.
array
(
param
.
value
().
get_tensor
()))))
print
(
"The returned threshold is {}."
.
format
(
threshold
))
print
(
"The max_value is {}."
.
format
(
max_value
))
self
.
assertEqual
(
max_value
,
threshold
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tests/test_unstructured_pruner.py
浏览文件 @
d3aeda6f
...
...
@@ -42,7 +42,7 @@ class TestUnstructuredPruner(StaticCase):
exe
.
run
(
self
.
startup_program
,
scope
=
self
.
scope
)
self
.
pruner
=
UnstructuredPruner
(
self
.
main_program
,
16
,
'ratio'
,
scope
=
self
.
scope
,
place
=
place
)
self
.
main_program
,
'ratio'
,
scope
=
self
.
scope
,
place
=
place
)
def
test_unstructured_prune
(
self
):
for
param
in
self
.
main_program
.
global_block
().
all_parameters
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录