diff --git a/demo/dygraph/unstructured_pruning/README.md b/demo/dygraph/unstructured_pruning/README.md index d343bfa5aea47ad9ec59f8746c750aeba371432b..9af494666060e2ab42834b8103a4b77dc34e34b5 100644 --- a/demo/dygraph/unstructured_pruning/README.md +++ b/demo/dygraph/unstructured_pruning/README.md @@ -1,8 +1,10 @@ -# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式) +# 非结构化稀疏 -- 动态图剪裁(包括按照阈值和比例剪裁两种模式)示例 ## 简介 -在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 +在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1`在`ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 + +本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`CIFAR-10`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的训练代码示例进行API调用。 ## 版本要求 ```bash @@ -13,12 +15,25 @@ paddleslim>=2.1.0 请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)和[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)。 -## 使用 +## 数据准备 + +本示例支持`CIFAR-10`和`ImageNet`两种数据。默认情况下,会自动下载并使用`CIFAR-10`数据,如果需要使用`ImageNet`数据。请按以下步骤操作: + +- 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。 +- 使用`train.py`和`evaluate.py`运行脚本时,指定`--data`选项为`imagenet`。 + +如果想要使用自定义的数据集,需要重写`../../imagenet_reader.py`文件,并在`train.py`中调用实现。 + +## 下载预训练模型 -训练前: -- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py/evaluate.py文件中调用实现。 -- 开发者可以通过重写paddleslim.dygraph.prune.unstructured_pruner.py中的UnstructuredPruner.mask_parameters()和UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 -- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。 +该示例中直接使用`paddle.vision.models`模块提供的针对`ImageNet`分类任务的预训练模型。 对预训练好的模型剪裁后,需要在目标数据集上进行重新训练,以便恢复因剪裁损失的精度。 + +## 自定义稀疏化方法 + +默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作: + +- 开发者可以通过重写`paddleslim.dygraph.prune.unstructured_pruner.py`中的`UnstructuredPruner.mask_parameters()`和`UnstructuredPruner.update_threshold()`来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 +- 开发可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.dygraph.prune.unstructured_pruner._get_skip_params())`。默认为所有的归一化层的参数不参与剪裁。 ```python def _get_skip_params(model): @@ -39,21 +54,43 @@ def _get_skip_params(model): return skip_params ``` -训练: +## 训练 + +按照阈值剪裁: +```bash +python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 +``` + +按照比例剪裁(训练速度较慢,推荐按照阈值剪裁): +```bash +python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5 +``` + +GPU多卡训练: ```bash -python3 train.py --data cifar10 --lr 0.1 --pruning_mode ratio --ratio=0.5 +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python3.7 -m paddle.distributed.launch \ +--gpus="0,1,2,3" \ +--log_dir="train_mbv1_imagenet_threshold_001_log" \ +train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 ``` -推理: +恢复训练(请替代命令中的`dir/to/the/saved/pruned/model`和`INTERRUPTED_EPOCH`): ```bash -python3 eval --pruned_model models/ --data cifar10 +python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \ + --pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH +``` + +## 推理: +```bash +python3.7 eval --pruned_model models/ --data imagenet ``` 剪裁训练代码示例: ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) #STEP1: initialize the pruner -pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) +pruner = UnstructuredPruner(model, mode='threshold', threshold=0.01) for epoch in range(epochs): for batch_id, data in enumerate(train_loader): @@ -80,27 +117,22 @@ for epoch in range(epochs): ```python model = mobilenet_v1(num_classes=class_dim, pretrained=True) model.set_state_dict(paddle.load("model-pruned.pdparams")) -print(UnstructuredPruner.total_sparse(model)) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 +#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 +print(UnstructuredPruner.total_sparse(model)) test() ``` 更多使用参数请参照shell文件或者运行如下命令查看: ```bash -python train --h -python evaluate --h +python3.7 train --h +python3.7 evaluate --h ``` -## 实验结果 (刚开始在动态图代码验证,以下为静态图代码上的结果) +## 实验结果 | 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch | |:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:| | MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - | | MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 | | YOLO v3 | VOC | - | - |76.24% | - | - | - | -| YOLO v3 | VOC |threshold | -41.35% | 75.29%(-0.95%) | 0.005 | 0.05 | 10w | -| YOLO v3 | VOC |threshold | -53.00% | 75.00%(-1.24%) | 0.005 | 0.075 | 10w | - -## TODO - -- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。 -- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。 +| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 | 102k iterations | diff --git a/demo/dygraph/unstructured_pruning/evaluate.py b/demo/dygraph/unstructured_pruning/evaluate.py index bd24a4b87b1bf093ec87e1d073dbb7494aed9a59..21e62b16d1e60c477903ff1022c7b5bce515a639 100644 --- a/demo/dygraph/unstructured_pruning/evaluate.py +++ b/demo/dygraph/unstructured_pruning/evaluate.py @@ -5,7 +5,7 @@ import argparse import numpy as np sys.path.append( os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir)) -from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +from paddleslim import UnstructuredPruner from utility import add_arguments, print_arguments import paddle.vision.transforms as T import paddle.nn.functional as F diff --git a/demo/dygraph/unstructured_pruning/train.py b/demo/dygraph/unstructured_pruning/train.py index 30343e0f91a8243a64f116010e539f7277ff08d3..af6fb2c644b5c0afa910d1f810729b9d81c8992f 100644 --- a/demo/dygraph/unstructured_pruning/train.py +++ b/demo/dygraph/unstructured_pruning/train.py @@ -3,7 +3,7 @@ import os import sys import argparse import numpy as np -from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +from paddleslim import UnstructuredPruner sys.path.append( os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir)) from utility import add_arguments, print_arguments @@ -35,6 +35,7 @@ parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.") add_arg('log_period', int, 100, "Log period in batches.") add_arg('test_period', int, 1, "Test period in epoches.") +add_arg('pretrained_model', str, None, "The pretrained model the load. Default: None.") add_arg('model_path', str, "./models", "The path to save model.") add_arg('model_period', int, 10, "The period to save model in epochs.") add_arg('resume_epoch', int, -1, "The epoch to resume training.") @@ -117,12 +118,13 @@ def compress(args): # model definition model = mobilenet_v1(num_classes=class_dim, pretrained=True) - dp_model = paddle.DataParallel(model) + if args.pretrained_model is not None: + model.set_state_dict(paddle.load(args.pretrained_model)) - opt, learning_rate = create_optimizer(args, step_per_epoch, dp_model) + opt, learning_rate = create_optimizer(args, step_per_epoch, model) def test(epoch): - dp_model.eval() + model.eval() acc_top1_ns = [] acc_top5_ns = [] for batch_id, data in enumerate(valid_loader): @@ -133,7 +135,7 @@ def compress(args): y_data = paddle.unsqueeze(y_data, 1) end_time = time.time() - logits = dp_model(x_data) + logits = model(x_data) loss = F.cross_entropy(logits, y_data) acc_top1 = paddle.metric.accuracy(logits, y_data, k=1) acc_top5 = paddle.metric.accuracy(logits, y_data, k=5) @@ -157,7 +159,7 @@ def compress(args): acc_top5_ns, dtype="object")))) def train(epoch): - dp_model.train() + model.train() for batch_id, data in enumerate(train_loader): start_time = time.time() x_data = data[0] @@ -165,7 +167,7 @@ def compress(args): if args.data == 'cifar10': y_data = paddle.unsqueeze(y_data, 1) - logits = dp_model(x_data) + logits = model(x_data) loss = F.cross_entropy(logits, y_data) acc_top1 = paddle.metric.accuracy(logits, y_data, k=1) acc_top5 = paddle.metric.accuracy(logits, y_data, k=5) @@ -183,7 +185,7 @@ def compress(args): pruner.step() pruner = UnstructuredPruner( - dp_model, + model, mode=args.pruning_mode, ratio=args.ratio, threshold=args.threshold) @@ -193,11 +195,11 @@ def compress(args): pruner.update_params() _logger.info( "The current density of the pruned model is: {}%".format( - round(100 * UnstructuredPruner.total_sparse(dp_model), 2))) + round(100 * UnstructuredPruner.total_sparse(model), 2))) test(i) if i > args.resume_epoch and i % args.model_period == 0: pruner.update_params() - paddle.save(dp_model.state_dict(), + paddle.save(model.state_dict(), os.path.join(args.model_path, "model-pruned.pdparams")) paddle.save(opt.state_dict(), os.path.join(args.model_path, "opt-pruned.pdopt")) diff --git a/demo/unstructured_prune/README.md b/demo/unstructured_prune/README.md index 17d095af5a3abc6e7a8038a643236dda95f73885..2bb142f610977851ce0fce1ac458eb740824743e 100644 --- a/demo/unstructured_prune/README.md +++ b/demo/unstructured_prune/README.md @@ -1,8 +1,10 @@ -# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式) +# 非结构化稀疏 -- 静态图剪裁(包括按照阈值和比例剪裁两种模式)示例 ## 简介 -在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,MobileNetV1在ImageNet上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 +在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1`在`ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。 + +本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`MNIST`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的**训练代码示例**进>行API调用。 ## 版本要求 ```bash @@ -11,15 +13,36 @@ paddlepaddle>=2.0.0 paddleslim>=2.1.0 ``` -请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)和[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)。 +请参照github安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)和[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)。 + +## 数据准备 + +本示例支持`MNIST`和`ImageNet`两种数据。默认情况下,会自动下载并使用`MNIST`数据,如果需要使用`ImageNet`数据。请按以下步骤操作: + +- 根据分类模型中[ImageNet数据准备文档](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E6%95%B0%E6%8D%AE%E5%87%86%E5%A4%87)下载数据到`PaddleSlim/demo/data/ILSVRC2012`路径下。 +- 使用`train.py`和`evaluate.py`运行脚本时,指定`--data`选项为`imagenet`。 + +如果想要使用自定义的数据集,需要重写`../imagenet_reader.py`文件,并在`train.py`中调用实现。 + +## 下载预训练模型 + +如果使用`ImageNet`数据,建议在预训练模型的基础上进行剪裁,请从[这里](http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar)下载预训练模型。 + +下载并解压预训练模型到当前路径: + +``` +wget http://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV1_pretrained.tar +tar -xf MobileNetV1_pretrained.tar +``` + +使用`train.py`脚本时,指定`--pretrained_model`加载预训练模型,`MNIST`数据无需指定。 + +## 自定义稀疏化方法 -## 使用 +默认根据参数的绝对值大小进行稀疏化,且不稀疏归一化层参数。如果开发者想更改相应的逻辑,可按照下述操作: -训练前: -- 预训练模型下载,并放到某目录下,通过train.py中的--pretrained_model设置。 -- 训练数据下载后,可以通过重写../imagenet_reader.py文件,并在train.py文件中调用实现。 -- 开发者可以通过重写paddleslim.prune.unstructured_pruner.py中的UnstructuredPruner.update_threshold()来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 -- 开发可以在初始化UnstructuredPruner时,传入自定义的skip_params_func,来定义哪些参数不参与剪裁。skip_params_func示例代码如下(路径:paddleslim.prune.unstructured_pruner._get_skip_params())。默认为所有的归一化层的参数不参与剪裁。 +- 可以通过重写`paddleslim.prune.unstructured_pruner.py`中的`UnstructuredPruner.update_threshold()`来定义自己的非结构化稀疏策略(目前为剪裁掉绝对值小的parameters)。 +- 可以在初始化`UnstructuredPruner`时,传入自定义的`skip_params_func`,来定义哪些参数不参与剪裁。`skip_params_func`示例代码如下(路径:`paddleslim.prune.unstructured_pruner._get_skip_params()`)。默认为所有的归一化层的参数不参与剪裁。 ```python def _get_skip_params(program): @@ -41,12 +64,25 @@ def _get_skip_params(program): return skip_params ``` -训练: +## 训练 + +按照阈值剪裁: ```bash -CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data mnist --lr 0.1 --pruning_mode ratio --ratio=0.5 +CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 ``` -推理: +按照比例剪裁(训练速度较慢,推荐按照阈值剪裁): +```bash +CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.5 +``` + +恢复训练(请替代命令中的`dir/to/the/saved/pruned/model`和`INTERRUPTED_EPOCH`): +``` +CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \ + --pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH +``` + +## 推理 ```bash CUDA_VISIBLE_DEVICES=0 python3.7 evaluate.py --pruned_model models/ --data imagenet ``` @@ -70,7 +106,8 @@ opt, learning_rate = create_optimizer(args, step_per_epoch) opt.minimize(avg_cost) #STEP1: initialize the pruner -pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place) +pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='threshold', threshold=0.01, place=place) # 按照阈值剪裁 +# pruner = UnstructuredPruner(paddle.static.default_main_program(), mode='ratio', ratio=0.5, place=place) # 按照比例剪裁 exe.run(paddle.static.default_startup_program()) paddle.fluid.io.load_vars(exe, args.pretrained_model) @@ -103,7 +140,8 @@ for epoch in range(epochs): ```python # intialize the model instance in static mode # load weights -print(UnstructuredPruner.total_sparse(paddle.static.default_main_program())) #注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 +print(UnstructuredPruner.total_sparse(paddle.static.default_main_program())) +#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。 test() ``` @@ -118,11 +156,6 @@ python3.7 evaluate.py --h | 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch | |:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:| | MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - | -| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 | +| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.05 | - | 68 | | YOLO v3 | VOC | - | - |76.24% | - | - | - | -| YOLO v3 | VOC |threshold | -55.15% | 75.45%(-0.79%) | 0.005 | 0.05 |12.8w| - -## TODO - -- [ ] 完成实验,验证动态图下的效果,并得到压缩模型。 -- [ ] 扩充衡量parameter重要性的方法(目前仅为绝对值)。 +| YOLO v3 | VOC |threshold | -56.50% | 77.02%(+0.78%) | 0.001 | 0.01 |102k iterations| diff --git a/demo/unstructured_prune/train.py b/demo/unstructured_prune/train.py index 384845f35bcc56d739a67356da7fa90ac1c640fc..af8a3ddc108f69eb2988c4db5fbb8240a04d0a97 100644 --- a/demo/unstructured_prune/train.py +++ b/demo/unstructured_prune/train.py @@ -140,7 +140,6 @@ def compress(args): pruner = UnstructuredPruner( paddle.static.default_main_program(), - batch_size=args.batch_size, mode=args.pruning_mode, ratio=args.ratio, threshold=args.threshold, diff --git a/docs/zh_cn/api_cn/dygraph/pruners/index.rst b/docs/zh_cn/api_cn/dygraph/pruners/index.rst index 59fa38e2ff786749dc6b99894f503eca2c3b490c..612c76af0a9d0e013b45fa4f7df8af85727dc18b 100644 --- a/docs/zh_cn/api_cn/dygraph/pruners/index.rst +++ b/docs/zh_cn/api_cn/dygraph/pruners/index.rst @@ -7,3 +7,4 @@ Pruners l1norm_filter_pruner.rst l2norm_filter_pruner.rst fpgm_filter_pruner.rst + unstructured_pruner.rst diff --git a/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst new file mode 100644 index 0000000000000000000000000000000000000000..63111975148e4140ccecfa60971f3c1fe89394d9 --- /dev/null +++ b/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst @@ -0,0 +1,113 @@ +非结构化稀疏 +================ + +UnstructuredPruner +---------- + +.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.3, skip_params_func=None) + + +`源代码 `_ + +对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置0。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。 + +**参数:** + +- **model(paddle.nn.Layer)** - 待剪裁的动态图模型。 +- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0。 +- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。 +- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。 +- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 + +**返回:** 一个UnstructuredPruner类的实例。 + +**示例代码:** + +此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 `_ + +.. code-block:: python + + from paddleslim import UnstructuredPruner + pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) + +.. + + .. py:method:: paddleslim.UnstructuredPruner.step() + + 更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。 + + **示例代码:** + + 此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 `_ + + .. code-block:: python + + from paddleslim import UnstructuredPruner + pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) + pruner.step() + + .. + + .. py:method:: paddleslim.UnstructuredPruner.update_params() + + 每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。 + + **示例代码:** + + 此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 `_ + + .. code-block:: python + + from paddleslim import UnstructuredPruner + pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) + pruner.update_params() + + .. + + .. py:method:: paddleslim.UnstructuredPruner.total_sparse(model) + + UnstructuredPruner中的静态方法,用于计算给定的模型(model)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。 + + **参数:** + + - **model(paddle.nn.Layer)** - 要计算稠密度的目标网络。 + + **返回:** + + - **density(float)** - 模型的稠密度。 + + **示例代码:** + + 此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 `_ + + .. code-block:: python + + from paddleslim import UnstructuredPruner + density = UnstructuredPruner.total_sparse(model) + + .. + + .. py:method:: paddleslim.UnstructuredPruner.summarize_weights(model, ratio=0.1) + + 该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)。 + + **参数:** + + - **model(paddle.nn.Layer)** - 要分析权重分布的目标网络。 + - **ratio(float)** - 需要查看的比例情况,具体如上方法描述。 + + **返回:** + + - **threshold(float)** - 和输入ratio对应的阈值。开发者可以根据该阈值初始化UnstructuredPruner。 + + **示例代码:** + + 此示例不能直接运行,因为需要定义和加载模型,详细用法请参考 `这里 `_ + + .. code-block:: python + + from paddleslim import UnstructuredPruner + pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5) + threshold = pruner.summarize_weights(model, ratio=0.1) + + .. diff --git a/docs/zh_cn/api_cn/static/prune/prune_index.rst b/docs/zh_cn/api_cn/static/prune/prune_index.rst index ad2c8ffd3cf7784aeb3a3578c63efcb937775254..998640e824d19345c7c83aea7a144f057d1b4e03 100644 --- a/docs/zh_cn/api_cn/static/prune/prune_index.rst +++ b/docs/zh_cn/api_cn/static/prune/prune_index.rst @@ -6,3 +6,4 @@ :maxdepth: 1 prune_api.rst + unstructured_prune_api.rst diff --git a/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst new file mode 100644 index 0000000000000000000000000000000000000000..2622fe48055e378b05787098c02dbdd6f5dba86a --- /dev/null +++ b/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst @@ -0,0 +1,121 @@ +非结构化稀疏 +================ + +UnstrucuturedPruner +---------- + +.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.5, threshold=1e-5, scope=None, place=None, skip_params_func=None) + +`源代码 `_ + +对于神经网络中的参数进行非结构化稀疏。非结构化稀疏是指,根据某些衡量指标,将不重要的参数置0。其不按照固定结构剪裁(例如一个通道等),这是和结构化剪枝的主要区别。 + +**参数:** + +- **program(paddle.static.Program)** - 一个paddle.static.Program对象,是待剪裁的模型。 +- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0。 +- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。 +- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。 +- **scope(paddle.static.Scope)** - 一个paddle.static.Scope对象,存储了所有变量的数值,默认(None)时表示paddle.static.global_scope。 +- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。 +- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。 + +**返回:** 一个UnstructuredPruner类的实例 + +**示例代码:** + +此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 `_ + +.. code-block:: python + + from paddleslim.prune import UnstructuredPruner + pruner = UnstructuredPruner() + +.. + + .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step() + + 更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。 + + **示例代码:** + + 此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 `_ + + .. code-block:: python + + from paddleslim.prune import UnstructuredPruner + + pruner = UnstructuredPruner( + paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0]) + pruner.step() + + .. + + .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params() + + 每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。 + + **示例代码:** + + 此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 `_ + + .. code-block:: python + + from paddleslim.prune import UnstructuredPruner + + pruner = UnstructuredPruner( + paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0]) + pruner.update_params() + + .. + + .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse(program) + + UnstructuredPruner中的静态方法,用于计算给定的模型(program)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。 + + **参数:** + + - **program(paddle.static.Program)** - 要计算稠密度的目标网络。 + + **返回:** + + - **density(float)** - 模型的稠密度。 + + **示例代码:** + + 此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 `_ + + .. code-block:: python + + from paddleslim.prune import UnstructuredPruner + + density = UnstructuredPruner.total_sparse(paddle.static.default_main_program()) + + .. + + .. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.summarize_weights(program, ratio=0.1) + + 该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)。 + + **参数:** + + - **program(paddle.static.Program)** - 要分析权重分布的目标网络。 + - **ratio(float)** - 需要查看的比例情况,具体如上方法描述。 + + **返回:** + + - **threshold(float)** - 和输入ratio对应的阈值。开发者可以根据该阈值初始化UnstructuredPruner。 + + **示例代码:** + + 此示例不能直接运行,因为需要加载数据和模型,详细demo请参照 `这里 `_ + + .. code-block:: python + + from paddleslim.prune import UnstructuredPruner + + pruner = UnstructuredPruner( + paddle.static.default_main_program(), 'ratio', scope=paddle.static.global_scope(), place=paddle.static.cpu_places()[0]) + threshold = pruner.summarize_weights(paddle.static.default_main_program(), 1.0) + .. + diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index 44116fc84425af645efbb355ce01e1c6d3d711e3..c5234c792a237698a657f8e7656e77940f817c74 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -80,6 +80,28 @@ class UnstructuredPruner(): self.threshold = np.sort(np.abs(params_flatten))[max( 0, round(self.ratio * total_length) - 1)].item() + def summarize_weights(self, model, ratio=0.1): + """ + The function is used to get the weights corresponding to a given ratio + when you are uncertain about the threshold in __init__() function above. + For example, when given 0.1 as ratio, the function will print the weight value, + the abs(weights) lower than which count for 10% of the total numbers. + Args: + - model(paddle.nn.Layer): The model which have all the parameters. + - ratio(float): The ratio illustrated above. + Return: + - threshold(float): a threshold corresponding to the input ratio. + """ + data = [] + for name, sub_layer in model.named_sublayers(): + if not self._should_prune_layer(sub_layer): + continue + for param in sub_layer.parameters(include_sublayers=False): + data.append(np.array(param.value().get_tensor()).flatten()) + data = np.concatenate(data, axis=0) + threshold = np.sort(np.abs(data))[max(0, int(ratio * len(data) - 1))] + return threshold + def step(self): """ Update the threshold after each optimization step. @@ -116,7 +138,7 @@ class UnstructuredPruner(): It is static because during testing, we can calculate sparsity without initializing a pruner instance. Args: - - model(Paddle.Model): The sparse model. + - model(paddle.nn.Layer): The sparse model. Returns: - ratio(float): The model's density. """ diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index bc7104578f45f8620dfff9fb29e8d7fcc01e1f9e..8c0857df48f5623d5863db5e7acd90ec8b8b7b99 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -12,7 +12,6 @@ class UnstructuredPruner(): Args: - program(paddle.static.Program): The model to be pruned. - - batch_size(int): batch size. - mode(str): the mode to prune the model, must be selected from 'ratio' and 'threshold'. - ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.5. - threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-5. @@ -23,7 +22,6 @@ class UnstructuredPruner(): def __init__(self, program, - batch_size, mode, ratio=0.5, threshold=1e-5, diff --git a/tests/dygraph/test_unstructured_prune.py b/tests/dygraph/test_unstructured_prune.py index 5fe74dbecac8c41cee3df5a2a77e2c8eb0a5698f..a8a123ed607a2130378ffff969efa39b37ddec79 100644 --- a/tests/dygraph/test_unstructured_prune.py +++ b/tests/dygraph/test_unstructured_prune.py @@ -3,7 +3,7 @@ sys.path.append("../../") import unittest import paddle import numpy as np -from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner +from paddleslim import UnstructuredPruner from paddle.vision.models import mobilenet_v1 @@ -37,6 +37,20 @@ class TestUnstructuredPruner(unittest.TestCase): self.pruner.update_params() self.assertEqual(cur_density, UnstructuredPruner.total_sparse(self.net)) + def test_summarize_weights(self): + max_value = -float("inf") + threshold = self.pruner.summarize_weights(self.net, 1.0) + for name, sub_layer in self.net.named_sublayers(): + if not self.pruner._should_prune_layer(sub_layer): + continue + for param in sub_layer.parameters(include_sublayers=False): + max_value = max( + max_value, + np.max(np.abs(np.array(param.value().get_tensor())))) + print("The returned threshold is {}.".format(threshold)) + print("The max_value is {}.".format(max_value)) + self.assertEqual(max_value, threshold) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unstructured_pruner.py b/tests/test_unstructured_pruner.py index cdfc68e8c297dbbc323b8417e6d38edf40d7907f..45ce381f9164af11f139178b9b29740cebf9f62f 100644 --- a/tests/test_unstructured_pruner.py +++ b/tests/test_unstructured_pruner.py @@ -42,7 +42,7 @@ class TestUnstructuredPruner(StaticCase): exe.run(self.startup_program, scope=self.scope) self.pruner = UnstructuredPruner( - self.main_program, 16, 'ratio', scope=self.scope, place=place) + self.main_program, 'ratio', scope=self.scope, place=place) def test_unstructured_prune(self): for param in self.main_program.global_block().all_parameters():