未验证 提交 94144bea 编写于 作者: G gushiqiao 提交者: GitHub

Add README.md (#1483)

上级 301822dd
......@@ -26,6 +26,9 @@
| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) |
| YOLOv6s | KL离线量化(量化分析前) | 640*640 | 30.3 | - | - | 1.83ms | - | - |
| YOLOv6s | KL离线量化(量化分析后) | 640*640 | 39.7 | - | - | - | - | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_analyzed_ptq.tar) |
| YOLOv6s | avg离线量化(+Adaround) | 640*640 | 39.2 | - | - | - | - | - |
| YOLOv6s | avg离线量化(+BRECQ) | 640*640 | 38.7 | - | - | - | - | - |
| YOLOv6s | avg离线量化(+QDrop) | 640*640 | 38.0 | - | - | - | - | - |
| | | | | | | | | |
| YOLOv7 | Base模型 | 640*640 | 51.1 | 26.84ms | 7.44ms | - | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) |
| YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - |
......@@ -116,8 +119,11 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml
#### 3.6 提高离线量化精度
###### 3.6.1 量化分析工具
本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisQuant```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisQuant```详解见[AnalysisQuant.md](../../../docs/zh_cn/tutorials/quant/AnalysisQuant.md)
由于YOLOv6离线量化效果较差,以YOLOv6为例,量化分析工具具体使用方法如下:
```shell
......@@ -148,8 +154,6 @@ python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_di
如想分析之后直接产出符合目标精度的量化模型,可在 `yolov6s_analysis.yaml` 中将`get_target_quant_model`设置为True,并填写 `target_metric`,注意 `target_metric` 不能比原模型精度高。
**加速分析过程**
使用量化分析工具时,因需要逐层量化模型并进行验证,因此过程可能较慢,若想加速分析过程,可以在配置文件中设置 `fast_val_anno_path` ,输入一个图片数量较少的annotation文件路径。注意,用少量数据验证的模型精度不一定等于全量数据验证的模型精度,若只需分析时获得不同层量化效果的相对排序,可以使用少量数据集;若要求准确精度,请使用全量验证数据集。如需要全量验证数据,将 `fast_val_anno_path` 设置为None即可。
......@@ -157,6 +161,48 @@ python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_di
注:分析之后若需要直接产出符合目标精度的量化模型,demo代码不会使用少量数据集验证,会自动使用全量验证数据。
量化分析工具详细介绍见[量化分析工具介绍](../analysis.md)
###### 3.6.2 精度重构工具
本节介绍如何使用精度重构工具提高精度。该工具的思想是,通过最小化量化前后模型输出的重构误差(minimizing the reconstruction error,MRE),学习权重的取整方式(上取整or下取整),从而`fine-tune`经量化后的模型的权重,提高精度。同样以YOLOv6为例,运行命令如下:
```shell
python fine_tune.py --config_path=./configs/yolov6s_fine_tune.yaml --recon_level=layer-wise
```
其中`recon_level`表示重构的粒度,默认为`layer-wise`,即逐层重构。如下图,该工具首先会统计激活和权重量化需要的`scales`,随后为每个权重添加`soft-rounding`操作使得权重的取整方式可学习,以及逐层的增加重构`loss`
<p align="center">
<img src="../../../docs/images/adaround.png" width=749 hspace='10'/> <br />
</p>
通过最小化重构`loss`,为每层的权重学习最合适的`round`方式,其思想类似[论文](https://arxiv.org/abs/2004.10568)提出的`Adround`方法。
该过程也可看成知识蒸馏,预训练模型可视为教师模型,经离线量化后的模型可视为学生模型。
类似的,该工具还支持以`region/block`为单位添加重构`loss`,类似[论文](https://arxiv.org/pdf/2102.05426)提出的`BRECQ`方法,其中`region`可能包含多层,如下图所示。
<p align="center">
<img src="../../../docs/images/brecq.png" width=749 hspace='10'/> <br />
</p>
具体运行命令如下:
```shell
python fine_tune.py --config_path=./configs/yolov6s_fine_tune.yaml --recon_level=region-wise
```
此外,该工具还支持在重构过程中引入激活量化产生的噪声,如下图所示,在每层前插入`quant/dequant`节点,随机的进行激活量化,核心思想类似[论文](https://arxiv.org/pdf/2203.05740)提出的`QDrop`方法。
<p align="center">
<img src="../../../docs/images/qdrop.png" width=749 hspace='10'/> <br />
</p>
具体运行命令如下,只需将`simulate_activation_quant`设置为`True`即可。
```shell
python fine_tune.py --config_path=./configs/yolov6s_fine_tune.yaml --simulate_activation_quant=True
```
实验结果如上表所示,与量化分析工具不同,精度重构工具无需跳过某些层,就可提升离线量化精度。
## 4.预测部署
预测部署可参考[YOLO系列模型自动压缩示例](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/pytorch_yolo_series)
......
......@@ -157,6 +157,8 @@ class ReconstructionQuantization(PostTrainingQuantization):
place=self._place,
quantized_op_pairs=self._quantized_op_pairs,
weight_quantize_type=self._weight_quantize_type,
activation_bits=self._activation_bits,
weight_bits=self._weight_bits,
scale_dict=copy.deepcopy(self._scale_dict),
regions=self._config['regions'],
region_weights_names=self._config['region_weights_names'],
......@@ -165,8 +167,7 @@ class ReconstructionQuantization(PostTrainingQuantization):
num_iterations=self._batch_nums,
lr=self._config['lr'],
bias_correction=self._bias_correction,
epochs=self._config['epochs'],
scale_trainable=self._config['scale_trainable'])
epochs=self._config['epochs'], )
self._program = reconstruction_quanter._run()
def _postprocessing(self):
......@@ -211,6 +212,8 @@ class ReconstructionQuanter(object):
place,
quantized_op_pairs,
weight_quantize_type,
activation_bits,
weight_bits,
scale_dict,
regions,
region_weights_names,
......@@ -220,7 +223,6 @@ class ReconstructionQuanter(object):
lr=0.1,
bias_correction=False,
epochs=20,
scale_trainable=False,
drop_prob=0.5):
'''
Reconstruction Quanter, used to optimize the rounding policy
......@@ -259,7 +261,6 @@ class ReconstructionQuanter(object):
lr(float, optional): The learning rate of Reconstruction Quanter. Default is 0.1.
bias_correction(bool, optional): If set as True, use the bias correction
method of https://arxiv.org/abs/1810.05723. Default is False.
scale_trainable: Wether weight‘s scale is trainable. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
Returns:
......@@ -286,6 +287,8 @@ class ReconstructionQuanter(object):
self._weight_var_names = list(self._quantized_op_pairs.keys())
self._weight_quantize_type = weight_quantize_type
self._scale_dict = scale_dict
self._activation_bits = activation_bits
self._weight_bits = weight_bits
self._num_iterations = num_iterations
self._epochs = epochs
self._lr = lr
......@@ -296,7 +299,6 @@ class ReconstructionQuanter(object):
regions, region_weights_names = self._get_layers()
self._regions = regions
self._region_weights_names = region_weights_names
self._scale_trainable = scale_trainable
self._drop_prob = drop_prob
def _get_layers(self):
......@@ -336,7 +338,9 @@ class ReconstructionQuanter(object):
for name in self._weight_var_names:
weight_np = utils.load_variable_data(self._scope, name)
scale = self._scale_dict[name]
weight_np_floor = np.floor(utils.quant_tensor(weight_np, scale))
weight_np_floor = np.floor(
utils.quant_tensor(
x=weight_np, scale=scale, weight_bits=self._weight_bits))
utils.set_variable_data(
self._scope,
self._place,
......@@ -359,7 +363,6 @@ class ReconstructionQuanter(object):
quant_op_out_name = region_[1]
with paddle.static.program_guard(tmp_program, startup_program):
loss_function = ReconstructionQuanterLoss(tmp_program, names)
quant_op_out_name = region_[1]
student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" +
quant_op_out_name)
......@@ -385,7 +388,11 @@ class ReconstructionQuanter(object):
prev_start_time = start_time
loader = self._data_loader()
for epoch in range(self._epochs):
for i, data in enumerate(loader):
for i, data in (
enumerate(loader) if
(isinstance(self._data_loader, paddle.fluid.io.DataLoader)
and self._data_loader.batch_size == 1) else
enumerate(self._data_loader())):
prev_start_time = start_time
start_time = time.time()
out = self._exe.run(
......@@ -418,13 +425,13 @@ class ReconstructionQuanter(object):
alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
return alpha
def _soft_rounding(self, weight, scale, weight_bits=8):
def _soft_rounding(self, weight, scale):
"""
Define network of soft rounding.
Args:
weight: The quanted weight with dtype=float32
"""
bnt = (1 << (weight_bits - 1)) - 1
bnt = (1 << (self._weight_bits - 1)) - 1
def _dequant(x, scale):
s = (scale + 1e-8) / bnt
......@@ -470,18 +477,18 @@ class ReconstructionQuanter(object):
scale = np.array(scale)
scale = scale.reshape(scale.shape[0], 1)
if len(shape) == 2:
scale = scale.repeat(shape[0], axis=0)
scale = scale.repeat(shape[0], axis=1).T
else:
scale = scale.repeat(shape[1] * shape[2] * shape[3], axis=1)
scale = scale.reshape(shape)
scale = scale.reshape(shape)
self._insert_func(var=weight, scale=scale, func="_soft_rounding")
def _drop_quant_dequant(self, inputs, scale, weight_bits=8):
def _drop_quant_dequant(self, inputs, scale):
x = paddle.static.data(
shape=inputs.shape,
dtype=inputs.dtype,
name=inputs.name + '.tmp', )
bnt = (1 << (weight_bits - 1)) - 1
bnt = (1 << (self._weight_bits - 1)) - 1
scale = scale / bnt
dequantized_tensor = paddle.round(x / scale) * scale
quant_noise = x - dequantized_tensor
......@@ -537,8 +544,7 @@ class ReconstructionQuanter(object):
shape=new_var.shape,
dtype=new_var.dtype,
type=new_var.type,
stop_gradient=True,
trainable=self._scale_trainable, )
stop_gradient=True, )
else:
if func == "_soft_rounding":
program.global_block().create_var(
......@@ -721,7 +727,7 @@ class ReconstructionQuanter(object):
weight_quant_tensor,
scale,
quant_axis=0,
weight_bits=8, )
weight_bits=self._weight_bits, )
utils.set_variable_data(
self._scope,
self._place,
......@@ -826,8 +832,6 @@ def quant_recon_static(executor,
"conv2d",
"depthwise_conv2d",
"mul",
"matmul",
"matmul_v2",
],
is_full_quantize=False,
weight_bits=8,
......@@ -842,7 +846,6 @@ def quant_recon_static(executor,
regions=None,
region_weights_names=None,
epochs=20,
scale_trainable=False,
drop_prob=0.5,
lr=0.1):
"""
......@@ -919,7 +922,6 @@ def quant_recon_static(executor,
is_use_cache_file(bool): This param is deprecated.
cache_dir(str): This param is deprecated.
epochs: The number of steps in the reconstruction proces. Default is 20.
scale_trainable: Wether weight‘s scale is trainable. Default is False.
drop_prob: The dropout probability of activation quantization, and it is valid only if
simulate_activation_quant is True. Default is 0.5.
regions(list[list], optional): The list of some regions, each region is a subgraph of
......@@ -963,7 +965,6 @@ def quant_recon_static(executor,
regions=regions,
region_weights_names=region_weights_names,
epochs=epochs,
scale_trainable=scale_trainable,
lr=lr)
reconstruction_quantization = ReconstructionQuantization(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册