Commit de866bf7 authored by czhu15, committed by Tao Luo

Enable MKL-DNN for slim FP32 vs. INT8 tests and update slim QAT MKL-DNN readme document (#18221)

Parent 6e310e2d
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
py_test(${target} SRCS ${filename}
ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --infer_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--int8_model_save_path int8_models/${target}
@@ -11,6 +12,14 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen
--batch_size 50)
endfunction()
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
endfunction()
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True)
endfunction()
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -44,6 +53,7 @@ if(LINUX AND WITH_MKLDNN)
# mobilenet int8
set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally,
# since the following UTs cost too much time on CI test.
...
# SLIM Quantization-aware training (QAT) on INT8 MKL-DNN
This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/advanced_usage/paddle_slim/paddle_slim.md) to convert a quantization-aware trained (QAT) model into an INT8 MKL-DNN runnable model that achieves almost the same accuracy as the fake QAT model on GoogleNet, MobileNet-V1, MobileNet-V2, ResNet-101, ResNet-50, VGG16 and VGG19. For these 7 models we report the accuracy obtained by running the QAT-trained models with the MKL-DNN INT8 kernels, compared against the fake QAT accuracy.
## 0. Prerequisite
You need to install at least the PaddlePaddle 1.5 Python package: `pip install paddlepaddle==1.5`.
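To quickly confirm the installed version, a minimal check (assuming the package imports cleanly) is:
```python
import paddle

# The features described below require 1.5.0 or newer.
print(paddle.__version__)
```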
## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users first apply the PaddleSlim quantization strategy to obtain a saved fake QAT model via [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use `TransformForMkldnnPass` to get a graph that can run with the MKL-DNN INT8 kernels. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights. A fuller end-to-end sketch follows the snippet below.
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
# Create the IrGraph from a Program
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
# Convert the IrGraph to an MKL-DNN-supported INT8 IrGraph using
# TransformForMkldnnPass
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
# Apply TransformForMkldnnPass to IrGraph
mkldnn_pass.apply(graph)
```
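The snippet above applies the pass to an empty program purely for illustration. Below is a hedged, end-to-end sketch of the flow described in this section, assuming a fake QAT inference model has already been saved to disk (the model path is a placeholder); `for_test=True` is used here because the graph is intended for inference, matching the unit-test usage:
```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.framework import IrGraph

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load the saved fake QAT inference model (placeholder path and file names).
[inference_program, feed_target_names,
 fetch_targets] = fluid.io.load_inference_model(
     '/PATH/TO/FAKE/QAT/MODEL', exe, 'model', 'params')

# Wrap the Program in an IrGraph so the pass can rewrite it.
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)

# Replace the fake quantize/dequantize ops with MKL-DNN INT8 runnable ops.
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
mkldnn_pass.apply(graph)

# Convert the transformed graph back to a Program; it can then be executed
# with exe.run(...) or saved with fluid.io.save_inference_model(...).
inference_program = graph.to_program()
```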
## 2. Accuracy benchmark
>**I. Top-1 and Top-5 Accuracy on Intel(R) Xeon(R) Gold 6271**

| Model | Fake QAT Top1 Accuracy | Fake QAT Top5 Accuracy |MKL-DNN INT8 Top1 Accuracy | Top1 Diff | MKL-DNN INT8 Top5 Accuracy | Top5 Diff |
| :----------: | :--------------------: | :--------------------: |:-----------------------: | :----------: | :------------------------: | :--------: |
| GoogleNet | 70.40% | 89.46% | 70.39% | 0.010% | 89.46% | 0.000% |
| MobileNet-V1 | 70.83% | 89.56% | 70.84% | -0.010% | 89.56% | 0.000% |
| MobileNet-V2 | 72.17% | 90.67% | 72.13% | 0.040% | 90.67% | 0.000% |
| ResNet-101 | 77.49% | 93.65% | 77.51% | -0.020% | 93.67% | -0.020% |
| ResNet-50 | 76.62% | 93.08% | 76.61% | 0.010% | 93.09% | -0.010% |
| VGG16 | 72.71% | 91.11% | 72.69% | 0.020% | 91.09% | 0.020% |
| VGG19 | 73.37% | 91.40% | 73.37% | 0.000% | 91.41% | -0.010% |

Notes:
* MKL-DNN and MKL are required.
* The accuracy diff columns are computed as (Fake QAT accuracy - MKL-DNN INT8 accuracy), so a negative value means the INT8 result is slightly higher.
## 3. How to reproduce the results
Follow the three steps below to reproduce the accuracy results above; we take the ResNet50 benchmark as an example:
* ### Prepare dataset
```bash
cd /PATH/TO/PADDLE
python paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py
```
The converted data binary file is saved by default to `~/.cache/paddle/dataset/int8/download/int8_full_val.bin`.
* ### Prepare model
You can run the following commands to download the ResNet50 model.
```bash
mkdir -p /PATH/TO/DOWNLOAD/MODEL/
cd /PATH/TO/DOWNLOAD/MODEL/
export MODEL_NAME=ResNet50
wget http://paddle-inference-dist.bj.bcebos.com/int8/QAT_models/${MODEL_NAME}_qat_model.tar.gz
mkdir -p ${MODEL_NAME}
tar -xvf ${MODEL_NAME}_qat_model.tar.gz -C ${MODEL_NAME}
```
To download and verify all 7 models, set `MODEL_NAME` to each of the following values in turn on the command line (a Python sketch that fetches all of them follows this list):
```text
MODEL_NAME=ResNet50, ResNet101, GoogleNet, MobileNetV1, MobileNetV2, VGG16, VGG19
```
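If it is more convenient to fetch everything in one go, the following Python 3 sketch loops over the seven model names and mirrors the `wget`/`tar` commands above (the URL pattern is taken from the ResNet50 example; run it from your download directory):
```python
import os
import tarfile
import urllib.request

# Model names accepted by the QAT_models URL pattern shown above.
MODEL_NAMES = ['ResNet50', 'ResNet101', 'GoogleNet', 'MobileNetV1',
               'MobileNetV2', 'VGG16', 'VGG19']
BASE_URL = 'http://paddle-inference-dist.bj.bcebos.com/int8/QAT_models'

for name in MODEL_NAMES:
    archive = '{}_qat_model.tar.gz'.format(name)
    if not os.path.exists(archive):
        urllib.request.urlretrieve('{}/{}'.format(BASE_URL, archive), archive)
    os.makedirs(name, exist_ok=True)
    with tarfile.open(archive) as tar:
        tar.extractall(path=name)  # equivalent to `tar -xvf <archive> -C <name>`
```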
* ### Commands to reproduce benchmark
You can run `qat_int8_comparison.py` with the following arguments to reproduce the accuracy result on ResNet50.
```bash
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py --qat_model=/PATH/TO/DOWNLOAD/MODEL/${MODEL_NAME}/model --infer_data=~/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.001
```
> Notes: The above command may take several hours in the prediction stage (both the INT8 and the FP32 prediction), since all 50000 images in `int8_full_val.bin` need to be predicted. You can set `OMP_NUM_THREADS` to the maximum number of physical cores of the server to speed up the process.
@@ -128,3 +128,4 @@ python ./test_mkldnn_int8_quantization_strategy.py --infer_model /PATH/TO/DOWNLO
Notes:
* The above commands may take several hours in the prediction stage (both the INT8 and the FP32 prediction), since all 50000 images in `int8_full_val.bin` need to be predicted.
* Running the above command with the environment variable `FLAGS_use_mkldnn=true` makes the FP32 part of the test run with MKL-DNN (the INT8 part uses MKL-DNN either way).
@@ -23,6 +23,8 @@ import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.log_helper import get_logger
@@ -112,6 +114,41 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
        return new_config_path
    def _transform_depthwise_conv(self, graph):
        '''
        Transform depthwise_conv2d into conv2d, with MKL-DNN only
        '''
        ops = graph.all_op_nodes()
        for op_node in ops:
            name = op_node.name()
            if name in ['depthwise_conv2d']:
                input_var_node = graph._find_node_by_name(
                    op_node.inputs, op_node.input("Input")[0])
                weight_var_node = graph._find_node_by_name(
                    op_node.inputs, op_node.input("Filter")[0])
                output_var_node = graph._find_node_by_name(
                    graph.all_var_nodes(), op_node.output("Output")[0])
                attrs = {
                    name: op_node.op().attr(name)
                    for name in op_node.op().attr_names()
                }
                conv_op_node = graph.create_op_node(
                    op_type='conv2d',
                    attrs=attrs,
                    inputs={
                        'Input': input_var_node,
                        'Filter': weight_var_node
                    },
                    outputs={'Output': output_var_node})
                graph.link_to(input_var_node, conv_op_node)
                graph.link_to(weight_var_node, conv_op_node)
                graph.link_to(conv_op_node, output_var_node)
                graph.safe_remove_nodes(op_node)
        return graph
    def _predict(self, test_reader=None, model_path=None):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
@@ -125,6 +162,13 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
             fetch_targets] = fluid.io.load_inference_model(
                 model_path, exe, 'model', 'params')
            use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
            if (use_mkldnn):
                graph = IrGraph(
                    core.Graph(inference_program.desc), for_test=True)
                graph = self._transform_depthwise_conv(graph)
                inference_program = graph.to_program()
            dshape = [3, 224, 224]
            top1 = 0.0
            top5 = 0.0
...