From f3b4c238b1b645fa0095e7000c0072abb54005ca Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Tue, 23 Nov 2021 16:53:38 +0800 Subject: [PATCH] Unstructured prune (#4653) (#4675) * unstructured prune for picodet --- configs/picodet/README.md | 9 ++ configs/picodet/README_PRUNER.md | 135 ++++++++++++++++++ .../picodet/pruner/optimizer_300e_pruner.yml | 18 +++ .../pruner/picodet_m_320_coco_pruner.yml | 13 ++ .../prune/picodet_m_unstructured_prune_75.yml | 11 ++ .../prune/picodet_m_unstructured_prune_85.yml | 11 ++ ppdet/engine/trainer.py | 12 +- ppdet/slim/__init__.py | 8 ++ ppdet/slim/unstructured_prune.py | 66 +++++++++ 9 files changed, 281 insertions(+), 2 deletions(-) create mode 100644 configs/picodet/README_PRUNER.md create mode 100644 configs/picodet/pruner/optimizer_300e_pruner.yml create mode 100644 configs/picodet/pruner/picodet_m_320_coco_pruner.yml create mode 100644 configs/slim/prune/picodet_m_unstructured_prune_75.yml create mode 100644 configs/slim/prune/picodet_m_unstructured_prune_85.yml create mode 100644 ppdet/slim/unstructured_prune.py diff --git a/configs/picodet/README.md b/configs/picodet/README.md index aa8f2f8b0..dc6d0dc75 100644 --- a/configs/picodet/README.md +++ b/configs/picodet/README.md @@ -267,6 +267,15 @@ python tools/post_quant.py -c configs/picodet/picodet_s_320_coco.yml \ +## Unstructured Pruning + +
+Toturial: + +Please refer this [documentation](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/README_PRUNER.md) for details such as requirements, training and deployment. + +
+ ## Application - **Pedestrian detection:** model zoo of `PicoDet-S-Pedestrian` please refer to [PP-TinyPose](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/keypoint/tiny_pose#%E8%A1%8C%E4%BA%BA%E6%A3%80%E6%B5%8B%E6%A8%A1%E5%9E%8B) diff --git a/configs/picodet/README_PRUNER.md b/configs/picodet/README_PRUNER.md new file mode 100644 index 000000000..e62ed100d --- /dev/null +++ b/configs/picodet/README_PRUNER.md @@ -0,0 +1,135 @@ +# 非结构化稀疏在 PicoDet 上的应用教程 + +## 1. 介绍 +在模型压缩中,常见的稀疏方式为结构化稀疏和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上对卷积、矩阵乘法进行剪枝操作,然后生成一个更小的模型结构,这样可以复用已有的卷积、矩阵乘计算,无需特殊实现推理算子;后者以每一个参数为单元进行稀疏化,然而并不会改变参数矩阵的形状,所以更依赖于推理库、硬件对于稀疏后矩阵运算的加速能力。我们在 PP-PicoDet (以下简称PicoDet) 模型上运用了非结构化稀疏技术,在精度损失较小时,获得了在 ARM CPU 端推理的显著性能提升。本文档会介绍如何非结构化稀疏训练 PicoDet,关于非结构化稀疏的更多介绍请参照[这里](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/dygraph/unstructured_pruning)。 + +## 2. 版本要求 +```bash +PaddlePaddle >= 2.1.2 +PaddleSlim develop分支 (pip install paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple) +``` + +## 3. 数据准备 +同 PicoDet + +## 4. 预训练模型 +在非结构化稀疏训练中,我们规定预训练模型是已经收敛完成的模型参数,所以需要额外在相关配置文件中声明。 + +声明预训练模型地址的配置文件:./configs/picodet/pruner/picodet_m_320_coco_pruner.yml +预训练模型地址请参照 PicoDet 文档:./configs/picodet/README.md + +## 5. 自定义稀疏化的作用范围 +为达到最佳推理加速效果,我们建议只对 1x1 卷积层进行稀疏化,其他层参数保持稠密。另外,有些层对于精度影响较大(例如head的最后几层,se-block的若干层),我们同样不建议对他们进行稀疏化,我们支持开发者通过传入自定义函数的形式,方便的指定哪些层不参与稀疏。例如,基于picodet_m_320这个模型,我们稀疏时跳过了后4层卷积以及6层se-block中的卷积,自定义函数如下: + +```python +NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D', + 'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D', + 'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ] + +def skip_params_self(model): + skip_params = set() + for _, sub_layer in model.named_sublayers(): + if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL: + skip_params.add(sub_layer.full_name()) + for param in sub_layer.parameters(include_sublayers=False): + cond_is_conv1x1 = len(param.shape) == 4 and param.shape[2] == 1 and param.shape[3] == 1 + cond_is_head_m = cond_is_conv1x1 and param.shape[0] == 112 and param.shape[1] == 128 + cond_is_se_block_m = param.name.split('.')[0] in ['conv2d_17', 'conv2d_18', 'conv2d_56', 'conv2d_57', 'conv2d_75', 'conv2d_76'] + if not cond_is_conv1x1 or cond_is_head_m or cond_is_se_block_m: + skip_params.add(param.name) + return skip_params +``` + +## 6. 训练 +我们已经将非结构化稀疏的核心功能通过 API 调用的方式嵌入到了训练中,所以如果您没有更细节的需求,直接运行 6.1 的命令启动训练即可。同时,为帮助您根据自己的需求更改、适配代码,我们也提供了更为详细的使用介绍,请参照 6.2。 + +### 6.1 直接使用 +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3.7 -m paddle.distributed.launch --log_dir=log_test --gpus 0,1,2,3 tools/train.py -c configs/picodet/pruner/picodet_m_320_coco_pruner.yml --slim_config configs/slim/prune/picodet_m_unstructured_prune_75.yml --eval +``` + +### 6.2 详细介绍 +- 自定义稀疏化的作用范围:可以参照本教程的第 5 节 +- 如何添加稀疏化训练所需的 4 行代码 + +```python +# after constructing model and before training + +# Pruner Step1: configs +configs = { + 'pruning_strategy': 'gmp', + 'stable_iterations': self.stable_epochs * steps_per_epoch, + 'pruning_iterations': self.pruning_epochs * steps_per_epoch, + 'tunning_iterations': self.tunning_epochs * steps_per_epoch, + 'resume_iteration': 0, + 'pruning_steps': self.pruning_steps, + 'initial_ratio': self.initial_ratio, +} + +# Pruner Step2: construct a pruner object +self.pruner = GMPUnstructuredPruner( + model, + ratio=self.cfg.ratio, + skip_params_func=skip_params_self, # Only pass in this value when you design your own skip_params function. And the following argument (skip_params_type) will be ignored. + skip_params_type=self.cfg.skip_params_type, + local_sparsity=True, + configs=configs) + +# training +for epoch_id in range(self.start_epoch, self.cfg.epoch): + model.train() + for step_id, data in enumerate(self.loader): + # model forward + outputs = model(data) + loss = outputs['loss'] + # model backward + loss.backward() + self.optimizer.step() + + # Pruner Step3: step during training + self.pruner.step() + + # Pruner Step4: save the sparse model + self.pruner.update_params() + # model-saving API +``` + +## 7. 模型评估与推理部署 +这部分与 PicoDet 文档中基本一致,只是在转换到 PaddleLite 模型时,需要添加一个输入参数(sparse_model): + +```bash +paddle_lite_opt --model_dir=inference_model/picodet_m_320_coco --valid_targets=arm --optimize_out=picodet_m_320_coco_fp32_sparse --sparse_model=True +``` + +**注意:** 目前稀疏化推理适用于 PaddleLite的 FP32 和 INT8 模型,所以执行上述命令时,请不要打开 FP16 开关。 + +## 8. 稀疏化结果 +我们在75%和85%稀疏度下,训练得到了 FP32 PicoDet-m模型,并在 SnapDragon-835设备上实测推理速度,效果如下表。其中: +- 对于 m 模型,mAP损失1.5,获得了 34\%-58\% 的加速性能 +- 同样对于 m 模型,除4线程推理速度基本持平外,单线程推理速度、mAP、模型体积均优于 s 模型。 + + +| Model | Input size | Sparsity | mAPval
0.5:0.95 | Size
(MB) | Latency single-thread[Lite](#latency)
(ms) | speed-up single-thread | Latency 4-thread[Lite](#latency)
(ms) | speed-up 4-thread | Download | SlimConfig | +| :-------- | :--------: |:--------: | :---------------------: | :----------------: | :----------------: |:----------------: | :---------------: | :-----------------------------: | :-----------------------------: | :----------------------------------------: | +| PicoDet-m-1.0 | 320*320 | 0 | 30.9 | 8.9 | 127 | 0 | 43 | 0 | [model](https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams)| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet/picodet_m_320_coco.yml)| +| PicoDet-m-1.0 | 320*320 | 75% | 29.4 | 5.6 | **80** | 58% | **32** | 34% | [model](https://paddledet.bj.bcebos.com/models/slim/picodet_m_320__coco_sparse_75.pdparams)| [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320__coco_sparse_75.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/slim/prune/picodet_m_unstructured_prune_75.yml)| +| PicoDet-s-1.0 | 320*320 | 0 | 27.1 | 4.6 | 68 | 0 | 26 | 0 | [model](https://paddledet.bj.bcebos.com/models/picodet_s_320_coco.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_s_320_coco.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/picodet/picodet_s_320_coco.yml)| +| PicoDet-m-1.0 | 320*320 | 85% | 27.6 | 4.1 | **65** | 96% | **27** | 59% | [model](https://paddledet.bj.bcebos.com/models/slim/picodet_m_320__coco_sparse_85.pdparams) | [log](https://paddledet.bj.bcebos.com/logs/train_picodet_m_320__coco_sparse_85.log) | [config](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/configs/slim/prune/picodet_m_unstructured_prune_85.yml)| + +**注意:** +- 上述模型体积是**部署模型体积**,即 PaddleLite 转换得到的 *.nb 文件的体积。 +- 加速一栏我们按照 FPS 增加百分比计算,即:$(dense\_latency - sparse\_latency) / sparse\_latency$ +- 上述稀疏化训练时,我们额外添加了一种数据增强方式到 _base_/picodet_320_reader.yml,代码如下。但是不添加的话,预期mAP也不会有明显下降(<0.1),且对速度和模型体积没有影响。 +```yaml +worker_num: 6 +TrainReader: + sample_transforms: + - Decode: {} + - RandomCrop: {} + - RandomFlip: {prob: 0.5} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomDistort: {} + batch_transforms: +etc. +``` diff --git a/configs/picodet/pruner/optimizer_300e_pruner.yml b/configs/picodet/pruner/optimizer_300e_pruner.yml new file mode 100644 index 000000000..064d56233 --- /dev/null +++ b/configs/picodet/pruner/optimizer_300e_pruner.yml @@ -0,0 +1,18 @@ +epoch: 300 + +LearningRate: + base_lr: 0.15 + schedulers: + - !CosineDecay + max_epochs: 300 + - !LinearWarmup + start_factor: 1.0 + steps: 34350 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.00004 + type: L2 diff --git a/configs/picodet/pruner/picodet_m_320_coco_pruner.yml b/configs/picodet/pruner/picodet_m_320_coco_pruner.yml new file mode 100644 index 000000000..cc357048d --- /dev/null +++ b/configs/picodet/pruner/picodet_m_320_coco_pruner.yml @@ -0,0 +1,13 @@ +_BASE_: [ + '../../datasets/coco_detection.yml', + '../../runtime.yml', + '../_base_/picodet_esnet.yml', + './optimizer_300e_pruner.yml', + '../_base_/picodet_320_reader.yml', +] + +weights: output/picodet_m_320_coco/model_final +find_unused_parameters: True +use_ema: true +cycle_epoch: 40 +snapshot_epoch: 10 diff --git a/configs/slim/prune/picodet_m_unstructured_prune_75.yml b/configs/slim/prune/picodet_m_unstructured_prune_75.yml new file mode 100644 index 000000000..94345b4e8 --- /dev/null +++ b/configs/slim/prune/picodet_m_unstructured_prune_75.yml @@ -0,0 +1,11 @@ +pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams +slim: UnstructuredPruner + +UnstructuredPruner: + stable_epochs: 0 + pruning_epochs: 150 + tunning_epochs: 150 + pruning_steps: 300 + ratio: 0.75 + initial_ratio: 0.15 + prune_params_type: conv1x1_only diff --git a/configs/slim/prune/picodet_m_unstructured_prune_85.yml b/configs/slim/prune/picodet_m_unstructured_prune_85.yml new file mode 100644 index 000000000..db0af7e10 --- /dev/null +++ b/configs/slim/prune/picodet_m_unstructured_prune_85.yml @@ -0,0 +1,11 @@ +pretrain_weights: https://paddledet.bj.bcebos.com/models/picodet_m_320_coco.pdparams +slim: UnstructuredPruner + +UnstructuredPruner: + stable_epochs: 0 + pruning_epochs: 150 + tunning_epochs: 150 + pruning_steps: 300 + ratio: 0.85 + initial_ratio: 0.20 + prune_params_type: conv1x1_only diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 2d0348a8f..dc739ff62 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -81,7 +81,8 @@ class Trainer(object): # JDE only support single class MOT now. if cfg.architecture == 'FairMOT' and self.mode == 'train': - cfg['FairMOTEmbeddingHead']['num_identities_dict'] = self.dataset.num_identities_dict + cfg['FairMOTEmbeddingHead'][ + 'num_identities_dict'] = self.dataset.num_identities_dict # FairMOT support single class and multi-class MOT now. # build model @@ -119,6 +120,10 @@ class Trainer(object): self.lr = create('LearningRate')(steps_per_epoch) self.optimizer = create('OptimizerBuilder')(self.lr, self.model) + if self.cfg.get('unstructured_prune'): + self.pruner = create('UnstructuredPruner')(self.model, + steps_per_epoch) + self._nranks = dist.get_world_size() self._local_rank = dist.get_rank() @@ -395,9 +400,10 @@ class Trainer(object): # model backward loss.backward() self.optimizer.step() - curr_lr = self.optimizer.get_lr() self.lr.step() + if self.cfg.get('unstructured_prune'): + self.pruner.step() self.optimizer.clear_grad() self.status['learning_rate'] = curr_lr @@ -414,6 +420,8 @@ class Trainer(object): if self.use_ema: weight = copy.deepcopy(self.model.state_dict()) self.model.set_dict(self.ema.apply()) + if self.cfg.get('unstructured_prune'): + self.pruner.update_params() self._compose_callback.on_epoch_end(self.status) diff --git a/ppdet/slim/__init__.py b/ppdet/slim/__init__.py index a3c6ac763..dc22d0717 100644 --- a/ppdet/slim/__init__.py +++ b/ppdet/slim/__init__.py @@ -15,10 +15,12 @@ from . import prune from . import quant from . import distill +from . import unstructured_prune from .prune import * from .quant import * from .distill import * +from .unstructured_prune import * import yaml from ppdet.core.workspace import load_config @@ -56,6 +58,12 @@ def build_slim_model(cfg, slim_cfg, mode='train'): cfg['slim_type'] = cfg.slim cfg['model'] = slim(model) cfg['slim'] = slim + elif slim_load_cfg['slim'] == 'UnstructuredPruner': + load_config(slim_cfg) + slim = create(cfg.slim) + cfg['slim_type'] = cfg.slim + cfg['slim'] = slim + cfg['unstructured_prune'] = True else: load_config(slim_cfg) model = create(cfg.architecture) diff --git a/ppdet/slim/unstructured_prune.py b/ppdet/slim/unstructured_prune.py new file mode 100644 index 000000000..1dc876a8c --- /dev/null +++ b/ppdet/slim/unstructured_prune.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle.utils import try_import + +from ppdet.core.workspace import register, serializable +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + + +@register +@serializable +class UnstructuredPruner(object): + def __init__(self, + stable_epochs, + pruning_epochs, + tunning_epochs, + pruning_steps, + ratio, + initial_ratio, + prune_params_type=None): + self.stable_epochs = stable_epochs + self.pruning_epochs = pruning_epochs + self.tunning_epochs = tunning_epochs + self.ratio = ratio + self.prune_params_type = prune_params_type + self.initial_ratio = initial_ratio + self.pruning_steps = pruning_steps + + def __call__(self, model, steps_per_epoch, skip_params_func=None): + paddleslim = try_import('paddleslim') + from paddleslim import GMPUnstructuredPruner + configs = { + 'pruning_strategy': 'gmp', + 'stable_iterations': self.stable_epochs * steps_per_epoch, + 'pruning_iterations': self.pruning_epochs * steps_per_epoch, + 'tunning_iterations': self.tunning_epochs * steps_per_epoch, + 'resume_iteration': 0, + 'pruning_steps': self.pruning_steps, + 'initial_ratio': self.initial_ratio, + } + + pruner = GMPUnstructuredPruner( + model, + ratio=self.ratio, + skip_params_func=skip_params_func, + prune_params_type=self.prune_params_type, + local_sparsity=True, + configs=configs) + + return pruner -- GitLab