README.md 6.4 KB
Newer Older
1
# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)示例
M
minghaoBD 已提交
2 3 4

## 简介

5 6 7
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1``ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。

本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`CIFAR-10`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的训练代码示例进行API调用。
M
minghaoBD 已提交
8 9 10 11 12 13 14 15 16 17

## 版本要求
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
```

请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)

18 19 20 21 22 23 24 25 26 27
## 数据准备

本示例支持`CIFAR-10``ImageNet`两种数据。默认情况下,会自动下载并使用`CIFAR-10`数据,如果需要使用`ImageNet`数据。请按以下步骤操作:

- 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。
- 使用`train.py``evaluate.py`运行脚本时,指定`--data`选项为`imagenet`

如果想要使用自定义的数据集,需要重写`../../imagenet_reader.py`文件,并在`train.py`中调用实现。

## 下载预训练模型
M
minghaoBD 已提交
28

29 30 31 32 33 34 35 36
该示例中直接使用`paddle.vision.models`模块提供的针对`ImageNet`分类任务的预训练模型。 对预训练好的模型剪裁后,需要在目标数据集上进行重新训练,以便恢复因剪裁损失的精度。

## 自定义稀疏化方法

默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作:

- 开发者可以通过重写`paddleslim.dygraph.prune.unstructured_pruner.py`中的`UnstructuredPruner.mask_parameters()``UnstructuredPruner.update_threshold()`来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。
- 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。
M
minghaoBD 已提交
37 38

```python
39 40 41 42
NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
    'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
    'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ]

M
minghaoBD 已提交
43 44 45 46 47 48 49 50 51 52 53 54 55
def _get_skip_params(model):
    """
    This function is used to check whether the given model's layers are valid to be pruned.
    Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
    Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance.

    Args:
      - model(Paddle.nn.Layer): the current model waiting to be checked.
    Return:
      - skip_params(set<String>): a set of parameters' names
    """
    skip_params = set()
    for _, sub_layer in model.named_sublayers():
56
        if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
M
minghaoBD 已提交
57 58 59 60
            skip_params.add(sub_layer.full_name())
    return skip_params
```

61 62 63 64 65 66 67 68 69 70 71 72 73
## 训练

按照阈值剪裁:
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
```

按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5
```

GPU多卡训练:
M
minghaoBD 已提交
74
```bash
75 76 77 78 79
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
--log_dir="train_mbv1_imagenet_threshold_001_log" \
train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
M
minghaoBD 已提交
80 81
```

82
恢复训练(请替代命令中的`dir/to/the/saved/pruned/model``INTERRUPTED_EPOCH`):
M
minghaoBD 已提交
83
```bash
84 85 86 87 88 89 90
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
                                            --pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH
```

## 推理:
```bash
python3.7 eval --pruned_model models/ --data imagenet
M
minghaoBD 已提交
91 92 93 94 95 96
```

剪裁训练代码示例:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
#STEP1: initialize the pruner
97
pruner = UnstructuredPruner(model, mode='threshold', threshold=0.01)
M
minghaoBD 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123

for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader):
        loss = calculate_loss()
        loss.backward()
        opt.step()
        opt.clear_grad()
        #STEP2: update the pruner's threshold given the updated parameters
        pruner.step()

    if epoch % args.test_period == 0:
        #STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros.
        pruner.update_params()
        eval(epoch)

    if epoch % args.model_period == 0:
        # STEP4: same purpose as STEP3
        pruner.update_params()
        paddle.save(model.state_dict(), "model-pruned.pdparams")
        paddle.save(opt.state_dict(), "opt-pruned.pdopt")
```

剪裁后测试代码示例:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
model.set_state_dict(paddle.load("model-pruned.pdparams"))
124 125
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
print(UnstructuredPruner.total_sparse(model))
M
minghaoBD 已提交
126 127 128 129 130
test()
```

更多使用参数请参照shell文件或者运行如下命令查看:
```bash
131 132
python3.7 train --h
python3.7 evaluate --h
M
minghaoBD 已提交
133 134
```

135
## 实验结果
M
minghaoBD 已提交
136 137 138 139 140 141

| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet |   ratio  | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| YOLO v3     |  VOC     | - | - |76.24% | - | - | - |
142
| YOLO v3     |  VOC     |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 | 102k iterations |