Unverified commit a00f830f, authored by ceci3, committed by GitHub

add source free (#974)

* add source-free

* add source-free

* add asp

* update

* add fleet

* add docs

* rename

* fix unittest
Parent d66afebf
# Example: Quantization-Aware Training with an Inference Model
Inference models are saved with the following APIs:
dygraph models with ``paddle.jit.save``;
static-graph models with ``paddle.static.save_inference_model``.
This example shows how to run distillation-based quantization training on an inference model:
first train the quantized model with the ``paddleslim.quant.quant_aware_with_infermodel`` API,
then export the trained quantized model as an inference model with the ``paddleslim.quant.export_quant_infermodel`` API.
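For readers unfamiliar with these saving APIs, the following is a minimal sketch (not part of this demo) of exporting a dygraph model as an inference model; the LeNet model and the input spec are illustrative placeholders.
```python
import paddle

# Minimal sketch (not part of this demo): export a dygraph model as an
# inference model with paddle.jit.save. LeNet and the input spec are
# illustrative placeholders.
model = paddle.vision.models.LeNet()
paddle.jit.save(
    model,
    path='./lenet_infer/inference',
    input_spec=[paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')])
# For static-graph programs, paddle.static.save_inference_model(path_prefix,
# feed_vars, fetch_vars, executor) produces the same kind of inference model.
```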
## Quantization-aware training workflow for a classification model
### 1. Prepare the data
Create a ``data`` folder under the ``demo`` directory and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` folder should contain:
- the ``train`` folder with the training images
- the ``train_list.txt`` file
- the ``val`` folder with the validation images
- the ``val_list.txt`` file
### 2. Prepare the model to be quantized
PaddleClas, PaddlePaddle's image recognition toolkit for industry and academia, is used in this example to produce the ImageNet classification model.
#### 2.1 Download the PaddleClas release/2.3 branch
<https://github.com/PaddlePaddle/PaddleClas/archive/refs/heads/release/2.3.zip>
After extracting it, enter the PaddleClas directory:
```
cd PaddleClas-release-2.3
```
#### 2.2 Download the MobileNetV2 pretrained model
Create a ``pretrained`` folder in the PaddleClas root directory:
```
mkdir pretrained
```
Download the pretrained model.
Classification pretrained model zoo: <https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md>
MobileNetV2 pretrained model: <https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams>
Run the download command:
```
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_pretrained.pdparams -O ./pretrained/MobileNetV2_pretrained.pdparams
```
#### 2.3 Export the inference model
Run the following command in the PaddleClas repository root to export the inference model:
```
python tools/export_model.py \
-c ppcls/configs/ImageNet/MobileNetV2/MobileNetV2.yaml \
-o Global.pretrained_model=pretrained/MobileNetV2_pretrained \
-o Global.save_inference_dir=infermodel_mobilenetv2
```
#### 2.4 Evaluate model accuracy
Copy the ``infermodel_mobilenetv2`` folder to ``PaddleSlim/demo/auto-compression/``:
```
cd PaddleSlim/demo/auto-compression/
```
Use the [eval.py](../quant_post/eval.py) script to evaluate the model's classification accuracy:
```
python ../quant_post/eval.py --model_path infermodel_mobilenetv2 --model_name inference.pdmodel --params_name inference.pdiparams
```
The accuracy output is:
```
top1_acc/top5_acc= [0.71918 0.90568]
```
### 3. Multi-strategy fused compression
Each of the following subsections is an independent fused compression strategy; they are alternatives and do not need to be run in sequence.
### 3.1 Quantization-aware training with distillation
The distillation + quantization training demo script is [demo_imagenet.py](./demo_imagenet.py); it uses the ``paddleslim.auto_compression.AutoCompression`` API to run quantization-aware training. Run:
```
python demo_imagenet.py \
--model_dir='infermodel_mobilenetv2' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \
--devices='gpu' \
--batch_size=64 \
--config_path='./configs/CV/mbv2_qat_dis.yaml'
```
### 3.2 Post-training quantization with hyper-parameter search
The post-training quantization hyper-parameter search demo script is [demo_imagenet.py](./demo_imagenet.py); it uses the ``paddleslim.auto_compression.AutoCompression`` API to compress the model. Run:
```
python demo_imagenet.py \
--model_dir='infermodel_mobilenetv2' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \
--devices='gpu' \
--batch_size=64 \
--config_path='./configs/CV/mbv2_ptq_hpo.yaml'
```
### 3.3 Pruning with distillation
Note: this example applies ASP sparsity to a BERT model.
First follow [these instructions](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples/language_model/bert#%E9%A2%84%E6%B5%8B) to obtain a deployable model.
The pruning + distillation demo script is [demo_glue.py](./demo_glue.py); it uses the ``paddleslim.auto_compression.AutoCompression`` API to compress the model. Run:
```
python demo_glue.py \
--model_dir='./static_bert_models/' \
--model_filename='bert.pdmodel' \
--params_filename='bert.pdiparams' \
--save_dir='./save_asp_bert/' \
--devices='gpu' \
--batch_size=32 \
--task='sst-2' \
--config_path='./configs/NLP/bert_asp_dis.yaml'
```
### 3.4 Unstructured sparsity with distillation
The unstructured sparsity + distillation demo script is [demo_imagenet.py](./demo_imagenet.py); it uses the ``paddleslim.auto_compression.AutoCompression`` API to compress the model. Run:
```
python demo_imagenet.py \
--model_dir='infermodel_mobilenetv2' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \
--devices='gpu' \
--batch_size=64 \
--config_path='./configs/CV/xxx.yaml'
```
HyperParameterOptimization:
batch_num:
- 4
- 16
bias_correct:
- true
hist_percent:
- 0.999
- 0.99999
max_quant_count: 20
ptq_algo:
- KL
- hist
weight_quantize_type:
- channel_wise_abs_max
Quantization:
activation_bits: 8
quantize_op_types:
- conv2d
- depthwise_conv2d
- mul
weight_bits: 8
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_conv2d_54.tmp_0
- conv2d_54.tmp_0
- teacher_conv2d_55.tmp_0
- conv2d_55.tmp_0
- teacher_conv2d_57.tmp_0
- conv2d_57.tmp_0
- teacher_elementwise_add_0
- elementwise_add_0
- teacher_conv2d_61.tmp_0
- conv2d_61.tmp_0
- teacher_elementwise_add_1
- elementwise_add_1
- teacher_elementwise_add_2
- elementwise_add_2
- teacher_conv2d_67.tmp_0
- conv2d_67.tmp_0
- teacher_elementwise_add_3
- elementwise_add_3
- teacher_elementwise_add_4
- elementwise_add_4
- teacher_elementwise_add_5
- elementwise_add_5
- teacher_conv2d_75.tmp_0
- conv2d_75.tmp_0
- teacher_elementwise_add_6
- elementwise_add_6
- teacher_elementwise_add_7
- elementwise_add_7
- teacher_conv2d_81.tmp_0
- conv2d_81.tmp_0
- teacher_elementwise_add_8
- elementwise_add_8
- teacher_elementwise_add_9
- elementwise_add_9
- teacher_conv2d_87.tmp_0
- conv2d_87.tmp_0
- teacher_linear_1.tmp_0
- linear_1.tmp_0
merge_feed: true
teacher_model_dir: ./MobileNetV2_ssld_infer
teacher_model_filename: inference.pdmodel
teacher_params_filename: inference.pdiparams
Quantization:
activation_bits: 8
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
weight_bits: 8
TrainConfig:
epochs: 1
eval_iter: 1000
learning_rate: 0.0001
optimizer: SGD
origin_metric: 0.765
optim_args:
weight_decay: 4.0e-05
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_tmp_9
- tmp_9
- teacher_tmp_12
- tmp_12
- teacher_tmp_15
- tmp_15
- teacher_tmp_18
- tmp_18
- teacher_tmp_21
- tmp_21
- teacher_tmp_24
- tmp_24
- teacher_tmp_27
- tmp_27
- teacher_tmp_30
- tmp_30
- teacher_tmp_33
- tmp_33
- teacher_tmp_36
- tmp_36
- teacher_tmp_39
- tmp_39
- teacher_tmp_42
- tmp_42
- teacher_linear_147.tmp_1
- linear_147.tmp_1
merge_feed: true
teacher_model_dir: ../auto-compression_origin/static_bert_models
teacher_model_filename: bert.pdmodel
teacher_params_filename: bert.pdiparams
Prune:
prune_algo: asp
TrainConfig:
epochs: 3
eval_iter: 1000
learning_rate: 2.0e-05
optim_args:
weight_decay: 0.0
optimizer: AdamW
origin_metric: 0.93
HyperParameterOptimization:
batch_num:
- 4
- 16
bias_correct:
- true
hist_percent:
- 0.999
- 0.99999
max_quant_count: 20
ptq_algo:
- KL
- hist
weight_quantize_type:
- channel_wise_abs_max
Quantization:
activation_bits: 8
quantize_op_types:
- conv2d
- depthwise_conv2d
- mul
weight_bits: 8
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_tmp_9
- tmp_9
- teacher_tmp_12
- tmp_12
- teacher_tmp_15
- tmp_15
- teacher_tmp_18
- tmp_18
- teacher_tmp_21
- tmp_21
- teacher_tmp_24
- tmp_24
- teacher_tmp_27
- tmp_27
- teacher_tmp_30
- tmp_30
- teacher_tmp_33
- tmp_33
- teacher_tmp_36
- tmp_36
- teacher_tmp_39
- tmp_39
- teacher_tmp_42
- tmp_42
- teacher_linear_147.tmp_1
- linear_147.tmp_1
merge_feed: true
teacher_model_dir: ../auto-compression_origin/static_bert_models
teacher_model_filename: bert.pdmodel
teacher_params_filename: bert.pdiparams
Quantization:
activation_bits: 8
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
- mul
- matmul
weight_bits: 8
TrainConfig:
epochs: 3
eval_iter: 1000
learning_rate: 0.0001
optimizer: SGD
optim_args:
weight_decay: 4.0e-05
origin_metric: 0.93
import os
import sys
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
import argparse
import functools
from functools import partial
import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.metric import Metric, Accuracy, Precision, Recall
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
from paddlenlp.datasets import load_dataset
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.data.sampler import SamplerHelper
from paddlenlp.metrics import Mcc, PearsonAndSpearman
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression.compressor import AutoCompression
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_dir', str, None, "inference model directory.")
add_arg('model_filename', str, None, "inference model filename.")
add_arg('params_filename', str, None, "inference params filename.")
add_arg('save_dir', str, None, "directory to save compressed model.")
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('task', str, 'sst-2', "task name in glue.")
add_arg('config_path', str, None, "path of compression strategy config.")
# yapf: enable
METRIC_CLASSES = {
"cola": Mcc,
"sst-2": Accuracy,
"sts-b": PearsonAndSpearman,
"mnli": Accuracy,
"qnli": Accuracy,
"rte": Accuracy,
}
def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""
Convert a glue example into necessary features.
"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['labels']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
if not is_test:
return example['input_ids'], example['token_type_ids'], label
else:
return example['input_ids'], example['token_type_ids']
def create_data_holder(task_name):
"""
Define the input data holder for the glue task.
"""
input_ids = paddle.static.data(
name="input_ids", shape=[-1, -1], dtype="int64")
token_type_ids = paddle.static.data(
name="token_type_ids", shape=[-1, -1], dtype="int64")
if task_name == "sts-b":
label = paddle.static.data(name="label", shape=[-1, 1], dtype="float32")
else:
label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
return [input_ids, token_type_ids, label]
def reader():
# Create the tokenizer and dataset
tokenizer = BertTokenizer.from_pretrained(args.model_dir)
train_ds = load_dataset('glue', args.task, splits="train")
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=128,
is_test=True)
train_ds = train_ds.map(trans_func, lazy=True)
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
): fn(samples)
train_batch_sampler = paddle.io.BatchSampler(
train_ds, batch_size=32, shuffle=True)
[input_ids, token_type_ids, labels] = create_data_holder(args.task)
feed_list_name = []
train_data_loader = DataLoader(
dataset=train_ds,
feed_list=[input_ids, token_type_ids],
batch_sampler=train_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
return_list=False)
dev_trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
max_seq_length=128)
dev_batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
Stack(dtype="int64" if train_ds.label_list else "float32") # label
): fn(samples)
dev_ds = load_dataset('glue', args.task, splits='dev')
dev_ds = dev_ds.map(dev_trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=32, shuffle=False)
dev_data_loader = DataLoader(
dataset=dev_ds,
batch_sampler=dev_batch_sampler,
collate_fn=dev_batchify_fn,
num_workers=0,
feed_list=[input_ids, token_type_ids, labels],
return_list=False)
return train_data_loader, dev_data_loader
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.reset()
for data in eval_dataloader():
logits = exe.run(compiled_test_program,
feed={
test_feed_names[0]: data[0]['input_ids'],
test_feed_names[1]: data[0]['token_type_ids']
},
fetch_list=test_fetch_list)
paddle.disable_static()
labels_pd = paddle.to_tensor(np.array(data[0]['label']))
logits_pd = paddle.to_tensor(logits[0])
correct = metric.compute(logits_pd, labels_pd)
metric.update(correct)
paddle.enable_static()
res = metric.accumulate()
return res
def apply_decay_param_fun(name):
if name.find("bias") > -1:
return True
elif name.find("norm") > -1:
return True
else:
return False
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
compress_config, train_config = load_config(args.config_path)
if train_config is not None and 'optim_args' in train_config:
train_config['optim_args'][
'apply_decay_param_fun'] = apply_decay_param_fun
train_dataloader, eval_dataloader = reader()
metric_class = METRIC_CLASSES[args.task]
metric = metric_class()
ac = AutoCompression(
model_dir=args.model_dir,
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function,
devices=args.devices)
ac.compress()
import os
import sys
sys.path[0] = os.path.join(os.path.dirname(__file__), os.path.pardir)
import argparse
import functools
from functools import partial
import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
import imagenet_reader as reader
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression import AutoCompression
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_dir', str, None, "inference model directory.")
add_arg('model_filename', str, None, "inference model filename.")
add_arg('params_filename', str, None, "inference params filename.")
add_arg('save_dir', str, None, "directory to save compressed model.")
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
# yapf: enable
def reader_wrapper(reader):
def gen():
for i, data in enumerate(reader()):
imgs = np.float32([item[0] for item in data])
yield {"inputs": imgs}
return gen
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = paddle.batch(reader.val(), batch_size=1)
image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
results = []
for batch_id, data in enumerate(val_reader()):
# top1_acc, top5_acc
if len(test_feed_names) == 1:
# eval "infer model", which input is image, output is classification probability
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
pred = exe.run(compiled_test_program,
feed={test_feed_names[0]: image},
fetch_list=test_fetch_list)
pred = np.array(pred[0])
label = np.array(label)
sort_array = pred.argsort(axis=1)
top_1_pred = sort_array[:, -1:][:, ::-1]
top_1 = np.mean(label == top_1_pred)
top_5_pred = sort_array[:, -5:][:, ::-1]
acc_num = 0
for i in range(len(label)):
if label[i][0] in top_5_pred[i]:
acc_num += 1
top_5 = float(acc_num) / len(label)
results.append([top_1, top_5])
else:
# eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
image = data[0][0].reshape((1, 3, 224, 224))
label = [[d[1]] for d in data]
result = exe.run(
compiled_test_program,
feed={test_feed_names[0]: image,
test_feed_names[1]: label},
fetch_list=test_fetch_list)
result = [np.mean(r) for r in result]
results.append(result)
result = np.mean(np.array(results), axis=0)
return result[0]
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
paddle.enable_static()
compress_config, train_config = load_config(args.config_path)
train_reader = paddle.batch(reader.train(), batch_size=64)
train_dataloader = reader_wrapper(train_reader)
ac = AutoCompression(
model_dir=args.model_dir,
model_filename=args.model_filename,
params_filename=args.params_filename,
save_dir=args.save_dir,
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function,
devices=args.devices)
ac.compress()
python3.7 demo_glue.py --config_path ./configs/NLP/bert_qat_dis.yaml --task 'sst-2' \
--model_dir='../auto-compression_origin/static_bert_models/' \
--model_filename='bert.pdmodel' \
--params_filename='bert.pdiparams' \
--save_dir='./save_asp_bert/' \
--devices='gpu' \
--batch_size=32
python3.7 demo_imagenet.py --config_path ./configs/CV/mbv2_ptq_hpo.yaml \
--model_dir='../auto-compression_origin/MobileNetV2_ssld_infer/' \
--model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \
--devices='gpu' \
--batch_size=64
AutoCompression: Automatic Compression
========================================
AutoCompression
---------------
.. py:class:: paddleslim.auto_compression.AutoCompression(model_dir, model_filename, params_filename, save_dir, strategy_config, train_config, train_dataloader, eval_callback, devices='gpu')
`Source code <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/auto_compression.py#L32>`_
Compresses an inference model saved with the ``paddle.jit.save`` or ``paddle.static.save_inference_model`` API according to the given configuration.
**Parameters:**
- **model_dir(str)** - Directory of the inference model to be compressed.
- **model_filename(str)** - Filename of the inference model to be compressed.
- **params_filename(str)** - Filename of the parameters of the inference model to be compressed.
- **save_dir(str)** - Directory where the compressed model is saved.
- **strategy_config(dict)** - The compression strategies to apply. The dictionary keys must be chosen from ``Quantization`` (quantization config), ``Distillation`` (distillation config), ``MultiTeacherDistillation`` (multi-teacher distillation config), ``HyperParameterOptimization`` (hyper-parameter search config), ``Prune`` (pruning config) and ``UnstructurePrune`` (unstructured sparsity config). Currently only the following key combinations are supported:
1) ``Quantization`` & ``HyperParameterOptimization``: post-training quantization with hyper-parameter search;
2) ``Quantization`` & ``Distillation``: quantization-aware training with distillation;
3) ``Prune`` & ``Distillation``: structured pruning with distillation;
4) ``UnstructurePrune`` & ``Distillation``: unstructured sparsity with distillation;
5) ``Distillation``: single-teacher distillation only;
6) ``MultiTeacherDistillation``: multi-teacher distillation only.
The detailed parameters of each configuration are described in the sections below.
- **train_config(dict)** - Training configuration; see the ``TrainConfig`` section below. Note: for the post-training quantization hyper-parameter search strategy, set ``train_config`` to ``None``.
- **train_dataloader(paddle.io.DataLoader)** - Training data loader. Note: for the post-training quantization hyper-parameter search strategy, ``train_dataloader`` and ``eval_callback`` can use the same data reader.
- **eval_callback(paddle.io.DataLoader|function)** - Exactly one of an evaluation callback function or a test data loader must be passed. If a callback function is passed, it is used to evaluate the model during training; see the custom callback section for how to write it. If a test data loader is passed, the ``EMD`` distance between the models before and after compression is used instead; currently only the post-training quantization hyper-parameter search strategy supports this mode.
- **devices(str)** - The device to run on: ``cpu``, ``gpu``, ``npu``, ``gpu:x``, ``xpu:x`` or ``npu:x``, where ``x`` is the index of the GPU, XPU or NPU. Default: ``gpu``.
**Returns:** An ``AutoCompression`` instance.
**Example:**
```python
import paddle
from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
default_ptq_config = {
    "quantize_op_types": ["conv2d", "depthwise_conv2d", "mul"],
    "weight_bits": 8,
    "activation_bits": 8,
}
default_hpo_config = {
    "ptq_algo": ["KL", "hist"],
    "bias_correct": [True],
    "weight_quantize_type": ["channel_wise_abs_max"],
    "hist_percent": [0.999, 0.99999],
    "batch_num": [4, 16],
    "max_quant_count": 20,
}
train_dataloader = Cifar10(mode='train')
eval_dataloader = Cifar10(mode='eval')
ac = AutoCompression(model_path, model_filename, params_filename, save_dir,
                     strategy_config={
                         "Quantization": Quantization(**default_ptq_config),
                         "HyperParameterOptimization": HyperParameterOptimization(**default_hpo_config)
                     },
                     train_config=None,
                     train_dataloader=train_dataloader,
                     eval_callback=eval_dataloader,
                     devices='gpu')
```
.. py:method:: paddleslim.auto_compression.AutoCompression.compress()
Runs the compression.
TrainConfig
-----------
Training hyper-parameter configuration. A construction example follows the parameter list.
**Parameters:**
- **epochs(int)** - Number of training epochs, i.e. how many passes are made over the training dataset.
- **learning_rate(float|LRScheduler)** - Learning rate used during optimization.
- **optimizer(str)** - Optimizer to use; must be the name of an optimizer in ``paddle.optimizer``, e.g. ``SGD``.
- **optim_args(dict)** - Optimizer arguments. The following keys can be set:
``grad_clip``: the gradient clipping method, given as the name of a gradient clipping class in ``paddle.nn``, e.g. ``ClipGradByValue``.
``grad_clip_args``: the arguments of the gradient clipping method; for example, if ``ClipGradByValue`` is chosen, ``grad_clip_args`` can set ``max`` and ``min``, see `ClipGradByValue <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/nn/ClipGradByValue_cn.html#clipgradbyvalue>`_ .
Other optimizer-specific arguments such as ``beta1``, ``beta2`` and ``apply_decay_param_fun`` can also be set, see `AdamW <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/optimizer/AdamW_cn.html#adamw>`_ .
- **eval_iter(int)** - Run an evaluation every this many training batches.
- **logging_iter(int)** - Print a log line every this many training batches.
- **origin_metric(float)** - Original accuracy of the inference model to be compressed, used to sanity-check the user-provided eval function. Default: ``None``.
- **target_metric(float, optional)** - If the compressed model reaches this accuracy during training, training stops early and the current model is returned; if unset, training runs for the configured number of epochs. Default: ``None``.
- **use_fleet(bool, optional)** - Whether to use the fleet API for distributed training. Default: ``None``.
- **amp_config(dict, optional)** - Set this to enable mixed-precision training, following these rules:
1) Without the fleet API:
a) for `static-graph AMP-O1 <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html#id2>`_ , set ``custom_white_list``, ``custom_black_list`` and ``custom_black_varnames``;
b) for `static-graph AMP-O2 <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html#id3>`_ , set ``use_pure_fp16`` and ``use_fp16_guard``.
2) With the fleet API:
configure the corresponding options as described for `amp_config <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#amp_configs>`_ .
- **recompute_config(dict, optional)** - When the fleet API is used, recompute memory optimization can be enabled; configure it as described for `recompute_configs <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#recompute_configs>`_ .
- **sharding_config(dict, optional)** - When the fleet API is used, the sharding strategy can be enabled; configure it as described for `sharding_configs <https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/distributed/fleet/DistributedStrategy_cn.html#sharding_configs>`_ .
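For reference, the sketch below builds the same ``TrainConfig`` in Python, using the values from the MobileNetV2 quantization demo YAML shown earlier; fields that are not set keep their default (``None``).
```python
from paddleslim.auto_compression import TrainConfig

# Mirrors the TrainConfig section of the MobileNetV2 QAT demo YAML.
train_config = TrainConfig(
    epochs=1,
    eval_iter=1000,
    learning_rate=0.0001,
    optimizer='SGD',
    optim_args={'weight_decay': 4.0e-05},
    origin_metric=0.765)
```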
Quantization
------------
Quantization configuration. A construction example follows the parameter list.
**Parameters:**
- **quantize_op_types(list[str])** - Op types to quantize.
- **weight_bits(int)** - Number of bits used to quantize weights.
- **activation_bits(int)** - Number of bits used to quantize activations.
- **is_full_quantize(bool)** - Whether to quantize all supported op types.
- **not_quant_pattern(str|list[str])** - Ops whose ``name_scope`` contains any of the given ``not_quant_pattern`` strings are left unquantized; see `fluid.name_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/name_scope_cn.html#name-scope>`_ for how to set a name scope.
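A sketch of the same fields built in Python, with the values used by the quantization sections of the demo YAMLs above:
```python
from paddleslim.auto_compression import Quantization

# Values taken from the QAT sections of the demo YAMLs; unset fields keep
# their defaults.
quant_config = Quantization(
    quantize_op_types=['conv2d', 'depthwise_conv2d'],
    weight_bits=8,
    activation_bits=8,
    is_full_quantize=False,
    not_quant_pattern=['skip_quant'])
```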
Distillation
------------
Distillation configuration. A construction example follows the parameter list.
**Parameters:**
- **distill_loss(str|list[str])** - Name(s) of the distillation loss. Any distillation loss supported by PaddleSlim can be used; the available losses are ``fsp_loss``, ``l2_loss`` and ``soft_label_loss``. If you need another loss, you can temporarily add it to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ , or open an issue and we will help.
- **distill_node_pair(list[str])** - List of distillation node names; every two consecutive names form a pair, the first from the teacher model and the second from the student model.
- **distill_lambda(float|list[float])** - Weight of each distillation loss; its length must match that of ``distill_loss``.
- **teacher_model_dir(str)** - Directory of the teacher model.
- **teacher_model_filename(str)** - Model filename of the teacher model.
- **teacher_params_filename(str)** - Parameters filename of the teacher model.
- **merge_feed(bool)** - Whether the teacher and the student share the same input data during distillation. Default: ``True``.
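A sketch of a single-teacher ``Distillation`` config; the node names below come from the MobileNetV2 demo YAML and must be replaced with nodes that exist in your own teacher and student programs:
```python
from paddleslim.auto_compression import Distillation

# Node pairs are given as [teacher_node, student_node, ...]; only the first
# pair from the MobileNetV2 demo YAML is shown here.
distill_config = Distillation(
    distill_loss='l2_loss',
    distill_lambda=1.0,
    distill_node_pair=['teacher_conv2d_54.tmp_0', 'conv2d_54.tmp_0'],
    teacher_model_dir='./MobileNetV2_ssld_infer',
    teacher_model_filename='inference.pdmodel',
    teacher_params_filename='inference.pdiparams',
    merge_feed=True)
```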
MultiTeacherDistillation
------------------------
Multi-teacher distillation configuration. A construction example follows the parameter list.
**Parameters:**
- **distill_loss(list[str])** - Names of the distillation losses. Any distillation loss supported by PaddleSlim can be used; the available losses are ``fsp_loss``, ``l2_loss`` and ``soft_label_loss``. If you need another loss, you can temporarily add it to the `distillation loss file <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py>`_ , or open an issue and we will help.
- **distill_node_pair(list[list[str]])** - Nested list of distillation node names; the length of the outer list must equal the number of teacher models. Each inner list holds the distillation nodes between one teacher and the student, where every two consecutive names form a pair, the first from the teacher and the second from the student.
- **distill_lambda(list[float])** - Weight of each distillation loss; its length must match that of ``distill_loss``.
- **teacher_model_dir(list[str])** - Directories of the teacher models.
- **teacher_model_filename(list[str])** - Model filenames of the teacher models.
- **teacher_params_filename(list[str])** - Parameters filenames of the teacher models.
- **merge_feed(bool)** - Whether the teachers and the student share the same input data during distillation. Default: ``True``.
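A sketch of a two-teacher config; every directory, filename and node name here is a hypothetical placeholder:
```python
from paddleslim.auto_compression import MultiTeacherDistillation

# Two hypothetical teachers; each inner list holds the logits node pair
# (teacher node, student node) for one teacher.
multi_distill_config = MultiTeacherDistillation(
    distill_loss=['soft_label_loss', 'soft_label_loss'],
    distill_lambda=[1.0, 1.0],
    distill_node_pair=[['teacher1_logits', 'logits'],
                       ['teacher2_logits', 'logits']],
    teacher_model_dir=['./teacher1_infer', './teacher2_infer'],
    teacher_model_filename=['inference.pdmodel', 'inference.pdmodel'],
    teacher_params_filename=['inference.pdiparams', 'inference.pdiparams'])
```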
HyperParameterOptimization
--------------------------
Search-space configuration for hyper-parameter search. A construction example follows the parameter list.
.. note::
    Hyper-parameter search currently only supports post-training quantization, so all search-space options relate to post-training quantization.
**Parameters:**
- **ptq_algo(str|list[str])** - Post-training quantization algorithm; one of ``KL``, ``mse``, ``hist``, ``avg`` or ``abs_max``. This option only affects activation quantization.
- **bias_correct(bool|list[bool])** - Whether to use the bias correction algorithm.
- **weight_quantize_type(str|list[str])** - Weight quantization type; ``abs_max`` or ``channel_wise_abs_max``.
- **hist_percent(float|list[float])** - Percentile used by the ``hist`` method. If a list is given, its minimum and maximum are used as bounds and values are sampled uniformly between them.
- **batch_num(int|list[int])** - Number of calibration batches. If a list is given, its minimum and maximum are used as bounds and values are sampled uniformly between them.
- **max_quant_count(int)** - Maximum number of hyper-parameter search rounds. Default: 20.
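The search space used by ``configs/CV/mbv2_ptq_hpo.yaml`` above, expressed in Python:
```python
from paddleslim.auto_compression import HyperParameterOptimization

# Same search space as the post-training quantization demo YAML.
hpo_config = HyperParameterOptimization(
    ptq_algo=['KL', 'hist'],
    bias_correct=[True],
    weight_quantize_type=['channel_wise_abs_max'],
    hist_percent=[0.999, 0.99999],
    batch_num=[4, 16],
    max_quant_count=20)
```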
PruneConfig
-----------
Pruning configuration. A construction example follows the parameter list.
**Parameters:**
- **prune_algo(str)** - Pruning algorithm; ``prune`` or ``asp``. ``prune`` currently only supports vision models, and ``asp`` pruning currently only supports ``FC`` layers.
- **pruned_ratio(float)** - Pruning ratio.
- **prune_params_name(list[str])** - Names of the parameters to prune.
- **criterion(str)** - When ``prune_algo`` is ``prune``, the criterion used to rank channel importance within a convolution layer; currently ``l1_norm``, ``bn_scale`` and ``geometry_median`` are supported.
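A sketch of a ``Prune`` config (the ``Prune`` key of ``strategy_config``); the ratio and parameter name below are illustrative placeholders:
```python
from paddleslim.auto_compression import Prune

# Illustrative channel-pruning config; prune_params_name must list real
# parameter names from the model being compressed.
prune_config = Prune(
    prune_algo='prune',
    pruned_ratio=0.25,
    prune_params_name=['conv2d_0.w_0'],
    criterion='l1_norm')
```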
UnstructurePrune
----------------
Unstructured sparsity configuration. A construction example follows the parameter list.
**Parameters:**
- **prune_strategy(str, optional)** - Whether to use ``GMP`` for unstructured sparsity: ``None`` trains without ``GMP``, ``gmp`` trains with ``GMP``. Default: None.
- **prune_mode(str)** - Sparsification mode; currently ``ratio`` and ``threshold`` are supported. In ``ratio`` mode a fixed ratio is given, e.g. 0.55, and the least important 55% of all parameters are set to 0. Similarly, in ``threshold`` mode a fixed threshold is given, e.g. 1e-2, and parameters whose importance is below 1e-2 are set to 0.
- **threshold(float)** - Target sparsification threshold; only effective when ``prune_mode`` is ``threshold``.
- **prune_ratio(float)** - Target sparsity ratio; only effective when ``prune_mode`` is ``ratio``.
- **gmp_config(dict, optional)** - Extra configuration required when using ``GMP`` for unstructured sparsity; it can include:
``prune_steps(int)`` - number of training iterations after which the sparsity ratio is updated;
``initial_ratio(float)`` - initial sparsity ratio.
Other options follow the `configs argument <https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/prune/unstructured_prune_api.rst#gmpunstrucuturedpruner>`_ of the unstructured pruning API.
- **prune_params_type(str)** - Which parameter types take part in sparsification. Currently only ``None`` and ``conv1x1_only`` are supported: the latter sparsifies only 1x1 convolutions, while the former sparsifies all parameters except normalization parameters.
- **local_sparsity(bool)** - Scope to which the pruning ratio applies: when ``local_sparsity`` is enabled, every pruned parameter matrix reaches sparsity ``ratio``; when disabled, only the overall model sparsity is guaranteed to reach ``ratio`` and the sparsity of individual matrices may differ.
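A sketch of a ratio-mode unstructured sparsity config; the values are illustrative:
```python
from paddleslim.auto_compression import UnstructurePrune

# Illustrative ratio-mode config: sparsify 75% of the 1x1-convolution
# weights, enforcing the ratio per parameter matrix.
unstructure_prune_config = UnstructurePrune(
    prune_mode='ratio',
    prune_ratio=0.75,
    prune_params_type='conv1x1_only',
    local_sparsity=True)
```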
Automatic Compression (AutoCompression)
========================================
.. toctree::
    :maxdepth: 1
    auto_compress_api.rst
    custom_function.rst
How to Build a Custom DataLoader with Paddle
============================================
See the PaddlePaddle documentation (a minimal sketch follows the links):
1. `Custom datasets <https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/02_paddle2.0_develop/02_data_load_cn.html#erzidingyishujuji>`_
2. `Data loading <https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/02_paddle2.0_develop/02_data_load_cn.html#sanshujujiazai>`_
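A minimal, self-contained sketch of the pattern described in those two guides; the random dataset below is purely illustrative and mimics the (image, label) sample layout of the image demos in this repository:
```python
import numpy as np
import paddle
from paddle.io import Dataset, DataLoader

class RandomImageNetLike(Dataset):
    """Toy dataset with the (image, label) sample layout used by the demos."""

    def __init__(self, num_samples=16):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.rand(3, 224, 224).astype('float32')
        label = np.random.randint(0, 1000, size=(1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples

loader = DataLoader(RandomImageNetLike(), batch_size=4, shuffle=True)
for image, label in loader:
    print(image.shape, label.shape)  # [4, 3, 224, 224] and [4, 1]
    break
```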
How to Write a Custom Evaluation Callback with Paddle
=====================================================
1. Input and output format
--------------------------
The inputs and outputs of a custom evaluation callback are fixed.
1.1 Inputs
##########
The callback must take the following 4 inputs:
**executor**: the Paddle executor, which can run a given ``Program`` or ``CompiledProgram``.
**program**: Paddle's static description of the computation graph.
**feed_name_list**: the names of all variables that must be fed, i.e. all input variable names.
**fetch_targets**: all output variables of the model; the model predictions are obtained from them.
1.2 Output
##########
The callback must return 1 output:
**result(float)**: the model's evaluation metric. Only the most important metric needs to be returned; it is used to check that the data is read correctly and whether the configured optimization target has been reached during training.
1.3 Custom computation logic
############################
First define the test dataset ``test_dataloader`` following the `How to Build a Custom DataLoader with Paddle <>`_ section.
```python
### Define an evaluation function with the fixed inputs described above.
def eval_function(exe, program, feed_name_list, fetch_targets):
    results = []
    ### iterate over the test dataset
    for data in test_dataloader():
        ### pop the label out of the batch
        labels = data.pop('label')
        ### feed the real data, run the program and fetch the outputs
        outputs = exe.run(program, feed=data, fetch_list=fetch_targets)
        ### compute the metric of this batch from the outputs and the labels
        results.append(metric(outputs, labels))
    ### return the overall metric as a float
    return np.mean(results)
```
......@@ -14,3 +14,4 @@
prune/prune_index.rst
dist/distill_index.rst
nas/nas_index.rst
auto-compression/auto_compression_index.rst
......@@ -20,7 +20,11 @@ from paddleslim import analysis
from paddleslim import dist
from paddleslim import quant
from paddleslim import dygraph
__all__ = ['models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'dygraph']
from paddleslim import auto_compression
__all__ = [
'models', 'prune', 'nas', 'analysis', 'dist', 'quant', 'dygraph',
'auto_compression'
]
from paddleslim.dygraph import *
__all__ += dygraph.__all__
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .compressor import *
from .strategy_config import *
from .config_helpers import *
__all__ = [
"AutoCompression", "Quantization", "Distillation",
"MultiTeacherDistillation", "HyperParameterOptimization", "Prune",
"UnstructurePrune", "ProgramInfo", "TrainConfig", "save_config",
"load_config"
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import numpy as np
import inspect
from collections import namedtuple
from collections.abc import Iterable
import paddle
import paddle.distributed.fleet as fleet
from ..quant.quant_post_hpo import quant_post_hpo
from ..quant.quanter import convert
from ..common.recover_program import recover_inference_program
from ..common import get_logger
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program
from .strategy_config import ProgramInfo, merge_config
_logger = get_logger(__name__, level=logging.INFO)
class AutoCompression:
def __init__(self,
model_dir,
model_filename,
params_filename,
save_dir,
strategy_config,
train_config,
train_dataloader,
eval_callback,
devices='gpu'):
### model_dir(str): directory of the model to compress
### model_filename(str): model filename
### params_filename(str): parameters filename
### save_dir(str): directory where the compressed model is saved
### strategy_config(dict[dict]): compression strategy configs, e.g. quantization config and distillation config
### train_config(dict): training config
### train_dataloader(paddle.io.DataLoader): dataloader over the training data
### eval_callback(function|paddle.io.DataLoader): evaluation callback; exactly one of a callback function or a test dataloader must be passed. If a callback is given, it is used to evaluate the model during training and receives the prediction results (Paddle tensors). Default: None.
self.model_dir = model_dir
self.model_filename = model_filename
self.params_filename = params_filename
self.save_dir = save_dir
self.strategy_config = strategy_config
self.train_config = train_config
self.train_dataloader = train_dataloader
paddle.enable_static()
if self.train_config is not None and self.train_config.use_fleet:
fleet.init(is_collective=True)
if self._prepare_eval(eval_callback) == 'eval_dataloader':
self.eval_function = None
self.eval_dataloader = eval_callback
else:
self.eval_function = eval_callback
self.eval_dataloader = None
self._strategy, self._config = self._prepare_strategy()
self._exe, self._places = self._prepare_envs(devices)
def _prepare_envs(self, devices):
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
return exe, places
def _prepare_strategy(self):
quant_config = self.strategy_config.get("Quantization", None)
hpo_config = self.strategy_config.get("HyperParameterOptimization",
None)
prune_config = self.strategy_config.get("Prune", None)
unstructure_prune_config = self.strategy_config.get("UnstructurePrune",
None)
single_teacher_distill_config = self.strategy_config.get("Distillation",
None)
multi_teacher_distill_config = self.strategy_config.get(
"MultiTeacherDistillation", None)
assert (single_teacher_distill_config is None) or (multi_teacher_distill_config is None), \
"Distillation and MultiTeacherDistillation cannot be set at the same time."
self._distill_config = single_teacher_distill_config if \
single_teacher_distill_config is not None else \
multi_teacher_distill_config
### case1: quant_config & hpo_config ==> PTQ & HPO
if quant_config is not None and hpo_config is not None:
strategy = 'ptq_hpo'
config = merge_config(quant_config, hpo_config)
### case2: quant_config & distill config ==> QAT & Distill
elif quant_config is not None and self._distill_config is not None:
strategy = 'qat_dis'
config = merge_config(quant_config, self._distill_config)
### case3: prune_config & distill config
elif prune_config is not None and self._distill_config is not None:
strategy = 'prune_dis'
config = merge_config(prune_config, self._distill_config)
### case4: unstructure_config & distill config
elif unstructure_prune_config is not None and self._distill_config is not None:
strategy = 'unstructure_prune_dis'
config = merge_config(unstructure_prune_config,
self._distill_config)
### case4: distill_config
elif self._distill_config is not None:
if single_teacher_distill_config is not None:
strategy = 'single_teacher_dis'
config = single_teacher_distill_config
else:
strategy = 'multi_teacher_dis'
config = multi_teacher_distill_config
### case N: todo
else:
raise NotImplementedError(
"The strategy combination {} is not supported yet".format(
self.strategy_config.keys()))
return strategy, config
def _prepare_fleet_strategy(self, train_config):
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
strategy = fleet.DistributedStrategy()
strategy.build_strategy = build_strategy
if train_config.recompute_config is not None:
strategy.recompute = True
strategy.recompute_configs = { ** train_config.recompute_config}
if train_config.sharding_config is not None:
strategy.sharding = True
strategy.sharding_configs = { ** train_config.sharding_config}
if train_config.amp_config is not None:
strategy.amp = True
strategy.amp_configs = { ** train_config.amp_config}
return strategy
def _prepare_program(self, program, feed_target_names, fetch_targets):
train_program = recover_inference_program(program)
startup_program = paddle.static.Program()
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, fetch_targets)
config_dict = dict(self._config._asdict())
### add prune program
self._pruner = None
if 'prune' in self._strategy:
self._pruner, train_program_info = build_prune_program(
self._exe, self._places, config_dict, train_program_info,
self._strategy)
if self.train_config.use_fleet:
dist_strategy = self._prepare_fleet_strategy(self.train_config)
else:
dist_strategy = None
### add distill program
if 'dis' in self._strategy:
train_program_info, test_program_info = build_distill_program(
self._exe,
self._places,
config_dict,
self.train_config._asdict(),
train_program_info,
pruner=self._pruner,
dist_strategy=dist_strategy)
self._quant_config = None
### add quant_aware program, quant always is last step
if 'qat' in self._strategy:
train_program_info, test_program_info, self._quant_config = build_quant_program(
self._exe, self._places, config_dict, train_program_info,
test_program_info)
self._exe.run(train_program_info.startup_program)
if (not self.train_config.use_fleet
) and self.train_config.amp_config is not None:
if hasattr(self.train_config.amp_config, 'use_pure_fp16'
) and self.train_config.amp_config.use_pure_fp16:
train_program_info.optimizer.amp_init(
self._places, scope=paddle.static.global_scope())
if 'prune_algo' in config_dict and config_dict['prune_algo'] == 'asp':
### prune weight in scope
self._pruner.prune_model(train_program_info.program)
if not self.train_config.use_fleet:
train_program_info = self._compiled_program(train_program_info,
self._strategy)
test_program_info = self._compiled_program(test_program_info,
self._strategy)
return train_program_info, test_program_info
def _prepare_eval(self, eval_callback):
if isinstance(eval_callback,
Iterable) or inspect.isgeneratorfunction(eval_callback):
return 'eval_dataloader'
else:
return 'eval_callback'
def _compiled_program(self, program_info, strategy):
compiled_prog = paddle.static.CompiledProgram(program_info.program)
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
if 'qat' in strategy:
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
compiled_prog = compiled_prog.with_data_parallel(
loss_name=program_info.fetch_targets[0].name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
program_info.program = compiled_prog
return program_info
def compress(self):
### start compress, including train/eval model
if self._strategy == 'ptq_hpo':
quant_post_hpo(
self._exe,
self._places,
model_dir=self.model_dir,
quantize_model_path=self.save_dir,
train_dataloader=self.train_dataloader,
eval_dataloader=self.eval_dataloader,
eval_function=self.eval_function,
model_filename=self.model_filename,
params_filename=self.params_filename,
save_model_filename=self.model_filename,
save_params_filename=self.params_filename,
quantizable_op_type=self._config.quantize_op_types,
weight_bits=self._config.weight_bits,
activation_bits=self._config.activation_bits,
weight_quantize_type=self._config.weight_quantize_type,
is_full_quantize=self._config.is_full_quantize,
algo=self._config.ptq_algo,
bias_correct=self._config.bias_correct,
hist_percent=self._config.hist_percent,
batch_size=[1],
batch_num=self._config.batch_num,
runcount_limit=self._config.max_quant_count)
else:
assert 'dis' in self._strategy, "Only strategies that include distillation are supported for training-based compression."
### convert an inference program into a trainable program
###[inference_program, feed_target_names, fetch_targets]= paddle.static.load_inference_model( \
### path_prefix=self.model_dir, \
### model_filename=self.model_filename, params_filename=self.params_filename,
### executor=self._exe)
[inference_program, feed_target_names, fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=self.model_dir, \
model_filename=self.model_filename, params_filename=self.params_filename,
executor=self._exe)
### used to check whether the dataloader is right
if self.eval_function is not None and self.train_config.origin_metric is not None:
metric = self.eval_function(self._exe, inference_program,
feed_target_names, fetch_targets)
_logger.info("metric of compressed model is: {}".format(metric))
buf = 0.05
if metric < (float(self.train_config.origin_metric) - buf) or \
metric > (float(self.train_config.origin_metric) + buf):
raise RuntimeError("target metric of pretrained model is {}, \
but now is {}, Please check the format of evaluation dataset \
or check the origin_metric in train_config"
.format(\
self.train_config.origin_metric, metric))
train_program_info, test_program_info = self._prepare_program(
inference_program, feed_target_names, fetch_targets)
test_program_info = self._start_train(train_program_info,
test_program_info)
self._save_model(test_program_info)
def _start_train(self, train_program_info, test_program_info):
best_metric = -1.0
for epoch_id in range(self.train_config.epochs):
for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \
feed=data, \
fetch_list=train_program_info.fetch_targets)
if 'unstructure' in self._strategy:
self._pruner.step()
if self.train_config.logging_iter is None:
logging_iter = 10
else:
logging_iter = self.train_config.logging_iter
if batch_id % int(logging_iter) == 0:
_logger.info("epoch: {}, batch: {}, loss: {}".format(
epoch_id, batch_id, np_probs_float))
if batch_id % int(self.train_config.eval_iter) == 0:
if self.eval_function is not None:
# GMP pruner step 3: update params before summarizing sparsity, saving model or evaluation.
if 'unstructure' in self._strategy:
self._pruner.update_params()
metric = self.eval_function(
self._exe, test_program_info.program,
test_program_info.feed_target_names,
test_program_info.fetch_targets)
_logger.info(
"epoch: {}, batch: {} metric of compressed model is: {}".
format(epoch_id, batch_id, metric))
if metric > best_metric:
paddle.static.save(
program=test_program_info.program._program,
model_path=os.path.join(self.save_dir,
'best_model'))
if self.train_config.target_metric is not None:
if metric > float(self.train_config.target_metric):
return
else:
raise NotImplementedError(
"Please support eval function")
if 'qat' in self._strategy:
### TODO: load best model to save
float_program, int8_program = convert(test_program_info.program._program, self._places, self._quant_config, \
scope=paddle.static.global_scope(), \
save_int8=True)
test_program_info.program = float_program
return test_program_info
def _save_model(self, test_program_info):
test_program = test_program_info.program._program if isinstance(
test_program_info.program,
paddle.static.CompiledProgram) else test_program_info.program
feed_vars = []
for name in test_program_info.feed_target_names:
for var in test_program.list_vars():
if var.name == name:
feed_vars.append(var)
break
assert len(feed_vars) > 0, "can not find feed vars in quant program"
paddle.static.save_inference_model(
path_prefix=os.path.join(self.save_dir, 'final_model'),
feed_vars=feed_vars,
fetch_vars=test_program_info.fetch_targets,
executor=self._exe,
program=test_program)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from .strategy_config import *
__all__ = ['save_config', 'load_config']
def load_config(config_path):
"""
convert yaml to dict config.
"""
f = open(config_path, 'r')
cfg = yaml.load(f, Loader=yaml.FullLoader)
f.close()
compress_config = {}
for key, value in cfg.items():
default_key = eval(key)(**value)
compress_config[key] = default_key
if compress_config.get('TrainConfig') != None:
train_config = compress_config.pop('TrainConfig')
else:
train_config = None
return compress_config, train_config
def save_config(config, config_path):
"""
convert dict config to yaml.
"""
f = open(config_path, "w")
yaml.dump(config, f)
f.close()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
'build_distill_program', 'build_quant_program', 'build_prune_program'
]
def _create_optimizer(train_config):
"""create optimizer"""
opt = getattr(optimizer, train_config.get('optimizer') or
'SGD') ### default optimizer is SGD
if 'optim_args' in train_config:
if train_config[
'optim_args'] is not None and 'grad_clip' in train_config[
'optim_args'] and train_config['optim_args'][
'grad_clip'] is not None:
grad_clip = getattr(
paddle.nn, train_config['optim_args']['grad_clip'])(
**train_config['optim_args']['grad_clip_args'])
train_config['optim_args'].pop('grad_clip')
train_config['optim_args'].pop('grad_clip_args')
else:
grad_clip = None
if 'grad_clip' in train_config['optim_args'] and train_config[
'optim_args']['grad_clip'] is None:
train_config['optim_args'].pop('grad_clip')
train_config['optim_args'].pop('grad_clip_args')
else:
train_config['optim_args'] = {}
grad_clip = None
op = opt(learning_rate=train_config["learning_rate"],
grad_clip=grad_clip,
**train_config['optim_args'])
return op
def _parse_distill_loss(distill_node_pair,
distill_loss='l2_loss',
distill_lambda=1.0):
"""parse distill loss config"""
loss_dist = 0.0
losses = []
if isinstance(distill_node_pair[0], str):
assert isinstance(distill_loss, str)
assert isinstance(distill_lambda, float)
distill_node_pair = [distill_node_pair]
distill_loss = [distill_loss]
distill_lambda = [distill_lambda]
assert len(distill_node_pair) == len(distill_loss)
assert len(distill_node_pair) == len(distill_lambda)
for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
tmp_loss = 0.0
_logger.info("train config.distill_node_pair: {}".format(node, loss,
lam))
assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2):
tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
loss_dist += lam * tmp_loss
losses.append(tmp_loss)
return loss_dist, losses
def _load_program_and_merge(executor,
place,
train_program,
config,
model_dir,
model_filename,
params_filename,
teacher_idx=None,
feed_target_names=None):
try:
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=model_dir, \
model_filename=model_filename, \
params_filename=params_filename, \
executor=executor)
except:
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=model_dir, \
executor=executor)
_remove_fetch_node(teacher_program)
if teacher_idx == None or teacher_idx == 1:
test_program = train_program.clone(for_test=True)
data_name_map = {}
if 'merge_feed' not in config or config['merge_feed'] == True:
assert len(feed_target_names) == len(teacher_feed_target_names), \
"the number of feed nodes in the teacher model is not equal to the student model"
for i, name in enumerate(feed_target_names):
data_name_map[teacher_feed_target_names[i]] = name
if teacher_idx is None:
teacher_name_prefix = 'teacher_'
else:
teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))
merge(
teacher_program,
train_program,
data_name_map,
place,
name_prefix=teacher_name_prefix,
merge_feed=config.get('merge_feed', True))
if teacher_idx == None or teacher_idx == 1:
return train_program, test_program, data_name_map
else:
return train_program, None, data_name_map
def build_distill_program(executor,
place,
config,
train_config,
train_program_info=None,
pruner=None,
dist_strategy=None):
"""build distill program with infermodel"""
startup_program = paddle.static.Program()
if train_program_info is None:
[train_program, feed_target_names, fetch_targets]= paddle.static.load_inference_model( \
path_prefix=config["model_dir"] if "model_dir" in config else config["model_path_prefix"], \
executor=executor)
train_program = recover_inference_program(train_program)
else:
train_program = train_program_info.program
feed_target_names = train_program_info.feed_target_names
fetch_targets = train_program_info.fetch_targets
teacher_model_dir = config[
"teacher_model_dir"] if "teacher_model_dir" in config else config[
"teacher_model_path_prefix"]
if isinstance(teacher_model_dir, list):
for tea_idx in range(len(teacher_model_dir)):
model_filename = config["teacher_model_filename"][
tea_idx] if "teacher_model_filename" in config else None
params_filename = config["teacher_params_filename"][
tea_idx] if "teacher_params_filename" in config else None
if tea_idx == 0:
train_program, test_program, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
config,
teacher_model_dir[tea_idx],
model_filename,
params_filename,
teacher_idx=(tea_idx + 1),
feed_target_names=feed_target_names)
else:
train_program, _, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
config,
teacher_model_dir[tea_idx],
model_filename,
params_filename,
teacher_idx=(tea_idx + 1),
feed_target_names=feed_target_names)
else:
model_filename = config[
"teacher_model_filename"] if "teacher_model_filename" in config else None
params_filename = config[
"teacher_params_filename"] if "teacher_params_filename" in config else None
train_program, test_program, data_name_map = _load_program_and_merge(
executor,
place,
train_program,
config,
teacher_model_dir,
model_filename,
params_filename,
teacher_idx=None,
feed_target_names=feed_target_names)
# all feed node should set stop_gradient is False, for using pact quant algo.
for var in train_program.list_vars():
if var.name in data_name_map.values() or var.name in data_name_map.keys(
):
var.stop_gradient = False
train_fetch_list = []
with paddle.static.program_guard(train_program, startup_program):
with paddle.utils.unique_name.guard('merge'):
optimizer = _create_optimizer(train_config)
if train_config.get('use_fleet'):
optimizer = fleet.distributed_optimizer(optimizer,
dist_strategy)
else:
if train_config.get('amp_config') is not None:
custom_white_list = train_config['amp_config'].get(
'custom_white_list', None)
if custom_white_list is not None:
train_config['amp_config'].pop('custom_white_list')
custom_black_list = train_config['amp_config'].get(
'custom_black_list', None)
if custom_black_list is not None:
train_config['amp_config'].pop('custom_black_list')
custom_black_varnames = train_config['amp_config'].get(
'custom_black_varnames', None)
if custom_black_varnames is not None:
train_config['amp_config'].pop('custom_black_varnames')
amp_list = paddle.static.amp.CustomOpLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list,
custom_black_varnames=custom_black_varnames)
optimizer = paddle.static.amp.decorate(
optimizer=optimizer,
amp_lists=amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
**train_config['amp_config'])
distill_loss, losses = _parse_distill_loss(
config['distill_node_pair'],
config.get('distill_loss') or
'l2_loss', ### default loss is l2_loss
config.get('distill_lambda') or 1.0) ### default lambda is 1.0
loss = paddle.mean(distill_loss)
loss.stop_gradient = False
if 'prune_algo' in config: ### prune & asp
if config['prune_algo'] == 'asp':
optimizer = pruner.decorate(optimizer)
optimizer.minimize(loss)
elif 'prune_strategy' in config: ###unstructure prune
optimizer.minimize(loss, no_grad_set=pruner.no_grad_set)
else:
optimizer.minimize(loss)
train_fetch_list.append(loss)
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, train_fetch_list,
optimizer)
test_program_info = ProgramInfo(startup_program, test_program,
feed_target_names, fetch_targets)
return train_program_info, test_program_info
def build_quant_program(executor, place, config, train_program_info,
test_program_info):
scope = paddle.static.global_scope()
assert isinstance(config, dict), "quant config must be dict"
default_config = _quant_config_default
default_config.update(config)
print(default_config)
config = _parse_configs(default_config)
use_pact = config["use_pact"]
if use_pact:
act_preprocess_func = pact
optimizer_func = get_pact_optimizer
pact_executor = executor
else:
act_preprocess_func = None
optimizer_func = None
pact_executor = None
test_program = quant_aware(
test_program_info.program,
place,
config,
scope=scope,
act_preprocess_func=None,
optimizer_func=None,
executor=None,
for_test=True)
train_program = quant_aware(
train_program_info.program,
place,
config,
scope=scope,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=pact_executor,
for_test=False,
return_program=True)
train_program_info.program = train_program
test_program_info.program = test_program
return train_program_info, test_program_info, config
def build_prune_program(executor, place, config, train_program_info, strategy):
if 'unstructure' in strategy:
from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
if config["prune_strategy"] is None:
pruner = UnstructuredPruner(
train_program_info.program,
mode=config['prune_mode'],
ratio=config['pruned_ratio'],
threshold=config['threshold'],
prune_params_type=config['prune_params_type'],
place=place,
local_sparsity=config['local_sparsity'], )
elif config["prune_strategy"] == "gmp":
pruner = GMPUnstructuredPruner(
train_program_info.program,
ratio=config['pruned_ratio'],
threshold=config['threshold'],
prune_params_type=config['prune_params_type'],
place=place,
local_sparsity=config['local_sparsity'],
config=config['gmp_config'])
else:
if config['prune_algo'] == 'prune':
from ..prune import Pruner
pruner = Pruner(config["criterion"])
params = []
for param in train_program_info.program.global_block(
).all_parameters():
if config[
'prune_params_name'] is not None and param.name in config[
'prune_params_name']:
params.append(param.name)
pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
params=params,
ratios=[config['pruned_ratio']] * len(params),
place=place)
train_program_info.program = pruned_program
elif config['prune_algo'] == 'asp':
from paddle.static import sparsity
pruner = sparsity
excluded_params_name = []
for param in train_program_info.program.global_block(
).all_parameters():
if config[
'prune_params_name'] is not None and param.name not in config[
'prune_params_name']:
excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program,
excluded_params_name)
else:
raise NotImplementedError(
"prune_algo must be choice in [\"prune\", \"asp\"], {} is not support".
format(config['prune_algo']))
return pruner, train_program_info
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
__all__ = [
"Quantization", "Distillation", "MultiTeacherDistillation", \
"HyperParameterOptimization", "Prune", "UnstructurePrune", \
"merge_config", "ProgramInfo", "TrainConfig",
]
### Quantization:
Quantization = namedtuple(
"Quantization",
[
"quantize_op_types",
"weight_bits",
"activation_bits",
"not_quant_pattern", ### ptq没有暴露相应接口
"use_pact", ### 仅QAT支持
"is_full_quantize"
])
Quantization.__new__.__defaults__ = (None, ) * (len(Quantization._fields) - 1
) + (False, )
### Distillation:
Distillation = namedtuple(
"Distillation",
[
"distill_loss", ### list[list],支持不同节点之间使用不同的loss。
"distill_node_pair", ### list[list],支持不同节点之间使用不同的loss。
"distill_lambda", ### list[list],支持不同节点之间使用不同的loss。
"teacher_model_dir",
"teacher_model_filename",
"teacher_params_filename",
"merge_feed",
])
Distillation.__new__.__defaults__ = (None, ) * (len(Distillation._fields) - 1
) + (True, )
### Multi-teacher distillation config
### Multi-Teacher Distillation:
MultiTeacherDistillation = namedtuple(
"MultiTeacherDistillation",
[
"distill_loss", ### list[str],每个teacher对应一个loss
"distill_node_pair", ### list[list],每个teacher对应一个蒸馏。仅支持logits蒸馏,不支持中间层蒸馏
"distill_lambda", ### list[float],每个teacher对应一个lambda。
"teacher_model_dir",
"teacher_model_filename", ### list[str], 每个teacher对应一个模型文件
"teacher_params_filename", ### list[str], 每个teacher对应一个参数文件
"merge_feed",
])
MultiTeacherDistillation.__new__.__defaults__ = (None, ) * (
len(MultiTeacherDistillation._fields) - 1) + (True, )
### If unset, hyper-parameter search uses the default search space; otherwise it searches the configured space, which also makes a pure PTQ strategy possible.
###HyperParameterOptimization
HyperParameterOptimization = namedtuple("HyperParameterOptimization", [
"ptq_algo", "bias_correct", "weight_quantize_type", "hist_percent",
"batch_num", "max_quant_count"
])
HyperParameterOptimization.__new__.__defaults__ = (None, ) * (
len(HyperParameterOptimization._fields) - 1) + (20, )
### Prune
Prune = namedtuple("Prune", [
"prune_algo",
"pruned_ratio",
"prune_params_name",
"criterion",
])
Prune.__new__.__defaults__ = (None, ) * len(Prune._fields)
### UnstructurePrune
UnstructurePrune = namedtuple("UnstructurePrune", [
"prune_strategy",
"prune_mode",
"threshold",
"prune_ratio",
"gmp_config",
"prune_params_type",
"local_sparsity",
])
UnstructurePrune.__new__.__defaults__ = (None, ) * len(UnstructurePrune._fields)
### Train
TrainConfig = namedtuple("Train", [
"epochs",
"learning_rate",
"optimizer",
"optim_args",
"eval_iter",
"logging_iter",
"origin_metric",
"target_metric",
"use_fleet",
"amp_config",
"recompute_config",
"sharding_config",
])
TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)
def merge_config(*args):
fields = tuple()
cfg = dict()
for arg in args:
fields += arg._fields
cfg.update(dict(arg._asdict()))
MergeConfig = namedtuple("MergeConfig", fields)
return MergeConfig(**cfg)
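A small usage sketch of `merge_config` with two of the configs above; the values are placeholders, not recommended settings.

```
quant_config = Quantization(weight_bits=8, activation_bits=8, use_pact=True)
train_config = TrainConfig(epochs=1, learning_rate=0.0001, optimizer="SGD", eval_iter=1000)

merged = merge_config(quant_config, train_config)
# The merged namedtuple exposes the fields of both inputs.
print(merged.weight_bits, merged.learning_rate)
```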
class ProgramInfo:
def __init__(self,
startup_program,
program,
feed_target_names,
fetch_targets,
optimizer=None):
self.startup_program = startup_program
self.program = program
self.feed_target_names = feed_target_names
self.fetch_targets = fetch_targets
self.optimizer = optimizer
......@@ -22,7 +22,8 @@ from .server import Server
from .client import Client
from .meter import AvgrageMeter
from .analyze_helper import VarCollector
from paddleslim.common import wrapper_function
from . import wrapper_function
from . import recover_program
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
......@@ -31,3 +32,4 @@ __all__ = [
]
__all__ += wrapper_function.__all__
__all__ += recover_program.__all__
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from paddle.fluid.framework import Parameter
from paddle.fluid import unique_name
from paddle.fluid import core
from ..core import GraphWrapper
__all__ = ['recover_inference_program']
def _remove_fetch_node(program):
"""remove fetch node in program"""
for block in program.blocks:
removed = 0
ops = list(block.ops)
for op in ops:
if op.type == "fetch":
idx = ops.index(op)
block._remove_op(idx - removed)
removed += 1
def _recover_reserve_space_with_bn(program):
"""Add the outputs which is only used for training and not saved in
inference program."""
for block_idx in six.moves.range(program.num_blocks):
block = program.block(block_idx)
for op in block.ops:
if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
return program
def _recover_param_attr(program):
"""recover parameters attribute.
Params in infermodel are stored in the form of variable, which can not be trained."""
all_weights = [param for param in program.list_vars() \
if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
for w in all_weights:
new_w = Parameter(
block=program.block(0),
shape=w.shape,
dtype=w.dtype,
type=w.type,
name=w.name)
new_w.set_value(w.get_value())
program.block(0).vars[w.name] = new_w
return program
def recover_inference_program(inference_program):
""" recover inference program to train program which can be trained. """
_remove_fetch_node(inference_program)
inference_program = _recover_param_attr(inference_program)
inference_program = _recover_reserve_space_with_bn(inference_program)
for var in inference_program.list_vars():
var.stop_gradient = False
for op in inference_program.global_block().ops:
op._set_attr("is_test", False)
return inference_program
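A hedged sketch of the intended workflow: load a saved inference model and convert it back into a trainable program. The model path below is a placeholder, and the import assumes this PR's layout, which re-exports the helper from `paddleslim.common` via the `__all__` update above.

```
import paddle
from paddleslim.common import recover_inference_program  # re-exported via __all__ above

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# "./infermodel_mobilenetv2/inference" is a placeholder path_prefix.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./infermodel_mobilenetv2/inference", exe)
train_program = recover_inference_program(program)  # params become trainable, is_test is cleared
```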
......@@ -387,5 +387,5 @@ class GraphWrapper(object):
It is used after loading pruned parameters from file.
"""
for op in self.ops():
if op.type() != 'conditional_block':
if op.type() != 'conditional_block' and op.type() != 'feed':
op._op.desc.infer_shape(op._op.block.desc)
......@@ -22,7 +22,8 @@ def merge(teacher_program,
data_name_map,
place,
scope=None,
name_prefix='teacher_'):
name_prefix='teacher_',
merge_feed=True):
"""Merge teacher program into student program and add a uniform prefix to the
names of all vars in teacher program
......@@ -40,6 +41,7 @@ def merge(teacher_program,
will be used. Default: None
name_prefix(str): Name prefix added for all vars of the teacher program.
Default: 'teacher_'
merge_feed(bool): Whether to merge the feed op when merging programs. Default: True.
Returns:
None
......@@ -49,7 +51,7 @@ def merge(teacher_program,
teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars():
skip_rename = False
if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
if teacher_var.name != 'fetch' and (not merge_feed or teacher_var.name != 'feed'):
if teacher_var.name in data_name_map.keys():
new_name = data_name_map[teacher_var.name]
if new_name == teacher_var.name:
......@@ -67,7 +69,7 @@ def merge(teacher_program,
teacher_var.name, new_name)
for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
if teacher_var.name != 'fetch' and (not merge_feed or teacher_var.name != 'feed'):
# student program add var
new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False)
......@@ -75,7 +77,7 @@ def merge(teacher_program,
for block in teacher_program.blocks:
for op in block.ops:
if op.type != 'feed' and op.type != 'fetch':
if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
inputs = {}
outputs = {}
attrs = {}
......
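An illustrative call with the new `merge_feed` flag; the programs, place, and feed-name map are assumed to exist elsewhere and are placeholders.

```
from paddleslim.dist import merge  # import path as used elsewhere in this PR

data_name_map = {"image": "image"}  # placeholder: teacher feed name -> student feed name
merge(
    teacher_program,
    student_program,
    data_name_map,
    place,
    name_prefix="teacher_",
    merge_feed=True,  # default; with False, the teacher's own feed vars/ops are also merged in
)
```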
......@@ -24,13 +24,11 @@ import shutil
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import unique_name
from paddle.fluid import core
from paddle.fluid.framework import Parameter
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
from paddleslim.core import GraphWrapper
from paddleslim.quant import quant_aware, convert
from ..common.recover_program import recover_inference_program
from .quanter import _quant_config_default, _parse_configs, pact, get_pact_optimizer
from .quanter import quant_aware, convert
from ..dist import merge, l2_loss, soft_label_loss, fsp_loss
from ..auto_compression.create_compressed_program import build_distill_program
import logging
logging.getLogger().setLevel(logging.INFO)
from ..common import get_logger
......@@ -102,139 +100,6 @@ def _parse_train_configs(train_config):
return train_config
def _create_optimizer(train_config):
"""create optimizer"""
optimizer = paddle.optimizer.SGD(
learning_rate=train_config["learning_rate"],
weight_decay=paddle.regularizer.L2Decay(train_config["weight_decay"]))
return optimizer
def _remove_fetch_node(program):
"""remove fetch node in program"""
for block in program.blocks:
removed = 0
ops = list(block.ops)
for op in ops:
if op.type == "fetch":
idx = ops.index(op)
block._remove_op(idx - removed)
removed += 1
def _recover_reserve_space_with_bn(program):
"""Add the outputs which is only used for training and not saved in
inference program."""
for block_idx in six.moves.range(program.num_blocks):
block = program.block(block_idx)
for op in block.ops:
if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
op.desc.set_output("ReserveSpace", [reserve_space.name])
return program
def _recover_param_attr(program):
"""recover parameters attribute.
Params in infermodel are stored in the form of variable, which can not be trained."""
all_weights = [param for param in program.list_vars() \
if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
for w in all_weights:
new_w = Parameter(
block=program.block(0),
shape=w.shape,
dtype=w.dtype,
type=w.type,
name=w.name)
new_w.set_value(w.get_value())
program.block(0).vars[w.name] = new_w
return program
def _parse_distill_loss(train_config):
"""parse distill loss config"""
assert len(train_config["distill_node_pair"]) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
print("train config.distill_node_pair: ", train_config["distill_node_pair"])
distill_loss = 0
for i in range(len(train_config["distill_node_pair"]) // 2):
print(train_config["distill_node_pair"][i * 2],
train_config["distill_node_pair"][i * 2 + 1])
distill_loss += l2_loss(train_config["distill_node_pair"][i * 2],
train_config["distill_node_pair"][i * 2 + 1])
return distill_loss
DistillProgramInfo = namedtuple("DistillProgramInfo", \
"startup_program train_program train_feed_names train_fetch_list \
optimizer test_program test_feed_names test_fetch_list"
)
def build_distill_prog_with_infermodel(executor, place, train_config):
"""build distill program with infermodel"""
[train_program, feed_target_names, fetch_targets]= paddle.static.load_inference_model( \
path_prefix=train_config["model_path_prefix"], \
executor=executor)
_remove_fetch_node(train_program)
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=train_config["teacher_model_path_prefix"], \
executor=executor)
_remove_fetch_node(teacher_program)
test_program = train_program.clone(for_test=True)
train_program = _recover_param_attr(train_program)
train_program = _recover_reserve_space_with_bn(train_program)
for var in train_program.list_vars():
var.stop_gradient = False
train_graph = GraphWrapper(train_program)
for op in train_graph.ops():
op._op._set_attr("is_test", False)
############################################################################
# distill
############################################################################
data_name_map = {}
assert len(feed_target_names) == len(teacher_feed_target_names), \
"the number of feed nodes in the teacher model is not equal to the student model"
for i, name in enumerate(feed_target_names):
data_name_map[teacher_feed_target_names[i]] = name
merge(teacher_program, train_program, data_name_map, place)
# all feed node should set stop_gradient is False, for using pact quant algo.
for var in train_program.list_vars():
if var.name in data_name_map.values() or var.name in data_name_map.keys(
):
var.stop_gradient = False
train_fetch_list = []
train_fetch_name_list = []
startup_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
with fluid.unique_name.guard('merge'):
optimizer = _create_optimizer(train_config)
distill_loss = _parse_distill_loss(train_config)
loss = paddle.mean(distill_loss)
loss.stop_gradient = False
p_g_list = paddle.static.append_backward(loss=loss)
opts = optimizer.apply_gradients(p_g_list)
train_fetch_list.append(loss)
train_fetch_name_list.append(loss.name)
return DistillProgramInfo(startup_program, train_program, \
feed_target_names, train_fetch_list, optimizer, \
test_program, feed_target_names, fetch_targets)
def _compile_program(program, fetch_var_name):
"""compiling program"""
compiled_prog = paddle.static.CompiledProgram(program)
......@@ -294,16 +159,16 @@ def quant_aware_with_infermodel(executor,
_logger.info("quant_aware config {}".format(quant_config))
train_config = _parse_train_configs(train_config)
distill_program_info = build_distill_prog_with_infermodel(executor, place,
train_config)
distill_program_info, test_program_info = build_distill_program(
executor, place, train_config, train_config)
startup_program = distill_program_info.startup_program
train_program = distill_program_info.train_program
train_feed_names = distill_program_info.train_feed_names
train_fetch_list = distill_program_info.train_fetch_list
train_program = distill_program_info.program
train_feed_names = distill_program_info.feed_target_names
train_fetch_list = distill_program_info.fetch_targets
optimizer = distill_program_info.optimizer
test_program = distill_program_info.test_program
test_feed_names = distill_program_info.test_feed_names
test_fetch_list = distill_program_info.test_fetch_list
test_program = test_program_info.program
test_feed_names = test_program_info.feed_target_names
test_fetch_list = test_program_info.fetch_targets
############################################################################
# quant
......@@ -412,11 +277,11 @@ def export_quant_infermodel(
_logger.info("quant_aware config {}".format(quant_config))
train_config = _parse_train_configs(train_config)
distill_program_info = build_distill_prog_with_infermodel(executor, place,
train_config)
test_program = distill_program_info.test_program
test_feed_names = distill_program_info.test_feed_names
test_fetch_list = distill_program_info.test_fetch_list
_, test_program_info = build_distill_program(executor, place, train_config,
train_config)
test_program = test_program_info.program
test_feed_names = test_program_info.feed_target_names
test_fetch_list = test_program_info.fetch_targets
############################################################################
# quant
......
......@@ -46,8 +46,16 @@ class QuantConfig:
place,
float_infer_model_path,
quantize_model_path,
algo,
hist_percent,
bias_correct,
batch_size,
batch_num,
train_sample_generator=None,
eval_sample_generator=None,
train_dataloader=None,
eval_dataloader=None,
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
......@@ -66,8 +74,16 @@ class QuantConfig:
self.place = place
self.float_infer_model_path = float_infer_model_path
self.quantize_model_path = quantize_model_path
self.algo = algo
self.hist_percent = hist_percent
self.bias_correct = bias_correct
self.batch_size = batch_size
self.batch_num = batch_num
self.train_sample_generator = train_sample_generator
self.eval_sample_generator = eval_sample_generator
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
self.eval_function = eval_function
self.model_filename = model_filename
self.params_filename = params_filename
self.save_model_filename = save_model_filename
......@@ -172,13 +188,19 @@ def eval_quant_model():
out_len_sum = 0
valid_data_num = 0
max_eval_data_num = 200
for i, data in enumerate(g_quant_config.eval_sample_generator()):
if g_quant_config.eval_sample_generator is not None:
feed_dict=False
eval_dataloader = g_quant_config.eval_sample_generator
else:
feed_dict=True
eval_dataloader = g_quant_config.eval_dataloader
for i, data in enumerate(eval_dataloader()):
with paddle.static.scope_guard(float_scope):
out_float = g_quant_config.executor.run(infer_prog_float, \
fetch_list=fetch_targets_float, feed=make_feed_dict(feed_target_names_float, data))
fetch_list=fetch_targets_float, feed=data if feed_dict else make_feed_dict(feed_target_names_float, data))
with paddle.static.scope_guard(quant_scope):
out_quant = g_quant_config.executor.run(infer_prog_quant, \
fetch_list=fetch_targets_quant, feed=make_feed_dict(feed_target_names_quant, data))
fetch_list=fetch_targets_quant, feed=data if feed_dict else make_feed_dict(feed_target_names_quant, data))
out_float = convert_model_out_2_nparr(out_float)
out_quant = convert_model_out_2_nparr(out_quant)
......@@ -213,11 +235,13 @@ def eval_quant_model():
def quantize(cfg):
"""model quantize job"""
algo = cfg["algo"]
hist_percent = cfg["hist_percent"]
bias_correct = cfg["bias_correct"]
batch_size = cfg["batch_size"]
batch_num = cfg["batch_num"]
algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
quant_post( \
executor=g_quant_config.executor, \
......@@ -225,13 +249,14 @@ def quantize(cfg):
model_dir=g_quant_config.float_infer_model_path, \
quantize_model_path=g_quant_model_cache_path, \
sample_generator=g_quant_config.train_sample_generator, \
data_loader=g_quant_config.train_dataloader,
model_filename=g_quant_config.model_filename, \
params_filename=g_quant_config.params_filename, \
save_model_filename=g_quant_config.save_model_filename, \
save_params_filename=g_quant_config.save_params_filename, \
quantizable_op_type=g_quant_config.quantizable_op_type, \
activation_quantize_type="moving_average_abs_max", \
weight_quantize_type=g_quant_config.weight_quantize_type, \
weight_quantize_type=weight_quantize_type, \
algo=algo, \
hist_percent=hist_percent, \
bias_correction=bias_correct, \
......@@ -239,7 +264,12 @@ def quantize(cfg):
batch_nums=batch_num)
global g_min_emd_loss
emd_loss = eval_quant_model()
### If eval_function is provided, use the user's eval function; otherwise fall back to the built-in EMD evaluation.
if g_quant_config.eval_function is not None:
emd_loss = g_quant_config.eval_function()
else:
emd_loss = eval_quant_model()
print("emd loss: ", emd_loss)
if emd_loss < g_min_emd_loss:
g_min_emd_loss = emd_loss
if os.path.exists(g_quant_config.quantize_model_path):
......@@ -255,6 +285,9 @@ def quant_post_hpo(executor,
quantize_model_path,
train_sample_generator=None,
eval_sample_generator=None,
train_dataloader=None,
eval_dataloader=None,
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
......@@ -264,7 +297,12 @@ def quant_post_hpo(executor,
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
weight_quantize_type='channel_wise_abs_max',
weight_quantize_type=['channel_wise_abs_max'],
algo=["KL", "hist", "avg", "mse"],
bias_correct=[True, False],
hist_percent=[0.98, 0.999], ### uniformly sampled from this [low, high] range.
batch_size=[10, 30], ### uniformly sampled from this [low, high] range.
batch_num=[10, 30], ### uniformly sampled from this [low, high] range.
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training",
......@@ -322,28 +360,79 @@ def quant_post_hpo(executor,
global g_quant_config
g_quant_config = QuantConfig(
executor, place, model_dir, quantize_model_path, train_sample_generator,
eval_sample_generator, model_filename, params_filename,
executor, place, model_dir, quantize_model_path, algo, hist_percent,
bias_correct, batch_size, batch_num, train_sample_generator,
eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
model_filename, params_filename,
save_model_filename, save_params_filename, scope, quantizable_op_type,
is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
optimize_model, is_use_cache_file, cache_dir)
cs = ConfigurationSpace()
algo = CategoricalHyperparameter(
"algo", ["KL", "hist", "avg", "mse"], default_value="KL")
bias_correct = CategoricalHyperparameter(
"bias_correct", [True, False], default_value=False)
weight_quantize_method = CategoricalHyperparameter("weight_quantize_method", \
[weight_quantize_type], default_value=weight_quantize_type)
hist_percent = UniformFloatHyperparameter(
"hist_percent", 0.98, 0.999, default_value=0.99)
batch_size = UniformIntegerHyperparameter(
"batch_size", 10, 30, default_value=10)
batch_num = UniformIntegerHyperparameter(
"batch_num", 10, 30, default_value=10)
cs.add_hyperparameters([algo, bias_correct, weight_quantize_method, \
hist_percent, batch_size, batch_num])
hyper_params = []
if 'hist' in algo:
hist_percent = UniformFloatHyperparameter(
"hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
hyper_params.append(hist_percent)
if len(algo) > 1:
algo = CategoricalHyperparameter(
"algo", algo, default_value=algo[0])
hyper_params.append(algo)
else:
algo = algo[0]
if len(bias_correct) > 1:
bias_correct = CategoricalHyperparameter(
"bias_correct", bias_correct, default_value=bias_correct[0])
hyper_params.append(bias_correct)
else:
bias_correct = bias_correct[0]
if len(weight_quantize_type) > 1:
weight_quantize_type = CategoricalHyperparameter("weight_quantize_type", \
weight_quantize_type, default_value=weight_quantize_type[0])
hyper_params.append(weight_quantize_type)
else:
weight_quantize_type = weight_quantize_type[0]
if len(batch_size) > 1:
batch_size = UniformIntegerHyperparameter(
"batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
hyper_params.append(batch_size)
else:
batch_size = batch_size[0]
if len(batch_num) > 1:
batch_num = UniformIntegerHyperparameter(
"batch_num", batch_num[0], batch_num[1], default_value=batch_num[0])
hyper_params.append(batch_num)
else:
batch_num = batch_num[0]
if len(hyper_params) == 0:
quant_post( \
executor=g_quant_config.executor, \
scope=g_quant_config.scope, \
model_dir=g_quant_config.float_infer_model_path, \
quantize_model_path=g_quant_model_cache_path, \
sample_generator=g_quant_config.train_sample_generator, \
data_loader=g_quant_config.train_dataloader,
model_filename=g_quant_config.model_filename, \
params_filename=g_quant_config.params_filename, \
save_model_filename=g_quant_config.save_model_filename, \
save_params_filename=g_quant_config.save_params_filename, \
quantizable_op_type=g_quant_config.quantizable_op_type, \
activation_quantize_type="moving_average_abs_max", \
weight_quantize_type=weight_quantize_type, \
algo=algo, \
hist_percent=hist_percent, \
bias_correction=bias_correct, \
batch_size=batch_size, \
batch_nums=batch_num)
return
cs.add_hyperparameters(hyper_params)
scenario = Scenario({
"run_obj": "quality", # we optimize quality (alternative runtime)
......
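Putting the new arguments together, here is a hedged end-to-end sketch of calling `quant_post_hpo` with dataloaders, a user eval function, and list-style search spaces. Paths, loaders, and the metric body are placeholders, and the import path is assumed.

```
import paddle
from paddleslim.quant import quant_post_hpo  # assumed import path

paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)

def eval_function():
    # Placeholder metric: run the quantized model on a validation set and return a
    # scalar where smaller is better; the HPO loop keeps the model with the lowest value.
    return 0.0

quant_post_hpo(
    exe,
    place,
    model_dir="./infermodel_mobilenetv2",      # placeholder float model dir
    quantize_model_path="./quantized_model",   # placeholder output dir
    train_dataloader=train_loader,             # placeholder paddle.io.DataLoader
    eval_dataloader=eval_loader,               # placeholder paddle.io.DataLoader
    eval_function=eval_function,
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    algo=["KL", "hist"],                       # a single-element list skips the search for that dimension
    hist_percent=[0.98, 0.999],
    batch_size=[10, 30],
    batch_num=[10, 30],
)
```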
......@@ -294,6 +294,15 @@ def quant_aware(program,
VARS_MAPPING_TABLE))
save_dict(main_graph.out_node_mapping_table)
main_graph.draw('./', 'graph.pdf')
#remove_ctr_vars = set()
#from paddle.fluid.framework import IrVarNode
#all_var_nodes = {IrVarNode(node) for node in main_graph.nodes() if node.is_var()}
#for node in all_var_nodes:
# print("node: ", node)
# if node.is_ctrl_var():
# remove_ctr_vars.add(node)
#self.safe_remove_nodes(remove_ctr_vars)
if for_test or return_program:
quant_program = main_graph.to_program()
else:
......