Unverified commit 1293c6c8, authored by zhouzj, committed by GitHub

Cherry pick pr on ACT (#1540)

Parent dfe2e3a7
......@@ -164,7 +164,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_quant",
config={"Quantization": {}, "HyperParameterOptimization": {'ptq_algo': ['avg'], 'max_quant_count': 3}},
config={"QuantPost": {}, "HyperParameterOptimization": {'ptq_algo': ['avg'], 'max_quant_count': 3}},
### config={"Quantization": {}, "Distillation": {}}, ### 如果您的系统为Windows系统, 请使用当前这一行配置
train_dataloader=train_loader,
eval_dataloader=train_loader)
......
......@@ -3,7 +3,7 @@
## 1.1 Hyperparameter reference for each compression method
### 1.1.1 Quantization
### 1.1.1 Quantization-aware training
The quantization parameters mainly set the quantization bit width and the op types to quantize; the quantizable ops include convolution layers (conv2d, depthwise_conv2d) and fully connected layers (mul, matmul_v2). The following example quantizes only the convolution layers:
```yaml
......@@ -50,7 +50,53 @@ print(TENSORRT_OP_TYPES)
- is_full_quantize: Whether to quantize all supported op types. Default: False.
### 1.1.2 Knowledge distillation
### 1.1.2 Post-training quantization
The basic quantization parameters for post-training quantization are the same as for quantization-aware training, so they are not repeated here. The parameters specific to post-training quantization are described below (a minimal usage sketch follows the parameter list):
```yaml
QuantPost:
  batch_size: 32
  batch_nums: None
  algo: 'hist'
  hist_percent: 0.999
  bias_correction: False
  recon_level: None
  regions: None
  epochs: 20
  lr: 0.1
  simulate_activation_quant: False
  skip_tensor_list: None
```
The configuration items above are described as follows:
- batch_size: Number of samples in each batch. Default: 32.
- batch_nums: Number of calibration iterations for post-training quantization. If set to None, iteration continues until the entire training dataset has been consumed; otherwise batch_nums iterations are run, i.e. the number of samples used to calibrate the scales is batch_nums * batch_size.
- algo: Name of the quantization algorithm: 'KL', 'mse', 'hist', 'avg', or 'abs_max'. With 'abs_max', the scale is the maximum absolute activation value over the calibration data; with 'KL', the scale is computed using KL divergence; with 'avg', the scale is the average of the per-batch maximum absolute activation values; with 'hist', the scale is computed from a percentile-based histogram; with 'mse', the scale is found by searching for the minimum MSE loss. Default: 'hist'.
- hist_percent: Percentile used by the 'hist' method. Default: 0.999.
- bias_correction: Whether to use the bias correction algorithm. Default: False.
- recon_level: If set, per-region reconstruction training is run after post-training quantization; 'layer-wise' and 'region-wise' are currently supported. With 'layer-wise', reconstruction is trained layer by layer; with 'region-wise', each block listed in `regions` is reconstructed as a unit; with None, no reconstruction training is performed. Default: None.
- regions(list[list]): Required when recon_level is 'region-wise'. Each element consists of the input and output variable names of one region; see this [example](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_fine_tune.yaml#L11).
- epochs: Number of epochs of per-region reconstruction training; each epoch processes batch_nums * batch_size samples. Default: 20.
- lr: Learning rate for per-region reconstruction training. Default: 0.1.
- simulate_activation_quant: Whether to inject activation-quantization noise during reconstruction training. Default: False.
- skip_tensor_list: Tensors to exclude from quantization, given by tensor name; names can be looked up with a visualization tool. Default: None.
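A minimal launch sketch for this strategy, assuming an inference model already exported to `./MobileNetV1_infer` and a `train_loader` built as in the README example (both names are placeholders):

```python
# Sketch: post-training quantization through ACT using the QuantPost config above.
import paddle
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()

ac = AutoCompression(
    model_dir="./MobileNetV1_infer",       # placeholder path to the inference model
    model_filename="inference.pdmodel",
    params_filename="inference.pdiparams",
    save_dir="MobileNetV1_ptq",
    # An empty dict keeps the QuantPost defaults listed above; set
    # recon_level='layer-wise' here to enable reconstruction training after PTQ.
    config={"QuantPost": {}},
    train_dataloader=train_loader,         # calibration data; labels are not needed
    eval_dataloader=train_loader)
ac.compress()
```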
### 1.1.3 Hyperparameter optimization for post-training quantization
Hyperparameter optimization searches over the post-training quantization hyperparameters to select the combination that gives the best quantization result. It requires setting both `QuantPost` and `HyperParameterOptimization`:
```yaml
HyperParameterOptimization:
  ptq_algo: ["KL", "hist", "avg", "mse"]
  bias_correct: [True, False]
  hist_percent: [0.98, 0.999]
  batch_num: [10, 30]
```
The configuration items above are described as follows (a dict-style sketch of this configuration follows the list):
- ptq_algo: Post-training quantization algorithms to search over.
- bias_correct: Whether to use the bias correction algorithm.
- hist_percent: Lower and upper bounds for the percentile of the 'hist' algorithm; the actual value is sampled uniformly from this range.
- batch_num: Lower and upper bounds for 'batch_num'; the actual value is sampled uniformly from this range.
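The same search can also be passed straight to `AutoCompression` as a `config` dict, as in the README example above. A minimal sketch mirroring the YAML:

```python
# Sketch: the 'ptq_hpo' strategy as a Python config dict; keys mirror the
# YAML sections above.
config = {
    "QuantPost": {},
    "HyperParameterOptimization": {
        "ptq_algo": ["KL", "hist", "avg", "mse"],
        "bias_correct": [True, False],
        "hist_percent": [0.98, 0.999],
        "batch_num": [10, 30],
    },
}
```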
### 1.1.4 Knowledge distillation
The distillation parameters mainly set the distillation nodes (`node`) and the path to the teacher inference model, as shown below:
```yaml
......@@ -96,7 +142,7 @@ Distillation:
- teacher_params_filename: File name of the teacher model's parameters, in *.pdiparams or __params__ format. Takes effect only when `teacher_model_dir` is set.
### 1.1.3 Structured sparsity
### 1.1.5 Structured sparsity
The structured sparsity parameters are set as follows:
```yaml
......@@ -126,7 +172,7 @@ for var_ in inference_program.list_vars():
- criterion: Metric for evaluating the importance of convolution channels. Options: "l1_norm", "bn_scale", "geometry_median". For definitions and usage, see the [structured sparsity API documentation](https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/prune/prune_api.html).
### 1.1.4 ASP semi-structured sparsity
### 1.1.6 ASP semi-structured sparsity
The semi-structured sparsity parameters are set as follows:
```yaml
......@@ -151,7 +197,7 @@ for var_ in inference_program.list_vars():
Alternatively, visualize the `*.pdmodel` model file with the [Netron tool](https://netron.app/) and choose suitable convolution layers to prune.
### 1.1.5 Structured pruning for Transformers
### 1.1.7 Structured pruning for Transformers
The structured pruning parameters for Transformer architectures are set as follows:
```yaml
......@@ -160,7 +206,7 @@ TransformerPrune:
```
- pruned_ratio: Pruning ratio for each fully connected layer.
### 1.1.6 Unstructured sparsity strategy
### 1.1.8 Unstructured sparsity strategy
The unstructured sparsity parameters are set as follows:
```yaml
......@@ -200,7 +246,7 @@ UnstructurePrune:
- local_sparsity: The scope over which the pruning ratio (ratio) applies; effective only in 'ratio' mode. When local_sparsity is enabled, every pruned parameter matrix has sparsity 'ratio'; when disabled, only the overall model sparsity is guaranteed to reach 'ratio', and the sparsity of individual matrices may vary. Keeping all matrices at the same sparsity gives a more pronounced speedup.
- For the meaning of further unstructured sparsity parameters, see the [unstructured sparsity API documentation](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/dygraph/pruners/unstructured_pruner.rst); a config sketch follows.
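For orientation, a minimal dict-style sketch of an unstructured sparsity strategy; the key names follow the `UnstructurePrune` config and the values are illustrative assumptions:

```python
# Sketch: unstructured pruning combined with distillation (values are illustrative).
config = {
    "UnstructurePrune": {
        "prune_mode": "ratio",    # prune to a target sparsity ratio
        "ratio": 0.75,            # target sparsity
        "local_sparsity": True,   # enforce 'ratio' on every pruned matrix
    },
    "Distillation": {},
}
```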
### 1.1.7 Training hyperparameters
### 1.1.9 Training hyperparameters
The training parameters mainly set the learning rate, the number of training epochs, and the optimizer; a dict-style sketch follows the YAML excerpt below.
```yaml
......
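A typical training configuration, sketched as a Python config dict; the keys follow the `TrainConfig` entries in the YAML configs later in this PR, and all values are illustrative:

```python
# Sketch: a TrainConfig entry (keys as in the YAML configs below; values illustrative).
config = {
    "TrainConfig": {
        "epochs": 6,
        "eval_iter": 1000,    # evaluate every 1000 training iterations
        "learning_rate": 2.0e-5,
        "optimizer_builder": {
            "optimizer": {"type": "AdamW"},
            "weight_decay": 0.01,
        },
    },
}
```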
......@@ -157,7 +157,9 @@ Prune:
pruned_ratio: 0.25
```
- Optimization parameters
- Hyperparameter search for post-training quantization
This example applies a hyperparameter search strategy to post-training quantization in order to select the hyperparameters that give the best result. First, configure the parameters to search:
```yaml
HyperParameterOptimization:
......@@ -177,12 +179,12 @@ HyperParameterOptimization:
- channel_wise_abs_max
```
- Quantization parameters
Next, configure the post-training quantization parameters:
The quantization parameters mainly set the quantization bit width and the op types to quantize; the quantizable ops include convolution layers (conv2d, depthwise_conv2d) and fully connected layers (mul, matmul_v2).
```yaml
Quantization:
QuantPost:
activation_bits: 8
quantize_op_types:
- conv2d
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 6
eval_iter: 1070
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 100
eval_iter: 70
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 6
eval_iter: 2000
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 16
eval_iter: 1000
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 12
eval_iter: 750
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 20
eval_iter: 1050
......
......@@ -10,7 +10,7 @@ TransformerPrune:
pruned_ratio: 0.25
HyperParameterOptimization:
Distillation:
Quantization:
QuantPost:
TrainConfig:
epochs: 6
eval_iter: 1110
......
......@@ -61,8 +61,14 @@ python -m pip install paddlepaddle_gpu==2.4rc0 -f https://www.paddlepaddle.org.c
pip install paddleslim==2.4rc
```
#### Version alignment
#### 3.2 Prepare the dataset
| PaddleSlim | x2paddle |
| :-----------: | :------------: |
| 2.3.x | 1.3.8 |
| develop / 2.4 | 1.3.9 |
### 3.2 Prepare the dataset
**Prepare the data with either method (1) or (2).**
......@@ -107,7 +113,7 @@ pip install paddleslim==2.4rc
```
#### 3.3 Prepare the inference model
### 3.3 Prepare the inference model
(1) Prepare the ONNX model:
......@@ -130,7 +136,7 @@ pip install paddleslim==2.4rc
**Note**: ACT currently supports models **without NMS**; export them with the command above. Alternatively, download our prepared [yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) directly.
#### 3.4 Automatically compress and export the model
### 3.4 Automatically compress and export the model
The distillation-quantization auto-compression example is launched through the run.py script, which calls ```paddleslim.auto_compression.AutoCompression``` to compress the model automatically. Set the model path, distillation, quantization, and training parameters in the config file; once configured, the model can be quantized and distilled.
......@@ -160,7 +166,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log -
│ ├── calibration.cache # calibration table that TensorRT can load directly
```
#### Deployment test with Paddle Inference
### Deployment test with Paddle Inference
The quantized model can be accelerated with TensorRT on GPU and with MKLDNN on CPU; a deployment sketch follows.
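A minimal Paddle Inference sketch for running the quantized model with TensorRT INT8; the model paths follow the layout above, and the input shape is an assumption for YOLOv7:

```python
# Sketch: load the quantized model and enable TensorRT INT8 in Paddle Inference.
import numpy as np
import paddle.inference as paddle_infer

config = paddle_infer.Config("yolov7_quant/model.pdmodel",
                             "yolov7_quant/model.pdiparams")
config.enable_use_gpu(256, 0)  # initial GPU memory pool (MB), GPU id
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Int8,
    use_static=False,
    use_calib_mode=False)      # quantization scales are read from the model
predictor = paddle_infer.create_predictor(config)

input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 3, 640, 640).astype("float32"))
predictor.run()
output = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
```

For CPU deployment, replace the GPU/TensorRT calls with `config.disable_gpu()` and `config.enable_mkldnn()`.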
......@@ -219,7 +225,7 @@ bash compile.sh
./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8
```
#### Export to ONNX and deploy with TensorRT
### Export to ONNX and deploy with TensorRT
Load `quant_model.onnx` and `calibration.cache`, then run the TensorRT test script directly for verification; see [TensorRT deployment](./TensorRT) for the detailed code.
......
......@@ -128,7 +128,7 @@ def create_strategy_config(strategy_str, model_type):
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**hpo_config_tester)
configs.append({
'Quantization': quant_config,
'QuantPost': quant_config,
'HyperParameterOptimization': hpo_config
})
else:
......@@ -251,7 +251,7 @@ def get_final_quant_config(ptq_loss, model_type=None):
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**default_hpo_config)
configs = [{
'Quantization': quant_config,
'QuantPost': quant_config,
'HyperParameterOptimization': hpo_config
}]
......
......@@ -26,9 +26,10 @@ import paddle
import itertools
import paddle.distributed.fleet as fleet
from ..quant.quanter import convert, quant_post
from ..quant.reconstruction_quantization import quant_recon_static
from ..common.recover_program import recover_inference_program
from ..common import get_logger
from ..common.patterns import get_patterns
from ..common.patterns import get_patterns, find_final_nodes
from ..common.load_model import load_inference_model, get_model_dir, export_onnx
from ..common.dataloader import wrap_dataloader, get_feed_vars
from ..common.config_helper import load_config
......@@ -88,27 +89,29 @@ class AutoCompression:
to None. Default: None.
strategy_config(dict, list(dict), optional): The strategy config. You can combine single-strategy configs into a multi-strategy config, for example:
1. set ``Quantization`` and ``Distillation`` to get quant_aware and distillation compress config.
The Quantization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
2. set ``Quantization`` and ``HyperParameterOptimization`` to get quant_post and hyperparameter optimization compress config.
The Quantization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24`_ .
The HyperParameterOptimization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L73`_ .
The Quantization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L55`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
2. set ``QuantPost`` and ``HyperParameterOptimization`` to get quant_post and hyperparameter optimization compress config.
The QuantPost config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L187`_ .
The HyperParameterOptimization config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L160`_ .
3. set ``ChannelPrune`` and ``Distillation`` to get channel prune and distillation compress config.
The ChannelPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
The ChannelPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L254`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
4. set ``ASPPrune`` and ``Distillation`` to get asp prune and distillation compress config.
The ASPPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
The ASPPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L268`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
5. set ``TransformerPrune`` and ``Distillation`` to get transformer prune and distillation compress config.
The TransformerPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
The TransformerPrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L278`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
6. set ``UnstructurePrune`` and ``Distillation`` to get unstructureprune and distillation compress config.
The UnstructurePrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L91`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
The UnstructurePrune config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L288`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
7. set ``Distillation`` to use one teacher model to distill the student model.
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39`_ .
The Distillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L107`_ .
8. set ``MultiTeacherDistillation`` to use multiple teachers to distill the student model.
The MultiTeacherDistillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L56`_ .
The MultiTeacherDistillation config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L134`_ .
9. set ``QuantPost`` to get quant_post compress config.
The QuantPost config can reference `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L187`_ .
If set to None, will choose a strategy automatically. Default: None.
target_speedup(float, optional): target speedup ratio by the way of auto compress. Default: None.
......@@ -155,7 +158,7 @@ class AutoCompression:
paddle.enable_static()
self._exe, self._places = self._prepare_envs()
self.model_type = self._get_model_type()
self.default_distill_node_pair, self.model_type = self._get_model_info()
if self.train_config is not None and self.train_config.use_fleet:
fleet.init(is_collective=True)
......@@ -188,7 +191,6 @@ class AutoCompression:
self._strategy, self._config = self._prepare_strategy(
self.strategy_config)
self.train_config = self._get_final_train_config(
self.train_config, self._strategy, self.model_type)
_logger.info(f"Selected strategies: {self._strategy}")
......@@ -206,7 +208,7 @@ class AutoCompression:
### The TrainConfig for quantization is extrapolated from the one above.
tmp_train_config = copy.deepcopy(train_config.__dict__)
### the epochs, train_iter and learning rate for quantization are 10% of those used in the pruning stage
if self.model_type != 'transformer':
if self.model_type != 'transformer' and train_config.epochs is not None:
tmp_train_config['epochs'] = max(
int(train_config.epochs * 0.1), 1)
if train_config.train_iter is not None:
......@@ -306,13 +308,25 @@ class AutoCompression:
exe = paddle.static.Executor(places)
return exe, places
def _get_model_type(self):
def _get_model_info(self):
[inference_program, _, _] = (load_inference_model(
self.model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe))
_, _, model_type = get_patterns(inference_program)
### set the outputs of the final weight nodes as the default distillation nodes
distill_node = []
final_weight_node = find_final_nodes(inference_program)
for out_var in final_weight_node:
distill_node.append('teacher_' + out_var.name())
distill_node.append(out_var.name())
model_type = None
if not isinstance(self.strategy_config, dict):
_, model_type = get_patterns(inference_program)
_logger.info(f"Detect model type: {model_type}")
if self.model_filename is None:
opt_model_filename = '__opt_model__'
else:
......@@ -326,8 +340,8 @@ class AutoCompression:
shutil.move(
os.path.join(self.updated_model_dir, opt_model_filename),
os.path.join(self.updated_model_dir, self.model_filename))
_logger.info(f"Detect model type: {model_type}")
return model_type
return distill_node, model_type
def _prepare_strategy(self, strategy_config):
if not isinstance(strategy_config, list):
......@@ -338,6 +352,7 @@ class AutoCompression:
for strategy_c in strategy_config:
quant_config = strategy_c.get("Quantization", None)
hpo_config = strategy_c.get("HyperParameterOptimization", None)
ptq_config = strategy_c.get("QuantPost", None)
prune_config = strategy_c.get("ChannelPrune", None)
asp_config = strategy_c.get("ASPPrune", None)
transformer_prune_config = strategy_c.get("TransformerPrune", None)
......@@ -388,10 +403,10 @@ class AutoCompression:
self._distill_config))
### case5: quant_config & hpo_config ==> PTQ & HPO
if quant_config is not None and hpo_config is not None:
if ptq_config is not None and hpo_config is not None:
only_distillation = False
strategy.append('ptq_hpo')
config.append(merge_config(quant_config, hpo_config))
config.append(merge_config(ptq_config, hpo_config))
### case6: quant_config & distill config ==> QAT & Distill
if quant_config is not None and self._distill_config is not None and 'ptq_hpo' not in strategy:
......@@ -408,6 +423,11 @@ class AutoCompression:
strategy.append('multi_teacher_dis')
config.append(multi_teacher_distill_config)
### case8: only ptq_config ==> PTQ
if ptq_config is not None and hpo_config is None:
strategy.append('quant_post')
config.append(ptq_config)
### NOTE: keep quantization as the last step
idx = -1
if 'qat_dis' in strategy and strategy.index('qat_dis') != (
......@@ -443,8 +463,7 @@ class AutoCompression:
return strategy
def _prepare_program(self, program, feed_target_names, fetch_targets,
patterns, default_distill_node_pair, strategy, config,
train_config):
patterns, strategy, config, train_config):
train_program = recover_inference_program(program)
startup_program = paddle.static.Program()
train_program_info = ProgramInfo(startup_program, train_program,
......@@ -481,7 +500,7 @@ class AutoCompression:
strategy, patterns, self.eval_dataloader)
if train_config.use_fleet:
dist_strategy = _prepare_fleet_strategy(train_config)
dist_strategy = self._prepare_fleet_strategy(train_config)
else:
dist_strategy = None
......@@ -495,7 +514,7 @@ class AutoCompression:
train_program_info,
pruner=self._pruner,
dist_strategy=dist_strategy,
default_distill_node_pair=default_distill_node_pair)
default_distill_node_pair=self.default_distill_node_pair)
self._quant_config = None
### add quant_aware program, quant always is last step
......@@ -567,6 +586,7 @@ class AutoCompression:
config = None
train_config = None
strategy_idx = None
self.final_metric = -1.0
for strategy_idx, (
strategy, config, train_config
) in enumerate(zip(self._strategy, self._config, self.train_config)):
......@@ -594,6 +614,19 @@ class AutoCompression:
if os.path.isfile(_file_path):
shutil.copy(_file_path, final_model_path)
shutil.rmtree(self.tmp_dir)
if self.eval_function is not None and self.final_metric < 0.0:
[inference_program, feed_target_names, fetch_targets] = load_inference_model( \
final_model_path, \
model_filename=self.model_filename, params_filename=self.params_filename,
executor=self._exe)
self.final_metric = self.eval_function(
self._exe, inference_program, feed_target_names,
fetch_targets)
if self.eval_function is not None:
_logger.info("==> The metric of final model is {:.4f}".format(
self.final_metric))
_logger.info(
"==> The ACT compression has been completed and the final model is saved in `{}`".
format(final_model_path))
......@@ -616,41 +649,64 @@ class AutoCompression:
params_filename=self.params_filename,
executor=self._exe)
if strategy == 'quant_post':
quant_post(
self._exe,
model_dir=model_dir,
quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader,
model_filename=self.model_filename,
params_filename=self.params_filename,
save_model_filename=self.model_filename,
save_params_filename=self.params_filename,
batch_size=1,
batch_nums=config.batch_num,
algo=config.ptq_algo,
round_type='round',
bias_correct=config.bias_correct,
hist_percent=config.hist_percent,
quantizable_op_type=config.quantize_op_types,
is_full_quantize=config.is_full_quantize,
weight_bits=config.weight_bits,
activation_bits=config.activation_bits,
activation_quantize_type='range_abs_max',
weight_quantize_type=config.weight_quantize_type,
onnx_format=False)
if config.recon_level is None:
quant_post(
self._exe,
model_dir=self.updated_model_dir,
quantize_model_path=os.path.join(
self.tmp_dir,
'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader,
model_filename=self.model_filename,
params_filename=self.params_filename,
save_model_filename=self.model_filename,
save_params_filename=self.params_filename,
batch_size=config.batch_size,
batch_nums=config.batch_nums,
algo=config.algo,
bias_correction=config.bias_correction,
hist_percent=config.hist_percent,
quantizable_op_type=config.quantize_op_types,
is_full_quantize=config.is_full_quantize,
weight_bits=config.weight_bits,
activation_bits=config.activation_bits,
activation_quantize_type=config.activation_quantize_type,
weight_quantize_type=config.weight_quantize_type,
onnx_format=config.onnx_format)
else:
quant_recon_static(
executor=self._exe,
model_dir=self.updated_model_dir,
quantize_model_path=os.path.join(
self.tmp_dir,
'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader,
model_filename=self.model_filename,
params_filename=self.params_filename,
batch_size=config.batch_size,
batch_nums=config.batch_nums,
algo=config.algo,
hist_percent=config.hist_percent,
quantizable_op_type=config.quantize_op_types,
is_full_quantize=config.is_full_quantize,
bias_correction=config.bias_correction,
onnx_format=config.onnx_format,
weight_bits=config.weight_bits,
activation_bits=config.activation_bits,
weight_quantize_type=config.weight_quantize_type,
activation_quantize_type=config.activation_quantize_type,
recon_level=config.recon_level,
simulate_activation_quant=config.simulate_activation_quant,
regions=config.regions,
region_weights_names=config.region_weights_names,
skip_tensor_list=config.skip_tensor_list,
epochs=config.epochs,
lr=config.lr)
elif strategy == 'ptq_hpo':
if platform.system().lower() != 'linux':
raise NotImplementedError(
"post-quant-hpo is not support in system other than linux")
if self.updated_model_dir != model_dir:
# If model is ONNX, convert it to inference model firstly.
load_inference_model(
model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe)
if self.eval_function is None:
# If eval_function is None, ptq_hpo uses the EMD distance to evaluate the quantized model, so it needs a dataloader without labels
eval_dataloader = self.train_dataloader
......@@ -659,7 +715,7 @@ class AutoCompression:
post_quant_hpo.quant_post_hpo(
self._exe,
self._places,
model_dir=model_dir,
model_dir=self.updated_model_dir,
quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
train_dataloader=self.train_dataloader,
......@@ -707,12 +763,12 @@ class AutoCompression:
train_config.origin_metric, metric))
self.metric_before_compressed = metric
patterns, default_distill_node_pair, _ = get_patterns(
inference_program)
patterns = None
if 'transformer' in strategy:
patterns, _ = get_patterns(inference_program)
train_program_info, test_program_info = self._prepare_program(
inference_program, feed_target_names, fetch_targets, patterns,
default_distill_node_pair, strategy, config, train_config)
strategy, config, train_config)
if 'unstructure' in self._strategy:
test_program_info.program._program = remove_unused_var_nodes(
test_program_info.program._program)
......@@ -776,7 +832,7 @@ class AutoCompression:
self.metric_before_compressed)
) / self.metric_before_compressed <= 0.005:
_logger.info(
"The error rate between the compressed model and original model is less than 5%. The training process ends."
"The error rate between the compressed model and original model is less than 0.5%. The training process ends."
)
stop_training = True
break
......@@ -798,8 +854,9 @@ class AutoCompression:
)
if (train_config.train_iter and total_train_iter >=
train_config.train_iter) or stop_training:
stop_training = True
break
self.final_metric = best_metric
if 'unstructure' in self._strategy or train_config.sparse_model:
self._pruner.update_params()
......
......@@ -29,6 +29,7 @@ __all__ = [
"TrainConfig",
"SUPPORTED_CONFIG",
"TRAIN_CONFIG_NAME",
"QuantPost",
]
SUPPORTED_CONFIG = [
......@@ -40,6 +41,7 @@ SUPPORTED_CONFIG = [
"UnstructurePrune",
"TransformerPrune",
"ASPPrune",
"QuantPost",
]
TRAIN_CONFIG_NAME = "TrainConfig"
......@@ -182,6 +184,73 @@ class HyperParameterOptimization(BaseStrategy):
self.max_quant_count = max_quant_count
class QuantPost(BaseStrategy):
def __init__(self,
batch_size=32,
batch_nums=None,
epochs=20,
lr=0.1,
algo='hist',
hist_percent=0.999,
regions=None,
region_weights_names=None,
recon_level=None,
is_full_quantize=False,
bias_correction=False,
weight_quantize_type='channel_wise_abs_max',
activation_quantize_type='range_abs_max',
simulate_activation_quant=False,
skip_tensor_list=None,
onnx_format=False,
quantize_op_types=[
"conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"
],
weight_bits=8,
activation_bits=8):
"""
QuantPost Config.
Args:
batch_size(int, optional): The batch size of the DataLoader. Default: 32.
batch_nums(int, optional): If batch_nums is not None, the number of calibration samples is 'batch_size*batch_nums'. If batch_nums is None, all data produced by the sample generator is used for calibration. Default: None.
epochs(int, optional): The number of epochs of reconstruction training; each epoch runs over batch_nums batches. Default: 20.
lr(float, optional): The learning rate of the Reconstruction Quanter. Default: 0.1.
algo(str, optional): The post-training quantization algorithm; available algorithms are listed at `<https://paddleslim.readthedocs.io/zh_CN/latest/api_cn/static/quant/quantization_api.html#quant-post-static>`. Default: 'hist'.
hist_percent(float, optional): The percentile of histogram for algo hist. Default: 0.999.
regions(list[list], optional): A list of regions; each region is a subgraph of the fp32 program with exactly one input operation and one output operation. When recon_level is 'region-wise', the reconstruction loss of each region is minimized. Default: None.
region_weights_names(list[list], optional): The weight names inside every region. Default: None.
recon_level(str, optional): The granularity of reconstruction; 'layer-wise' and 'region-wise' are currently supported. The Reconstruction Quanter is used only when recon_level is not None. Default: None.
is_full_quantize(bool): If True, 'quantize_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES. Default: False.
bias_correction(bool): Whether to use the bias correction method of https://arxiv.org/abs/1810.05723. Default: False.
weight_quantize_type(str): Weight quantize type. Default: 'channel_wise_abs_max'.
activation_quantize_type(str): Activation quantize type. Default: 'range_abs_max'.
simulate_activation_quant(bool, optional): Whether to simulate activation-quantization noise during the reconstruction process. Default: False.
skip_tensor_list(list): List of names of tensors to skip during quantization. Default: None.
onnx_format(bool): Whether to export the quantized model in ONNX format. Default: False.
quantize_op_types(list(str)): Ops whose type is in quantize_op_types will be quantized. Default: ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'].
weight_bits(int): Weight quantize bit num. Default: 8.
activation_bits(int): Activation quantize bit num. Default: 8.
"""
super(QuantPost, self).__init__("PTQ")
self.batch_size = batch_size
self.batch_nums = batch_nums
self.epochs = epochs
self.lr = lr
self.algo = algo
self.hist_percent = hist_percent
self.regions = regions
self.region_weights_names = region_weights_names
self.recon_level = recon_level
self.is_full_quantize = is_full_quantize
self.bias_correction = bias_correction
self.weight_quantize_type = weight_quantize_type
self.activation_quantize_type = activation_quantize_type
self.simulate_activation_quant = simulate_activation_quant
self.skip_tensor_list = skip_tensor_list
self.onnx_format = onnx_format
self.quantize_op_types = quantize_op_types
self.weight_bits = weight_bits
self.activation_bits = activation_bits
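# Usage sketch (illustrative): a QuantPost strategy can be built programmatically
# and handed to AutoCompression, mirroring create_strategy_config above; the
# import path is an assumption based on this module's location.
#
#   from paddleslim.auto_compression.strategy_config import (
#       QuantPost, HyperParameterOptimization)
#
#   quant_config = QuantPost(algo='avg', batch_nums=10)
#   hpo_config = HyperParameterOptimization(ptq_algo=['avg'], max_quant_count=3)
#   configs = [{'QuantPost': quant_config,
#               'HyperParameterOptimization': hpo_config}]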
class ChannelPrune:
def __init__(self, pruned_ratio, prune_params_name, criterion='l1_norm'):
"""
......
......@@ -79,8 +79,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
def get_patterns(program, only_final_node=True):
""" distinguish the pattern in the program and get distillation node """
distill_node = []
""" distinguish the pattern in the program and get model type """
skip_quant_tensor_list = []
patterns = {}
graph = GraphWrapper(program)
......@@ -124,10 +123,6 @@ def get_patterns(program, only_final_node=True):
pattern_name = 'FFN$' + str(block_num)
block_num += 1
if not only_final_node:
distill_node.append('teacher_' + out_var_name)
distill_node.append(out_var_name)
if model_type == 'transformer' and (
'fetch' in pattern_ops_type or
pattern_ops_type[-1] == 'scale'):
......@@ -140,16 +135,6 @@ def get_patterns(program, only_final_node=True):
patterns[pattern_name] = pattern_ops
if model_type != 'transformer' and (not only_final_node):
distill_node.append('teacher_' + out_var_name)
distill_node.append(out_var_name)
### add the output of final weight node to distill node
final_weight_node = find_final_nodes(program)
for out_var in final_weight_node:
distill_node.append('teacher_' + out_var.name())
distill_node.append(out_var.name())
#### skip quant matmul in attention
if model_type == 'transformer':
for block_id in range(len(program.blocks)):
......@@ -158,4 +143,4 @@ def get_patterns(program, only_final_node=True):
if inp_name in skip_quant_tensor_list:
op._set_attr("op_namescope", "skip_quant")
return patterns, distill_node, model_type
return patterns, model_type
......@@ -319,7 +319,7 @@ def quant_aware(program,
skip_tensor_list = []
same_scale_tensor_list = []
if model_type == 'transformer' and pattern_ops is None:
pattern_ops, _, model_type = get_patterns(program)
pattern_ops, model_type = get_patterns(program)
if model_type != 'transformer':
_logger.info(
'Warning! After analysis, the real model type is not transformer! If you encounter this situation, please raise an issue to let us know the case in which "get_patterns" decides the model type is not transformer.'
......
......@@ -352,40 +352,42 @@ class ReconstructionQuanter(object):
def _run(self):
self._preprocess()
startup_program = paddle.static.Program()
tmp_program = self._student_program.clone()
for k in range(len(self._regions)):
region_ = self._regions[k]
names = self._region_weights_names[k]
tmp_program = self._student_program.clone()
tmp_program.global_block().var(region_[0]).stop_gradient = True
quant_op_out_name = region_[1]
names = self._region_weights_names[k]
_logger.info(f"Current weights: {names}")
loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names)
update_params = [
tmp_program.global_block().var(name + '.alpha')
for name in names
]
with paddle.static.program_guard(tmp_program, startup_program):
loss_function = ReconstructionQuanterLoss(tmp_program, names)
quant_op_out_name = region_[1]
student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" +
quant_op_out_name)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=20,
eta_min=2,
T_max=2000,
verbose=True, )
total_loss, recon_loss, round_loss = loss_function.get_loss(
student_var,
teacher_var,
scheduler, )
teacher_var, )
train_fetches_loss = {
"total_loss": total_loss,
"recon_loss": recon_loss,
"round_loss": round_loss,
}
optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
optimizer = paddle.optimizer.Adam(
learning_rate=self._lr, parameters=update_params)
optimizer.minimize(total_loss)
self._exe.run(startup_program)
start_time = time.time()
prev_start_time = start_time
loader = self._data_loader()
for epoch in range(self._epochs):
for i, data in enumerate(loader):
for i, data in enumerate(self._data_loader()):
prev_start_time = start_time
start_time = time.time()
out = self._exe.run(
......@@ -396,14 +398,14 @@ class ReconstructionQuanter(object):
],
return_numpy=True, )
_logger.info(
"Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
.format(epoch, self._lr,
"Epoch {:d}, Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
.format(epoch, i, self._lr,
np.mean(out[0]),
np.mean(out[1]),
np.mean(out[2]),
start_time - prev_start_time), )
sys.stdout.flush()
if i == self._num_iterations:
if i + 1 == self._num_iterations:
break
self._update_weights_to_int()
if self._bias_correction:
......@@ -776,7 +778,7 @@ class ReconstructionQuanterLoss(object):
paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0,
1)
def get_loss(self, student_tensor, teacher_tensor, scheduler):
def get_loss(self, student_tensor, teacher_tensor, scheduler=None):
if self.rec_loss_type == 'mse':
rec_loss = paddle.nn.functional.mse_loss(
student_tensor,
......
......@@ -153,5 +153,55 @@ class TestLoadONNXModel(ACTBase):
deploy_backend='tensorrt')
class TestDictPTQ(ACTBase):
def __init__(self, *args, **kwargs):
super(TestDictPTQ, self).__init__(*args, **kwargs)
def test_compress(self):
image = paddle.static.data(
name='data', shape=[-1, 3, 32, 32], dtype='float32')
train_loader = paddle.io.DataLoader(
self.eval_dataset,
feed_list=[image],
batch_size=4,
return_list=False)
ac = AutoCompression(
model_dir=self.tmpdir.name,
model_filename="infer.pdmodel",
params_filename="infer.pdiparams",
save_dir="output",
config={'QuantPost': {}},
train_dataloader=train_loader,
eval_dataloader=train_loader
) # eval_function to verify accuracy
ac.compress()
class TestDictPTQRecon(ACTBase):
def __init__(self, *args, **kwargs):
super(TestDictPTQRecon, self).__init__(*args, **kwargs)
def test_compress(self):
image = paddle.static.data(
name='data', shape=[-1, 3, 32, 32], dtype='float32')
train_loader = paddle.io.DataLoader(
self.eval_dataset,
feed_list=[image],
batch_size=4,
return_list=False)
ac = AutoCompression(
model_dir=self.tmpdir.name,
model_filename="infer.pdmodel",
params_filename="infer.pdiparams",
save_dir="output",
config={'QuantPost': {
'recon_level': 'layer-wise'
}},
train_dataloader=train_loader,
eval_dataloader=train_loader
) # eval_function to verify accuracy
ac.compress()
if __name__ == '__main__':
unittest.main()
......@@ -56,7 +56,7 @@ class ACTDemo(unittest.TestCase):
params_filename="inference.pdiparams",
save_dir="MobileNetV1_quant",
config={
'Quantization': {},
'QuantPost': {},
"HyperParameterOptimization": {
'ptq_algo': ['avg'],
'max_quant_count': 3
......