未验证 提交 358b51af 编写于 作者: M minghaoBD 提交者: GitHub

optimize unstructured pruner using GMP (#881)

* optimize unstructured pruner using GMP

* add uniittests for GMP pruner

* add training scripts for fleet and gmp; add conv1x1 pruning

* add conv1x1 pruning to dygraph

* unify hyper parameters across methods and scripts, static graph

* unify hyper parameters across methods and scripts, dynamic graph

* refine test cases

* fix singlecard training in static graph

* refine readmes

* update unstructured pruner api doc static

* update api docs dygraph

* fix bugs in api docs

* rename UnstructuredPrunerGMP to GMPUnstructuredPruner

* fix some format issues

* remove redundant methods/arguments

* description for gmp configs

* fix minor bugs

* fix test cases accordingly
Co-authored-by: Nwhs <wanghaoshuang@baidu.com>
上级 7e5a588f
......@@ -4,13 +4,13 @@
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1``ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`CIFAR-10`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的训练代码示例进行API调用。
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`CIFAR-10`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的训练代码示例进行API调用。此外,为保证大稀疏度(75%+)下模型的精度,我们引入了`GMP`训练策略,详细的介绍和使用请参照[介绍](../../unstructured_prune/README_GMP.md)
## 版本要求
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
paddlepaddle>=2.2.0
paddleslim>=2.2.0
```
请参照github安装[paddlepaddle](https://github.com/PaddlePaddle/Paddle)[paddleslim](https://github.com/PaddlePaddle/PaddleSlim)
......@@ -62,34 +62,34 @@ def _get_skip_params(model):
按照阈值剪裁:
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
python train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
```
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁):
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
python train.py --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
```
GPU多卡训练:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3.7 -m paddle.distributed.launch \
python -m paddle.distributed.launch \
--gpus="0,1,2,3" \
--log_dir="train_mbv1_imagenet_threshold_001_log" \
train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 --batch_size 256
train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 --batch_size 64
```
**注意**,这里的batch_size为单卡上的。
恢复训练(请替代命令中的`dir/to/the/saved/pruned/model``INTERRUPTED_EPOCH`):
恢复训练(请替代命令中的`dir/to/the/saved/pruned/model``LAST_EPOCH`):
```bash
python3.7 train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH
python train.py --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--pretrained_model dir/to/the/saved/pruned/model --resume_epoch LAST_EPOCH
```
## 推理:
```bash
python3.7 evaluate.py --pruned_model models/model-pruned.pdparams --data imagenet
python evaluate.py --pruned_model models/model.pdparams --data imagenet
```
**注意**,上述`pruned_model` 参数应该指向pdparams文件。
......@@ -118,14 +118,14 @@ for epoch in range(epochs):
if epoch % args.model_period == 0:
# STEP4: same purpose as STEP3
pruner.update_params()
paddle.save(model.state_dict(), "model-pruned.pdparams")
paddle.save(opt.state_dict(), "opt-pruned.pdopt")
paddle.save(model.state_dict(), "model.pdparams")
paddle.save(opt.state_dict(), "model.pdopt")
```
剪裁后测试代码示例:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
model.set_state_dict(paddle.load("model-pruned.pdparams"))
model.set_state_dict(paddle.load("model.pdparams"))
#注意,total_sparse为静态方法(static method),可以不创建实例(instance)直接调用,方便只做测试的写法。
print(UnstructuredPruner.total_sparse(model))
test()
......@@ -133,8 +133,8 @@ test()
更多使用参数请参照shell文件或者运行如下命令查看:
```bash
python3.7 train.py --h
python3.7 evaluate.py --h
python train.py --h
python evaluate.py --h
```
## 实验结果
......@@ -144,5 +144,7 @@ python3.7 evaluate.py --h
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| MobileNetV1 | ImageNet | threshold | -49.49% | 71.22%/89.78% (+0.23%/+0.10%) | 0.05 | 0.01 | 93 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 75% | 70.49%/89.48% (-0.5%/-0.20%) | 0.005 | - | 108 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 80% | 70.02%/89.26% (-0.97%/-0.42%) | 0.005 | - | 108 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -56.50% | 77.21% (+0.97%) | 0.001 | 0.01 | 150k iterations |
......@@ -23,7 +23,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pruned_model', str, "dymodels/model-pruned.pdparams", "Whether to use pretrained model.")
add_arg('pruned_model', str, "dymodels/model.pdparams", "Whether to use pretrained model.")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
# yapf: enable
......@@ -92,7 +92,7 @@ def compress(args):
acc_top5_ns, dtype="object"))))
model.set_state_dict(paddle.load(args.pruned_model))
_logger.info("The current density of the pruned model is: {}%".format(
_logger.info("The current sparsity of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(model), 2)))
test(0)
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
python evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="cifar10"
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
python evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="imagenet"
......@@ -3,7 +3,7 @@ import os
import sys
import argparse
import numpy as np
from paddleslim import UnstructuredPruner
from paddleslim import UnstructuredPruner, GMPUnstructuredPruner
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
......@@ -22,33 +22,42 @@ _logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('ratio', float, 0.3, "The ratio to set zeros, the smaller part bounded by the ratio will be zeros.")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.")
add_arg('threshold', float, 0.001, "The threshold to set zeros.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('batch_size', int, 64, "Minibatch size. Default: 64")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation. Default: 64")
add_arg('lr', float, 0.05, "The learning rate used to fine-tune pruned model. Default: 0.05")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy. Default: piecewise_decay")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter. Default: 3e-5")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate. Default: 0.9")
add_arg('ratio', float, 0.55, "The ratio to set zeros, the smaller part bounded by the ratio will be zeros. Default: 0.55")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold. Default: ratio")
add_arg('threshold', float, 0.01, "The threshold to set zeros. Default: 0.01")
add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
add_arg('test_period', int, 1, "Test period in epoches.")
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'. Default: imagenet")
add_arg('log_period', int, 100, "Log period in batches. Default: 100")
add_arg('test_period', int, 5, "Test period in epoches. Default: 5")
add_arg('pretrained_model', str, None, "The pretrained model the load. Default: None.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('checkpoint', str, None, "The checkpoint path to resume training. Default: None.")
add_arg('model_path', str, "./models", "The path to save model. Default: ./models")
add_arg('model_period', int, 10, "The period to save model in epochs.")
add_arg('resume_epoch', int, -1, "The epoch to resume training.")
add_arg('num_workers', int, 16, "number of workers when loading dataset.")
add_arg('last_epoch', int, -1, "The last epoch we'll train from. Default: -1")
add_arg('num_workers', int, 16, "number of workers when loading dataset. Default: 16")
add_arg('stable_epochs', int, 0, "The epoch numbers used to stablize the model before pruning. Default: 0")
add_arg('pruning_epochs', int, 60, "The epoch numbers used to prune the model by a ratio step. Default: 60")
add_arg('tunning_epochs', int, 60, "The epoch numbers used to tune the after-pruned models. Default: 60")
add_arg('pruning_steps', int, 100, "How many times you want to increase your ratio during training. Default: 100")
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
# yapf: enable
def piecewise_decay(args, step_per_epoch, model):
bd = [step_per_epoch * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
last_iter = (1 + args.last_epoch) * step_per_epoch
learning_rate = paddle.optimizer.lr.PiecewiseDecay(
boundaries=bd, values=lr, last_epoch=last_iter)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
......@@ -59,8 +68,11 @@ def piecewise_decay(args, step_per_epoch, model):
def cosine_decay(args, step_per_epoch, model):
last_iter = (1 + args.last_epoch) * step_per_epoch
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
learning_rate=args.lr,
T_max=args.num_epochs * step_per_epoch,
last_epoch=last_iter)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
......@@ -76,11 +88,24 @@ def create_optimizer(args, step_per_epoch, model):
return cosine_decay(args, step_per_epoch, model)
def compress(args):
if args.use_gpu:
place = paddle.set_device('gpu')
def create_unstructured_pruner(model, args, configs=None):
if configs is None:
return UnstructuredPruner(
model,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold,
prune_params_type=args.prune_params_type)
else:
place = paddle.set_device('cpu')
return GMPUnstructuredPruner(
model,
ratio=args.ratio,
prune_params_type=args.prune_params_type,
configs=configs)
def compress(args):
place = paddle.set_device('gpu')
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
......@@ -132,11 +157,38 @@ def compress(args):
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
if args.pretrained_model is not None:
model.set_state_dict(paddle.load(args.pretrained_model))
opt, learning_rate = create_optimizer(args, step_per_epoch, model)
if args.checkpoint is not None and args.last_epoch > -1:
if args.checkpoint.endswith('pdparams'):
args.checkpoint = args.checkpoint[:-9]
if args.checkpoint.endswith('pdopt'):
args.checkpoint = args.checkpoint[:-6]
model.set_state_dict(paddle.load(args.checkpoint + ".pdparams"))
opt.set_state_dict(paddle.load(args.checkpoint + ".pdopt"))
elif args.pretrained_model is not None:
if args.pretrained_model.endswith('pdparams'):
args.pretrained_model = args.pretrained_model[:-9]
if args.pretrained_model.endswith('pdopt'):
args.pretrained_model = args.pretrained_model[:-6]
model.set_state_dict(paddle.load(args.pretrained_model + ".pdparams"))
if args.pruning_strategy == 'gmp':
# GMP pruner step 0: define configs. No need to do this if you are not using 'gmp'
configs = {
'stable_iterations': args.stable_epochs * step_per_epoch,
'pruning_iterations': args.pruning_epochs * step_per_epoch,
'tunning_iterations': args.tunning_epochs * step_per_epoch,
'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
'pruning_steps': args.pruning_steps,
'initial_ratio': args.initial_ratio,
}
else:
configs = None
# GMP pruner step 1: initialize a pruner object
pruner = create_unstructured_pruner(model, args, configs=configs)
def test(epoch):
model.eval()
acc_top1_ns = []
......@@ -193,7 +245,9 @@ def compress(args):
opt.step()
learning_rate.step()
opt.clear_grad()
# GMP pruner step 2: step() to update ratios and other internal states of the pruner.
pruner.step()
train_run_cost += time.time() - train_start
total_samples += args.batch_size
......@@ -215,26 +269,23 @@ def compress(args):
reader_start = time.time()
pruner = UnstructuredPruner(
model,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold)
for i in range(args.resume_epoch + 1, args.num_epochs):
for i in range(args.last_epoch + 1, args.num_epochs):
train(i)
# GMP pruner step 3: update params before summrizing sparsity, saving model or evaluation.
pruner.update_params()
if (i + 1) % args.test_period == 0:
pruner.update_params()
_logger.info(
"The current density of the pruned model is: {}%".format(
"The current sparsity of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(model), 2)))
test(i)
if (i + 1) % args.model_period == 0:
pruner.update_params()
paddle.save(model.state_dict(),
os.path.join(args.model_path, "model-pruned.pdparams"))
os.path.join(args.model_path, "model.pdparams"))
paddle.save(opt.state_dict(),
os.path.join(args.model_path, "opt-pruned.pdopt"))
os.path.join(args.model_path, "model.pdopt"))
def main():
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
python train.py \
--batch_size=256 \
--lr=0.05 \
--threshold=0.01 \
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 train.py \
--batch_size=256 \
--lr=0.05 \
--threshold=0.01 \
--pruning_mode="threshold" \
--data="imagenet" \
#!/bin/bash
python -m paddle.distributed.launch \
--gpus='0,1,2,3' \
--log_dir='log' \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05
#!/bin/bash
python -m paddle.distributed.launch \
--gpus='0,1,2,3' \
--log_dir='log' \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.75 \
--lr 0.005 \
--num_epochs 108 \
--step_epochs 71 88 \
--initial_ratio 0.15 \
--pruning_steps 100 \
--stable_epochs 0 \
--pruning_epochs 54 \
--tunning_epochs 54 \
--pruning_strategy gmp
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.75 \
--lr 0.005 \
--num_epochs 108 \
--step_epochs 71 88 \
--initial_ratio 0.15 \
--pruning_steps 100 \
--stable_epochs 0 \
--pruning_epochs 54 \
--tunning_epochs 54 \
--pruning_strategy gmp
......@@ -195,19 +195,7 @@ class ImageNetDataset(Dataset):
with open(train_file_list) as flist:
full_lines = [line.strip() for line in flist]
np.random.shuffle(full_lines)
if os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(
trainer_id + 1) * per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines,
len(lines), len(full_lines)))
else:
lines = full_lines
lines = full_lines
self.data = [line.split() for line in lines]
else:
with open(val_file_list) as flist:
......
......@@ -4,13 +4,13 @@
在模型压缩中,常见的稀疏方式为结构化和非结构化稀疏,前者在某个特定维度(特征通道、卷积核等等)上进行稀疏化操作;后者以每一个参数为单元进行稀疏化,并不会改变参数矩阵的形状,所以更加依赖于硬件对稀疏后矩阵运算的加速能力。本目录即在PaddlePaddle和PaddleSlim框架下开发的非结构化稀疏算法,`MobileNetV1``ImageNet`上的稀疏化实验中,剪裁率55.19%,达到无损的表现。
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`MNIST`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的**训练代码示例**>行API调用
本示例将演示基于不同的剪裁模式(阈值/比例)进行非结构化稀疏。默认会自动下载并使用`MNIST`数据集。当前示例目前支持`MobileNetV1`,使用其他模型可以按照下面的**训练代码示例**行API调用。另外,为提升大稀疏度下的稀疏模型精度,我们引入了`GMP`训练策略(`Gradual Magnititude Pruning`),使得稀疏度在训练过程中逐步增加。`GMP`训练策略在[这里](./README_GMP.md)介绍
## 版本要求
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
paddlepaddle>=2.2.0
paddleslim>=2.2.0
```
请参照github安装[PaddlePaddle](https://github.com/PaddlePaddle/Paddle)[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)
......@@ -66,27 +66,40 @@ def _get_skip_params(program):
## 训练
按照阈值剪裁:
按照阈值剪裁,GPU单卡训练
```bash
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
CUDA_VISIBLE_DEVICES=0 python train.py --batch_size 64 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01
```
按照比例剪裁(训练速度较慢,推荐按照阈值剪裁)
按照比例剪裁,GPU单卡训练
```bash
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
CUDA_VISIBLE_DEVICES=0 python train.py --batch_size 64 --data imagenet --lr 0.05 --pruning_mode ratio --ratio 0.55
```
恢复训练(请替代命令中的`dir/to/the/saved/pruned/model``INTERRUPTED_EPOCH`):
GPU多卡训练:由于静态图多卡训练方式与非结构化稀疏中的mask逻辑存在兼容性问题,会在一定程度上影响训练精度,我们建议使用[Fleet](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/user_guides/howto/training/fleet_api_howto_cn.html)方式启动稀疏化多卡训练,实测精度与单卡一致。同时,为帮助开发者将`with_data_parallel`方式配置的分布式代码转换为`Fleet`我们在[示例代码](./train.py)里面也用`"Fleet step"`清晰标注出了用代码需要做的更改
```bash
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
train.py \
--batch_size 64 \
--data imagenet \
--lr 0.05 \
--pruning_mode ratio \
--ratio 0.55 \
--is_distributed True
```
恢复训练(请替代命令中的`dir/to/the/saved/pruned/model``LAST_EPOCH`):
```
CUDA_VISIBLE_DEVICES=2,3 python3.7 train.py --batch_size 512 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--pretrained_model dir/to/the/saved/pruned/model --resume_epoch INTERRUPTED_EPOCH
CUDA_VISIBLE_DEVICES=0 python train.py --batch_size 64 --data imagenet --lr 0.05 --pruning_mode threshold --threshold 0.01 \
--checkpoint dir/to/the/saved/pruned/model --last_epoch LAST_EPOCH
```
**注意**,上述命令中的`batch_size`多张卡上总的`batch_size`,即一张卡的`batch_size`为256
**注意**,上述命令中的`batch_size`单张卡上的`batch_size`
## 推理
```bash
CUDA_VISIBLE_DEVICES=0 python3.7 evaluate.py --pruned_model models/ --data imagenet
CUDA_VISIBLE_DEVICES=0 python evaluate.py --pruned_model models/ --data imagenet
```
剪裁训练代码示例:
......@@ -146,8 +159,8 @@ test()
更多使用参数请参照shell文件,或者通过运行以下命令查看:
```bash
python3.7 train.py --h
python3.7 evaluate.py --h
python train.py --h
python evaluate.py --h
```
## 实验结果
......@@ -155,7 +168,11 @@ python3.7 evaluate.py --h
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.05 | - | 68 |
| MobileNetV1 | ImageNet | threshold | -49.49% | 71.22%/89.78% (+0.23%/+0.10%) | 0.05 | 0.01 | 93 |
| MobileNetV1 | ImageNet | ratio | 55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.05 | - | 68 |
| MobileNetV1 | ImageNet | threshold | 49.49% | 71.22%/89.78% (+0.23%/+0.10%) | 0.05 | 0.01 | 93 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 75% | 70.49%/89.48% (-0.5%/-0.20%) | 0.005 | - | 108 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 80% | 70.02%/89.26% (-0.97%/-0.42%) | 0.005 | - | 108 |
| YOLO v3 | VOC | - | - |76.24% | - | - | - |
| YOLO v3 | VOC |threshold | -56.50% | 77.21%(+0.97%) | 0.001 | 0.01 |150k iterations|
| YOLO v3 | VOC |threshold | 56.50% | 77.21%(+0.97%) | 0.001 | 0.01 |150k iterations|
**注意**,上述`ratio, 1x1conv, GMP`代表根据比例剪裁,只稀疏化1x1conv层参数,并且使用GMP训练方式。
# 非结构化稀疏 -- GMP训练方式介绍与示例
## 简介
承接一步到位的稀疏化训练方式(即根据预训练模型只做一次剪裁,再去`finetune`),我们为`PaddleSlim`引入了一种逐步增加稀疏度的训练方式:`GMP``Gradual Magnitude Pruning`)。详细介绍可以参考[博客](https://neuralmagic.com/blog/pruning-gmp/)。这个训练策略的引入是为了解决大稀疏度训练时,精度恢复困难的问题。最终实验证明,`MobileNetV1-ImageNet`任务,在75%稀疏化conv1x1的实验中,使用`GMP`算法会使模型精度提升1%以上。
## 具体介绍
概念上讲,由于网络的参数的重要性与绝对值大小不是严格的对应关系,且随着数值增大,该对应关系会越薄弱,所以一步剪裁会对某些绝对值大的权重不友好,进而影响稀疏化模型的精度。而`GMP`算法采用逐步增加稀疏度的训练方式,增加了权重空间的灵活性,可以使得权重在稀疏化训练中去做适应性优化,从而在一定程度上保证精度。`GMP`将稀疏化训练过程分为三个阶段:稳定训练(`stable phase`)、剪裁训练(`pruning phase`)和调优训练(`finetuning phase`)。三个阶段中逐步增加稀疏度,且有各自的学习率变化。
- 稳定阶段:该阶段epoch较小,用于正式剪裁前的稳定模型。我们测试来看,对于模型精度的影响较小。这是因为我们调用了`pretrained_model`。开发者可以根据需求自行调整,一般取0(有预训练模型)或者2~5(无预训练模型)即可。稀疏度保持为0、学习率保持初始数值。
- 剪裁阶段:该阶段的epoch/iteration数目等于全量训练的一半,且在该过程中,稀疏度从某一个初始值(`inital ratio`)增加到最终值(`target ratio`),且增加的幅度逐渐减小。数学上,`ratio`变化为:
$ratio = ((i / pruning_steps) - 1.0) ^ 3 + 1$
$ratio_scaled = initial_ratio + (target_ratio - initial_ratio) * ratio$
上述$ratio$为一个三次函数所得,在$i == pruning_steps$时,$ratio=1$且梯度为0,保证了稀疏度随训练增加且增加速度减小。其中,$pruning_steps$代表$ratio$增加的次数,一般每一个epoch增加2次即可,我们实验发现,当剪裁次数过多时,也会不利于精度恢复。$ratio_scaled$则是根据输入的初始稀疏度和目标稀疏度,对$ratio$进行缩放和平移所得。
此外,学习率在该过程中保持为初始数值,不衰减。
- 调优阶段:该阶段的epoch/iteration数目等于全量训练的一半,且在该过程中,稀疏度保持为最终值(`target ratio`)。学习率衰减。`piecewise_decay`方式时,将调优阶段等分,设置为衰减边界即可。
## 参数介绍
根据上一节的具体介绍,我们归纳参数及其设置规则如下:
- stable_epochs: 0 (pretrained_model) / 2-5 (from-scratch model)
- pruning_epochs: total_epochs / 2
- tunning_epochs: total_epochs / 2
- pruning_steps: pruning_epochs * 2
- initial_ratio: 0.15
- lr: 预训练时的一个中间lr即可。例如,`MobileNetV1-ImageNet`预训练时,学习率由0.1降低为0.0001,我们在稀疏化训练时就采用了0.005。
- learning_rate_strategy: 目前仅支持piecewise_decay。cosine_decay的方式正在开发中。
- piecewise_decay_bound: $stable_epochs+pruning_epochs+tunning_epochs/3$, $stable_epochs+pruning_epochs+2*tunning_epochs/3$
## 代码调用
本节介绍如何在静态图和动态图中方便的调用`GMP`训练策略,以达到保证精度的目标。
```python
# 将上述参数定义为配置字典
configs = {
'stable_iterations': args.stable_epochs * step_per_epoch,
'pruning_iterations': args.pruning_epochs * step_per_epoch,
'tunning_iterations': args.tunning_epochs * step_per_epoch,
'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
'pruning_steps': args.pruning_steps,
'initial_ratio': args.initial_ratio,
}
# 将configs作为参数初始化GMPUnstructuredPruner即可。
# 静态图
pruner = GMPUnstructuredPruner(
train_program,
mode='ratio', # 模式必须为'ratio','threshold'模式与GMP不兼容。
ratio=args.ratio,
place=place,
configs=configs)
# 动态图
pruner = GMPUnstructuredPruner(
model,
mode='ratio', # 模式必须为'ratio','threshold'模式与GMP不兼容。
ratio=args.ratio,
configs=configs)
```
后续调用与正常训练无异,请参照[静态图](./README.md)[动态图](../dygraph/unstructured_pruning/README.md)
## 实验结果
| 模型 | 数据集 | 压缩方法 | 压缩率| Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | Imagenet | ratio, 1x1conv | 75% | 68.76%/88.91% (-2.23%/-0.77%) | 0.005 | - | 108 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 75% | 70.49%/89.48% (-0.50%/-0.20%) | 0.005 | - | 108 |
| MobileNetV1 | Imagenet | ratio, 1x1conv, GMP | 80% | 70.02%/89.26% (-0.97%/-0.42%) | 0.005 | - | 108 |
......@@ -96,9 +96,10 @@ def compress(args):
acc_top1_ns = []
acc_top5_ns = []
_logger.info("The current density of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
_logger.info(
"The current sparsity of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 evaluate.py \
python evaluate.py \
--pruned_model="models" \
--data="imagenet"
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
python evaluate.py \
--pruned_model="models" \
--data="mnist"
......@@ -7,38 +7,47 @@ import functools
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune.unstructured_pruner import UnstructuredPruner
from paddleslim.prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
from paddleslim.common import get_logger
sys.path.append(os.path.join(os.path.dirname("__file__"), os.path.pardir))
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('batch_size', int, 64, "Minibatch size. Default: 64")
add_arg('batch_size_for_validation', int, 64, "Minibatch size for validation. Default: 64")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('threshold', float, 1e-5, "The threshold to set zeros, the abs(weights) lower than which will be zeros.")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold.")
add_arg('ratio', float, 0.5, "The ratio to set zeros, the smaller portion will be zeros.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model. Default: None")
add_arg('checkpoint', str, None, "The model to load for resuming training. Default: None")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model. Default: 0.1")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy. Default: piecewise_decay")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter. Default: 3e-5")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate. Default: 0.9")
add_arg('pruning_strategy', str, 'base', "The pruning strategy, currently we support base and gmp. Default: base")
add_arg('threshold', float, 0.01, "The threshold to set zeros, the abs(weights) lower than which will be zeros. Default: 0.01")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold. Default: ratio")
add_arg('ratio', float, 0.55, "The ratio to set zeros, the smaller portion will be zeros. Default: 0.55")
add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('model_period', int, 10, "The period to save model in epochs.")
add_arg('resume_epoch', int, -1, "The epoch to resume training.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'. Default: imagenet")
add_arg('log_period', int, 100, "Log period in batches. Default: 100")
add_arg('test_period', int, 5, "Test period in epoches. Default: 5")
add_arg('model_path', str, "./models", "The path to save model. Default: ./models")
add_arg('model_period', int, 10, "The period to save model in epochs. Default: 10")
add_arg('last_epoch', int, -1, "The last epoch we could train from. Default: -1")
add_arg('stable_epochs', int, 0, "The epoch numbers used to stablize the model before pruning. Default: 0")
add_arg('pruning_epochs', int, 60, "The epoch numbers used to prune the model by a ratio step. Default: 60")
add_arg('tunning_epochs', int, 60, "The epoch numbers used to tune the after-pruned models. Default: 60")
add_arg('pruning_steps', int, 120, "How many times you want to increase your ratio during training. Default: 120")
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
# yapf: enable
model_list = models.__all__
......@@ -47,7 +56,9 @@ model_list = models.__all__
def piecewise_decay(args, step_per_epoch):
bd = [step_per_epoch * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
last_iter = (1 + args.last_epoch) * step_per_epoch
learning_rate = paddle.optimizer.lr.PiecewiseDecay(
boundaries=bd, values=lr, last_epoch=last_iter)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
......@@ -57,8 +68,11 @@ def piecewise_decay(args, step_per_epoch):
def cosine_decay(args, step_per_epoch):
last_iter = (1 + args.last_epoch) * step_per_epoch
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=args.num_epochs * step_per_epoch)
learning_rate=args.lr,
T_max=args.num_epochs * step_per_epoch,
last_epoch=last_iter)
optimizer = paddle.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
......@@ -73,7 +87,34 @@ def create_optimizer(args, step_per_epoch):
return cosine_decay(args, step_per_epoch)
def create_unstructured_pruner(train_program, args, place, configs):
if configs is None:
return UnstructuredPruner(
train_program,
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold,
prune_params_type=args.prune_params_type,
place=place)
else:
return GMPUnstructuredPruner(
train_program,
ratio=args.ratio,
prune_params_type=args.prune_params_type,
place=place,
configs=configs)
def compress(args):
env = os.environ
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
use_data_parallel = num_trainers > 1
if use_data_parallel:
# Fleet step 1: initialize the distributed environment
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
train_reader = None
test_reader = None
if args.data == "mnist":
......@@ -96,27 +137,30 @@ def compress(args):
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
if args.use_gpu:
places = paddle.static.cuda_places()
else:
places = paddle.static.cpu_places()
places = paddle.static.cuda_places()
place = places[0]
exe = paddle.static.Executor(place)
image = paddle.static.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
batch_size_per_card = int(args.batch_size / len(places))
train_loader = paddle.io.DataLoader(
batch_size_per_card = args.batch_size
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset,
places=places,
feed_list=[image, label],
drop_last=True,
batch_size=batch_size_per_card,
shuffle=True,
drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
batch_sampler=batch_sampler,
feed_list=[image, label],
return_list=False,
use_shared_memory=True,
num_workers=32)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
......@@ -126,7 +170,9 @@ def compress(args):
use_shared_memory=True,
batch_size=args.batch_size_for_validation,
shuffle=False)
step_per_epoch = int(np.ceil(len(train_dataset) * 1. / args.batch_size))
step_per_epoch = int(
np.ceil(len(train_dataset) * 1. / args.batch_size / num_trainers))
# model definition
model = models.__dict__[args.model]()
......@@ -139,18 +185,46 @@ def compress(args):
val_program = paddle.static.default_main_program().clone(for_test=True)
opt, learning_rate = create_optimizer(args, step_per_epoch)
opt.minimize(avg_cost)
pruner = UnstructuredPruner(
paddle.static.default_main_program(),
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold,
place=place)
# Fleet step 2: distributed strategy
if use_data_parallel:
dist_strategy = DistributedStrategy()
dist_strategy.sync_batch_norm = False
dist_strategy.exec_strategy = paddle.static.ExecutionStrategy()
dist_strategy.fuse_all_reduce_ops = False
train_program = paddle.static.default_main_program()
if args.pruning_strategy == 'gmp':
# GMP pruner step 0: define configs for GMP, no need to define configs for the base training.
configs = {
'stable_iterations': args.stable_epochs * step_per_epoch,
'pruning_iterations': args.pruning_epochs * step_per_epoch,
'tunning_iterations': args.tunning_epochs * step_per_epoch,
'resume_iteration': (args.last_epoch + 1) * step_per_epoch,
'pruning_steps': args.pruning_steps,
'initial_ratio': args.initial_ratio,
}
elif args.pruning_strategy == 'base':
configs = None
# GMP pruner step 1: initialize a pruner object by calling entry function.
pruner = create_unstructured_pruner(
train_program, args, place, configs=configs)
if use_data_parallel:
# Fleet step 3: decorate the origial optimizer and minimize it
opt = fleet.distributed_optimizer(opt, strategy=dist_strategy)
opt.minimize(avg_cost, no_grad_set=pruner.no_grad_set)
exe.run(paddle.static.default_startup_program())
if args.last_epoch > -1:
assert args.checkpoint is not None and os.path.exists(
args.checkpoint), "Please specify a valid checkpoint path."
paddle.fluid.io.load_persistables(
executor=exe, dirname=args.checkpoint, main_program=train_program)
if args.pretrained_model:
elif args.pretrained_model:
assert os.path.exists(
args.
pretrained_model), "Pretrained model path {} doesn't exist".format(
......@@ -162,7 +236,7 @@ def compress(args):
_logger.info("Load pretrained model from {}".format(
args.pretrained_model))
# NOTE: We are using fluid.io.load_vars() because the pretrained model is from an older version which requires this API.
# Please consider using paddle.static.load(program, model_path) when possible
# Please consider using paddle.static.load(program, model_path) when possible
paddle.fluid.io.load_vars(
exe, args.pretrained_model, predicate=if_exist)
......@@ -170,9 +244,10 @@ def compress(args):
acc_top1_ns = []
acc_top5_ns = []
_logger.info("The current density of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
_logger.info(
"The current sparsity of the inference model is {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
for batch_id, data in enumerate(valid_loader):
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
......@@ -200,10 +275,12 @@ def compress(args):
train_reader_cost += time.time() - reader_start
train_start = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
program,
feed=data,
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
# GMP pruner step 2: step() to update ratios and other internal states of the pruner.
pruner.step()
train_run_cost += time.time() - train_start
total_samples += args.batch_size
loss_n = np.mean(loss_n)
......@@ -225,28 +302,30 @@ def compress(args):
learning_rate.step()
reader_start = time.time()
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
train_program = paddle.static.CompiledProgram(
paddle.static.default_main_program()).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
for i in range(args.resume_epoch + 1, args.num_epochs):
train(i, train_program)
_logger.info("The current density of the pruned model is: {}%".format(
if use_data_parallel:
# Fleet step 4: get the compiled program from fleet
compiled_train_program = fleet.main_program
else:
compiled_train_program = paddle.static.CompiledProgram(
paddle.static.default_main_program())
for i in range(args.last_epoch + 1, args.num_epochs):
train(i, compiled_train_program)
# GMP pruner step 3: update params before summrizing sparsity, saving model or evaluation.
pruner.update_params()
_logger.info("The current sparsity of the pruned model is: {}%".format(
round(100 * UnstructuredPruner.total_sparse(
paddle.static.default_main_program()), 2)))
if (i + 1) % args.test_period == 0:
pruner.update_params()
test(i, val_program)
if (i + 1) % args.model_period == 0:
pruner.update_params()
# NOTE: We are using fluid.io.save_params() because the pretrained model is from an older version which requires this API.
# Please consider using paddle.static.save(program, model_path) as long as it becomes possible.
fluid.io.save_params(executor=exe, dirname=args.model_path)
if use_data_parallel:
fleet.save_persistables(executor=exe, dirname=args.model_path)
else:
paddle.fluid.io.save_persistables(
executor, dirname=args.model_path)
def main():
......
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05 \
--model MobileNet \
--pretrained_model "MobileNetV1_pretrained"
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.75 \
--lr 0.005 \
--model MobileNet \
--num_epochs 108 \
--pretrained_model "MobileNetV1_pretrained" \
--step_epochs 71 88 \
--initial_ratio 0.15 \
--pruning_steps 100 \
--stable_epochs 0 \
--pruning_epochs 54 \
--tunning_epochs 54 \
--last_epoch -1 \
--pruning_strategy gmp
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2,3
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size 512 \
CUDA_VISIBLE_DEVICES=0 python train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.55 \
--lr 0.05 \
--pretrained_model ./MobileNetV1_pretrained
--model MobileNet \
--pretrained_model "MobileNetV1_pretrained" \
CUDA_VISIBLE_DEVICES=0 python train.py \
--batch_size 64 \
--data imagenet \
--pruning_mode ratio \
--ratio 0.75 \
--lr 0.005 \
--model MobileNet \
--num_epochs 108 \
--pretrained_model "MobileNetV1_pretrained" \
--model_path "./models" \
--step_epochs 71 88 \
--last_epoch -1 \
--initial_ratio 0.15 \
--pruning_steps 100 \
--stable_epochs 0 \
--pruning_epochs 54 \
--tunning_epochs 54 \
--pruning_strategy gmp
#!/bin/bash
export CUDA_VISIBLE_DEVICES=2,3
export CUDA_VISIBLE_DEVICES=2
export FLAGS_fraction_of_gpu_memory_to_use=0.98
python3.7 train.py \
--batch_size=512 \
python train.py \
--batch_size=64 \
--data="mnist" \
--pruning_mode="threshold" \
--threshold=0.01 \
--lr=0.05 \
--lr=0.05
......@@ -4,8 +4,7 @@
UnstructuredPruner
----------
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.3, skip_params_func=None)
.. py:class:: paddleslim.UnstructuredPruner(model, mode, threshold=0.01, ratio=0.55, prune_params_type=None, skip_params_func=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
......@@ -17,7 +16,33 @@ UnstructuredPruner
- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio''threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
- **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。示例代码如下:
.. code-block:: python
NORMS_ALL = [ 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1D',
'BatchNorm2D', 'BatchNorm3D', 'InstanceNorm1D', 'InstanceNorm2D',
'InstanceNorm3D', 'SyncBatchNorm', 'LocalResponseNorm' ]
def _get_skip_params(model):
"""
This function is used to check whether the given model's layers are valid to be pruned.
Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
Deverlopers could replace this function by passing their own when initializing the UnstructuredPuner instance.
Args:
- model(Paddle.nn.Layer): the current model waiting to be checked.
Return:
- skip_params(set<String>): a set of parameters' names
"""
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
skip_params.add(sub_layer.full_name())
return skip_params
..
**返回:** 一个UnstructuredPruner类的实例。
......@@ -33,13 +58,13 @@ UnstructuredPruner
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.55)
..
.. py:method:: paddleslim.UnstructuredPruner.step()
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。该函数调用在训练过程中每个batchoptimizer.step()之后。
**示例代码:**
......@@ -52,7 +77,7 @@ UnstructuredPruner
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.55)
print(pruner.threshold)
pruner.step()
......@@ -77,27 +102,26 @@ UnstructuredPruner
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='threshold', threshold=0.5)
density = UnstructuredPruner.total_sparse(model)
print(density)
model(paddle.to_tensor(
np.random.uniform(0, 1, [16, 1, 28, 28]), dtype='float32'))
sparsity = UnstructuredPruner.total_sparse(model)
print(sparsity)
pruner.step()
pruner.update_params()
density = UnstructuredPruner.total_sparse(model)
print(density) # 可以看出,这里打印的模型稠密度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
sparsity = UnstructuredPruner.total_sparse(model)
print(sparsity) # 可以看出,这里打印的模型稀疏度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
..
.. py:method:: paddleslim.UnstructuredPruner.total_sparse(model)
UnstructuredPruner中的静态方法,用于计算给定的模型(model)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
UnstructuredPruner中的静态方法,用于计算给定的模型(model)的稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
- **model(paddle.nn.Layer)** - 要计算稠密度的目标网络。
- **model(paddle.nn.Layer)** - 要计算稀疏度的目标网络。
**返回:**
- **density(float)** - 模型的稠密度。
- **sparsity(float)** - 模型的稀疏度。
**示例代码:**
......@@ -110,11 +134,39 @@ UnstructuredPruner
place = paddle.set_device('cpu')
model = net(num_classes=10)
density = UnstructuredPruner.total_sparse(model)
print(density)
sparsity = UnstructuredPruner.total_sparse(model)
print(sparsity)
..
.. py:method:: paddleslim.UnstructuredPruner.total_sparse_conv1x1(model)
UnstructuredPruner中的静态方法,用于计算给定的模型(model)的1x1卷积的稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
- **model(paddle.nn.Layer)** - 要计算稀疏度的目标网络。
**返回:**
- **sparsity(float)** - 模型的1x1卷积稀疏度。
**示例代码:**
.. code-block:: python
import paddle
from paddleslim import UnstructuredPruner
from paddle.vision.models import MobileNetV1 as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
sparsity = UnstructuredPruner.total_sparse_conv1x1(model)
print(sparsity)
..
.. py:method:: paddleslim.UnstructuredPruner.summarize_weights(model, ratio=0.1)
该函数用于估计预训练模型中参数的分布情况,尤其是在不清楚如何设置threshold的数值时,尤为有用。例如,当输入为ratio=0.1时,函数会返回一个数值v,而绝对值小于v的权重的个数占所有权重个数的(100*ratio%)
......@@ -139,9 +191,97 @@ UnstructuredPruner
place = paddle.set_device('cpu')
model = net(num_classes=10)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.55)
threshold = pruner.summarize_weights(model, 0.5)
print(threshold)
..
GMPUnstructuredPruner
----------
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dygraph/prune/unstructured_pruner.py>`_
.. py:class:: paddleslim.GMPUnstructuredPruner(model, ratio=0.55, prune_params_type=None, skip_params_func=None, configs=None)
该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。
**参数:**
- **model(paddle.nn.Layer)** - 待剪裁的动态图模型。
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **prune_params_type(str)** - 用以指定哪些类型的参数参与稀疏。目前只支持None"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化层的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。各参数介绍如下:
.. code-block:: python
{'stable_iterations': int} # the duration of stable phase in terms of global iterations
{'pruning_iterations': int} # the duration of pruning phase in terms of global iterations
{'tunning_iterations': int} # the duration of tunning phase in terms of global iterations
{'resume_iteration': int} # the start timestamp you want to train from, in terms if global iteration
{'pruning_steps': int} # the total times you want to increase the ratio
{'initial_ratio': float} # the initial ratio value
..
**返回:** 一个GMPUnstructuredPruner类的实例
.. code-block:: python
import paddle
from paddleslim import GMPUnstructuredPruner
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 0,
'pruning_steps': 10,
'initial_ratio': 0.15,
}
pruner = GMPUnstructuredPruner(model, ratio=0.55, configs=configs)
..
.. py:method:: paddleslim.GMPUnstructuredPruner.step()
更新稀疏化的阈值:根据优化后的模型参数和设定的比例,重新计算阈值。该函数调用在训练过程中每个batchoptimizer.step()之后。
**示例代码:**
.. code-block:: python
import paddle
from paddleslim import GMPUnstructuredPruner
from paddle.vision.models import LeNet as net
import numpy as np
place = paddle.set_device('cpu')
model = net(num_classes=10)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 0,
'pruning_steps': 10,
'initial_ratio': 0.15,
}
pruner = GMPUnstructuredPruner(model, ratio=0.55, configs=configs)
print(pruner.threshold)
for i in range(200):
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
..
......@@ -4,7 +4,7 @@
UnstrucuturedPruner
----------
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.5, threshold=1e-5, scope=None, place=None, skip_params_func=None)
.. py:class:: paddleslim.prune.UnstructuredPruner(program, mode, ratio=0.55, threshold=1e-2, scope=None, place=None, prune_params_type, skip_params_func=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
......@@ -13,13 +13,37 @@ UnstrucuturedPruner
**参数:**
- **program(paddle.static.Program)** - 一个paddle.static.Program对象,是待剪裁的模型。
- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.5,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-5,然后重要性低于1e-5的参数会被置0。
- **mode(str)** - 稀疏化的模式,目前支持的模式有:'ratio'和'threshold'。在'ratio'模式下,会给定一个固定比例,例如0.55,然后所有参数中重要性较低的50%会被置0。类似的,在'threshold'模式下,会给定一个固定阈值,例如1e-2,然后重要性低于1e-2的参数会被置0。
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **threshold(float)** - 稀疏化阈值期望,只有在 mode=='threshold' 时才会生效。
- **scope(paddle.static.Scope)** - 一个paddle.static.Scope对象,存储了所有变量的数值,默认(None)时表示paddle.static.global_scope。
- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。
- **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
.. code-block:: python
def _get_skip_params(program):
"""
The function is used to get a set of all the skipped parameters when performing pruning.
By default, the normalization-related ones will not be pruned.
Developers could replace it by passing their own function when initializing the UnstructuredPruner instance.
Args:
- program(paddle.static.Program): the current model.
Returns:
- skip_params(Set<String>): a set of parameters' names.
"""
skip_params = set()
graph = paddleslim.core.GraphWrapper(program)
for op in graph.ops():
if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs():
skip_params.add(input.name())
return skip_params
..
**返回:** 一个UnstructuredPruner类的实例
**示例代码:**
......@@ -47,12 +71,12 @@ UnstrucuturedPruner
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place)
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.step()
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。
更新稀疏化的阈值,如果是'threshold'模式,则维持设定的阈值,如果是'ratio'模式,则根据优化后的模型参数和设定的比例,重新计算阈值。该函数调用在训练过程中每个batch的optimizer.step()之后。
**示例代码:**
......@@ -79,7 +103,7 @@ UnstrucuturedPruner
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place)
print(pruner.threshold)
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
......@@ -87,7 +111,7 @@ UnstrucuturedPruner
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.update_params()
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。但是,在训练过程中,由于step()函数会调用该方法,故不需要开发者在训练过程中额外调用了。
每一步优化后,重制模型中本来是0的权重。这一步通常用于模型evaluation和save之前,确保模型的稀疏率。但是,在训练过程中,由于前向过程中插入了稀疏化权重的op,故不需要开发者在训练过程中额外调用了。
**示例代码:**
......@@ -114,19 +138,19 @@ UnstrucuturedPruner
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'threshold', threshold=0.5, place=place)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'threshold', threshold=0.55, place=place)
sparsity = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(sparsity)
pruner.step()
pruner.update_params()
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density) # 可以看出,这里打印的模型稠密度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.5的权重。
sparsity = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(sparsity) # 可以看出,这里打印的模型稀疏度与上述不同,这是因为update_params()函数置零了所有绝对值小于0.55的权重。
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse(program)
UnstructuredPruner中的静态方法,用于计算给定的模型(program)的稠密度(1-稀疏度)并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
UnstructuredPruner中的静态方法,用于计算给定的模型(program)的稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
......@@ -134,7 +158,7 @@ UnstrucuturedPruner
**返回:**
- **density(float)** - 模型的稠密度。
- **sparsity(float)** - 模型的稀疏度。
**示例代码:**
......@@ -161,8 +185,51 @@ UnstrucuturedPruner
exe = paddle.static.Executor(place)
exe.run(startup_program)
density = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(density)
sparsity = UnstructuredPruner.total_sparse(paddle.static.default_main_program())
print(sparsity)
..
.. py:method:: paddleslim.prune.unstructured_pruner.UnstructuredPruner.total_sparse_conv1x1(program)
UnstructuredPruner中的静态方法,用于计算给定的模型(program)的1x1卷积稀疏度并返回。该方法为静态方法,是考虑到在单单做模型评价的时候,我们就不需要初始化一个UnstructuredPruner示例了。
**参数:**
- **program(paddle.static.Program)** - 要计算稠密度的目标网络。
**返回:**
- **sparsity(float)** - 模型的1x1卷积部分的稀疏度。
**示例代码:**
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddleslim.prune import UnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv1x1 = fluid.layers.conv2d(image, 32, 1)
conv3x3 = fluid.layers.conv2d(conv1x1, 32, 3)
feature = fluid.layers.fc(conv3x3, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
sparsity = UnstructuredPruner.total_sparse_conv1x1(paddle.static.default_main_program())
print(sparsity)
..
......@@ -204,9 +271,126 @@ UnstrucuturedPruner
exe = paddle.static.Executor(place)
exe.run(startup_program)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.5, place=place)
threshold = pruner.summarize_weights(paddle.static.default_main_program(), ratio=0.5)
pruner = UnstructuredPruner(paddle.static.default_main_program(), 'ratio', ratio=0.55, place=place)
threshold = pruner.summarize_weights(paddle.static.default_main_program(), ratio=0.55)
print(threshold)
..
GMPUnstrucuturedPruner
----------
.. py:class:: paddleslim.prune.GMPUnstructuredPruner(program, ratio=0.55, scope=None, place=None, prune_params_type=None, skip_params_func=None, configs=None)
`源代码 <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/unstructured_pruner.py>`_
该类是UnstructuredPruner的一个子类,通过覆盖step()方法,优化了训练策略,使稀疏化训练更易恢复到稠密模型精度。其他方法均继承自父类。
**参数:**
- **program(paddle.static.Program)** - 一个paddle.static.Program对象,是待剪裁的模型。
- **ratio(float)** - 稀疏化比例期望,只有在 mode=='ratio' 时才会生效。
- **scope(paddle.static.Scope)** - 一个paddle.static.Scope对象,存储了所有变量的数值,默认(None)时表示paddle.static.global_scope。
- **place(CPUPlace|CUDAPlace)** - 模型执行的设备,类型为CPUPlace或者CUDAPlace,默认(None)时代表CPUPlace。
- **prune_params_type(String)** - 用以指定哪些类型的参数参与稀疏。目前只支持None和"conv1x1_only"两个选项,后者表示只稀疏化1x1卷积。而前者表示稀疏化除了归一化的参数。
- **skip_params_func(function)** - 一个指向function的指针,该function定义了哪些参数不应该被剪裁,默认(None)时代表所有归一化层参数不参与剪裁。
- **configs(Dict)** - 传入额外的训练超参用以指导GMP训练过程。具体描述如下:
.. code-block:: python
{'stable_iterations': int} # the duration of stable phase in terms of global iterations
{'pruning_iterations': int} # the duration of pruning phase in terms of global iterations
{'tunning_iterations': int} # the duration of tunning phase in terms of global iterations
{'resume_iteration': int} # the start timestamp you want to train from, in terms if global iteration
{'pruning_steps': int} # the total times you want to increase the ratio
{'initial_ratio': float} # the initial ratio value
..
**返回:** 一个GMPUnstructuredPruner类的实例
**示例代码:**
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddleslim.prune import GMPUnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 0,
'pruning_steps': 10,
'initial_ratio': 0.15,
}
pruner = GMPUnstructuredPruner(paddle.static.default_main_program(), ratio=0.55, place=place, configs=configs)
for i in range(2000):
pruner.step()
print(pruner.ratio) # 可以看到ratio从0.15非线性的增加到0.55。
..
.. py:method:: paddleslim.prune.unstructured_pruner.GMPUnstructuredPruner.step()
根据优化后的模型参数和设定的比例,重新计算阈值,并且更新mask。该函数调用在训练过程中每个batch的optimizer.step()之后。
**示例代码:**
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddleslim.prune import GMPUnstructuredPruner
paddle.enable_static()
train_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with fluid.program_guard(train_program, startup_program):
image = fluid.data(name='x', shape=[None, 1, 28, 28])
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
conv = fluid.layers.conv2d(image, 32, 1)
feature = fluid.layers.fc(conv, 10, act='softmax')
cost = fluid.layers.cross_entropy(input=feature, label=label)
avg_cost = fluid.layers.mean(x=cost)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
exe.run(startup_program)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 0,
'pruning_steps': 10,
'initial_ratio': 0.15,
}
pruner = GMPUnstructuredPruner(paddle.static.default_main_program(), ratio=0.55, place=place, configs=configs)
print(pruner.threshold)
for i in range(200):
pruner.step()
print(pruner.threshold) # 可以看出,这里的threshold和上面打印的不同,这是因为step函数根据设定的ratio更新了threshold数值,便于剪裁操作。
..
......@@ -3,7 +3,7 @@ import paddle
import logging
from paddleslim.common import get_logger
__all__ = ["UnstructuredPruner"]
__all__ = ["UnstructuredPruner", "GMPUnstructuredPruner"]
_logger = get_logger(__name__, level=logging.INFO)
......@@ -21,7 +21,8 @@ class UnstructuredPruner():
- model(Paddle.nn.Layer): The model to be pruned.
- mode(str): Pruning mode, must be selected from 'ratio' and 'threshold'.
- threshold(float): The parameters whose absolute values are smaller than the THRESHOLD will be zeros. Default: 0.01
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.3
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
"""
......@@ -29,15 +30,25 @@ class UnstructuredPruner():
model,
mode,
threshold=0.01,
ratio=0.3,
ratio=0.55,
prune_params_type=None,
skip_params_func=None):
assert mode in ('ratio', 'threshold'
), "mode must be selected from 'ratio' and 'threshold'"
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
self.model = model
self.mode = mode
self.threshold = threshold
self.ratio = ratio
if skip_params_func is None: skip_params_func = self._get_skip_params
# Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params
if skip_params_func is not None:
skip_params_func = skip_params_func
elif prune_params_type == 'conv1x1_only':
skip_params_func = self._get_skip_params_conv1x1
elif skip_params_func is None:
skip_params_func = self._get_skip_params
self.skip_params = skip_params_func(model)
self._apply_masks()
......@@ -49,8 +60,6 @@ class UnstructuredPruner():
- parameters(list<Tensor>): The parameters to be pruned.
- masks(list<Tensor>): The masks used to keep zero values in parameters.
"""
bool_tmp = (paddle.abs(param) >= self.threshold)
paddle.assign(mask * bool_tmp, output=mask)
param_tmp = param * mask
param_tmp.stop_gradient = True
paddle.assign(param_tmp, output=param)
......@@ -86,6 +95,14 @@ class UnstructuredPruner():
self.threshold = np.sort(np.abs(params_flatten))[max(
0, round(self.ratio * total_length) - 1)].item()
def _update_masks(self):
for name, sub_layer in self.model.named_sublayers():
if not self._should_prune_layer(sub_layer): continue
for param in sub_layer.parameters(include_sublayers=False):
mask = self.masks.get(param.name)
bool_tmp = (paddle.abs(param) >= self.threshold)
paddle.assign(bool_tmp, output=mask)
def summarize_weights(self, model, ratio=0.1):
"""
The function is used to get the weights corresponding to a given ratio
......@@ -114,8 +131,9 @@ class UnstructuredPruner():
"""
if self.mode == 'ratio':
self.update_threshold()
self._update_masks()
elif self.mode == 'threshold':
return
self._update_masks()
def _forward_pre_hook(self, layer, input):
if not self._should_prune_layer(layer):
......@@ -140,21 +158,46 @@ class UnstructuredPruner():
@staticmethod
def total_sparse(model):
"""
This static function is used to get the whole model's density (1-sparsity).
This static function is used to get the whole model's sparsity.
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- model(paddle.nn.Layer): The sparse model.
Returns:
- ratio(float): The model's sparsity.
"""
total = 0
values = 0
for name, sub_layer in model.named_sublayers():
for param in sub_layer.parameters(include_sublayers=False):
total += np.product(param.shape)
values += len(paddle.nonzero(param))
ratio = 1 - float(values) / total
return ratio
@staticmethod
def total_sparse_conv1x1(model):
"""
This static function is used to get the partial model's sparsity in terms of conv1x1 layers.
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- model(paddle.nn.Layer): The sparse model.
Returns:
- ratio(float): The model's density.
- ratio(float): The model's sparsity.
"""
total = 0
values = 0
for name, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
continue
for param in sub_layer.parameters(include_sublayers=False):
cond = len(param.shape) == 4 and param.shape[
2] == 1 and param.shape[3] == 1
if not cond: continue
total += np.product(param.shape)
values += len(paddle.nonzero(param))
ratio = float(values) / total
ratio = 1 - float(values) / total
return ratio
def _get_skip_params(self, model):
......@@ -174,6 +217,108 @@ class UnstructuredPruner():
skip_params.add(sub_layer.full_name())
return skip_params
def _get_skip_params_conv1x1(self, model):
skip_params = set()
for _, sub_layer in model.named_sublayers():
if type(sub_layer).__name__.split('.')[-1] in NORMS_ALL:
skip_params.add(sub_layer.full_name())
for param in sub_layer.parameters(include_sublayers=False):
cond = len(param.shape) == 4 and param.shape[
2] == 1 and param.shape[3] == 1
if not cond: skip_params.add(sub_layer.full_name())
return skip_params
def _should_prune_layer(self, layer):
should_prune = layer.full_name() not in self.skip_params
return should_prune
class GMPUnstructuredPruner(UnstructuredPruner):
"""
The unstructure pruner using GMP training strategy (Gradual Magnitute Pruning). In this subclass of UnstructuredPruner, most methods are inheritated apart from the step(), since we add some ratio increment logics here.
Conceptually, the algorithm divide the training into three phases: stable, pruning and tuning. And the ratio is increasing from initial_ratio gradually and nonlinearly w.r.t. the training epochs/iterations.
Args:
- model(Paddle.nn.Layer): The model to be pruned.
- ratio(float): The parameters whose absolute values are in the smaller part decided by the ratio will be zeros. Default: 0.55
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None
.. code-block:: python
{'stable_iterations': int} # the duration of stable phase in terms of global iterations
{'pruning_iterations': int} # the duration of pruning phase in terms of global iterations
{'tunning_iterations': int} # the duration of tunning phase in terms of global iterations
{'resume_iteration': int} # the start timestamp you want to train from, in terms if global iteration
{'pruning_steps': int} # the total times you want to increase the ratio
{'initial_ratio': float} # the initial ratio value
..
"""
def __init__(self,
model,
ratio=0.55,
prune_params_type=None,
skip_params_func=None,
configs=None):
assert configs is not None, "Configs must be passed in for GMP pruner."
super(GMPUnstructuredPruner, self).__init__(
model, 'ratio', 0.0, ratio, prune_params_type, skip_params_func)
self.stable_iterations = configs.get('stable_iterations')
self.pruning_iterations = configs.get('pruning_iterations')
self.tunning_iterations = configs.get('tunning_iterations')
self.pruning_steps = configs.get('pruning_steps')
self.initial_ratio = configs.get('initial_ratio')
self.ratio = 0.0
self.target_ratio = ratio
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
self.ratios_stack = []
self.ratio_increment_period = int(self.pruning_iterations /
self.pruning_steps)
for i in range(self.pruning_steps):
ratio_tmp = ((i / self.pruning_steps) - 1.0)**3 + 1
ratio_tmp = ratio_tmp * (self.target_ratio - self.initial_ratio
) + self.initial_ratio
self.ratios_stack.append(ratio_tmp)
stable_steps = int(
float(self.stable_iterations) / self.pruning_iterations *
self.pruning_steps)
tunning_steps = int(
float(self.tunning_iterations) / self.pruning_iterations *
self.pruning_steps)
stable_ratios_stack = [0.0] * stable_steps
tunning_ratios_stack = [self.target_ratio] * tunning_steps
self.ratios_stack = stable_ratios_stack + self.ratios_stack + tunning_ratios_stack
self.ratios_stack.reverse()
# pop out used ratios to resume training
for i in range(self.cur_iteration):
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
def step(self):
ori_ratio = self.ratio
if self.cur_iteration % self.ratio_increment_period == 0:
if len(self.ratios_stack) > 0:
self.ratio = self.ratios_stack.pop()
else:
self.ratio = self.target_ratio
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
self.update_threshold()
self._update_masks()
self.cur_iteration += 1
......@@ -2,8 +2,9 @@ import numpy as np
from ..common import get_logger
from ..core import GraphWrapper
import paddle
import copy
__all__ = ["UnstructuredPruner"]
__all__ = ["UnstructuredPruner", "GMPUnstructuredPruner"]
class UnstructuredPruner():
......@@ -13,20 +14,22 @@ class UnstructuredPruner():
Args:
- program(paddle.static.Program): The model to be pruned.
- mode(str): the mode to prune the model, must be selected from 'ratio' and 'threshold'.
- ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.5.
- threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-5.
- ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.55.
- threshold(float): the threshold to prune the model. Only set it when mode=='threshold'. Default: 1e-2.
- scope(paddle.static.Scope): The scope storing values of all variables. None means paddle.static.global_scope. Default: None.
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params.
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
"""
def __init__(self,
program,
mode,
ratio=0.5,
threshold=1e-5,
ratio=0.55,
threshold=1e-2,
scope=None,
place=None,
prune_params_type=None,
skip_params_func=None):
self.mode = mode
self.ratio = ratio
......@@ -34,15 +37,46 @@ class UnstructuredPruner():
assert self.mode in [
'ratio', 'threshold'
], "mode must be selected from 'ratio' and 'threshold'"
assert prune_params_type is None or prune_params_type == 'conv1x1_only', "prune_params_type only supports None or conv1x1_only for now."
self.scope = paddle.static.global_scope() if scope == None else scope
self.place = paddle.CPUPlace() if place is None else place
if skip_params_func is None: skip_params_func = self._get_skip_params
self.place = paddle.static.cpu_places()[0] if place is None else place
# Prority: passed-in skip_params_func > prune_params_type (conv1x1_only) > built-in _get_skip_params
if skip_params_func is not None:
skip_params_func = skip_params_func
elif prune_params_type == 'conv1x1_only':
skip_params_func = self._get_skip_params_conv1x1
elif skip_params_func is None:
skip_params_func = self._get_skip_params
self.skip_params = skip_params_func(program)
self.masks = self._apply_masks(program)
self.masks = self._apply_masks(program, self.mask_parameters)
def _apply_masks(self, program):
def mask_parameters(self, parameters, masks, program):
"""
Update masks and parameters. It is executed before each iteration.
User can overwrite this function in subclass to implememt different pruning stragies.
Args:
- parameters(list<Tensor>): The parameters to be pruned.
- masks(list<Tensor>): The masks used to keep zero values in parameters.
- program(paddle.static.Program): The model to add mask op to.
"""
block = program.global_block()
for param, mask in zip(parameters, masks):
block._prepend_op(
type='elementwise_mul',
inputs={'X': param,
'Y': mask},
outputs={'Out': param},
attrs={'axis': -1,
'use_mkldnn': False})
def _apply_masks(self, program, mask_func):
params = []
masks = []
self.no_grad_set = set()
for param in program.all_parameters():
mask = program.global_block().create_var(
name=param.name + "_mask",
......@@ -56,6 +90,13 @@ class UnstructuredPruner():
np.ones(mask.shape).astype("float32"), self.place)
params.append(param)
masks.append(mask)
self.no_grad_set.add(param.name + "_mask")
with paddle.static.program_guard(main_program=program):
ops = program.global_block().ops
ori_len = len(ops)
mask_func(params, masks, program)
program.global_block().ops = ops
d_masks = {}
for _param, _mask in zip(params, masks):
......@@ -100,7 +141,8 @@ class UnstructuredPruner():
value = np.count_nonzero(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()))
layer_sparse[param.name] = value / np.product(param.shape)
layer_sparse[param.name] = 1 - float(value) / np.product(
param.shape)
return layer_sparse
def update_threshold(self):
......@@ -116,11 +158,18 @@ class UnstructuredPruner():
v_param = np.array(t_param)
params_flatten.append(v_param.flatten())
params_flatten = np.concatenate(params_flatten, axis=0)
total_len = len(params_flatten)
self.threshold = np.sort(np.abs(params_flatten))[max(
0, int(self.ratio * total_len) - 1)]
self.threshold = self._partition_sort(params_flatten)
def _update_params_masks(self):
def _partition_sort(self, params):
total_len = len(params)
params_zeros = params[params == 0]
params_nonzeros = params[params != 0]
new_ratio = max((self.ratio * total_len - len(params_zeros)),
0) / len(params_nonzeros)
return np.sort(np.abs(params_nonzeros))[max(
0, int(new_ratio * len(params_nonzeros)) - 1)]
def _update_masks(self):
for param in self.masks:
if not self._should_prune_param(param):
continue
......@@ -131,22 +180,20 @@ class UnstructuredPruner():
v_param[np.abs(v_param) < self.threshold] = 0
v_mask = (v_param != 0).astype(v_param.dtype)
t_mask.set(v_mask, self.place)
v_param = np.array(t_param) * np.array(t_mask)
t_param.set(v_param, self.place)
def step(self):
"""
Update the threshold after each optimization step.
Update the threshold and masks.
"""
if self.mode == 'threshold':
pass
elif self.mode == 'ratio':
self.update_threshold()
self._update_params_masks()
self._update_masks()
def update_params(self):
"""
Update the parameters given self.masks, usually called before saving models.
Update the parameters given self.masks, usually called before saving or evaluating models.
"""
for param in self.masks:
mask = self.masks[param]
......@@ -158,13 +205,13 @@ class UnstructuredPruner():
@staticmethod
def total_sparse(program):
"""
The function is used to get the whole model's density (1-sparsity).
The function is used to get the whole model's sparsity.
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- program(paddle.static.Program): The current model.
Returns:
- density(float): the model's density.
- sparsity(float): the model's sparsity.
"""
total = 0
values = 0
......@@ -173,8 +220,8 @@ class UnstructuredPruner():
values += np.count_nonzero(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()))
density = float(values) / total
return density
sparsity = 1 - float(values) / total
return sparsity
def _get_skip_params(self, program):
"""
......@@ -195,6 +242,141 @@ class UnstructuredPruner():
skip_params.add(input.name())
return skip_params
def _get_skip_params_conv1x1(self, program):
skip_params = set()
graph = GraphWrapper(program)
for op in graph.ops():
if 'norm' in op.type() and 'grad' not in op.type():
for input in op.all_inputs():
skip_params.add(input.name())
for param in program.all_parameters():
if not (len(param.shape) == 4 and param.shape[2] == 1 and
param.shape[3] == 1):
skip_params.add(param.name)
return skip_params
@staticmethod
def total_sparse_conv1x1(program):
"""
The function is used to get the model's spasity for all the 1x1 convolutional weights.
It is static because during testing, we can calculate sparsity without initializing a pruner instance.
Args:
- program(paddle.static.Program): The current model.
Returns:
- sparsity(float): the model's sparsity.
"""
total = 0
values = 0
for param in program.all_parameters():
if not (len(param.shape) == 4 and param.shape[2] == 1 and
param.shape[3] == 1):
continue
total += np.product(param.shape)
values += np.count_nonzero(
np.array(paddle.static.global_scope().find_var(param.name)
.get_tensor()))
sparsity = 1 - float(values) / total
return sparsity
def _should_prune_param(self, param):
should_prune = param not in self.skip_params
return should_prune
class GMPUnstructuredPruner(UnstructuredPruner):
"""
The unstructure pruner using GMP training strategy (Gradual Magnitute Pruning). In this subclass of UnstructuredPruner, most methods are inheritated apart from the step(), since we add some ratio increment logics here.
Conceptually, the algorithm divide the training into three phases: stable, pruning and tuning. And the ratio is increasing from initial_ratio gradually and nonlinearly w.r.t. the training epochs/iterations.
Args:
- program(paddle.static.Program): The model to be pruned.
- ratio(float): the ratio to prune the model. Only set it when mode=='ratio'. Default: 0.55.
- scope(paddle.static.Scope): The scope storing values of all variables. None means paddle.static.global_scope. Default: None.
- place(CPUPlace | CUDAPlace): The device place used to execute model. None means CPUPlace. Default: None.
- prune_params_type(str): The argument to control which type of ops will be pruned. Currently we only support None (all but norms) or conv1x1_only as input. It acts as a straightforward call to conv1x1 pruning. Default: None
- skip_params_func(function): The function used to select the parameters which should be skipped when performing pruning. Default: normalization-related params. Default: None
- configs(Dict): The dictionary contains all the configs for GMP pruner. Default: None. The detailed description is as below:
.. code-block:: python
{'stable_iterations': int} # the duration of stable phase in terms of global iterations
{'pruning_iterations': int} # the duration of pruning phase in terms of global iterations
{'tunning_iterations': int} # the duration of tunning phase in terms of global iterations
{'resume_iteration': int} # the start timestamp you want to train from, in terms if global iteration
{'pruning_steps': int} # the total times you want to increase the ratio
{'initial_ratio': float} # the initial ratio value
..
"""
def __init__(self,
program,
ratio=0.55,
scope=None,
place=None,
prune_params_type=None,
skip_params_func=None,
configs=None):
assert configs is not None, "Please pass in a valid config dictionary."
super(GMPUnstructuredPruner, self).__init__(
program, 'ratio', ratio, 0.0, scope, place, prune_params_type,
skip_params_func)
self.stable_iterations = configs.get('stable_iterations')
self.pruning_iterations = configs.get('pruning_iterations')
self.tunning_iterations = configs.get('tunning_iterations')
self.pruning_steps = configs.get('pruning_steps')
self.initial_ratio = configs.get('initial_ratio')
self.ratio = 0
self.target_ratio = ratio
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
self.ratios_stack = []
self.ratio_increment_period = int(self.pruning_iterations /
self.pruning_steps)
for i in range(self.pruning_steps):
ratio_tmp = ((i / self.pruning_steps) - 1.0)**3 + 1
ratio_tmp = ratio_tmp * (self.target_ratio - self.initial_ratio
) + self.initial_ratio
self.ratios_stack.append(ratio_tmp)
stable_steps = int(
float(self.stable_iterations) / self.pruning_iterations *
self.pruning_steps)
tunning_steps = int(
float(self.tunning_iterations) / self.pruning_iterations *
self.pruning_steps)
stable_ratios_stack = [0.0] * stable_steps
tunning_ratios_stack = [self.target_ratio] * tunning_steps
self.ratios_stack = stable_ratios_stack + self.ratios_stack + tunning_ratios_stack
self.ratios_stack.reverse()
# pop out used ratios to resume training
for i in range(self.cur_iteration):
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
def step(self):
"""
Update the threshold and masks.
"""
ori_ratio = self.ratio
if self.cur_iteration % self.ratio_increment_period == 0:
if len(self.ratios_stack) > 0:
self.ratio = self.ratios_stack.pop()
else:
self.ratio = self.target_ratio
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
self.update_threshold()
self._update_masks()
self.cur_iteration += 1
......@@ -15,27 +15,33 @@ class TestUnstructuredPruner(unittest.TestCase):
def _gen_model(self):
self.net = mobilenet_v1(num_classes=10, pretrained=False)
self.pruner = UnstructuredPruner(
self.net, mode='ratio', ratio=0.98, threshold=0.0)
self.net_conv1x1 = mobilenet_v1(num_classes=10, pretrained=False)
self.pruner = UnstructuredPruner(self.net, mode='ratio', ratio=0.55)
self.pruner_conv1x1 = UnstructuredPruner(
self.net_conv1x1,
mode='ratio',
ratio=0.55,
prune_params_type='conv1x1_only')
def test_prune(self):
ori_density = UnstructuredPruner.total_sparse(self.net)
ori_sparsity = UnstructuredPruner.total_sparse(self.net)
ori_threshold = self.pruner.threshold
self.pruner.step()
self.net(
paddle.to_tensor(
np.random.uniform(0, 1, [16, 3, 32, 32]), dtype='float32'))
cur_density = UnstructuredPruner.total_sparse(self.net)
cur_sparsity = UnstructuredPruner.total_sparse(self.net)
cur_threshold = self.pruner.threshold
print("Original threshold: {}".format(ori_threshold))
print("Current threshold: {}".format(cur_threshold))
print("Original density: {}".format(ori_density))
print("Current density: {}".format(cur_density))
print("Original sparsity: {}".format(ori_sparsity))
print("Current sparsity: {}".format(cur_sparsity))
self.assertLessEqual(ori_threshold, cur_threshold)
self.assertLessEqual(cur_density, ori_density)
self.assertGreaterEqual(cur_sparsity, ori_sparsity)
self.pruner.update_params()
self.assertEqual(cur_density, UnstructuredPruner.total_sparse(self.net))
self.assertEqual(cur_sparsity,
UnstructuredPruner.total_sparse(self.net))
def test_summarize_weights(self):
max_value = -float("inf")
......@@ -51,6 +57,16 @@ class TestUnstructuredPruner(unittest.TestCase):
print("The max_value is {}.".format(max_value))
self.assertEqual(max_value, threshold)
def test_unstructured_prune_conv1x1(self):
print(self.pruner.skip_params)
print(self.pruner_conv1x1.skip_params)
self.assertTrue(
len(self.pruner.skip_params) < len(self.pruner_conv1x1.skip_params))
self.pruner_conv1x1.step()
self.pruner_conv1x1.update_params()
cur_sparsity = UnstructuredPruner.total_sparse_conv1x1(self.net_conv1x1)
self.assertTrue(abs(cur_sparsity - 0.55) < 0.01)
if __name__ == "__main__":
unittest.main()
import sys
sys.path.append("../../")
import unittest
import paddle
import numpy as np
from paddleslim import UnstructuredPruner, GMPUnstructuredPruner
from paddle.vision.models import mobilenet_v1
class TestUnstructuredPruner(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestUnstructuredPruner, self).__init__(*args, **kwargs)
paddle.disable_static()
self._gen_model()
def _gen_model(self):
self.net = mobilenet_v1(num_classes=10, pretrained=False)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 500,
'pruning_steps': 20,
'initial_ratio': 0.05,
}
self.pruner = GMPUnstructuredPruner(
self.net, ratio=0.55, configs=configs)
self.assertGreater(self.pruner.ratio, 0.3)
def test_unstructured_prune_gmp(self):
last_ratio = 0.0
ratio = 0.0
while len(self.pruner.ratios_stack) > 0:
self.pruner.step()
last_ratio = ratio
ratio = self.pruner.ratio
self.assertGreaterEqual(ratio, last_ratio)
self.assertEqual(ratio, 0.55)
if __name__ == "__main__":
unittest.main()
......@@ -27,8 +27,8 @@ class TestUnstructuredPruner(StaticCase):
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
conv3 = conv_bn_layer(sum1, 8, 1, "conv3")
conv4 = conv_bn_layer(conv3, 8, 1, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
......@@ -43,6 +43,12 @@ class TestUnstructuredPruner(StaticCase):
self.pruner = UnstructuredPruner(
self.main_program, 'ratio', scope=self.scope, place=place)
self.pruner_conv1x1 = UnstructuredPruner(
self.main_program,
'ratio',
scope=self.scope,
place=place,
prune_params_type='conv1x1_only')
def test_unstructured_prune(self):
for param in self.main_program.global_block().all_parameters():
......@@ -51,23 +57,24 @@ class TestUnstructuredPruner(StaticCase):
self.assertTrue(tuple(mask_shape) == param.shape)
def test_sparsity(self):
ori_density = UnstructuredPruner.total_sparse(self.main_program)
ori_sparsity = UnstructuredPruner.total_sparse(self.main_program)
self.pruner.step()
cur_density = UnstructuredPruner.total_sparse(self.main_program)
cur_layer_density = self.pruner.sparse_by_layer(self.main_program)
print('original density: {}.'.format(ori_density))
print('current density: {}.'.format(cur_density))
self.pruner.update_params()
cur_sparsity = UnstructuredPruner.total_sparse(self.main_program)
cur_layer_sparsity = self.pruner.sparse_by_layer(self.main_program)
print('original sparsity: {}.'.format(ori_sparsity))
print('current sparsity: {}.'.format(cur_sparsity))
total = 0
non_zeros = 0
for param in self.main_program.all_parameters():
total += np.product(param.shape)
non_zeros += np.count_nonzero(
np.array(self.scope.find_var(param.name).get_tensor()))
self.assertEqual(cur_density, non_zeros / total)
self.assertLessEqual(cur_density, ori_density)
self.assertEqual(cur_sparsity, 1 - non_zeros / total)
self.assertGreater(cur_sparsity, ori_sparsity)
self.pruner.update_params()
self.assertEqual(cur_density,
self.assertEqual(cur_sparsity,
UnstructuredPruner.total_sparse(self.main_program))
def test_summarize_weights(self):
......@@ -76,11 +83,32 @@ class TestUnstructuredPruner(StaticCase):
for param in self.main_program.global_block().all_parameters():
max_value = max(
max_value,
np.max(np.array(self.scope.find_var(param.name).get_tensor())))
np.max(
np.abs(
np.array(self.scope.find_var(param.name).get_tensor(
)))))
print("The returned threshold is {}.".format(threshold))
print("The max_value is {}.".format(max_value))
self.assertEqual(max_value, threshold)
def test_unstructured_prune_conv1x1(self):
print(self.pruner.skip_params)
print(self.pruner_conv1x1.skip_params)
self.assertTrue(
self.pruner.skip_params < self.pruner_conv1x1.skip_params)
def test_sparsity_conv1x1(self):
ori_sparsity = UnstructuredPruner.total_sparse_conv1x1(
self.main_program)
self.pruner.ratio = 0.99
self.pruner.step()
self.pruner.update_params()
cur_sparsity = UnstructuredPruner.total_sparse_conv1x1(
self.main_program)
print('original sparsity: {}.'.format(ori_sparsity))
print('current sparsity: {}.'.format(cur_sparsity))
self.assertGreater(cur_sparsity, ori_sparsity)
if __name__ == '__main__':
unittest.main()
import sys
sys.path.append("../")
import unittest
from static_case import StaticCase
import paddle.fluid as fluid
import paddle
from paddleslim.prune import UnstructuredPruner, GMPUnstructuredPruner
from layers import conv_bn_layer
import numpy as np
class TestUnstructuredPruner(StaticCase):
def __init__(self, *args, **kwargs):
super(TestUnstructuredPruner, self).__init__(*args, **kwargs)
paddle.enable_static()
self._gen_model()
def _gen_model(self):
self.main_program = paddle.static.default_main_program()
self.startup_program = paddle.static.default_startup_program()
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
with paddle.static.program_guard(self.main_program,
self.startup_program):
input = paddle.static.data(name='image', shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
conv7 = fluid.layers.conv2d_transpose(
input=conv6, num_filters=16, filter_size=2, stride=2)
place = paddle.static.cpu_places()[0]
exe = paddle.static.Executor(place)
self.scope = paddle.static.global_scope()
exe.run(self.startup_program, scope=self.scope)
configs = {
'stable_iterations': 0,
'pruning_iterations': 1000,
'tunning_iterations': 1000,
'resume_iteration': 500,
'pruning_steps': 20,
'initial_ratio': 0.05,
}
self.pruner = GMPUnstructuredPruner(
self.main_program,
scope=self.scope,
place=place,
configs=configs,
ratio=0.55)
print(self.pruner.ratio)
self.assertGreater(self.pruner.ratio, 0.3)
def test_unstructured_prune_gmp(self):
last_ratio = 0.0
ratio = 0.0
while len(self.pruner.ratios_stack) > 0:
self.pruner.step()
last_ratio = ratio
ratio = self.pruner.ratio
self.assertGreaterEqual(ratio, last_ratio)
self.assertEqual(ratio, 0.55)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册