# Example: Auto-Compression of an Inference Model
This example shows how to run auto-compression training on a PaddleSeg inference model.
Taking [PP-HumanSeg-Lite](https://github.com/PaddlePaddle/PaddleSeg/tree/develop/contrib/PP-HumanSeg#portrait-segmentation) as an example, we use the auto-compression API to run distillation + sparsity training and distillation + quantization training, and measure the single-thread speedup on an SD710. The compression accuracy and latency results are listed below:
| Compression strategy | Total IoU | Latency (ms)<br>thread=1 | Speedup |
|:-----:|:----------:|:---------:| :------:|
| Baseline | 0.9287 | 56.363 | - |
| Unstructured sparsity | 0.9235 | 37.712 | 49.456% |
| Quantization | 0.9284 | 49.656 | 13.506% |
## Auto-Compression Training Workflow
### 1. Prepare the dataset
See the [PaddleSeg data preparation guide](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/data/marker/marker_cn.md).
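Once the image and annotation lists are ready, the training script (`run.py`, included at the end of this demo) builds `paddleseg.datasets.Dataset` objects from them. A minimal sketch, assuming a 2-class custom dataset; `data/portrait` and the list-file paths are placeholders for the dataset prepared in this step:
```python
import paddleseg.transforms as T
from paddleseg.datasets import Dataset

transforms = [T.RandomPaddingCrop(crop_size=(512, 512)), T.Normalize()]

# Placeholder paths: point these at your own dataset root and list files.
train_dataset = Dataset(
    transforms=transforms,
    dataset_root='data/portrait',
    num_classes=2,
    train_path='data/portrait/train.txt',
    mode='train')
eval_dataset = Dataset(
    transforms=transforms,
    dataset_root='data/portrait',
    num_classes=2,
    val_path='data/portrait/val.txt',
    mode='val')
```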
### 2. Prepare the model to compress
PaddleSeg is an end-to-end image segmentation toolkit built on PaddlePaddle, covering a large number of high-quality segmentation models, from high-accuracy to lightweight.
Install PaddleSeg with:
```shell
pip install paddleseg
```
See the [installation guide](https://github.com/PaddlePaddle/PaddleSeg/blob/develop/docs/install_cn.md) for PaddleSeg's environment requirements.
#### 2.1 Download the code
```
git clone https://github.com/PaddlePaddle/PaddleSeg.git
```
#### 2.2 Prepare the pretrained model
Run the following commands in the PaddleSeg directory to download the pretrained model:
``` shell
wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224.tar.gz
tar -xzf ppseg_lite_portrait_398x224.tar.gz
```
#### 2.3 Export the inference model
Run the following command in the PaddleSeg directory; the inference model will be saved to the inference_model folder.
```shell
# Use one visible GPU
export CUDA_VISIBLE_DEVICES=0
# On Windows, use the following command instead
# set CUDA_VISIBLE_DEVICES=0
python export.py \
       --config configs/pp_humanseg_lite/pp_humanseg_lite_export_398x224.yml \
       --model_path ppseg_lite_portrait_398x224/model.pdparams \
       --save_dir inference_model \
       --with_softmax
```
Or download the exported PP-HumanSeg-Lite inference model directly:
```shell
wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz
```
### 3. Multi-strategy compression
Each of the subsections below describes one combination of compression strategies.
### 3.1 Distillation + sparsity training
Auto-compression training requires a config file, a dataset dataloader, and an evaluation function (``eval_function``).
#### 3.1.1 Configure the config file
To run joint distillation and unstructured-sparsity training with auto-compression, first write a config file containing three groups of parameters: distillation, sparsity, and training.
- Distillation parameters
The distillation parameters mainly specify the distillation node pairs (``distill_node_pair``) and the path to the teacher inference model. Each pair consists of a teacher-network node and its corresponding student-network node; the teacher node name is automatically given a "teacher_" prefix by the program, as shown below (a sketch for discovering candidate node names in your own model follows this YAML example).
```yaml
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_relu_30.tmp_0
- relu_30.tmp_0
merge_feed: true
teacher_model_dir: ./inference_model
teacher_model_filename: model.pdmodel
teacher_params_filename: model.pdiparams
```
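To choose distillation nodes for your own model, one option is to load the exported inference program and print the output names of its ops. Below is a minimal sketch; it assumes the model exported in step 2.3 and uses standard Paddle static-graph APIs, so adjust the directory and filenames to whatever is actually inside `inference_model`:
```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Load the exported inference model (directory and filenames from step 2.3).
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./inference_model", exe,
    model_filename="model.pdmodel",
    params_filename="model.pdiparams")

# Print candidate tensor names; pick student-side names such as relu_30.tmp_0.
# The teacher-side entry in distill_node_pair is the same name with a
# "teacher_" prefix.
for op in program.global_block().ops:
    for name in op.output_arg_names:
        if "relu" in name or "reshape2" in name:
            print(op.type, name)
```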
- Sparsity parameters
The sparsity parameters are set as follows; see the [unstructured pruning API documentation](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst) for the meaning of each field.
```yaml
UnstructurePrune:
prune_strategy: gmp
prune_mode: ratio
pruned_ratio: 0.75
gmp_config:
stable_iterations: 0
pruning_iterations: 4500
tunning_iterations: 4500
resume_iteration: -1
pruning_steps: 100
initial_ratio: 0.15
prune_params_type: conv1x1_only
local_sparsity: True
```
- Training parameters
The training parameters mainly set the learning rate, the number of training epochs, and the optimizer.
```yaml
TrainConfig:
epochs: 14
eval_iter: 400
learning_rate: 5.0e-03
optimizer: SGD
optim_args:
weight_decay: 0.0005
```
#### 3.1.2 Prepare the dataloader and the evaluation function
After preparing the dataset, the training data must be wrapped into dict batches before being passed to the auto-compression API; the wrapper below can be used as a reference. The evaluation function measures model accuracy and must be implemented in static-graph mode.
```python
import numpy as np

def reader_wrapper(reader):
    def gen():
        for i, data in enumerate(reader()):
            imgs = np.array(data[0])
            yield {"x": imgs}
    return gen
```
> Note: the dict keys must match the input names used when the inference model was exported.
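The evaluation function receives the executor, the compiled test program, the feed names, and the fetch list, and must return a single scalar metric (mIoU in this demo). A trimmed sketch of the shape it takes in `run.py`; ``val_samples`` and ``compute_miou`` are placeholders for your own validation iterator and metric helper:
```python
import numpy as np

def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    # Run the static-graph program over the validation data and return a scalar.
    scores = []
    for image, label in val_samples:  # placeholder: your validation iterator
        logits = exe.run(compiled_test_program,
                         feed={test_feed_names[0]: np.array(image)},
                         fetch_list=test_fetch_list,
                         return_numpy=True)
        scores.append(compute_miou(logits[0], label))  # placeholder metric helper
    return float(np.mean(scores))
```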
#### 3.1.3 Start training
Pass the training dataloader and the evaluation function to ``paddleslim.auto_compression.AutoCompression`` to run unstructured-sparsity training on the model. The command is:
```shell
python run.py \
--model_dir='inference_model' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/humanseg_sparse_dis.yaml'
```
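Inside `run.py`, these command-line arguments end up in the AutoCompression call roughly as follows (mirroring the full script at the end of this demo):
```python
from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression.config_helpers import load_config

# Split the YAML into the strategy part and the training part.
compress_config, train_config = load_config('configs/humanseg_sparse_dis.yaml')

ac = AutoCompression(
    model_dir='inference_model',
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams',
    save_dir='./save_model',
    strategy_config=compress_config,
    train_config=train_config,
    train_dataloader=train_dataloader,  # dataloader wrapped by reader_wrapper
    eval_callback=eval_function)
ac.compress()
```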
### 3.2 Distillation + quantization training
#### 3.2.1 Configure the config file
To run quantization-aware training with auto-compression, first write a config file containing three groups of parameters: distillation, quantization, and training. The distillation and training parameters are similar to those used for sparsity training, so only the quantization parameters are described here.
- Quantization parameters
The quantization parameters mainly set the bit width and the op types to quantize; the quantizable ops include convolution layers (conv2d, depthwise_conv2d) and fully connected layers (matmul). The example below quantizes only the convolution layers:
```yaml
Quantization:
activation_bits: 8
weight_bits: 8
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
```
#### 3.2.2 Start training
Pass the dataloader and the evaluation function (``eval_function``) to ``paddleslim.auto_compression.AutoCompression`` to run quantization training on the model. The command is:
```shell
python run.py \
--model_dir='inference_model' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_model' \
--config_path='configs/humanseg_quant_dis.yaml'
```

The full `configs/humanseg_quant_dis.yaml` used in section 3.2:
```yaml
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_reshape2_1.tmp_0 #
- reshape2_1.tmp_0
- teacher_reshape2_3.tmp_0 #
- reshape2_3.tmp_0
- teacher_reshape2_5.tmp_0 #
- reshape2_5.tmp_0
- teacher_reshape2_7.tmp_0 #block1
- reshape2_7.tmp_0
- teacher_reshape2_9.tmp_0 #
- reshape2_9.tmp_0
- teacher_reshape2_11.tmp_0 #
- reshape2_11.tmp_0
- teacher_reshape2_13.tmp_0 #
- reshape2_13.tmp_0
- teacher_reshape2_15.tmp_0 #
- reshape2_15.tmp_0
- teacher_reshape2_17.tmp_0 #
- reshape2_17.tmp_0
- teacher_reshape2_19.tmp_0 #
- reshape2_19.tmp_0
- teacher_reshape2_21.tmp_0 #
- reshape2_21.tmp_0
- teacher_depthwise_conv2d_14.tmp_0 # block2
- depthwise_conv2d_14.tmp_0
- teacher_depthwise_conv2d_15.tmp_0
- depthwise_conv2d_15.tmp_0
- teacher_reshape2_23.tmp_0 #block1
- reshape2_23.tmp_0
- teacher_relu_30.tmp_0 # final_conv
- relu_30.tmp_0
- teacher_bilinear_interp_v2_1.tmp_0
- bilinear_interp_v2_1.tmp_0
merge_feed: true
teacher_model_dir: ./inference_model
teacher_model_filename: inference.pdmodel
teacher_params_filename: inference.pdiparams
Quantization:
activation_bits: 8
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
TrainConfig:
epochs: 1
eval_iter: 400
learning_rate: 0.0005
optimizer: SGD
optim_args:
weight_decay: 4.0e-05
```

The full `configs/humanseg_sparse_dis.yaml` used in section 3.1:
```yaml
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_reshape2_1.tmp_0
- reshape2_1.tmp_0
- teacher_reshape2_3.tmp_0
- reshape2_3.tmp_0
- teacher_reshape2_5.tmp_0
- reshape2_5.tmp_0
- teacher_reshape2_7.tmp_0 #block1
- reshape2_7.tmp_0
- teacher_reshape2_9.tmp_0
- reshape2_9.tmp_0
- teacher_reshape2_11.tmp_0
- reshape2_11.tmp_0
- teacher_reshape2_13.tmp_0
- reshape2_13.tmp_0
- teacher_reshape2_15.tmp_0
- reshape2_15.tmp_0
- teacher_reshape2_17.tmp_0
- reshape2_17.tmp_0
- teacher_reshape2_19.tmp_0
- reshape2_19.tmp_0
- teacher_reshape2_21.tmp_0
- reshape2_21.tmp_0
- teacher_depthwise_conv2d_14.tmp_0 # block2
- depthwise_conv2d_14.tmp_0
- teacher_depthwise_conv2d_15.tmp_0
- depthwise_conv2d_15.tmp_0
- teacher_reshape2_23.tmp_0 #block1
- reshape2_23.tmp_0
- teacher_relu_30.tmp_0 # final_conv
- relu_30.tmp_0
- teacher_bilinear_interp_v2_1.tmp_0
- bilinear_interp_v2_1.tmp_0
merge_feed: true
teacher_model_dir: ./inference_model
teacher_model_filename: inference.pdmodel
teacher_params_filename: inference.pdiparams
UnstructurePrune:
prune_strategy: gmp
prune_mode: ratio
pruned_ratio: 0.75
gmp_config:
stable_iterations: 0
pruning_iterations: 4500
tunning_iterations: 4500
resume_iteration: -1
pruning_steps: 100
initial_ratio: 0.15
prune_params_type: conv1x1_only
local_sparsity: True
TrainConfig:
epochs: 14
eval_iter: 400
learning_rate: 5.0e-03
optim_args:
weight_decay: 0.0005
optimizer: SGD
```

The full `run.py` training script:
```python
import os
import argparse
import random
import paddle
import numpy as np
import paddleseg.transforms as T
from paddleseg.datasets import Dataset
from paddleseg.utils import worker_init_fn
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression import AutoCompression
from paddleseg.core.infer import reverse_transform
from paddleseg.utils import metrics
import paddle.nn.functional as F
import cv2
import paddle.fluid as fluid
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
parser.add_argument(
'--model_dir',
type=str,
default=None,
help="inference model directory.")
parser.add_argument(
'--model_filename',
type=str,
default=None,
help="inference model filename.")
parser.add_argument(
'--params_filename',
type=str,
default=None,
help="inference params filename.")
parser.add_argument(
'--save_dir',
type=str,
default=None,
help="directory to save compressed model.")
parser.add_argument(
'--config_path',
type=str,
default=None,
help="path of compression strategy config.")
return parser.parse_args()
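
# Evaluation callback for AutoCompression: runs the compiled static-graph
# program over eval_dataset and returns mIoU.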
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    nranks = paddle.distributed.ParallelEnv().nranks
batch_sampler = paddle.io.DistributedBatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
num_workers=1,
return_list=True, )
total_iters = len(loader)
intersect_area_all = 0
pred_area_all = 0
label_area_all = 0
print("Start evaluating (total_samples: {}, total_iters: {})...".format(
len(eval_dataset), total_iters))
print("nranks:", nranks)
for iter, (image, label) in enumerate(loader):
paddle.enable_static()
label = np.array(label).astype('int64')
ori_shape = np.array(label).shape[-2:]
image = np.array(image)
logits = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list,
return_numpy=True)
paddle.disable_static()
logit = logits[0]
logit = reverse_transform(
paddle.to_tensor(logit),
ori_shape,
eval_dataset.transforms.transforms,
mode='bilinear')
pred = paddle.argmax(
paddle.to_tensor(logit), axis=1, keepdim=True, dtype='int32')
intersect_area, pred_area, label_area = metrics.calculate_area(
pred,
paddle.to_tensor(label),
eval_dataset.num_classes,
ignore_index=eval_dataset.ignore_index)
if nranks > 1:
intersect_area_list = []
pred_area_list = []
label_area_list = []
paddle.distributed.all_gather(intersect_area_list, intersect_area)
paddle.distributed.all_gather(pred_area_list, pred_area)
paddle.distributed.all_gather(label_area_list, label_area)
# Some image has been evaluated and should be eliminated in last iter
if (iter + 1) * nranks > len(eval_dataset):
valid = len(eval_dataset) - iter * nranks
intersect_area_list = intersect_area_list[:valid]
pred_area_list = pred_area_list[:valid]
label_area_list = label_area_list[:valid]
for i in range(len(intersect_area_list)):
intersect_area_all = intersect_area_all + intersect_area_list[i]
pred_area_all = pred_area_all + pred_area_list[i]
label_area_all = label_area_all + label_area_list[i]
else:
intersect_area_all = intersect_area_all + intersect_area
pred_area_all = pred_area_all + pred_area
label_area_all = label_area_all + label_area
class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
label_area_all)
class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)
class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all,
label_area_all)
infor = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
len(eval_dataset), miou, acc, kappa, mdice)
print(infor)
paddle.enable_static()
return miou
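
# Wrap a paddle.io.DataLoader into a generator of feed dicts; the key "x"
# must match the input name of the exported inference model.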
def reader_wrapper(reader):
def gen():
for i, data in enumerate(reader()):
imgs = np.array(data[0])
yield {"x": imgs}
return gen
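
# Build the datasets and dataloader, load the compression config, and launch
# auto-compression training.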
if __name__ == '__main__':
args = parse_args()
transforms = [T.RandomPaddingCrop(crop_size=(512, 512)), T.Normalize()]
train_dataset = Dataset(
transforms=transforms,
dataset_root='dataset_root', # Need to fill in
num_classes=2,
train_path='train_path', # Need to fill in
mode='train')
eval_dataset = Dataset(
transforms=transforms,
dataset_root='dataset_root', # Need to fill in
num_classes=2,
        val_path='val_path',  # Need to fill in
mode='val')
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=128, shuffle=True, drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=batch_sampler,
num_workers=2,
return_list=True,
worker_init_fn=worker_init_fn, )
train_dataloader = reader_wrapper(train_loader)
# set auto_compression
compress_config, train_config = load_config(args.config_path)
ac = AutoCompression(
model_dir=args.model_dir,
model_filename=args.model_filename,
        params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function)
    ac.compress()
```