From f3b4c238b1b645fa0095e7000c0072abb54005ca Mon Sep 17 00:00:00 2001
From: minghaoBD <79566150+minghaoBD@users.noreply.github.com>
Date: Tue, 23 Nov 2021 16:53:38 +0800
Subject: [PATCH] Unstructured prune (#4653) (#4675)
* unstructured prune for picodet
---
configs/picodet/README.md | 9 ++
configs/picodet/README_PRUNER.md | 135 ++++++++++++++++++
.../picodet/pruner/optimizer_300e_pruner.yml | 18 +++
.../pruner/picodet_m_320_coco_pruner.yml | 13 ++
.../prune/picodet_m_unstructured_prune_75.yml | 11 ++
.../prune/picodet_m_unstructured_prune_85.yml | 11 ++
ppdet/engine/trainer.py | 12 +-
ppdet/slim/__init__.py | 8 ++
ppdet/slim/unstructured_prune.py | 66 +++++++++
9 files changed, 281 insertions(+), 2 deletions(-)
create mode 100644 configs/picodet/README_PRUNER.md
create mode 100644 configs/picodet/pruner/optimizer_300e_pruner.yml
create mode 100644 configs/picodet/pruner/picodet_m_320_coco_pruner.yml
create mode 100644 configs/slim/prune/picodet_m_unstructured_prune_75.yml
create mode 100644 configs/slim/prune/picodet_m_unstructured_prune_85.yml
create mode 100644 ppdet/slim/unstructured_prune.py
diff --git a/configs/picodet/README.md b/configs/picodet/README.md
index aa8f2f8b0..dc6d0dc75 100644
--- a/configs/picodet/README.md
+++ b/configs/picodet/README.md
@@ -267,6 +267,15 @@ python tools/post_quant.py -c configs/picodet/picodet_s_320_coco.yml \
+## Unstructured Pruning
+
+
+Tutorial:
+
+Please refer to this [documentation](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/README_PRUNER.md) for details on requirements, training, and deployment.
+
+
+
## Application
- **Pedestrian detection:** model zoo of `PicoDet-S-Pedestrian` please refer to [PP-TinyPose](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/keypoint/tiny_pose#%E8%A1%8C%E4%BA%BA%E6%A3%80%E6%B5%8B%E6%A8%A1%E5%9E%8B)
diff --git a/configs/picodet/README_PRUNER.md b/configs/picodet/README_PRUNER.md
new file mode 100644
index 000000000..e62ed100d
--- /dev/null
+++ b/configs/picodet/README_PRUNER.md
@@ -0,0 +1,135 @@
+# Tutorial: Unstructured Sparsity on PicoDet
+
+## 1. Introduction
+In model compression, the two common forms of sparsity are structured and unstructured. Structured sparsity prunes convolutions and matrix multiplications along a specific dimension (feature channels, convolution kernels, etc.) and produces a smaller model structure, so existing dense conv/matmul kernels can be reused and no special inference operators are required. Unstructured sparsity prunes at the granularity of individual parameters and does not change the shapes of the parameter matrices, so it depends on the inference library and hardware to accelerate sparse matrix operations. We applied unstructured sparsity to the PP-PicoDet (hereafter PicoDet) model and obtained a significant inference speed-up on ARM CPUs with only a small accuracy loss. This document describes how to train PicoDet with unstructured sparsity; for more background, please refer to [this page](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning).
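+
+As a toy numpy illustration of the difference (for intuition only, not part of the training code): unstructured pruning zeroes individual weights and keeps the tensor shape, while structured pruning removes whole channels and leaves a smaller dense tensor.
+
+```python
+import numpy as np
+
+w = np.random.randn(8, 8, 1, 1)  # a 1x1-conv weight: [out_ch, in_ch, kh, kw]
+
+# unstructured: zero the 75% of weights with the smallest magnitude; shape unchanged
+mask = np.abs(w) > np.quantile(np.abs(w), 0.75)
+w_unstructured = w * mask
+
+# structured: drop half of the output channels; a smaller dense tensor remains
+w_structured = w[:4]
+```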
+
+## 2. Requirements
+```bash
+PaddlePaddle >= 2.1.2
+PaddleSlim develop branch (pip install paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple)
+```
+
+## 3. Data Preparation
+Same as PicoDet.
+
+## 4. Pretrained Model
+For unstructured sparse training, the pretrained model must be a set of weights that has already converged, so it needs to be declared explicitly in the corresponding config file.
+
+Config file that declares the pretrained-model path: ./configs/picodet/pruner/picodet_m_320_coco_pruner.yml
+For the pretrained-model URLs, please refer to the PicoDet documentation: ./configs/picodet/README.md
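+
+For concreteness, the slim configs in this patch declare the converged PicoDet-m weights like this (see configs/slim/prune/picodet_m_unstructured_prune_75.yml):
+
+```yaml
+pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams
+slim: UnstructuredPruner
+```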
+
+## 5. Customizing the Pruning Scope
+For the best inference speed-up, we recommend sparsifying only the 1x1 convolution layers and keeping all other parameters dense. In addition, some layers have a large impact on accuracy (e.g. the last few layers of the head and several se-block layers), so we recommend not sparsifying them either. Developers can pass in a custom function to conveniently specify which layers are excluded from sparsification. For example, for the picodet_m_320 model we skip the last 4 convolution layers and the convolutions in 6 se-blocks; the custom function is as follows:
+
+```python
+NORMS_ALL = [
+    'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
+    'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
+    'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm'
+]
+
+def skip_params_self(model):
+    skip_params = set()
+    for _, sub_layer in model.named_sublayers():
+        # never sparsify normalization layers
+        if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
+            skip_params.add(sub_layer.full_name())
+        for param in sub_layer.parameters(include_sublayers=False):
+            # only 1x1 conv weights ([out_ch, in_ch, 1, 1]) are pruning candidates
+            cond_is_conv1x1 = len(param.shape) == 4 and param.shape[2] == 1 and param.shape[3] == 1
+            # the final head convolutions of picodet_m_320 (out_ch=112, in_ch=128)
+            cond_is_head_m = cond_is_conv1x1 and param.shape[0] == 112 and param.shape[1] == 128
+            # the se-block convolutions of picodet_m_320, matched by layer name
+            cond_is_se_block_m = param.name.split('.')[0] in ['conv2d_17', 'conv2d_18', 'conv2d_56', 'conv2d_57', 'conv2d_75', 'conv2d_76']
+            if not cond_is_conv1x1 or cond_is_head_m or cond_is_se_block_m:
+                skip_params.add(param.name)
+    return skip_params
+```
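+
+As shown in section 6.2 below, this function is passed to the pruner via the `skip_params_func` argument; every parameter name it returns stays dense.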
+
+## 6. Training
+The core unstructured-sparsity functionality is embedded into training via API calls, so if you have no finer-grained requirements you can start training directly with the command in 6.1. To help you modify and adapt the code to your own needs, a more detailed walkthrough is provided in 6.2.
+
+### 6.1 Quick Start
+```bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python3.7 -m paddle.distributed.launch --log_dir=log_test --gpus 0,1,2,3 tools/train.py -c configs/picodet/pruner/picodet_m_320_coco_pruner.yml --slim_config configs/slim/prune/picodet_m_unstructured_prune_75.yml --eval
+```
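+
+With a single GPU, the same entry point can be launched without the distributed launcher (a sketch under that assumption):
+
+```bash
+export CUDA_VISIBLE_DEVICES=0
+python3.7 tools/train.py -c configs/picodet/pruner/picodet_m_320_coco_pruner.yml --slim_config configs/slim/prune/picodet_m_unstructured_prune_75.yml --eval
+```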
+
+### 6.2 Walkthrough
+- Customizing the pruning scope: see section 5 of this tutorial.
+- How to add the 4 lines of code required for sparse training:
+
+```python
+# after constructing model and before training
+
+# Pruner Step1: configs
+configs = {
+ 'pruning_strategy': 'gmp',
+ 'stable_iterations': self.stable_epochs * steps_per_epoch,
+ 'pruning_iterations': self.pruning_epochs * steps_per_epoch,
+ 'tunning_iterations': self.tunning_epochs * steps_per_epoch,
+ 'resume_iteration': 0,
+ 'pruning_steps': self.pruning_steps,
+ 'initial_ratio': self.initial_ratio,
+}
+
+# Pruner Step2: construct a pruner object
+self.pruner = GMPUnstructuredPruner(
+    model,
+    ratio=self.cfg.ratio,
+    # Only pass skip_params_func when you provide your own skip function;
+    # the following argument (skip_params_type) is then ignored.
+    skip_params_func=skip_params_self,
+    skip_params_type=self.cfg.skip_params_type,
+    local_sparsity=True,
+    configs=configs)
+
+# training
+for epoch_id in range(self.start_epoch, self.cfg.epoch):
+    model.train()
+    for step_id, data in enumerate(self.loader):
+        # model forward
+        outputs = model(data)
+        loss = outputs['loss']
+        # model backward
+        loss.backward()
+        self.optimizer.step()
+
+        # Pruner Step3: update the sparsity masks once per optimizer step
+        self.pruner.step()
+        self.optimizer.clear_grad()
+
+    # Pruner Step4: set the pruned weights to exact zeros before saving
+    self.pruner.update_params()
+    # model-saving API
+```
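+
+The `configs` dict above drives the gradual magnitude pruning (GMP) schedule: after `stable_iterations` of dense training, the sparsity ratio ramps from `initial_ratio` up to the target `ratio` over `pruning_iterations` (raised `pruning_steps` times), and the remaining `tunning_iterations` fine-tune at the fixed final ratio. As a rough sketch of such a ramp (an assumption in the spirit of Zhu & Gupta's cubic schedule, not necessarily PaddleSlim's exact implementation):
+
+```python
+def sparsity_at(it, stable_iters, pruning_iters, initial_ratio, final_ratio):
+    """Illustrative cubic ramp; GMPUnstructuredPruner's internals may differ."""
+    if it < stable_iters:
+        return 0.0  # stable phase: train dense
+    t = min(1.0, (it - stable_iters) / float(pruning_iters))
+    # interpolate from initial_ratio (t=0) to final_ratio (t=1)
+    return final_ratio + (initial_ratio - final_ratio) * (1.0 - t) ** 3
+
+# e.g. the 75% config: initial_ratio=0.15, ratio=0.75, pruned over 150 epochs
+```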
+
+## 7. Evaluation and Deployment
+These steps are essentially the same as in the PicoDet documentation; the only difference is that an extra argument (sparse_model) must be passed when converting to a Paddle Lite model:
+
+```bash
+paddle_lite_opt --model_dir=inference_model/picodet_m_320_coco --valid_targets=arm --optimize_out=picodet_m_320_coco_fp32_sparse --sparse_model=True
+```
+
+**Note:** sparse inference currently supports Paddle Lite FP32 and INT8 models, so please do not enable the FP16 switch when running the command above.
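+
+For reference, a typical export step before the Paddle Lite conversion might look like the following sketch (standard PaddleDetection usage is assumed; adjust paths as needed):
+
+```bash
+python3.7 tools/export_model.py -c configs/picodet/pruner/picodet_m_320_coco_pruner.yml \
+    --slim_config configs/slim/prune/picodet_m_unstructured_prune_75.yml \
+    -o weights=output/picodet_m_320_coco/model_final
+```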
+
+## 8. Sparsity Results
+We trained FP32 PicoDet-m models at 75% and 85% sparsity and measured inference latency on a SnapDragon-835 device; the results are shown in the table below. In particular:
+- For the m model, an mAP loss of 1.5 yields a 34%-58% speed-up.
+- Likewise for the m model, apart from 4-thread inference speed being roughly on par, its single-thread speed, mAP, and model size are all better than the s model's.
+
+
+| Model | Input size | Sparsity | mAP<sup>val<br>0.5:0.95</sup> | Size<br>(MB) | Latency single-thread [Lite](#latency)<br>(ms) | speed-up single-thread | Latency 4-thread [Lite](#latency)<br>(ms) | speed-up 4-thread | Download | SlimConfig |
+| :-------- | :--------: |:--------: | :---------------------: | :----------------: | :----------------: |:----------------: | :---------------: | :-----------------------------: | :-----------------------------: | :----------------------------------------: |
+| PicoDet-m-1.0 | 320*320 | 0 | 30.9 | 8.9 | 127 | 0 | 43 | 0 | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams) \| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet/picodet_m_320_coco.yml) |
+| PicoDet-m-1.0 | 320*320 | 75% | 29.4 | 5.6 | **80** | 58% | **32** | 34% | [model](https://paddledet.bj.bcebos.com/models/slim/picodet_m_320__coco_sparse_75.pdparams) \| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320__coco_sparse_75.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/slim/prune/picodet_m_unstructured_prune_75.yml) |
+| PicoDet-s-1.0 | 320*320 | 0 | 27.1 | 4.6 | 68 | 0 | 26 | 0 | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) \| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet/picodet_s_320_coco.yml) |
+| PicoDet-m-1.0 | 320*320 | 85% | 27.6 | 4.1 | **65** | 96% | **27** | 59% | [model](https://paddledet.bj.bcebos.com/models/slim/picodet_m_320__coco_sparse_85.pdparams) \| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320__coco_sparse_85.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/slim/prune/picodet_m_unstructured_prune_85.yml) |
+
+**Notes:**
+- The model sizes above are **deployment model sizes**, i.e. the sizes of the *.nb files produced by the Paddle Lite conversion.
+- The speed-up columns are computed as the percentage increase in FPS: $(dense\_latency - sparse\_latency) / sparse\_latency$
+- For the sparse training above, we added one extra data augmentation to _base_/picodet_320_reader.yml, shown below. Training without it is expected to cost no noticeable mAP (a drop of <0.1) and has no effect on speed or model size.
+```yaml
+worker_num: 6
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomCrop: {}
+ - RandomFlip: {prob: 0.5}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomDistort: {}
+ batch_transforms:
+etc.
+```
diff --git a/configs/picodet/pruner/optimizer_300e_pruner.yml b/configs/picodet/pruner/optimizer_300e_pruner.yml
new file mode 100644
index 000000000..064d56233
--- /dev/null
+++ b/configs/picodet/pruner/optimizer_300e_pruner.yml
@@ -0,0 +1,18 @@
+epoch: 300
+
+LearningRate:
+ base_lr: 0.15
+ schedulers:
+ - !CosineDecay
+ max_epochs: 300
+ - !LinearWarmup
+ start_factor: 1.0
+ steps: 34350
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.00004
+ type: L2
diff --git a/configs/picodet/pruner/picodet_m_320_coco_pruner.yml b/configs/picodet/pruner/picodet_m_320_coco_pruner.yml
new file mode 100644
index 000000000..cc357048d
--- /dev/null
+++ b/configs/picodet/pruner/picodet_m_320_coco_pruner.yml
@@ -0,0 +1,13 @@
+_BASE_: [
+ '../../datasets/coco_detection.yml',
+ '../../runtime.yml',
+ '../_base_/picodet_esnet.yml',
+ './optimizer_300e_pruner.yml',
+ '../_base_/picodet_320_reader.yml',
+]
+
+weights: output/picodet_m_320_coco/model_final
+find_unused_parameters: True
+use_ema: true
+cycle_epoch: 40
+snapshot_epoch: 10
diff --git a/configs/slim/prune/picodet_m_unstructured_prune_75.yml b/configs/slim/prune/picodet_m_unstructured_prune_75.yml
new file mode 100644
index 000000000..94345b4e8
--- /dev/null
+++ b/configs/slim/prune/picodet_m_unstructured_prune_75.yml
@@ -0,0 +1,11 @@
+pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams
+slim: UnstructuredPruner
+
+UnstructuredPruner:
+ stable_epochs: 0
+ pruning_epochs: 150
+ tunning_epochs: 150
+ pruning_steps: 300
+ ratio: 0.75
+ initial_ratio: 0.15
+ prune_params_type: conv1x1_only
diff --git a/configs/slim/prune/picodet_m_unstructured_prune_85.yml b/configs/slim/prune/picodet_m_unstructured_prune_85.yml
new file mode 100644
index 000000000..db0af7e10
--- /dev/null
+++ b/configs/slim/prune/picodet_m_unstructured_prune_85.yml
@@ -0,0 +1,11 @@
+pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams
+slim: UnstructuredPruner
+
+UnstructuredPruner:
+ stable_epochs: 0
+ pruning_epochs: 150
+ tunning_epochs: 150
+ pruning_steps: 300
+ ratio: 0.85
+ initial_ratio: 0.20
+ prune_params_type: conv1x1_only
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 2d0348a8f..dc739ff62 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -81,7 +81,8 @@ class Trainer(object):
# JDE only support single class MOT now.
if cfg.architecture == 'FairMOT' and self.mode == 'train':
- cfg['FairMOTEmbeddingHead']['num_identities_dict'] = self.dataset.num_identities_dict
+ cfg['FairMOTEmbeddingHead'][
+ 'num_identities_dict'] = self.dataset.num_identities_dict
# FairMOT support single class and multi-class MOT now.
# build model
@@ -119,6 +120,10 @@ class Trainer(object):
self.lr = create('LearningRate')(steps_per_epoch)
self.optimizer = create('OptimizerBuilder')(self.lr, self.model)
+ if self.cfg.get('unstructured_prune'):
+ self.pruner = create('UnstructuredPruner')(self.model,
+ steps_per_epoch)
+
self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank()
@@ -395,9 +400,10 @@ class Trainer(object):
# model backward
loss.backward()
self.optimizer.step()
-
curr_lr = self.optimizer.get_lr()
self.lr.step()
+ if self.cfg.get('unstructured_prune'):
+ self.pruner.step()
self.optimizer.clear_grad()
self.status['learning_rate'] = curr_lr
@@ -414,6 +420,8 @@ class Trainer(object):
if self.use_ema:
weight = copy.deepcopy(self.model.state_dict())
self.model.set_dict(self.ema.apply())
+ if self.cfg.get('unstructured_prune'):
+ self.pruner.update_params()
self._compose_callback.on_epoch_end(self.status)
diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py
index a3c6ac763..dc22d0717 100644
--- a/ppdet/slim/__init__.py
+++ b/ppdet/slim/__init__.py
@@ -15,10 +15,12 @@
from . import prune
from . import quant
from . import distill
+from . import unstructured_prune
from .prune import *
from .quant import *
from .distill import *
+from .unstructured_prune import *
import yaml
from ppdet.core.workspace import load_config
@@ -56,6 +58,12 @@ def build_slim_model(cfg, slim_cfg, mode='train'):
cfg['slim_type'] = cfg.slim
cfg['model'] = slim(model)
cfg['slim'] = slim
+ elif slim_load_cfg['slim'] == 'UnstructuredPruner':
+ load_config(slim_cfg)
+ slim = create(cfg.slim)
+ cfg['slim_type'] = cfg.slim
+ cfg['slim'] = slim
+ cfg['unstructured_prune'] = True
else:
load_config(slim_cfg)
model = create(cfg.architecture)
diff --git a/ppdet/slim/unstructured_prune.py b/ppdet/slim/unstructured_prune.py
new file mode 100644
index 000000000..1dc876a8c
--- /dev/null
+++ b/ppdet/slim/unstructured_prune.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle.utils import try_import
+
+from ppdet.core.workspace import register, serializable
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+
+@register
+@serializable
+class UnstructuredPruner(object):
+ def __init__(self,
+ stable_epochs,
+ pruning_epochs,
+ tunning_epochs,
+ pruning_steps,
+ ratio,
+ initial_ratio,
+ prune_params_type=None):
+ self.stable_epochs = stable_epochs
+ self.pruning_epochs = pruning_epochs
+ self.tunning_epochs = tunning_epochs
+ self.ratio = ratio
+ self.prune_params_type = prune_params_type
+ self.initial_ratio = initial_ratio
+ self.pruning_steps = pruning_steps
+
+ def __call__(self, model, steps_per_epoch, skip_params_func=None):
+ paddleslim = try_import('paddleslim')
+ from paddleslim import GMPUnstructuredPruner
+ configs = {
+ 'pruning_strategy': 'gmp',
+ 'stable_iterations': self.stable_epochs * steps_per_epoch,
+ 'pruning_iterations': self.pruning_epochs * steps_per_epoch,
+ 'tunning_iterations': self.tunning_epochs * steps_per_epoch,
+ 'resume_iteration': 0,
+ 'pruning_steps': self.pruning_steps,
+ 'initial_ratio': self.initial_ratio,
+ }
+
+ pruner = GMPUnstructuredPruner(
+ model,
+ ratio=self.ratio,
+ skip_params_func=skip_params_func,
+ prune_params_type=self.prune_params_type,
+ local_sparsity=True,
+ configs=configs)
+
+ return pruner
--
GitLab