# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)示例 ## 简介 在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1`在`ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`CIFAR-10`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的训练代码示例进行API调用。 ## 版本要求 ```bash python3.5+ paddlepaddle>=2.0.0 paddleslim>=2.1.0 ``` 请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)和[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)。 ## 数据准备 本示例支持`CIFAR-10`和`ImageNet`两种数据。默认情况下,会自动下载并使用`CIFAR-10`数据,如果需要使用`ImageNet`数据。请按以下步骤操作: - 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。 - 使用`train.py`和`evaluate.py`运行脚本时,指定`--data`选项为`imagenet`。 如果想要使用自定义的数据集,需要重写`../../imagenet_reader.py`文件,并在`train.py`中调用实现。 ## 下载预训练模型 该示例中直接使用`paddle.vision.models`模块提供的针对`ImageNet`分类任务的预训练模型。 对预训练好的模型剪裁后,需要在目标数据集上进行重新训练,以便恢复因剪裁损失的精度。 ## 自定义稀疏化方法 默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作: - 开发者可以通过重写`paddleslim.dygraph.prune.unstructured_pruner.py`中的`UnstructuredPruner.mask_parameters()`和`UnstructuredPruner.update_threshold()`来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 - 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。 ```python def _get_skip_params(model): """ This function is used to check whether the given model's layers are valid to be pruned. Usually, the convolutions are to be pruned while we skip the normalization-related parameters. Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance. Args: - model(Paddle.nn.Layer): the current model waiting to be checked. Return: - skip_params(set): a set of parameters' names """ skip_params = set() for _, sub_layer in model.named_sublayers(): if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__: skip_params.add(sub_layer.full_name()) return skip_params ``` ## 训练 按照阈值剪裁: ```bash python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 ``` 按照比例剪裁(训练速度较慢,推荐按照阈值剪裁): ```bash python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5 ``` GPU多卡训练: ```bash export CUDA_VISIBLE_DEVICES=0,1,2,3 python3.7 -m paddle.distributed.launch \ --gpus="0,1,2,3" \ --log_dir="train_mbv1_imagenet_threshold_001_log" \ train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 ``` 恢复训练(请替代命令中的`dir/to/the/saved/pruned/model`和`INTERRUPTED_EPOCH`): ```bash python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \ --pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH ``` ## 推理: ```bash python3.7 eval --pruned_model models/ --data imagenet ``` 剪裁训练代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) #STEP1: initialize the pruner pruner = UnstructuredPruner(model, mode='threshold', threshold=0.01) for epoch in range(epochs): for batch_id, data in enumerate(train_loader): loss = calculate_loss() loss.backward() opt.step() opt.clear_grad() #STEP2: update the pruner's threshold given the updated parameters pruner.step() if epoch % args.test_period == 0: #STEP3: before evaluation during training, eliminate the non-zeros generated by opt.step(), which, however, the cached masks setting to be zeros. pruner.update_params() eval(epoch) if epoch % args.model_period == 0: # STEP4: same purpose as STEP3 pruner.update_params() paddle.save(model.state_dict(), "model-pruned.pdparams") paddle.save(opt.state_dict(), "opt-pruned.pdopt") ``` 剪裁后测试代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) model.set_state_dict(paddle.load("model-pruned.pdparams")) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 print(UnstructuredPruner.total_sparse(model)) test() ``` 更多使用参数请参照shell文件或者运行如下命令查看: ```bash python3.7 train --h python3.7 evaluate --h ``` ## 实验结果 | 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch | |:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:| | MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - | | MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 | | YOLO v3 | VOC | - | - |76.24% | - | - | - | | YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 | 102k iterations |