Unverified commit 23a4f54b authored by: W Wojciech Uss, committed by: GitHub

rename qat into quant (#24948)

test=develop
Parent f1a9593d
...@@ -20,7 +20,7 @@ function(download_int8_data install_dir data_file) ...@@ -20,7 +20,7 @@ function(download_int8_data install_dir data_file)
endif() endif()
endfunction() endfunction()
function(download_qat_data install_dir data_file) function(download_quant_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif() endif()
...@@ -85,7 +85,7 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary ...@@ -85,7 +85,7 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
--disable_mkldnn_fc=${disable_fc}) --disable_mkldnn_fc=${disable_fc})
endfunction() endfunction()
function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path) function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
inference_analysis_test_run(${TARGET_NAME} inference_analysis_test_run(${TARGET_NAME}
COMMAND ${test_binary} COMMAND ${test_binary}
ARGS --fp32_model=${fp32_model_dir} ARGS --fp32_model=${fp32_model_dir}
...@@ -249,7 +249,7 @@ if(WITH_MKLDNN) ...@@ -249,7 +249,7 @@ if(WITH_MKLDNN)
## Image classification models ## Image classification models
# ImageNet small dataset # ImageNet small dataset
# May be already downloaded for INT8 QAT unit tests # It may be already downloaded for Quant & INT8 unit tests
set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz") set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet") set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin") set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
...@@ -315,21 +315,21 @@ if(WITH_MKLDNN) ...@@ -315,21 +315,21 @@ if(WITH_MKLDNN)
download_int8_data(${INT8_MOBILENET_SSD_MODEL_DIR} "mobilenet_ssd_int8_model.tar.gz" ) download_int8_data(${INT8_MOBILENET_SSD_MODEL_DIR} "mobilenet_ssd_int8_model.tar.gz" )
inference_analysis_api_object_dection_int8_test_run(test_analyzer_int8_mobilenet_ssd ${INT8_OBJ_DETECT_TEST_APP} ${INT8_MOBILENET_SSD_MODEL_DIR} ${PASCALVOC_DATA_PATH}) inference_analysis_api_object_dection_int8_test_run(test_analyzer_int8_mobilenet_ssd ${INT8_OBJ_DETECT_TEST_APP} ${INT8_MOBILENET_SSD_MODEL_DIR} ${PASCALVOC_DATA_PATH})
### optimized FP32 vs. QAT INT8 tests ### optimized FP32 vs. Quant INT8 tests
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat") set(QUANT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant")
set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification") set(QUANT_IMG_CLASS_TEST_APP "test_analyzer_quant_image_classification")
set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_quant_image_classification_tester.cc") set(QUANT_IMG_CLASS_TEST_APP_SRC "analyzer_quant_image_classification_tester.cc")
# build test binary to be used in subsequent tests # build test binary to be used in subsequent tests
inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC}) inference_analysis_api_test_build(${QUANT_IMG_CLASS_TEST_APP} ${QUANT_IMG_CLASS_TEST_APP_SRC})
# MobileNet FP32 vs. QAT INT8 # MobileNetV1 FP32 vs. Quant INT8
# The FP32 model should already be downloaded for slim QAT unit tests # The FP32 model should already be downloaded for slim Quant unit tests
set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf") set(QUANT2_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2")
set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8") set(QUANT2_INT8_MobileNetV1_MODEL_DIR "${QUANT_DATA_DIR}/MobileNetV1_quant2_int8")
download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz") download_quant_data(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH}) inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
### Other tests ### Other tests
......
...@@ -108,7 +108,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs, ...@@ -108,7 +108,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
} }
} }
TEST(Analyzer_qat_image_classification, quantization) { TEST(Analyzer_quant_image_classification, quantization) {
AnalysisConfig fp32_cfg; AnalysisConfig fp32_cfg;
SetConfig(&fp32_cfg, FLAGS_fp32_model); SetConfig(&fp32_cfg, FLAGS_fp32_model);
......
...@@ -16,17 +16,17 @@ import numpy as np ...@@ -16,17 +16,17 @@ import numpy as np
from .... import core from .... import core
from ....framework import IrGraph from ....framework import IrGraph
__all__ = ['Qat2Int8MkldnnPass'] __all__ = ['Quant2Int8MkldnnPass']
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
class Qat2Int8MkldnnPass(object): class Quant2Int8MkldnnPass(object):
""" """
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph. Transform a quant model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations: The pass consists of the following transformations:
1. gather scale values from fake quantize/dequantize operators, 1. gather scale values from fake quantize/dequantize operators,
2. extract FP32 inference model graph from the QAT graph, i.e. 2. extract FP32 inference model graph from the quant graph, i.e.
a. remove fake quantize/dequantize operators, a. remove fake quantize/dequantize operators,
b. dequantize conv2d and mul's weights, b. dequantize conv2d and mul's weights,
3. optimize the FP32 graph using standard FP32 optimization fuses 3. optimize the FP32 graph using standard FP32 optimization fuses
...@@ -67,7 +67,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -67,7 +67,7 @@ class Qat2Int8MkldnnPass(object):
self._relu_ops = ['relu', 'relu6'] self._relu_ops = ['relu', 'relu6']
self._matmul_ops = ['matmul'] self._matmul_ops = ['matmul']
self._weight_scales = {} self._weight_scales = {}
# Collect the Input and Output scales from Fake QAT models # Collect the Input and Output scales from Fake quant models
self._var_quant_scales = {} self._var_quant_scales = {}
self._max_range = {} self._max_range = {}
self._s8_max = 127 self._s8_max = 127
...@@ -362,7 +362,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -362,7 +362,7 @@ class Qat2Int8MkldnnPass(object):
ir_pass.set(attr, value) ir_pass.set(attr, value)
ir_pass.apply(cpp_graph) ir_pass.apply(cpp_graph)
if self._debug: if self._debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.draw('.', 'quant_fp32_{}'.format(pass_name),
graph.all_op_nodes()) graph.all_op_nodes())
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
return graph return graph
...@@ -472,7 +472,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -472,7 +472,7 @@ class Qat2Int8MkldnnPass(object):
self._find_avg_pooling_ids(graph)) self._find_avg_pooling_ids(graph))
ir_pass.apply(cpp_graph) ir_pass.apply(cpp_graph)
if self._debug: if self._debug:
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()), graph.draw('.', 'quant_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes()) graph.all_op_nodes())
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph, graph = self._apply_pass(graph,
......
...@@ -17,10 +17,10 @@ from .... import core ...@@ -17,10 +17,10 @@ from .... import core
from ....framework import IrGraph from ....framework import IrGraph
from ....framework import IrNode from ....framework import IrNode
__all__ = ['QatInt8MkldnnPass'] __all__ = ['QuantInt8MkldnnPass']
class QatInt8MkldnnPass(object): class QuantInt8MkldnnPass(object):
""" """
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8 Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. The following transformations are done in this pass: IrGraph. The following transformations are done in this pass:
...@@ -48,13 +48,13 @@ class QatInt8MkldnnPass(object): ...@@ -48,13 +48,13 @@ class QatInt8MkldnnPass(object):
# The original graph will be rewritten. # The original graph will be rewritten.
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \ from paddle.fluid.contrib.slim.quantization \
import QatInt8MkldnnPass import QuantInt8MkldnnPass
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid import core from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False) graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace() place = fluid.CPUPlace()
mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(), mkldnn_pass = QuantInt8MkldnnPass(fluid.global_scope(),
place) place)
mkldnn_pass.apply(graph) mkldnn_pass.apply(graph)
""" """
...@@ -163,7 +163,7 @@ class QatInt8MkldnnPass(object): ...@@ -163,7 +163,7 @@ class QatInt8MkldnnPass(object):
'Filter': weight_var_node}, 'Filter': weight_var_node},
outputs={'Output': output_var_node}) outputs={'Output': output_var_node})
# Calculate the scales of the MKL-DNN INT8 conv2d based on the QAT scales # Calculate the scales of the MKL-DNN INT8 conv2d based on the Quant scales
scale_in = self._s8_max / self._in_scale[output_name] scale_in = self._s8_max / self._in_scale[output_name]
scale_w = [] scale_w = []
scale_w = [self._max_range[output_name] / self._s8_max] scale_w = [self._max_range[output_name] / self._s8_max]
...@@ -207,7 +207,7 @@ class QatInt8MkldnnPass(object): ...@@ -207,7 +207,7 @@ class QatInt8MkldnnPass(object):
'Y': weight_var_node}, 'Y': weight_var_node},
outputs={'Out': output_var_node}) outputs={'Out': output_var_node})
# Calculate the scales of the MKL-DNN INT8 mul based on the QAT scales # Calculate the scales of the MKL-DNN INT8 mul based on the Quant scales
scale_in = self._s8_max / self._in_scale[output_name] scale_in = self._s8_max / self._in_scale[output_name]
scale_w = [] scale_w = []
scale_w = [self._max_range[output_name] / self._s8_max] scale_w = [self._max_range[output_name] / self._s8_max]
......
...@@ -25,30 +25,30 @@ function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_pa ...@@ -25,30 +25,30 @@ function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_pa
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True) _inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction() endfunction()
function(download_qat_data install_dir data_file) function(download_quant_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif() endif()
endfunction() endfunction()
function(download_qat_model install_dir data_file) function(download_quant_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif() endif()
endfunction() endfunction()
function(download_qat_fp32_model install_dir data_file) function(download_quant_fp32_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file})
endif() endif()
endfunction() endfunction()
function(inference_qat_int8_image_classification_test target qat_model_dir dataset_path) function(inference_quant_int8_image_classification_test target quant_model_dir dataset_path)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py" py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --qat_model ${qat_model_dir} ARGS --quant_model ${quant_model_dir}
--infer_data ${dataset_path} --infer_data ${dataset_path}
--batch_size 25 --batch_size 25
--batch_num 2 --batch_num 2
...@@ -57,12 +57,12 @@ endfunction() ...@@ -57,12 +57,12 @@ endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
function(inference_qat2_int8_image_classification_test target qat_model_dir fp32_model_dir dataset_path ops_to_quantize) function(inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path ops_to_quantize)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py" py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --qat_model ${qat_model_dir} ARGS --quant_model ${quant_model_dir}
--fp32_model ${fp32_model_dir} --fp32_model ${fp32_model_dir}
--infer_data ${dataset_path} --infer_data ${dataset_path}
--batch_size 10 --batch_size 10
...@@ -72,12 +72,12 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32 ...@@ -72,12 +72,12 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32
endfunction() endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir dataset_path labels_path) function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py" py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --qat_model ${qat_model_dir} ARGS --quant_model ${quant_model_dir}
--fp32_model ${fp32_model_dir} --fp32_model ${fp32_model_dir}
--infer_data ${dataset_path} --infer_data ${dataset_path}
--labels ${labels_path} --labels ${labels_path}
...@@ -86,29 +86,30 @@ function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir datase ...@@ -86,29 +86,30 @@ function(inference_qat2_int8_nlp_test target qat_model_dir fp32_model_dir datase
--acc_diff_threshold 0.1) --acc_diff_threshold 0.1)
endfunction() endfunction()
function(download_qat_data install_dir data_file) function(download_quant_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif() endif()
endfunction() endfunction()
function(download_qat_model install_dir data_file) function(download_quant_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file}) if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif() endif()
endfunction() endfunction()
function(save_qat_ic_model_test target qat_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize) function(save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS --qat_model_path ${qat_model_dir} ARGS --quant_model_path ${quant_model_dir}
--fp32_model_save_path ${fp32_model_save_path} --fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path} --int8_model_save_path ${int8_model_save_path}
--ops_to_quantize ${ops_to_quantize}) --ops_to_quantize ${ops_to_quantize}
--debug)
endfunction() endfunction()
function(save_qat_nlp_model_test target qat_model_dir fp32_model_save_path int8_model_save_path) function(save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS --qat_model_path ${qat_model_dir} ARGS --quant_model_path ${quant_model_dir}
--fp32_model_save_path ${fp32_model_save_path} --fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path}) --int8_model_save_path ${int8_model_save_path})
endfunction() endfunction()
...@@ -173,126 +174,126 @@ if(LINUX AND WITH_MKLDNN) ...@@ -173,126 +174,126 @@ if(LINUX AND WITH_MKLDNN)
inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH}) inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
endif() endif()
#### QAT FP32 & INT8 comparison python api tests #### QUANT & INT8 comparison python api tests
set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat") set(QUANT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant")
### QATv1 for image classification ### Quant1 for image classification
# QAT ResNet50 # Quant ResNet50
set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT") set(QUANT_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant")
set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz") set(QUANT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE}) download_quant_model(${QUANT_RESNET50_MODEL_DIR} ${QUANT_RESNET50_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) inference_quant_int8_image_classification_test(test_quant_int8_resnet50_mkldnn ${QUANT_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT ResNet101 # Quant ResNet101
set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT") set(QUANT_RESNET101_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet101_quant")
set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz") set(QUANT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE}) download_quant_model(${QUANT_RESNET101_MODEL_DIR} ${QUANT_RESNET101_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) # inference_quant_int8_image_classification_test(test_quant_int8_resnet101_mkldnn ${QUANT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT GoogleNet # Quant GoogleNet
set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT") set(QUANT_GOOGLENET_MODEL_DIR "${QUANT_INSTALL_DIR}/GoogleNet_quant")
set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz") set(QUANT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE}) download_quant_model(${QUANT_GOOGLENET_MODEL_DIR} ${QUANT_GOOGLENET_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) inference_quant_int8_image_classification_test(test_quant_int8_googlenet_mkldnn ${QUANT_GOOGLENET_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT MobileNetV1 # Quant MobileNetV1
set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT") set(QUANT_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant")
set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz") set(QUANT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE}) download_quant_model(${QUANT_MOBILENETV1_MODEL_DIR} ${QUANT_MOBILENETV1_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) inference_quant_int8_image_classification_test(test_quant_int8_mobilenetv1_mkldnn ${QUANT_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT MobileNetV2 # Quant MobileNetV2
set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT") set(QUANT_MOBILENETV2_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV2_quant")
set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz") set(QUANT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE}) download_quant_model(${QUANT_MOBILENETV2_MODEL_DIR} ${QUANT_MOBILENETV2_MODEL_ARCHIVE})
inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) inference_quant_int8_image_classification_test(test_quant_int8_mobilenetv2_mkldnn ${QUANT_MOBILENETV2_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT VGG16 # Quant VGG16
set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT") set(QUANT_VGG16_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG16_quant")
set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz") set(QUANT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE}) download_quant_model(${QUANT_VGG16_MODEL_DIR} ${QUANT_VGG16_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) # inference_quant_int8_image_classification_test(test_quant_int8_vgg16_mkldnn ${QUANT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# QAT VGG19 # Quant VGG19
set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT") set(QUANT_VGG19_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG19_quant")
set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz") set(QUANT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE}) download_quant_model(${QUANT_VGG19_MODEL_DIR} ${QUANT_VGG19_MODEL_ARCHIVE})
# inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH}) # inference_quant_int8_image_classification_test(test_quant_int8_vgg19_mkldnn ${QUANT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
### QATv2 for image classification ### Quant2 for image classification
set(QAT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d") set(QUANT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d")
# QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators, # Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
# with weight scales in `fake_dequantize_max_abs` operators # with weight scales in `fake_dequantize_max_abs` operators
set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf") set(QUANT2_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2")
set(QUANT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
download_quant_model(${QUANT2_RESNET50_MODEL_DIR} ${QUANT2_RESNET50_MODEL_ARCHIVE})
set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_mkldnn ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_dequantize_max_abs` operators # with weight scales in `fake_dequantize_max_abs` operators
set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range") set(QUANT2_RESNET50_RANGE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_range")
set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz") set(QUANT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE}) download_quant_model(${QUANT2_RESNET50_RANGE_MODEL_DIR} ${QUANT2_RESNET50_RANGE_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE}) inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_range_mkldnn ${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes, # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators # with weight scales in `fake_channel_wise_dequantize_max_abs` operators
set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise") set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_channelwise")
set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz") set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE}) download_quant_model(${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE}) inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_channelwise_mkldnn ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
# QAT2 MobileNetV1 # Quant2 MobileNetV1
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf") set(QUANT2_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant2")
set(QUANT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
download_quant_model(${QUANT2_MOBILENETV1_MODEL_DIR} ${QUANT2_MOBILENETV1_MODEL_ARCHIVE})
set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz") inference_quant2_int8_image_classification_test(test_quant2_int8_mobilenetv1_mkldnn ${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_OPS_TO_QUANTIZE})
### QATv2 for NLP ### Quant2 for NLP
set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz") set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset") set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1") set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev") set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE}) download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})
# QAT2 Ernie # Quant2 Ernie
set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz") set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat") set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2")
download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE}) download_quant_model(${QUANT2_ERNIE_MODEL_DIR} ${QUANT2_ERNIE_MODEL_ARCHIVE})
set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz") set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
set(FP32_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_float") set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
download_qat_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}) download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH}) inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH})
### Save QAT2 FP32 model or QAT2 INT8 model ### Save FP32 model or INT8 model from Quant model
set(QAT2_INT8_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_int8") set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
set(QAT2_FP32_RESNET50_SAVE_PATH "${QAT_INSTALL_DIR}/ResNet50_qat2_fp32") set(QUANT2_FP32_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_fp32")
save_qat_ic_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_RESNET50_SAVE_PATH} ${QAT2_INT8_RESNET50_SAVE_PATH} ${QAT2_IC_OPS_TO_QUANTIZE}) save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_FP32_RESNET50_SAVE_PATH} ${QUANT2_INT8_RESNET50_SAVE_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8") set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8")
set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32") set(QUANT2_FP32_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_fp32")
save_qat_nlp_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH}) save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_FP32_ERNIE_SAVE_PATH} ${QUANT2_INT8_ERNIE_SAVE_PATH})
# Convert QAT2 model to dot and pdf files # Convert Quant2 model to dot and pdf files
set(QAT2_INT8_ERNIE_DOT_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8_dot_file") set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
convert_model2dot_test(convert_model2dot_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_qat2_int8") convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8")
endif() endif()
# Since the tests for QAT FP32 & INT8 comparison support only testing on Linux # Since the tests for Quant & INT8 comparison support only testing on Linux
# with MKL-DNN, we remove them here so they are not tested on other systems. # with MKL-DNN, we remove them here so they are not tested on other systems.
list(REMOVE_ITEM TEST_OPS list(REMOVE_ITEM TEST_OPS
test_mkldnn_int8_quantization_strategy test_mkldnn_int8_quantization_strategy
qat_int8_image_classification_comparison quant_int8_image_classification_comparison
qat_int8_nlp_comparison) quant_int8_nlp_comparison)
#TODO(wanghaoshuang): Fix this unit test, which fails on GCC8. #TODO(wanghaoshuang): Fix this unit test, which fails on GCC8.
LIST(REMOVE_ITEM TEST_OPS test_auto_pruning) LIST(REMOVE_ITEM TEST_OPS test_auto_pruning)
......
# SLIM Quantization-aware training (QAT) for INT8 MKL-DNN # SLIM Quantization-aware training (QAT) for INT8 MKL-DNN
This document describes how to use [Paddle Slim](https://paddlepaddle.github.io/PaddleSlim/index.html) to convert a quantization-aware trained model into INT8 MKL-DNN quantized model and run it. This document describes how to use [Paddle Slim](https://paddlepaddle.github.io/PaddleSlim/index.html) to convert a quantization-aware trained model (Quant model) into INT8 MKL-DNN quantized model and run it.
In **Release 1.5**, we have released the first approach to the MKL-DNN-based quantization of QAT models, called QAT1. It enabled the `conv2d` and `mul` INT8 MKL-DNN kernels for QAT trained models (GoogleNet, MobileNetV1, MobileNetV2, ResNet50, ResNet101, VGG16, and VGG19) with 0.05% accuracy diff. In **Release 1.5**, we have released the first approach to the MKL-DNN-based quantization of Quant models, called Quant1. It enabled the `conv2d` and `mul` INT8 MKL-DNN kernels for Quant trained models (GoogleNet, MobileNetV1, MobileNetV2, ResNet50, ResNet101, VGG16, and VGG19) with 0.05% accuracy diff.
In **Release 1.6**, a new approach was introduced, called QAT2, which adds support for more performance optimizations and more INT8 MKL-DNN kernels. INT8 MKL-DNN models obtained using QAT2 have much better inference performance than using QAT1, with only a little bit bigger accuracy diff. In **Release 1.6**, a new approach was introduced, called Quant2, which adds support for more performance optimizations and more INT8 MKL-DNN kernels. INT8 MKL-DNN models obtained using Quant2 have much better inference performance than using Quant1, with only a little bit bigger accuracy diff.
In **Release 1.7**, support for the [Ernie (NLP) QAT trained model](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn) was added to QAT2. In **Release 1.7**, support for the [Ernie (NLP) Quant trained model](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn) was added to Quant2.
In **Release 2.0**, further optimizations were added to the QAT2: INT8 `matmul` kernel, inplace execution of activation and `elementwise_add` operators, and broader support for quantization aware strategy from PaddleSlim. In **Release 2.0**, further optimizations were added to the Quant2: INT8 `matmul` kernel, inplace execution of activation and `elementwise_add` operators, and broader support for quantization aware strategy from PaddleSlim.
In this document we focus on the QAT2 approach only. In this document we focus on the Quant2 approach only.
## 0. Prerequisites ## 0. Prerequisites
* PaddlePaddle in version 2.0 or higher is required. For instructions on how to install it see the [installation document](https://www.paddlepaddle.org.cn/install/quick). * PaddlePaddle in version 2.0 or higher is required. For instructions on how to install it see the [installation document](https://www.paddlepaddle.org.cn/install/quick).
...@@ -20,15 +20,15 @@ In this document we focus on the QAT2 approach only. ...@@ -20,15 +20,15 @@ In this document we focus on the QAT2 approach only.
## 1. Introduction ## 1. Introduction
There are two forms of quantization supported in PaddlePaddle: [post-training quantization](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md) (PTQ) and quantization-aware training (QAT). Using both PTQ and QAT a user can convert models created by PaddleSlim into INT8 models and run INT8 inference on CPU. PTQ is more automatic and requires less model preparation than QAT, but usually QAT gives better accuracy with similar performance. In this document we focus on QAT2 approach to the QAT and INT8 quantization. There are two approaches to quantization supported in PaddlePaddle: [post-training quantization](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/inference/tests/api/int8_mkldnn_quantization.md) (PTQ) and quantization-aware training (QAT). Using both PTQ and QAT a user can convert models created by PaddleSlim into INT8 models and run INT8 inference on CPU. PTQ is more automatic and requires less model preparation. However, QAT usually gives better accuracy with similar performance. In this document we focus on a transformation from intermediate models obtained during the QAT process (Quant models) into MKL-DNN INT8 models. We call this procedure Quant2.
## 2. How to turn an FP32 model into a QAT model? ## 2. How to turn an FP32 model into a Quant model?
A procedure on how to transform an FP32 model into a QAT model supported by the QAT2 approach is described in [this document](https://github.com/PaddlePaddle/PaddleSlim/blob/80c9fab3f419880dd19ca6ea30e0f46a2fedf6b3/demo/mkldnn_quant/quant_aware/PaddleCV_mkldnn_quantaware_tutorial.md). A procedure on how to transform an FP32 model into a Quant model supported by the Quant2 approach is described in [this document](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/demo/mkldnn_quant/README.md).
## 3. How to turn a QAT model into an INT8 MKL-DNN model? ## 3. How to turn a Quant model into an INT8 MKL-DNN model?
A QAT model can be transformed into an INT8 quantized model if it contains enough information about quantization scales for every quantized operator in the graph. The process of quantization is done by the `Qat2Int8MkldnnPass` pass which comprises several steps: A Quant model can be transformed into an INT8 quantized model if it contains enough information about quantization scales for every quantized operator in the graph. The process of quantization is done by the `Quant2Int8MkldnnPass` pass which comprises several steps:
### Gathering scales ### Gathering scales
...@@ -51,7 +51,7 @@ Notes: ...@@ -51,7 +51,7 @@ Notes:
```... → input1 → conv2d → output1 → batch_norm → output2 → relu → output3 → ...``` ```... → input1 → conv2d → output1 → batch_norm → output2 → relu → output3 → ...```
and we want to quantize the `conv2d` op, then after applying FP32 optimizations the sequence will become and we want to quantize the `conv2d` op, then after applying FP32 optimizations the sequence will become
```... → input1 → conv2d → output3 → ...``` ```... → input1 → conv2d → output3 → ...```
and the quantization scales have to be collected for the `input1` and `output3` tensors in the QAT model. and the quantization scales have to be collected for the `input1` and `output3` tensors in the Quant model (see the scale sketch after these notes).
2. Quantization of the following operators is supported: `conv2d`, `depthwise_conv2d`, `mul`, `fc`, `matmul`, `pool2d`, `reshape2`, `transpose2`, `concat`. 2. Quantization of the following operators is supported: `conv2d`, `depthwise_conv2d`, `mul`, `fc`, `matmul`, `pool2d`, `reshape2`, `transpose2`, `concat`.
3. The longer the sequence of consecutive quantizable operators in the model, the bigger the performance boost that can be achieved through quantization: 3. The longer the sequence of consecutive quantizable operators in the model, the bigger the performance boost that can be achieved through quantization:
```... → conv2d → conv2d → pool2d → conv2d → conv2d → ...``` ```... → conv2d → conv2d → pool2d → conv2d → conv2d → ...```
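For illustration, the notes above refer to max-abs quantization scales gathered from the fake quantize/dequantize operators. The plain-NumPy sketch below (not the pass's internal code; the tensor values are made up) shows how such a max-abs value translates into a saturating INT8 scale and back:

```python
import numpy as np

s8_max = 127  # INT8 saturation limit used throughout the passes

# Max-abs value gathered for a tensor (e.g. `input1` or `output3` above)
# from its fake quantize/dequantize operator.
max_abs = 5.2

# Derived INT8 quantization scale.
scale = s8_max / max_abs

# Saturating quantization of the tensor and the matching dequantization.
x = np.array([-6.0, -1.3, 0.0, 2.7, 5.2], dtype=np.float32)
x_int8 = np.clip(np.round(x * scale), -s8_max, s8_max).astype(np.int8)
x_back = x_int8.astype(np.float32) / scale  # approximate reconstruction
```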
...@@ -64,7 +64,7 @@ All the `fake_quantize_*` and `fake_dequantize_*` operators are being removed fr ...@@ -64,7 +64,7 @@ All the `fake_quantize_*` and `fake_dequantize_*` operators are being removed fr
### Dequantizing weights ### Dequantizing weights
Weights of `conv2d`, `depthwise_conv2d` and `mul` operators are assumed to be fake-quantized (with integer values in the `int8` range, but kept as `float`s) in QAT models. Here, the information about the scale from `fake_dequantize_max_abs` and `fake_channel_wise_dequantize_max_abs` operators is used to fake-dequantize the weights back to the full float range of values. At this moment the model becomes an unoptimized clean FP32 inference model. Weights of `conv2d`, `depthwise_conv2d` and `mul` operators are assumed to be fake-quantized (with integer values in the `int8` range, but kept as `float`s) in Quant models. Here, the information about the scale from `fake_dequantize_max_abs` and `fake_channel_wise_dequantize_max_abs` operators is used to fake-dequantize the weights back to the full float range of values. At this moment the model becomes an unoptimized clean FP32 inference model.
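As a concrete illustration of the arithmetic behind this step (a sketch with made-up values, following the `Out = Scale * X / max_range` formula of the fake dequantize operators; it is not the pass's internal code):

```python
import numpy as np

s8_max = 127  # max_range for 8-bit quantization

# Fake-quantized conv2d weights: integer values in [-127, 127] stored as floats.
fake_quant_weights = np.array([[-127., 53., 0.],
                               [89., -12., 127.]], dtype=np.float32)

# Per-tensor scale taken from a `fake_dequantize_max_abs` operator
# (the max-abs value of the original FP32 weights).
scale = 0.82
fp32_weights = fake_quant_weights * scale / s8_max

# For `fake_channel_wise_dequantize_max_abs` there is one scale per output
# channel (axis 0 of the conv2d weights):
channel_scales = np.array([0.82, 0.31], dtype=np.float32)
fp32_weights_cw = fake_quant_weights * channel_scales[:, None] / s8_max
```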
### Optimizing FP32 graph ### Optimizing FP32 graph
...@@ -88,11 +88,11 @@ Having gathered all the data needed for quantization we apply the `cpu_quantize_ ...@@ -88,11 +88,11 @@ Having gathered all the data needed for quantization we apply the `cpu_quantize_
## 4. Code example ## 4. Code example
The code snippet below shows how the `Qat2Int8MkldnnPass` can be applied to a model graph: The code snippet below shows how the `Quant2Int8MkldnnPass` can be applied to a model graph:
```python ```python
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid import core from paddle.fluid import core
...@@ -100,16 +100,16 @@ The code snipped shows how the `Qat2Int8MkldnnPass` can be applied to a model gr ...@@ -100,16 +100,16 @@ The code snipped shows how the `Qat2Int8MkldnnPass` can be applied to a model gr
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False) graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace() place = fluid.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph using the # Convert the IrGraph to MKL-DNN supported INT8 IrGraph using the
# Qat2Int8MkldnnPass. It requires a list of operators to be quantized # Quant2Int8MkldnnPass. It requires a list of operators to be quantized
mkldnn_pass = Qat2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(), place, fluid.core, False) mkldnn_pass = Quant2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(), place, fluid.core, False)
# Apply Qat2Int8MkldnnPass to IrGraph # Apply Quant2Int8MkldnnPass to IrGraph
mkldnn_pass.apply(graph) mkldnn_pass.apply(graph)
``` ```
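For completeness, here is a fuller sketch of applying the same pass to a downloaded Quant model and saving the result for the C-API benchmark. The paths are placeholders, and the load/save calls are assumptions based on the standard `fluid.io` API rather than a prescribed workflow:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load the downloaded Quant model (placeholder path).
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    '/PATH/TO/DOWNLOADED/QUANT/MODEL', exe)

# Wrap it in an IrGraph and apply the pass (inference graph, hence for_test=True).
graph = IrGraph(core.Graph(program.desc), for_test=True)
mkldnn_pass = Quant2Int8MkldnnPass({'conv2d', 'pool2d'}, fluid.global_scope(),
                                   place, core, False)
graph = mkldnn_pass.apply(graph)

# Save the transformed INT8 model so it can be benchmarked (placeholder path).
fluid.io.save_inference_model('/PATH/TO/SAVE/QUANT/INT8/MODEL', feed_names,
                              fetch_targets, exe, graph.to_program())
```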
## 5. Accuracy and Performance benchmark ## 5. Accuracy and Performance benchmark
This section contains QAT2 MKL-DNN accuracy and performance benchmark results measured on the following server: This section contains Quant2 MKL-DNN accuracy and performance benchmark results measured on the following server:
* Intel(R) Xeon(R) Gold 6271 (with AVX512 VNNI support), * Intel(R) Xeon(R) Gold 6271 (with AVX512 VNNI support),
...@@ -134,7 +134,7 @@ Performance benchmarks were run with the following environment settings: ...@@ -134,7 +134,7 @@ Performance benchmarks were run with the following environment settings:
>**Intel(R) Xeon(R) Gold 6271** >**Intel(R) Xeon(R) Gold 6271**
| Model | FP32 Top1 Accuracy | INT8 QAT Top1 Accuracy | Top1 Diff | FP32 Top5 Accuracy | INT8 QAT Top5 Accuracy | Top5 Diff | | Model | FP32 Top1 Accuracy | INT8 Quant Top1 Accuracy | Top1 Diff | FP32 Top5 Accuracy | INT8 Quant Top5 Accuracy | Top5 Diff |
| :----------: | :----------------: | :--------------------: | :-------: | :----------------: | :--------------------: | :-------: | | :----------: | :----------------: | :--------------------: | :-------: | :----------------: | :--------------------: | :-------: |
| MobileNet-V1 | 70.78% | 70.71% | -0.07% | 89.69% | 89.41% | -0.28% | | MobileNet-V1 | 70.78% | 70.71% | -0.07% | 89.69% | 89.41% | -0.28% |
| MobileNet-V2 | 71.90% | 72.11% | +0.21% | 90.56% | 90.62% | +0.06% | | MobileNet-V2 | 71.90% | 72.11% | +0.21% | 90.56% | 90.62% | +0.06% |
...@@ -150,7 +150,7 @@ Image classification models performance was measured using a single thread. The ...@@ -150,7 +150,7 @@ Image classification models performance was measured using a single thread. The
>**Intel(R) Xeon(R) Gold 6271** >**Intel(R) Xeon(R) Gold 6271**
| Model | FP32 (images/s) | INT8 QAT (images/s) | Ratio (INT8/FP32) | | Model | FP32 (images/s) | INT8 Quant (images/s) | Ratio (INT8/FP32) |
| :----------: | :-------------: | :-----------------: | :---------------: | | :----------: | :-------------: | :-----------------: | :---------------: |
| MobileNet-V1 | 74.05 | 196.98 | 2.66 | | MobileNet-V1 | 74.05 | 196.98 | 2.66 |
| MobileNet-V2 | 88.60 | 187.67 | 2.12 | | MobileNet-V2 | 88.60 | 187.67 | 2.12 |
...@@ -169,7 +169,7 @@ Notes: ...@@ -169,7 +169,7 @@ Notes:
>**Intel(R) Xeon(R) Gold 6271** >**Intel(R) Xeon(R) Gold 6271**
| Model | FP32 Accuracy | QAT INT8 Accuracy | Accuracy Diff | | Model | FP32 Accuracy | Quant INT8 Accuracy | Accuracy Diff |
|:------------:|:----------------------:|:----------------------:|:---------:| |:------------:|:----------------------:|:----------------------:|:---------:|
| Ernie | 80.20% | 79.44% | -0.76% | | Ernie | 80.20% | 79.44% | -0.76% |
...@@ -179,7 +179,7 @@ Notes: ...@@ -179,7 +179,7 @@ Notes:
>**Intel(R) Xeon(R) Gold 6271** >**Intel(R) Xeon(R) Gold 6271**
| Model | Threads | FP32 Latency (ms) | QAT INT8 Latency (ms) | Ratio (FP32/INT8) | | Model | Threads | FP32 Latency (ms) | Quant INT8 Latency (ms) | Ratio (FP32/INT8) |
|:------------:|:----------------------:|:-------------------:|:---------:|:---------:| |:------------:|:----------------------:|:-------------------:|:---------:|:---------:|
| Ernie | 1 thread | 237.21 | 79.26 | 2.99x | | Ernie | 1 thread | 237.21 | 79.26 | 2.99x |
| Ernie | 20 threads | 22.08 | 12.57 | 1.76x | | Ernie | 20 threads | 22.08 | 12.57 | 1.76x |
...@@ -188,7 +188,7 @@ Notes: ...@@ -188,7 +188,7 @@ Notes:
## 6. How to reproduce the results ## 6. How to reproduce the results
The steps below show, taking ResNet50 as an example, how to reproduce the above accuracy and performance results for Image Classification models. The steps below show, taking ResNet50 as an example, how to reproduce the above accuracy and performance results for Image Classification models.
To reproduce NLP models results (Ernie), please follow [How to reproduce Ernie QAT results on MKL-DNN](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn/README.md). To reproduce NLP models results (Ernie), please follow [How to reproduce Ernie Quant results on MKL-DNN](https://github.com/PaddlePaddle/benchmark/tree/master/Inference/c%2B%2B/ernie/mkldnn/README.md).
### Prepare dataset ### Prepare dataset
...@@ -202,18 +202,18 @@ The converted data binary file is saved by default in `$HOME/.cache/paddle/datas ...@@ -202,18 +202,18 @@ The converted data binary file is saved by default in `$HOME/.cache/paddle/datas
### Prepare models ### Prepare models
Run the following commands to download and extract the QAT model: Run the following commands to download and extract the Quant model:
```bash ```bash
mkdir -p /PATH/TO/DOWNLOAD/MODEL/ mkdir -p /PATH/TO/DOWNLOAD/MODEL/
cd /PATH/TO/DOWNLOAD/MODEL/ cd /PATH/TO/DOWNLOAD/MODEL/
export QAT_MODEL_NAME=resnet50 export QUANT_MODEL_NAME=resnet50
export QAT_MODEL_ARCHIVE=${QAT_MODEL_NAME}_quant.tar.gz export QUANT_MODEL_ARCHIVE=${QUANT_MODEL_NAME}_quant.tar.gz
wget http://paddle-inference-dist.bj.bcebos.com/int8/QAT2_models/${QAT_MODEL_ARCHIVE} wget http://paddle-inference-dist.bj.bcebos.com/int8/QAT2_models/${QUANT_MODEL_ARCHIVE}
mkdir ${QAT_MODEL_NAME} && tar -xvf ${QAT_MODEL_ARCHIVE} -C ${QAT_MODEL_NAME} mkdir ${QUANT_MODEL_NAME} && tar -xvf ${QUANT_MODEL_ARCHIVE} -C ${QUANT_MODEL_NAME}
``` ```
To download other QAT models, set the `QAT_MODEL_NAME` variable in the above commands to one of the values: `resnet101`, `mobilenetv1`, `mobilenetv2`, `vgg16`, `vgg19`. To download other Quant models, set the `QUANT_MODEL_NAME` variable in the above commands to one of the values: `resnet101`, `mobilenetv1`, `mobilenetv2`, `vgg16`, `vgg19`.
Download clean FP32 model for accuracy comparison against the INT8 model: Download clean FP32 model for accuracy comparison against the INT8 model:
...@@ -231,23 +231,23 @@ To download other FP32 models, set the `FP32_MODEL_NAME` variable to on of the v ...@@ -231,23 +231,23 @@ To download other FP32 models, set the `FP32_MODEL_NAME` variable to on of the v
#### Accuracy benchmark commands #### Accuracy benchmark commands
You can use the `qat2_int8_image_classification_comparison.py` script to reproduce the accuracy result of the INT8 QAT models. The following options are required: You can use the `quant2_int8_image_classification_comparison.py` script to reproduce the accuracy result of the INT8 Quant models. The following options are required:
* `--qat_model` - a path to a QAT model that will be transformed into INT8 model. * `--quant_model` - a path to a Quant model that will be transformed into INT8 model.
* `--fp32_model` - a path to an FP32 model whose accuracy will be measured and compared to the accuracy of the INT8 model. * `--fp32_model` - a path to an FP32 model whose accuracy will be measured and compared to the accuracy of the INT8 model.
* `--infer_data` - a path to the validation dataset. * `--infer_data` - a path to the validation dataset.
The following options are also accepted: The following options are also accepted:
* `--ops_to_quantize` - a comma-separated list of operator types to quantize. If the option is not used, an attempt to quantize all quantizable operators will be made, and in that case only quantizable operators which have quantization scales provided in the QAT model will be quantized. When deciding which operators to put on the list, the following have to be considered: * `--ops_to_quantize` - a comma-separated list of operator types to quantize. If the option is not used, an attempt to quantize all quantizable operators will be made, and in that case only quantizable operators which have quantization scales provided in the Quant model will be quantized. When deciding which operators to put on the list, the following have to be considered:
* Only operators which support quantization will be taken into account. * Only operators which support quantization will be taken into account.
* All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, quantization of the operator will be skipped with a message saying which variable is missing a quantization scale. * All the quantizable operators from the list, which are present in the model, must have quantization scales provided in the model. Otherwise, quantization of the operator will be skipped with a message saying which variable is missing a quantization scale.
* Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, a user can run the benchmark a few times with different lists of quantized operators present in the model and compare the results. For the Image Classification models mentioned above, the list usually comprises `conv2d` and `pool2d` operators. * Sometimes it may be suboptimal to quantize all quantizable operators in the model (cf. *Notes* in the **Gathering scales** section above). To find the optimal configuration for this option, a user can run the benchmark a few times with different lists of quantized operators present in the model and compare the results. For the Image Classification models mentioned above, the list usually comprises `conv2d` and `pool2d` operators.
* `--op_ids_to_skip` - a comma-separated list of operator ids to skip in quantization. To get an id of a particular operator run the script with the `--debug` option first (see below for the description of the option), and having opened the generated file `qat_int8_cpu_quantize_placement_pass.dot` find the id number written in parentheses next to the name of the operator. * `--op_ids_to_skip` - a comma-separated list of operator ids to skip in quantization. To get an id of a particular operator run the script with the `--debug` option first (see below for the description of the option), and having opened the generated file `int8_<some_number>_cpu_quantize_placement_pass.dot` find the id number written in parentheses next to the name of the operator.
* `--debug` - add this option to generate a series of `*.dot` files containing the model graphs after each step of the transformation. For a description of the DOT format see [DOT]( https://graphviz.gitlab.io/_pages/doc/info/lang.html). The files will be saved in the current location. To open the `*.dot` files use any of the Graphviz tools available on your system (e.g. `xdot` tool on Linux or `dot` tool on Windows, for documentation see [Graphviz](http://www.graphviz.org/documentation/)). * `--debug` - add this option to generate a series of `*.dot` files containing the model graphs after each step of the transformation. For a description of the DOT format see [DOT]( https://graphviz.gitlab.io/_pages/doc/info/lang.html). The files will be saved in the current location. To open the `*.dot` files use any of the Graphviz tools available on your system (e.g. `xdot` tool on Linux or `dot` tool on Windows, for documentation see [Graphviz](http://www.graphviz.org/documentation/)).
```bash ```bash
cd /PATH/TO/PADDLE cd /PATH/TO/PADDLE
OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/qat2_int8_image_classification_comparison.py --qat_model=/PATH/TO/DOWNLOADED/QAT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d" OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim/tests/quant2_int8_image_classification_comparison.py --quant_model=/PATH/TO/DOWNLOADED/QUANT/MODEL --fp32_model=/PATH/TO/DOWNLOADED/FP32/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=50 --batch_num=1000 --acc_diff_threshold=0.01 --ops_to_quantize="conv2d,pool2d"
``` ```
> Notes: Due to the large number of images in the `int8_full_val.bin` dataset (50 000), the accuracy benchmark may take a long time. To speed up accuracy measurement, it is recommended to set `OMP_NUM_THREADS` to the maximum number of physical cores available on the server. > Notes: Due to the large number of images in the `int8_full_val.bin` dataset (50 000), the accuracy benchmark may take a long time. To speed up accuracy measurement, it is recommended to set `OMP_NUM_THREADS` to the maximum number of physical cores available on the server.
...@@ -256,16 +256,16 @@ OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim ...@@ -256,16 +256,16 @@ OMP_NUM_THREADS=28 FLAGS_use_mkldnn=true python python/paddle/fluid/contrib/slim
To reproduce the performance results, the environment variable `OMP_NUM_THREADS=1` and `--batch_size=1` option should be set. To reproduce the performance results, the environment variable `OMP_NUM_THREADS=1` and `--batch_size=1` option should be set.
1. Transform the QAT model into INT8 model by applying the `Qat2Int8MkldnnPass` pass and save the result. You can use the script `save_qat_model.py` for this purpose. It also accepts the option `--ops_to_quantize` with a list of operators to quantize. 1. Transform the Quant model into INT8 model by applying the `Quant2Int8MkldnnPass` pass and save the result. You can use the script `save_quant_model.py` for this purpose. It also accepts the option `--ops_to_quantize` with a list of operators to quantize.
```bash ```bash
cd /PATH/TO/PADDLE/build cd /PATH/TO/PADDLE/build
python ../python/paddle/fluid/contrib/slim/tests/save_qat_model.py --qat_model_path=/PATH/TO/DOWNLOADED/QAT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QAT/INT8/MODEL --ops_to_quantize="conv2d,pool2d" python ../python/paddle/fluid/contrib/slim/tests/save_quant_model.py --quant_model_path=/PATH/TO/DOWNLOADED/QUANT/MODEL --int8_model_save_path=/PATH/TO/SAVE/QUANT/INT8/MODEL --ops_to_quantize="conv2d,pool2d"
``` ```
2. Run the C-API test for the performance benchmark (see the sketch after this list). 2. Run the C-API test for the performance benchmark (see the sketch after this list).
```bash ```bash
cd /PATH/TO/PADDLE/build cd /PATH/TO/PADDLE/build
OMP_NUM_THREADS=1 paddle/fluid/inference/tests/api/test_analyzer_qat_image_classification ARGS --enable_fp32=false --with_accuracy_layer=false --int8_model=/PATH/TO/SAVED/QAT/INT8/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=1 --paddle_num_threads=1 OMP_NUM_THREADS=1 paddle/fluid/inference/tests/api/test_analyzer_quant_image_classification ARGS --enable_fp32=false --with_accuracy_layer=false --int8_model=/PATH/TO/SAVED/QUANT/INT8/MODEL --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin --batch_size=1 --paddle_num_threads=1
``` ```
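As an extension of steps 1 and 2 above, a hedged sketch that saves both an FP32 and an INT8 variant of the model in one invocation (using the `--fp32_model_save_path` option added elsewhere in this patch) and then runs a throughput-oriented benchmark with more threads and a larger batch. All paths are placeholders and the thread and batch values are only illustrative:
```bash
cd /PATH/TO/PADDLE/build

# Step 1 variant: save both the FP32 and the INT8 transformed models (placeholder paths)
python ../python/paddle/fluid/contrib/slim/tests/save_quant_model.py \
    --quant_model_path=/PATH/TO/DOWNLOADED/QUANT/MODEL \
    --fp32_model_save_path=/PATH/TO/SAVE/QUANT/FP32/MODEL \
    --int8_model_save_path=/PATH/TO/SAVE/QUANT/INT8/MODEL \
    --ops_to_quantize="conv2d,pool2d"

# Step 2 variant: throughput-oriented run with 28 threads and batch size 50 (adjust to your machine)
OMP_NUM_THREADS=28 paddle/fluid/inference/tests/api/test_analyzer_quant_image_classification ARGS \
    --enable_fp32=false --with_accuracy_layer=false \
    --int8_model=/PATH/TO/SAVED/QUANT/INT8/MODEL \
    --infer_data=$HOME/.cache/paddle/dataset/int8/download/int8_full_val.bin \
    --batch_size=50 --paddle_num_threads=28
```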
...@@ -24,7 +24,7 @@ import time ...@@ -24,7 +24,7 @@ import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
...@@ -42,7 +42,7 @@ def parse_args(): ...@@ -42,7 +42,7 @@ def parse_args():
help='Number of the first minibatches to skip in performance statistics.' help='Number of the first minibatches to skip in performance statistics.'
) )
parser.add_argument( parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.') '--quant_model', type=str, default='', help='A path to a Quant model.')
parser.add_argument( parser.add_argument(
'--fp32_model', type=str, default='', help='A path to an FP32 model.') '--fp32_model', type=str, default='', help='A path to an FP32 model.')
parser.add_argument('--infer_data', type=str, default='', help='Data file.') parser.add_argument('--infer_data', type=str, default='', help='Data file.')
...@@ -71,15 +71,15 @@ def parse_args(): ...@@ -71,15 +71,15 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--debug', '--debug',
action='store_true', action='store_true',
help='If used, the graph of QAT model is drawn.') help='If used, the graph of Quant model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest) test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args return test_args, sys.argv[:1] + args
class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
""" """
Test for accuracy comparison of FP32 and QAT2 INT8 Image Classification inference. Test for accuracy comparison of FP32 and Quant2 INT8 Image Classification inference.
""" """
def _reader_creator(self, data_file='data.bin'): def _reader_creator(self, data_file='data.bin'):
...@@ -182,9 +182,9 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -182,9 +182,9 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
graph = IrGraph(core.Graph(inference_program.desc), for_test=True) graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug): if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes()) graph.draw('.', 'quant_orig', graph.all_op_nodes())
if (transform_to_int8): if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
self._quantized_ops, self._quantized_ops,
_op_ids_to_skip=self._op_ids_to_skip, _op_ids_to_skip=self._op_ids_to_skip,
_scope=inference_scope, _scope=inference_scope,
...@@ -223,7 +223,7 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -223,7 +223,7 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
labels = np.array([x[1] for x in data]).astype('int64') labels = np.array([x[1] for x in data]).astype('int64')
if (transform_to_int8 == True): if (transform_to_int8 == True):
# QAT INT8 models do not have accuracy measuring layers # INT8 models obtained from Quant models do not have accuracy measuring layers
start = time.time() start = time.time()
out = exe.run(inference_program, out = exe.run(inference_program,
feed={feed_target_names[0]: images}, feed={feed_target_names[0]: images},
...@@ -301,8 +301,8 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -301,8 +301,8 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
if not fluid.core.is_compiled_with_mkldnn(): if not fluid.core.is_compiled_with_mkldnn():
return return
qat_model_path = test_case_args.qat_model quant_model_path = test_case_args.quant_model
assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.' assert quant_model_path, 'The Quant model path cannot be empty. Please, use the --quant_model option.'
fp32_model_path = test_case_args.fp32_model fp32_model_path = test_case_args.fp32_model
assert fp32_model_path, 'The FP32 model path cannot be empty. Please, use the --fp32_model option.' assert fp32_model_path, 'The FP32 model path cannot be empty. Please, use the --fp32_model option.'
data_path = test_case_args.infer_data data_path = test_case_args.infer_data
...@@ -323,8 +323,8 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -323,8 +323,8 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
self._op_ids_to_skip = set( self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(','))) map(int, test_case_args.op_ids_to_skip.split(',')))
_logger.info('FP32 & QAT INT8 prediction run.') _logger.info('FP32 & Quant INT8 prediction run.')
_logger.info('QAT model: {}'.format(qat_model_path)) _logger.info('Quant model: {}'.format(quant_model_path))
_logger.info('FP32 model: {}'.format(fp32_model_path)) _logger.info('FP32 model: {}'.format(fp32_model_path))
_logger.info('Dataset: {}'.format(data_path)) _logger.info('Dataset: {}'.format(data_path))
_logger.info('Batch size: {}'.format(batch_size)) _logger.info('Batch size: {}'.format(batch_size))
...@@ -346,12 +346,12 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -346,12 +346,12 @@ class Qat2Int8ImageClassificationComparisonTest(unittest.TestCase):
batch_num, batch_num,
skip_batch_num, skip_batch_num,
transform_to_int8=False) transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---') _logger.info('--- Quant INT8 prediction start ---')
val_reader = paddle.batch( val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size) self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict( int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader, val_reader,
qat_model_path, quant_model_path,
batch_size, batch_size,
batch_num, batch_num,
skip_batch_num, skip_batch_num,
......
...@@ -24,7 +24,7 @@ import time ...@@ -24,7 +24,7 @@ import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
...@@ -42,12 +42,12 @@ def parse_args(): ...@@ -42,12 +42,12 @@ def parse_args():
help='Number of the first minibatches to skip in performance statistics.' help='Number of the first minibatches to skip in performance statistics.'
) )
parser.add_argument( parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.') '--quant_model', type=str, default='', help='A path to a Quant model.')
parser.add_argument( parser.add_argument(
'--fp32_model', '--fp32_model',
type=str, type=str,
default='', default='',
help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.' help='A path to an FP32 model. If empty, the Quant model will be used for FP32 inference.'
) )
parser.add_argument('--infer_data', type=str, default='', help='Data file.') parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument( parser.add_argument(
...@@ -77,16 +77,16 @@ def parse_args(): ...@@ -77,16 +77,16 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--debug', '--debug',
action='store_true', action='store_true',
help='If used, the graph of QAT model is drawn.') help='If used, the graph of Quant model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest) test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args return test_args, sys.argv[:1] + args
class QatInt8NLPComparisonTest(unittest.TestCase): class QuantInt8NLPComparisonTest(unittest.TestCase):
""" """
Test for accuracy comparison of QAT FP32 and INT8 NLP inference. Test for accuracy comparison of Quant FP32 and INT8 NLP inference.
""" """
def _reader_creator(self, data_file=None, labels_file=None): def _reader_creator(self, data_file=None, labels_file=None):
...@@ -158,9 +158,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase): ...@@ -158,9 +158,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
graph = IrGraph(core.Graph(inference_program.desc), for_test=True) graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug): if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes()) graph.draw('.', 'quant_orig', graph.all_op_nodes())
if (transform_to_int8): if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
self._quantized_ops, self._quantized_ops,
_op_ids_to_skip=self._op_ids_to_skip, _op_ids_to_skip=self._op_ids_to_skip,
_scope=inference_scope, _scope=inference_scope,
...@@ -248,9 +248,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase): ...@@ -248,9 +248,9 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
if not fluid.core.is_compiled_with_mkldnn(): if not fluid.core.is_compiled_with_mkldnn():
return return
qat_model_path = test_case_args.qat_model quant_model_path = test_case_args.quant_model
assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.' assert quant_model_path, 'The Quant model path cannot be empty. Please, use the --quant_model option.'
fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else quant_model_path
data_path = test_case_args.infer_data data_path = test_case_args.infer_data
assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.' assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
labels_path = test_case_args.labels labels_path = test_case_args.labels
...@@ -270,8 +270,8 @@ class QatInt8NLPComparisonTest(unittest.TestCase): ...@@ -270,8 +270,8 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
self._op_ids_to_skip = set( self._op_ids_to_skip = set(
map(int, test_case_args.op_ids_to_skip.split(','))) map(int, test_case_args.op_ids_to_skip.split(',')))
_logger.info('FP32 & QAT INT8 prediction run.') _logger.info('FP32 & Quant INT8 prediction run.')
_logger.info('QAT model: {}'.format(qat_model_path)) _logger.info('Quant model: {}'.format(quant_model_path))
_logger.info('FP32 model: {}'.format(fp32_model_path)) _logger.info('FP32 model: {}'.format(fp32_model_path))
_logger.info('Dataset: {}'.format(data_path)) _logger.info('Dataset: {}'.format(data_path))
_logger.info('Labels: {}'.format(labels_path)) _logger.info('Labels: {}'.format(labels_path))
...@@ -295,12 +295,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase): ...@@ -295,12 +295,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
skip_batch_num, skip_batch_num,
transform_to_int8=False) transform_to_int8=False)
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc)) _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
_logger.info('--- QAT INT8 prediction start ---') _logger.info('--- Quant INT8 prediction start ---')
val_reader = paddle.batch( val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size) self._reader_creator(data_path, labels_path), batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict( int8_acc, int8_pps, int8_lat = self._predict(
val_reader, val_reader,
qat_model_path, quant_model_path,
batch_size, batch_size,
batch_num, batch_num,
skip_batch_num, skip_batch_num,
......
...@@ -24,7 +24,7 @@ import time ...@@ -24,7 +24,7 @@ import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass from paddle.fluid.contrib.slim.quantization import QuantInt8MkldnnPass
from paddle.fluid import core from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
...@@ -44,9 +44,9 @@ def parse_args(): ...@@ -44,9 +44,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--debug', '--debug',
action='store_true', action='store_true',
help='If used, the graph of QAT model is drawn.') help='If used, the graph of Quant model is drawn.')
parser.add_argument( parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.') '--quant_model', type=str, default='', help='A path to a Quant model.')
parser.add_argument('--infer_data', type=str, default='', help='Data file.') parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument( parser.add_argument(
'--batch_num', '--batch_num',
...@@ -64,9 +64,9 @@ def parse_args(): ...@@ -64,9 +64,9 @@ def parse_args():
return test_args, sys.argv[:1] + args return test_args, sys.argv[:1] + args
class QatInt8ImageClassificationComparisonTest(unittest.TestCase): class QuantInt8ImageClassificationComparisonTest(unittest.TestCase):
""" """
Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference. Test for accuracy comparison of Quant FP32 and INT8 Image Classification inference.
""" """
def _reader_creator(self, data_file='data.bin'): def _reader_creator(self, data_file='data.bin'):
...@@ -169,9 +169,9 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -169,9 +169,9 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
graph = IrGraph(core.Graph(inference_program.desc), for_test=True) graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug): if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes()) graph.draw('.', 'quant_orig', graph.all_op_nodes())
if (transform_to_int8): if (transform_to_int8):
mkldnn_int8_pass = QatInt8MkldnnPass( mkldnn_int8_pass = QuantInt8MkldnnPass(
_scope=inference_scope, _place=place) _scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph) graph = mkldnn_int8_pass.apply(graph)
else: else:
...@@ -264,8 +264,8 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -264,8 +264,8 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
if not fluid.core.is_compiled_with_mkldnn(): if not fluid.core.is_compiled_with_mkldnn():
return return
qat_model_path = test_case_args.qat_model quant_model_path = test_case_args.quant_model
assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.' assert quant_model_path, 'The Quant model path cannot be empty. Please, use the --quant_model option.'
data_path = test_case_args.infer_data data_path = test_case_args.infer_data
assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.' assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
batch_size = test_case_args.batch_size batch_size = test_case_args.batch_size
...@@ -274,29 +274,29 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase): ...@@ -274,29 +274,29 @@ class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
acc_diff_threshold = test_case_args.acc_diff_threshold acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug self._debug = test_case_args.debug
_logger.info('QAT FP32 & INT8 prediction run.') _logger.info('Quant FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path)) _logger.info('Quant model: {0}'.format(quant_model_path))
_logger.info('Dataset: {0}'.format(data_path)) _logger.info('Dataset: {0}'.format(data_path))
_logger.info('Batch size: {0}'.format(batch_size)) _logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num)) _logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold)) _logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('--- QAT FP32 prediction start ---') _logger.info('--- Quant FP32 prediction start ---')
val_reader = paddle.batch( val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size) self._reader_creator(data_path), batch_size=batch_size)
fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict( fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
val_reader, val_reader,
qat_model_path, quant_model_path,
batch_size, batch_size,
batch_num, batch_num,
skip_batch_num, skip_batch_num,
transform_to_int8=False) transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---') _logger.info('--- Quant INT8 prediction start ---')
val_reader = paddle.batch( val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size) self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict( int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader, val_reader,
qat_model_path, quant_model_path,
batch_size, batch_size,
batch_num, batch_num,
skip_batch_num, skip_batch_num,
......
...@@ -24,14 +24,17 @@ import time ...@@ -24,14 +24,17 @@ import time
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core from paddle.fluid import core
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'--qat_model_path', type=str, default='', help='A path to a QAT model.') '--quant_model_path',
type=str,
default='',
help='A path to a Quant model.')
parser.add_argument( parser.add_argument(
'--fp32_model_save_path', '--fp32_model_save_path',
type=str, type=str,
...@@ -56,7 +59,7 @@ def parse_args(): ...@@ -56,7 +59,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--debug', '--debug',
action='store_true', action='store_true',
help='If used, the graph of QAT model is drawn.') help='If used, the graph of Quant model is drawn.')
test_args, args = parser.parse_known_args(namespace=unittest) test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args return test_args, sys.argv[:1] + args
...@@ -85,8 +88,8 @@ def transform_and_save_model(original_path, save_path, save_type): ...@@ -85,8 +88,8 @@ def transform_and_save_model(original_path, save_path, save_type):
graph = IrGraph(core.Graph(inference_program.desc), for_test=True) graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (test_args.debug): if (test_args.debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes()) graph.draw('.', 'quant_orig', graph.all_op_nodes())
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass( transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
ops_to_quantize, ops_to_quantize,
_op_ids_to_skip=op_ids_to_skip, _op_ids_to_skip=op_ids_to_skip,
_scope=inference_scope, _scope=inference_scope,
...@@ -103,16 +106,16 @@ def transform_and_save_model(original_path, save_path, save_type): ...@@ -103,16 +106,16 @@ def transform_and_save_model(original_path, save_path, save_type):
with fluid.scope_guard(inference_scope): with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(save_path, feed_target_names, fluid.io.save_inference_model(save_path, feed_target_names,
fetch_targets, exe, inference_program) fetch_targets, exe, inference_program)
print("Success! Transformed QAT_{0} model can be found at {1}\n".format( print("Success! Transformed Quant_{0} model can be found at {1}\n".
save_type, save_path)) format(save_type, save_path))
if __name__ == '__main__': if __name__ == '__main__':
global test_args global test_args
test_args, remaining_args = parse_args() test_args, remaining_args = parse_args()
if test_args.fp32_model_save_path: if test_args.fp32_model_save_path:
transform_and_save_model(test_args.qat_model_path, transform_and_save_model(test_args.quant_model_path,
test_args.fp32_model_save_path, 'FP32') test_args.fp32_model_save_path, 'FP32')
if test_args.int8_model_save_path: if test_args.int8_model_save_path:
transform_and_save_model(test_args.qat_model_path, transform_and_save_model(test_args.quant_model_path,
test_args.int8_model_save_path, 'INT8') test_args.int8_model_save_path, 'INT8')
...@@ -17,10 +17,10 @@ import numpy as np ...@@ -17,10 +17,10 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
class TestQat2Int8MkldnnPass(unittest.TestCase): class TestQuant2Int8MkldnnPass(unittest.TestCase):
def setUp(self): def setUp(self):
self.scope = fluid.Scope() self.scope = fluid.Scope()
self.place = fluid.CPUPlace() self.place = fluid.CPUPlace()
...@@ -109,20 +109,20 @@ class TestQat2Int8MkldnnPass(unittest.TestCase): ...@@ -109,20 +109,20 @@ class TestQat2Int8MkldnnPass(unittest.TestCase):
if op.op().has_attr("fuse_brelu") and op.op().attr("fuse_brelu"): if op.op().has_attr("fuse_brelu") and op.op().attr("fuse_brelu"):
self.assertTrue(op.op().attr("fuse_activation") == "relu6") self.assertTrue(op.op().attr("fuse_activation") == "relu6")
def test_qat_update_activation(self): def test_quant_update_activation(self):
program = fluid.Program() program = fluid.Program()
with fluid.program_guard(program): with fluid.program_guard(program):
self.prepare_program(program) self.prepare_program(program)
graph = IrGraph(core.Graph(program.desc), for_test=True) graph = IrGraph(core.Graph(program.desc), for_test=True)
graph = self.remove_fuse_activation_attribute(graph) graph = self.remove_fuse_activation_attribute(graph)
self.check_graph_before_pass(graph) self.check_graph_before_pass(graph)
qat2_int8_mkldnn_pass = Qat2Int8MkldnnPass( quant2_int8_mkldnn_pass = Quant2Int8MkldnnPass(
self.quantized_ops, self.quantized_ops,
_scope=self.scope, _scope=self.scope,
_place=self.place, _place=self.place,
_core=core, _core=core,
_debug=False) _debug=False)
graph = qat2_int8_mkldnn_pass._update_activations(graph) graph = quant2_int8_mkldnn_pass._update_activations(graph)
self.check_graph_after_pass(graph) self.check_graph_after_pass(graph)
......
...@@ -22,7 +22,7 @@ import paddle ...@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass from paddle.fluid.contrib.slim.quantization import QuantInt8MkldnnPass
from paddle.fluid import core from paddle.fluid import core
os.environ["CPU_NUM"] = "1" os.environ["CPU_NUM"] = "1"
...@@ -90,7 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): ...@@ -90,7 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
seed, seed,
activation_quant_type, activation_quant_type,
weight_quant_type='abs_max', weight_quant_type='abs_max',
qat_perf=False, quant_perf=False,
for_ci=False): for_ci=False):
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -109,7 +109,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): ...@@ -109,7 +109,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
scope = fluid.Scope() scope = fluid.Scope()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
exe.run(startup) exe.run(startup)
# Apply the QAT QuantizationTransformPass # Apply the QuantizationTransformPass
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=scope, scope=scope,
place=place, place=place,
...@@ -149,7 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase): ...@@ -149,7 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference # Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place) mkldnn_int8_pass = QuantInt8MkldnnPass(_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph) mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_' dev_name = '_cpu_'
if not for_ci: if not for_ci:
......