From de866bf79234785c876a715d9d0ac588fad6002c Mon Sep 17 00:00:00 2001 From: czhu15 <41610754+czhu15@users.noreply.github.com> Date: Thu, 20 Jun 2019 14:46:04 +0800 Subject: [PATCH] Enable MKL-DNN for slim FP32 vs. INT8 tests and update slim QAT MKL-DNN readme document (#18221) --- .../fluid/contrib/slim/tests/CMakeLists.txt | 12 ++- .../slim/tests/QAT_mkldnn_int8_readme.md | 76 +++++++++++++++++++ ..._int8_mkldnn_post_training_quantization.md | 1 + .../test_mkldnn_int8_quantization_strategy.py | 44 +++++++++++ 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index c59df49f626..f76797da2a1 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -1,9 +1,10 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -function(inference_analysis_python_api_int8_test target model_dir data_dir filename) +function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn) py_test(${target} SRCS ${filename} ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=${use_mkldnn} ARGS --infer_model ${model_dir}/model --infer_data ${data_dir}/data.bin --int8_model_save_path int8_models/${target} @@ -11,6 +12,14 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen --batch_size 50) endfunction() +function(inference_analysis_python_api_int8_test target model_dir data_dir filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False) +endfunction() + +function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename) + _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True) +endfunction() + function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn) py_test(${target} SRCS ${test_script} ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} @@ -44,6 +53,7 @@ if(LINUX AND WITH_MKLDNN) # mobilenet int8 set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1") inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE}) + inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE}) # temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally, # since the following UTs cost too much time on CI test. diff --git a/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md b/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md new file mode 100644 index 00000000000..843f7e2d335 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/QAT_mkldnn_int8_readme.md @@ -0,0 +1,76 @@ +# SLIM Quantization-aware training (QAT) on INT8 MKL-DNN + +This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/advanced_usage/paddle_slim/paddle_slim.md) to convert a quantization-aware trained model to an INT8 MKL-DNN runnable model which has almost the same accuracy as QAT on GoogleNet, MobileNet-V1, MobileNet-V2, ResNet-101, ResNet-50, VGG16 and VGG19. We provide the accuracy results compared with fake QAT accuracy by running the QAT trained model with MKL-DNN int8 kernel on above 7 models. + +## 0. Prerequisite +You need to install at least PaddlePaddle-1.5 python package `pip install paddlepaddle==1.5`. + +## 1. How to generate INT8 MKL-DNN QAT model +You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `TransformForMkldnnPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights. + +```python + import paddle.fluid as fluid + from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass + from paddle.fluid.framework import IrGraph + from paddle.fluid import core + + # Create the IrGraph by Program + graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False) + place = fluid.CPUPlace() + # Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using + # TransformForMkldnnPass + mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place) + # Apply TransformForMkldnnPass to IrGraph + mkldnn_pass.apply(graph) +``` + +## 2. Accuracy benchmark + +>**I. Top-1 Accuracy on Intel(R) Xeon(R) Gold 6271** + +| Model | Fake QAT Top1 Accuracy | Fake QAT Top5 Accuracy |MKL-DNN INT8 Top1 Accuracy | Top1 Diff | MKL-DNN INT8 Top5 Accuracy | Top5 Diff | +| :----------: | :--------------------: | :--------------------: |:-----------------------: | :----------: | :------------------------: | :--------: | +| GoogleNet | 70.40% | 89.46% | 70.39% | 0.010% | 89.46% | 0.000% | +| MobileNet-V1 | 70.83% | 89.56% | 70.84% | -0.010% | 89.56% | 0.000% | +| MobileNet-V2 | 72.17% | 90.67% | 72.13% | 0.040% | 90.67% | 0.000% | +| ResNet-101 | 77.49% | 93.65% | 77.51% | -0.020% | 93.67% | -0.020% | +| ResNet-50 | 76.62% | 93.08% | 76.61% | 0.010% | 93.09% | -0.010% | +| VGG16 | 72.71% | 91.11% | 72.69% | 0.020% | 91.09% | 0.020% | +| VGG19 | 73.37% | 91.40% | 73.37% | 0.000% | 91.41% | -0.010% | + +Notes: + +* MKL-DNN and MKL are required. + +## 3. How to reproduce the results +Three steps to reproduce the above-mentioned accuracy results, and we take ResNet50 benchmark as an example: + * ### Prepare dataset +```bash +cd /PATH/TO/PADDLE +python paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +``` +The converted data binary file is saved by default in `~/.cache/paddle/dataset/int8/download/int8_full_val.bin` + * ### Prepare model +You can run the following commands to download ResNet50 model. + +```bash +mkdir -p /PATH/TO/DOWNLOAD/MODEL/ +cd /PATH/TO/DOWNLOAD/MODEL/ +export MODEL_NAME=ResNet50 +wget http://paddle-inference-dist.bj.bcebos.com/int8/QAT_models/${MODEL_NAME}_qat_model.tar.gz +mkdir -p ${MODEL_NAME} +tar -xvf ${MODEL_NAME}_qat_model.tar.gz -C ${MODEL_NAME} +``` + +To download and verify all the 7 models, you need to set `MODEL_NAME` to one of the following values in command line: + +```text +MODEL_NAME=ResNet50, ResNet101, GoogleNet, MobileNetV1, MobileNetV2, VGG16, VGG19 +``` +* ### Commands to reproduce benchmark +You can run `qat_int8_comparison.py` with the following arguments to reproduce the accuracy result on ResNet50. + +```bash +OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py --qat_model=/PATH/TO/DOWNLOAD/MODEL/${MODEL_NAME}/model --infer_data=~/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.001 +``` +> Notes: The above commands will cost maybe several hours in the prediction stage (include int8 prediction and fp32 prediction) since there have 50000 pictures need to be predicted in `int8_full_val.bin`. User can set `OMP_NUM_THREADS` to the max number of physical cores of the used server to accelerate the process. diff --git a/python/paddle/fluid/contrib/slim/tests/slim_int8_mkldnn_post_training_quantization.md b/python/paddle/fluid/contrib/slim/tests/slim_int8_mkldnn_post_training_quantization.md index 33edb13c48c..0e9fd33ee36 100644 --- a/python/paddle/fluid/contrib/slim/tests/slim_int8_mkldnn_post_training_quantization.md +++ b/python/paddle/fluid/contrib/slim/tests/slim_int8_mkldnn_post_training_quantization.md @@ -128,3 +128,4 @@ python ./test_mkldnn_int8_quantization_strategy.py --infer_model /PATH/TO/DOWNLO Notes: * The above commands will cost maybe several hours in the prediction stage (include int8 prediction and fp32 prediction) since there have 50000 pictures need to be predicted in `int8_full_val.bin` +* Running the above command with environment variable `FLAGS_use_mkldnn=true` will make the FP32 part of the test running using MKL-DNN (the INT8 part uses MKL-DNN either way). diff --git a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py index 1c41a316a62..36242efb8b3 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py @@ -23,6 +23,8 @@ import six import numpy as np import paddle import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid import core from paddle.fluid.contrib.slim.core import Compressor from paddle.fluid.log_helper import get_logger @@ -112,6 +114,41 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase): return new_config_path + def _transform_depthwise_conv(self, graph): + ''' + Transform depthwise_conv2d into conv2d, with MKL-DNN only + ''' + ops = graph.all_op_nodes() + for op_node in ops: + name = op_node.name() + if name in ['depthwise_conv2d']: + input_var_node = graph._find_node_by_name( + op_node.inputs, op_node.input("Input")[0]) + weight_var_node = graph._find_node_by_name( + op_node.inputs, op_node.input("Filter")[0]) + output_var_node = graph._find_node_by_name( + graph.all_var_nodes(), op_node.output("Output")[0]) + attrs = { + name: op_node.op().attr(name) + for name in op_node.op().attr_names() + } + + conv_op_node = graph.create_op_node( + op_type='conv2d', + attrs=attrs, + inputs={ + 'Input': input_var_node, + 'Filter': weight_var_node + }, + outputs={'Output': output_var_node}) + + graph.link_to(input_var_node, conv_op_node) + graph.link_to(weight_var_node, conv_op_node) + graph.link_to(conv_op_node, output_var_node) + graph.safe_remove_nodes(op_node) + + return graph + def _predict(self, test_reader=None, model_path=None): place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -125,6 +162,13 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase): fetch_targets] = fluid.io.load_inference_model( model_path, exe, 'model', 'params') + use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False)) + if (use_mkldnn): + graph = IrGraph( + core.Graph(inference_program.desc), for_test=True) + graph = self._transform_depthwise_conv(graph) + inference_program = graph.to_program() + dshape = [3, 224, 224] top1 = 0.0 top5 = 0.0 -- GitLab