未验证 提交 dfe2e3a7 编写于 作者: C ceci3 提交者: GitHub

fix conflict (#1534)

上级 3c8722a4
...@@ -98,7 +98,7 @@ ACT相比传统的模型压缩方法, ...@@ -98,7 +98,7 @@ ACT相比传统的模型压缩方法,
```shell ```shell
# CPU # CPU
pip install paddlepaddle==2.4rc0 pip install paddlepaddle==2.4rc0
# GPU 以CUDA11.2为例 # GPU 以CUDA11.2为例
python -m pip install paddlepaddle_gpu==2.4rc0 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html python -m pip install paddlepaddle_gpu==2.4rc0 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
``` ```
...@@ -124,6 +124,8 @@ tar -xf ILSVRC2012_data_demo.tar.gz ...@@ -124,6 +124,8 @@ tar -xf ILSVRC2012_data_demo.tar.gz
- **2.运行自动化压缩** - **2.运行自动化压缩**
由于目前离线量化超参搜索仅支持Linux系统,以下默认示例需在Linux环境中测试。如果想要在Windows环境中测试,可以使用代码中Windows环境的config,由于Windows环境中配置的压缩策略为量化训练,所以需要全量数据集,否则会有一定的精度下降。
```python ```python
# 导入依赖包 # 导入依赖包
import paddle import paddle
...@@ -162,7 +164,8 @@ ac = AutoCompression( ...@@ -162,7 +164,8 @@ ac = AutoCompression(
model_filename="inference.pdmodel", model_filename="inference.pdmodel",
params_filename="inference.pdiparams", params_filename="inference.pdiparams",
save_dir="MobileNetV1_quant", save_dir="MobileNetV1_quant",
config={'Quantization': {}, "HyperParameterOptimization": {'ptq_algo': ['avg'], 'max_quant_count': 3}}, config={"Quantization": {}, "HyperParameterOptimization": {'ptq_algo': ['avg'], 'max_quant_count': 3}},
### config={"Quantization": {}, "Distillation": {}}, ### 如果您的系统为Windows系统, 请使用当前这一行配置
train_dataloader=train_loader, train_dataloader=train_loader,
eval_dataloader=train_loader) eval_dataloader=train_loader)
ac.compress() ac.compress()
...@@ -190,7 +193,7 @@ ac.compress() ...@@ -190,7 +193,7 @@ ac.compress()
- 量化模型速度的测试依赖推理库的支持,所以确保安装的是带有TensorRT的PaddlePaddle。以下示例和展示的测试结果是基于Tesla V100、CUDA 10.2、Python3.7、TensorRT得到的。 - 量化模型速度的测试依赖推理库的支持,所以确保安装的是带有TensorRT的PaddlePaddle。以下示例和展示的测试结果是基于Tesla V100、CUDA 10.2、Python3.7、TensorRT得到的。
- 使用以下指令查看本地cuda版本,并且在[下载链接](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)中下载对应cuda版本和对应python版本的paddlepaddle安装包。 - 使用以下指令查看本地cuda版本,并且在[下载链接](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)中下载对应cuda版本和对应python版本的PaddlePaddle安装包。
```shell ```shell
cat /usr/local/cuda/version.txt ### CUDA Version 10.2.89 cat /usr/local/cuda/version.txt ### CUDA Version 10.2.89
......
...@@ -143,7 +143,12 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' - ...@@ -143,7 +143,12 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' -
- TensorRT预测: - TensorRT预测:
环境配置:如果使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) 环境配置:如果使用 TesorRT 预测引擎,需安装的是带有TensorRT的PaddlePaddle,使用以下指令查看本地cuda版本,并且在[下载链接](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)中下载对应cuda版本和对应python版本的PaddlePaddle安装包。
```shell
cat /usr/local/cuda/version.txt ### CUDA Version 10.2.89
### 10.2.89 为cuda版本号,可以根据这个版本号选择需要安装的带有TensorRT的PaddlePaddle安装包。
```
```shell ```shell
python paddle_inference_eval.py \ python paddle_inference_eval.py \
......
...@@ -31,12 +31,12 @@ ...@@ -31,12 +31,12 @@
| ERNIE 3.0-Medium | 剪枝+量化训练| 74.17 | 56.84 | 59.75 | 80.54 | 76.03 | 76.97 | 80.80 | 72.16 | | ERNIE 3.0-Medium | 剪枝+量化训练| 74.17 | 56.84 | 59.75 | 80.54 | 76.03 | 76.97 | 80.80 | 72.16 |
模型在不同任务上平均精度以及加速对比如下: 模型在不同任务上平均精度以及加速对比如下:
| 模型 |策略| Accuracy(avg) | 时延(ms) | 加速比 | | 模型 |策略| Accuracy(avg) | 预测时延<sup><small>FP32</small><sup><br><sup> | 预测时延<sup><small>FP16</small><sup><br><sup> | 预测时延<sup><small>INT8</small><sup><br><sup> | 加速比 |
|:-------:|:--------:|:----------:|:------------:| :------:| |:-------:|:--------:|:----------:|:------------:|:------:|:------:|:------:|
|PP-MiniLM| Base模型| 72.81 | 128.01 | - | |PP-MiniLM| Base模型| 72.81 | 94.49ms | 23.31ms | - | - |
|PP-MiniLM| 剪枝+离线量化 | 72.44 | 17.97 | 7.12 | |PP-MiniLM| 剪枝+离线量化 | 71.85 | - | - | 15.76ms | 5.99x |
|ERNIE 3.0-Medium| Base模型| 73.09 | 29.25(fp16) | - | |ERNIE 3.0-Medium| Base模型| 73.09 | 89.71ms | 20.76ms | - | - |
|ERNIE 3.0-Medium| 剪枝+量化训练 | 72.16 | 19.61 | 1.49 | |ERNIE 3.0-Medium| 剪枝+量化训练 | 72.16 | - | - | 14.08ms | 6.37x |
性能测试的环境为 性能测试的环境为
- 硬件:NVIDIA Tesla T4 单卡 - 硬件:NVIDIA Tesla T4 单卡
......
...@@ -316,7 +316,7 @@ class AutoCompression: ...@@ -316,7 +316,7 @@ class AutoCompression:
if self.model_filename is None: if self.model_filename is None:
opt_model_filename = '__opt_model__' opt_model_filename = '__opt_model__'
else: else:
opt_model_filename = 'opt_' + self.model_filename opt_model_filename = self.model_filename
program_bytes = inference_program._remove_training_info( program_bytes = inference_program._remove_training_info(
clip_extra=False).desc.serialize_to_string() clip_extra=False).desc.serialize_to_string()
with open( with open(
...@@ -553,8 +553,8 @@ class AutoCompression: ...@@ -553,8 +553,8 @@ class AutoCompression:
def create_tmp_dir(self, base_dir, prefix="tmp"): def create_tmp_dir(self, base_dir, prefix="tmp"):
# create a new temp directory in final dir # create a new temp directory in final dir
s_datetime = strftime("%Y_%m_%d_%H_%M_%S", gmtime()) s_datetime = strftime("%Y_%m_%d_%H_%M", gmtime())
tmp_base_name = "_".join([prefix, str(os.getpid()), s_datetime]) tmp_base_name = "_".join([prefix, str(os.getppid()), s_datetime])
tmp_dir = os.path.join(base_dir, tmp_base_name) tmp_dir = os.path.join(base_dir, tmp_base_name)
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir) os.makedirs(tmp_dir)
...@@ -585,10 +585,10 @@ class AutoCompression: ...@@ -585,10 +585,10 @@ class AutoCompression:
self.single_strategy_compress(quant_strategy[0], self.single_strategy_compress(quant_strategy[0],
quant_config[0], strategy_idx, quant_config[0], strategy_idx,
train_config) train_config)
tmp_model_path = os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
final_model_path = os.path.join(self.final_dir)
if paddle.distributed.get_rank() == 0: if paddle.distributed.get_rank() == 0:
tmp_model_path = os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1)))
final_model_path = os.path.join(self.final_dir)
for _file in os.listdir(tmp_model_path): for _file in os.listdir(tmp_model_path):
_file_path = os.path.join(tmp_model_path, _file) _file_path = os.path.join(tmp_model_path, _file)
if os.path.isfile(_file_path): if os.path.isfile(_file_path):
...@@ -718,7 +718,8 @@ class AutoCompression: ...@@ -718,7 +718,8 @@ class AutoCompression:
test_program_info.program._program) test_program_info.program._program)
test_program_info = self._start_train( test_program_info = self._start_train(
train_program_info, test_program_info, strategy, train_config) train_program_info, test_program_info, strategy, train_config)
self._save_model(test_program_info, strategy, strategy_idx) if paddle.distributed.get_rank() == 0:
self._save_model(test_program_info, strategy, strategy_idx)
def _start_train(self, train_program_info, test_program_info, strategy, def _start_train(self, train_program_info, test_program_info, strategy,
train_config): train_config):
...@@ -854,16 +855,17 @@ class AutoCompression: ...@@ -854,16 +855,17 @@ class AutoCompression:
def export_onnx(self, def export_onnx(self,
model_name='quant_model.onnx', model_name='quant_model.onnx',
deploy_backend='tensorrt'): deploy_backend='tensorrt'):
infer_model_path = os.path.join(self.final_dir, self.model_filename) if paddle.distributed.get_rank() == 0:
assert os.path.exists( infer_model_path = os.path.join(self.final_dir, self.model_filename)
infer_model_path), 'Not found {}, please check it.'.format( assert os.path.exists(
infer_model_path) infer_model_path), 'Not found {}, please check it.'.format(
onnx_save_path = os.path.join(self.final_dir, 'ONNX') infer_model_path)
if not os.path.exists(onnx_save_path): onnx_save_path = os.path.join(self.final_dir, 'ONNX')
os.makedirs(onnx_save_path) if not os.path.exists(onnx_save_path):
export_onnx( os.makedirs(onnx_save_path)
self.final_dir, export_onnx(
model_filename=self.model_filename, self.final_dir,
params_filename=self.params_filename, model_filename=self.model_filename,
save_file_path=os.path.join(onnx_save_path, model_name), params_filename=self.params_filename,
deploy_backend=deploy_backend) save_file_path=os.path.join(onnx_save_path, model_name),
deploy_backend=deploy_backend)
...@@ -3,6 +3,7 @@ Quantization: ...@@ -3,6 +3,7 @@ Quantization:
quantize_op_types: quantize_op_types:
- conv2d - conv2d
- depthwise_conv2d - depthwise_conv2d
onnx_format: True
Distillation: Distillation:
alpha: 1.0 alpha: 1.0
......
...@@ -119,6 +119,7 @@ class TestDictQATDist(ACTBase): ...@@ -119,6 +119,7 @@ class TestDictQATDist(ACTBase):
train_dataloader=train_loader, train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress() ac.compress()
ac.export_onnx()
class TestLoadONNXModel(ACTBase): class TestLoadONNXModel(ACTBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册