Commit de866bf7 authored by czhu15, committed by Tao Luo

Enable MKL-DNN for slim FP32 vs. INT8 tests and update slim QAT MKL-DNN readme document (#18221)

Parent 6e310e2d
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
    py_test(${target} SRCS ${filename}
        ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
             FLAGS_use_mkldnn=${use_mkldnn}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_dir}/data.bin
             --int8_model_save_path int8_models/${target}
@@ -11,6 +12,14 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen
             --batch_size 50)
endfunction()
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
endfunction()

function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
    _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True)
endfunction()
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
    py_test(${target} SRCS ${test_script}
        ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -44,6 +53,7 @@ if(LINUX AND WITH_MKLDNN)
    # mobilenet int8
    set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
    inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
    # Temporarily guard the following UTs with the WITH_SLIM_MKLDNN_FULL_TEST flag
    # so QA can run them locally, since they cost too much time in CI tests.
......
# SLIM Quantization-aware training (QAT) on INT8 MKL-DNN
This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/advanced_usage/paddle_slim/paddle_slim.md) to convert a quantization-aware trained (QAT) model into an INT8 MKL-DNN runnable model that achieves almost the same accuracy as fake QAT on GoogleNet, MobileNet-V1, MobileNet-V2, ResNet-101, ResNet-50, VGG16, and VGG19. We provide accuracy results, compared against the fake QAT accuracy, obtained by running the QAT-trained models with the MKL-DNN INT8 kernel on these 7 models.
## 0. Prerequisite
You need to install at least the PaddlePaddle 1.5 Python package: `pip install paddlepaddle==1.5`.
## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). First use the PaddleSlim quantization strategy to obtain a saved fake QAT model via [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then apply `TransformForMkldnnPass` to obtain a graph that can run with the MKL-DNN INT8 kernel. In Paddle release 1.5, this pass supports only `conv2d` and `depthwise_conv2d`, with channel-wise quantization for weights.
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

# Create an IrGraph from a Program
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()

# Convert the IrGraph to an MKL-DNN-supported INT8 IrGraph
# using TransformForMkldnnPass
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)

# Apply TransformForMkldnnPass to the IrGraph
mkldnn_pass.apply(graph)
```
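
For reference, below is a minimal sketch of the two-pass flow described above. It assumes a hypothetical `qat_program`, the test program of a QAT-trained model whose parameters live in the global scope; the `QuantizationFreezePass` arguments shown are the common defaults and may differ across releases:
```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass

# Assumption: `qat_program` is the test program of a quantization-aware
# trained model and its parameters are in the global scope.
scope = fluid.global_scope()
place = fluid.CPUPlace()
graph = IrGraph(core.Graph(qat_program.desc), for_test=True)

# Step 1: freeze the fake-quantized graph, folding the fake
# quantize/dequantize ops into real quantized weights.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(graph)

# Step 2: rewrite the frozen graph so conv2d/depthwise_conv2d run
# with MKL-DNN INT8 kernels.
mkldnn_pass = TransformForMkldnnPass(scope, place)
mkldnn_pass.apply(graph)

# Convert back to a Program for saving or direct inference.
int8_program = graph.to_program()
```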
## 2. Accuracy benchmark
>**I. Top-1 and Top-5 Accuracy on Intel(R) Xeon(R) Gold 6271**

| Model | Fake QAT Top1 Accuracy | Fake QAT Top5 Accuracy |MKL-DNN INT8 Top1 Accuracy | Top1 Diff | MKL-DNN INT8 Top5 Accuracy | Top5 Diff |
| :----------: | :--------------------: | :--------------------: |:-----------------------: | :----------: | :------------------------: | :--------: |
| GoogleNet | 70.40% | 89.46% | 70.39% | 0.010% | 89.46% | 0.000% |
| MobileNet-V1 | 70.83% | 89.56% | 70.84% | -0.010% | 89.56% | 0.000% |
| MobileNet-V2 | 72.17% | 90.67% | 72.13% | 0.040% | 90.67% | 0.000% |
| ResNet-101 | 77.49% | 93.65% | 77.51% | -0.020% | 93.67% | -0.020% |
| ResNet-50 | 76.62% | 93.08% | 76.61% | 0.010% | 93.09% | -0.010% |
| VGG16 | 72.71% | 91.11% | 72.69% | 0.020% | 91.09% | 0.020% |
| VGG19 | 73.37% | 91.40% | 73.37% | 0.000% | 91.41% | -0.010% |
Notes:
* MKL-DNN and MKL are required.
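* Top1 Diff and Top5 Diff are computed as fake QAT accuracy minus MKL-DNN INT8 accuracy, so a negative value means the INT8 result is slightly higher.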
## 3. How to reproduce the results
Three steps are needed to reproduce the accuracy results above; we take the ResNet50 benchmark as an example:
* ### Prepare dataset
```bash
cd /PATH/TO/PADDLE
python paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
```
The converted data binary file is saved by default to `~/.cache/paddle/dataset/int8/download/int8_full_val.bin`.
* ### Prepare model
You can run the following commands to download the ResNet50 model.
```bash
mkdir -p /PATH/TO/DOWNLOAD/MODEL/
cd /PATH/TO/DOWNLOAD/MODEL/
export MODEL_NAME=ResNet50
wget http://paddle-inference-dist.bj.bcebos.com/int8/QAT_models/${MODEL_NAME}_qat_model.tar.gz
mkdir -p ${MODEL_NAME}
tar -xvf ${MODEL_NAME}_qat_model.tar.gz -C ${MODEL_NAME}
```
To download and verify all 7 models, set `MODEL_NAME` in the commands above to each of the following values in turn; a Python sketch that automates the downloads follows below:
```text
MODEL_NAME=ResNet50, ResNet101, GoogleNet, MobileNetV1, MobileNetV2, VGG16, VGG19
```
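
If you prefer not to repeat the commands for each model, the following Python 3 sketch (a convenience added here, not part of the original instructions) downloads and extracts all seven archives using the URL pattern above; `DOWNLOAD_DIR` is a placeholder:
```python
import os
import tarfile
import urllib.request

MODELS = ["ResNet50", "ResNet101", "GoogleNet", "MobileNetV1",
          "MobileNetV2", "VGG16", "VGG19"]
BASE_URL = "http://paddle-inference-dist.bj.bcebos.com/int8/QAT_models"
DOWNLOAD_DIR = "/PATH/TO/DOWNLOAD/MODEL"  # placeholder, replace with your path

os.makedirs(DOWNLOAD_DIR, exist_ok=True)
for name in MODELS:
    archive = os.path.join(DOWNLOAD_DIR, "%s_qat_model.tar.gz" % name)
    target_dir = os.path.join(DOWNLOAD_DIR, name)
    os.makedirs(target_dir, exist_ok=True)
    # Download each QAT model archive and unpack it into its own directory.
    urllib.request.urlretrieve(
        "%s/%s_qat_model.tar.gz" % (BASE_URL, name), archive)
    with tarfile.open(archive) as tar:
        tar.extractall(target_dir)
```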
* ### Commands to reproduce benchmark
You can run `qat_int8_comparison.py` with the following arguments to reproduce the accuracy result on ResNet50.
```bash
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py --qat_model=/PATH/TO/DOWNLOAD/MODEL/${MODEL_NAME}/model --infer_data=~/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.001
```
> Notes: The above command may take several hours in the prediction stage (both INT8 and FP32 prediction), since all 50,000 images in `int8_full_val.bin` need to be predicted. Users can set `OMP_NUM_THREADS` to the number of physical cores of the server to accelerate the process.
@@ -128,3 +128,4 @@ python ./test_mkldnn_int8_quantization_strategy.py --infer_model /PATH/TO/DOWNLO
Notes:
* The above command may take several hours in the prediction stage (both INT8 and FP32 prediction), since all 50,000 images in `int8_full_val.bin` need to be predicted.
* Running the above command with the environment variable `FLAGS_use_mkldnn=true` makes the FP32 part of the test run with MKL-DNN (the INT8 part uses MKL-DNN either way).
@@ -23,6 +23,8 @@ import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.log_helper import get_logger
@@ -112,6 +114,41 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
        return new_config_path

    def _transform_depthwise_conv(self, graph):
        '''
        Transform depthwise_conv2d ops into conv2d ops (MKL-DNN only).
        '''
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in ['depthwise_conv2d']:
                input_var_node = graph._find_node_by_name(
                    op_node.inputs, op_node.input("Input")[0])
                weight_var_node = graph._find_node_by_name(
                    op_node.inputs, op_node.input("Filter")[0])
                output_var_node = graph._find_node_by_name(
                    graph.all_var_nodes(), op_node.output("Output")[0])
                # Carry over all attributes of the original op.
                attrs = {
                    name: op_node.op().attr(name)
                    for name in op_node.op().attr_names()
                }
                # Recreate the op as a plain conv2d node with the same
                # attributes, inputs and outputs.
                conv_op_node = graph.create_op_node(
                    op_type='conv2d',
                    attrs=attrs,
                    inputs={
                        'Input': input_var_node,
                        'Filter': weight_var_node
                    },
                    outputs={'Output': output_var_node})
                # Re-link the variable nodes to the new op and drop the
                # original depthwise_conv2d node.
                graph.link_to(input_var_node, conv_op_node)
                graph.link_to(weight_var_node, conv_op_node)
                graph.link_to(conv_op_node, output_var_node)
                graph.safe_remove_nodes(op_node)
        return graph
    def _predict(self, test_reader=None, model_path=None):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
@@ -125,6 +162,13 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
             fetch_targets] = fluid.io.load_inference_model(
                 model_path, exe, 'model', 'params')

            # os.getenv returns a string, so parse it explicitly instead of
            # relying on bool() (bool("false") would be True).
            use_mkldnn = os.getenv("FLAGS_use_mkldnn", "false").lower() in (
                "true", "1")
            if use_mkldnn:
                graph = IrGraph(
                    core.Graph(inference_program.desc), for_test=True)
                graph = self._transform_depthwise_conv(graph)
                inference_program = graph.to_program()

            dshape = [3, 224, 224]
            top1 = 0.0
            top5 = 0.0
......