Unverified commit d0a921ba authored by Wojciech Uss, committed by GitHub

Quant2 updates and fixes (#25313)

Parent: 869d59cc
@@ -46,10 +46,8 @@ void LogCannotQuantizeOp(Node* op, const char* details = nullptr) {
 }
 
 void LogScaleIsMissingForVar(Node* var) {
-  std::stringstream msg_ss;
-  msg_ss << "Quantization scale for the variable " << var->Name()
-         << " is missing.";
-  PrettyLogDetail(msg_ss.str().c_str());
+  VLOG(4) << "Quantization scale for the variable " << var->Name()
+          << " is missing.";
 }
 
 void LogQuantizationDisabled(Node* op) {
@@ -256,6 +254,14 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
   GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
 
+  auto has_output_scale = AreScalesPresentForNodes(conv_op, {conv_output});
+  if (with_residual_data && !has_output_scale) {
+    LogCannotQuantizeOp(conv_op,
+                        "Conv op with ResidualData input cannot be quantized "
+                        "without output scale.");
+    return;
+  }
+
   if (with_residual_data) {
     GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data,
                               conv_pattern);
@@ -294,7 +300,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
   conv_op->Op()->SetAttr("Scale_weights", filter_scale);
 
   // if quantization scale is missing for output tensor, return fp32 data
-  if (AreScalesPresentForNodes(conv_op, {conv_output})) {
+  if (has_output_scale) {
     bool is_output_unsigned{false};
     auto output_scale =
         GetScaleValueForNode(conv_output, &is_output_unsigned);
...
@@ -55,7 +55,7 @@ class Quant2Int8MkldnnPass(object):
             'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
         ]
         self._ops_to_quantize = _ops_to_quantize
-        self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip != None else set(
+        self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
            [-1])
         self._scale_immutable_ops = [
             'transpose2', 'reshape2', 'pool2d', 'scale'
@@ -71,11 +71,14 @@ class Quant2Int8MkldnnPass(object):
         self._var_quant_scales = {}
         self._max_range = {}
         self._s8_max = 127
+        self._pass_idx = 0
+        self._pass_group = 'int8'
 
     def apply(self, graph):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
+        self._reset_pass_idx_and_group('int8')
 
         graph = self._gather_weight_scales_from_fake(graph)
         graph = self._gather_output_scales_from_attr(graph)
         graph = self._gather_input_scales_from_fake(graph)
@@ -86,21 +89,24 @@ class Quant2Int8MkldnnPass(object):
         graph = self._update_relu_output_scales(graph)
         graph = self._propagate_scales(graph)
         graph = self._quantize_fp32_graph(graph)
-        graph = self._optimize_int8_graph(graph)
+        graph = self._final_optimizations(graph)
         graph = self._cleanup(graph)
         return graph
 
-    def apply_fp32(self, graph):
+    def prepare_and_optimize_fp32(self, graph):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        graph = self._gather_weight_scales_from_fake(graph)
-        graph = self._remove_fake_ops(graph)
-        graph = self._dequantize_weights(graph)
+        self._reset_pass_idx_and_group('fp32')
         graph = self._optimize_fp32_graph(graph)
+        graph = self._final_optimizations(graph)
         graph = self._cleanup(graph)
         return graph
 
+    def _reset_pass_idx_and_group(self, group):
+        self._pass_idx = 0
+        self._pass_group = group
+
     def _convert_scale2tensor(self, scale):
         tensor = core.LoDTensor()
         tensor.set(scale, core.CPUPlace())
@@ -333,20 +339,38 @@ class Quant2Int8MkldnnPass(object):
     def _optimize_fp32_graph(self, graph):
         graph = self._update_activations(graph)
         graph = self._remove_ctrl_vars(graph)
+        graph = self._apply_pass(graph, 'attention_lstm_fuse_pass')
+        graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass')
+        # graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass')
+        graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass')
+        # graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass')
+        graph = self._apply_pass(graph, 'fc_lstm_fuse_pass')
+        graph = self._apply_pass(graph, 'mul_lstm_fuse_pass')
+        graph = self._apply_pass(graph, 'fc_gru_fuse_pass')
+        graph = self._apply_pass(graph, 'mul_gru_fuse_pass')
+        graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
+        graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
+        graph = self._apply_pass(graph, 'is_test_pass')
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
         graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
         graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_transpose_bn_fuse_pass')
+        graph = self._apply_pass(graph,
+                                 'conv_transpose_eltwiseadd_bn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'fc_fuse_pass',
                                  ['use_gpu', 'use_fc_padding'], [False, False])
+        graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
         if self._is_fc_quantized(graph):
             graph = self._apply_pass(graph, 'fc_mkldnn_pass')
         graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
+        # the following pass should be the last one since it will work on all fused ops.
+        graph = self._apply_pass(graph, 'runtime_context_cache_pass')
         return graph
 
     def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
@@ -362,12 +386,13 @@ class Quant2Int8MkldnnPass(object):
                 ir_pass.set(attr, value)
         ir_pass.apply(cpp_graph)
         if self._debug:
-            graph.draw('.', 'quant_fp32_{}'.format(pass_name),
-                       graph.all_op_nodes())
+            graph.draw('.', '{}_{}_{}'.format(self._pass_group, self._pass_idx,
+                                              pass_name), graph.all_op_nodes())
         self._remove_unused_var_nodes(graph)
+        self._pass_idx += 1
         return graph
 
-    def _optimize_int8_graph(self, graph):
+    def _final_optimizations(self, graph):
         # remove dropout ops
         graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
         # make some MKL-DNN ops working inplace
@@ -448,8 +473,7 @@ class Quant2Int8MkldnnPass(object):
             self._var_quant_scales[out_name] = (True, tensor)
             return graph
 
-        conv_predicate = lambda op: op.attr("fuse_activation") in self._relu_ops and \
-            op.attr("fuse_residual_connection") == False
+        conv_predicate = lambda op: op.attr("fuse_activation") in self._relu_ops
         graph = _set_unsigned_scale(graph, self._conv_ops, "Output",
                                     conv_predicate)
@@ -465,15 +489,10 @@ class Quant2Int8MkldnnPass(object):
         return 'NHWC' if self._is_conv_quantized(graph) else 'NCHW'
 
     def _quantize_fp32_graph(self, graph):
-        ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
-        cpp_graph = graph.graph
-        ir_pass.set('quantize_enabled_op_types', self._ops_to_quantize)
-        ir_pass.set('quantize_excluded_op_ids',
-                    self._find_avg_pooling_ids(graph))
-        ir_pass.apply(cpp_graph)
-        if self._debug:
-            graph.draw('.', 'quant_int8_{}'.format(ir_pass.type()),
-                       graph.all_op_nodes())
+        graph = self._apply_pass(
+            graph, 'cpu_quantize_placement_pass',
+            ['quantize_enabled_op_types', 'quantize_excluded_op_ids'],
+            [self._ops_to_quantize, self._find_avg_pooling_ids(graph)])
         graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
         graph = self._apply_pass(graph,
                                  'reshape_transpose_matmul_mkldnn_fuse_pass')
...
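Note: the hunks above rename the pass's entry points: apply() now covers the whole Quant-to-INT8 transform, and prepare_and_optimize_fp32() replaces apply_fp32(). A minimal usage sketch follows, assuming the Paddle 1.8-era import paths used by the tests in this PR; inference_program, scope, and place are assumed to already exist:

# Minimal sketch (not part of this PR): driving the renamed entry points.
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass

# Build an IR graph from an already-loaded inference program.
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
quant_pass = Quant2Int8MkldnnPass(
    {'conv2d', 'pool2d'},  # ops to quantize
    _scope=scope,
    _place=place,
    _core=core,
    _debug=True)

# Full Quant -> INT8 transform; with _debug=True each applied pass dumps a
# graph named '<group>_<idx>_<pass_name>' (e.g. 'int8_0_...'), per the
# _reset_pass_idx_and_group/_apply_pass changes above.
int8_graph = quant_pass.apply(graph)
# FP32-only path; dumps would be named 'fp32_<idx>_<pass_name>':
# fp32_graph = quant_pass.prepare_and_optimize_fp32(graph)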
@@ -57,7 +57,7 @@ endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
-function(inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path ops_to_quantize)
+function(inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path)
   py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py"
           ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -67,12 +67,11 @@ function(inference_quant2_int8_image_classification_test target quant_model_dir
                --infer_data ${dataset_path}
                --batch_size 10
                --batch_num 2
-               --acc_diff_threshold 0.1
-               --ops_to_quantize ${ops_to_quantize})
+               --acc_diff_threshold 0.1)
 endfunction()
 
 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
-function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path)
+function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path ops_to_quantize)
   py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py"
           ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
                OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
@@ -83,7 +82,8 @@ function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir da
                --labels ${labels_path}
                --batch_size 10
                --batch_num 2
-               --acc_diff_threshold 0.1)
+               --acc_diff_threshold 0.1
+               --ops_to_quantize ${ops_to_quantize})
 endfunction()
 
 function(download_quant_data install_dir data_file)
@@ -98,20 +98,20 @@ function(download_quant_model install_dir data_file)
   endif()
 endfunction()
 
-function(save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
+function(save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path)
   py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
           ARGS --quant_model_path ${quant_model_dir}
               --fp32_model_save_path ${fp32_model_save_path}
               --int8_model_save_path ${int8_model_save_path}
-               --ops_to_quantize ${ops_to_quantize}
               --debug)
 endfunction()
 
-function(save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path)
+function(save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
   py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
          ARGS --quant_model_path ${quant_model_dir}
               --fp32_model_save_path ${fp32_model_save_path}
-               --int8_model_save_path ${int8_model_save_path})
+               --int8_model_save_path ${int8_model_save_path}
+               --ops_to_quantize ${ops_to_quantize})
 endfunction()
 
 function(convert_model2dot_test target model_path save_graph_dir save_graph_name)
@@ -224,36 +224,34 @@ if(LINUX AND WITH_MKLDNN)
 
   ### Quant2 for image classification
 
-  set(QUANT2_IC_OPS_TO_QUANTIZE "conv2d,pool2d")
-
   # Quant2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
   # with weight scales in `fake_dequantize_max_abs` operators
   set(QUANT2_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2")
   set(QUANT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
   download_quant_model(${QUANT2_RESNET50_MODEL_DIR} ${QUANT2_RESNET50_MODEL_ARCHIVE})
   set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
-  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_mkldnn ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
+  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_mkldnn ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
 
   # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
   # with weight scales in `fake_dequantize_max_abs` operators
   set(QUANT2_RESNET50_RANGE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_range")
   set(QUANT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
   download_quant_model(${QUANT2_RESNET50_RANGE_MODEL_DIR} ${QUANT2_RESNET50_RANGE_MODEL_ARCHIVE})
-  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_range_mkldnn ${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
+  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_range_mkldnn ${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
 
   # Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
   # with weight scales in `fake_channel_wise_dequantize_max_abs` operators
   set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2_channelwise")
   set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
   download_quant_model(${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
-  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_channelwise_mkldnn ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
+  inference_quant2_int8_image_classification_test(test_quant2_int8_resnet50_channelwise_mkldnn ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
 
   # Quant2 MobileNetV1
   set(QUANT2_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant2")
   set(QUANT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
   download_quant_model(${QUANT2_MOBILENETV1_MODEL_DIR} ${QUANT2_MOBILENETV1_MODEL_ARCHIVE})
   set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
-  inference_quant2_int8_image_classification_test(test_quant2_int8_mobilenetv1_mkldnn ${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
+  inference_quant2_int8_image_classification_test(test_quant2_int8_mobilenetv1_mkldnn ${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float ${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
 
   ### Quant2 for NLP
@@ -263,6 +261,8 @@ if(LINUX AND WITH_MKLDNN)
   set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
   download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})
 
+  set(QUANT2_NLP_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add")
+
   # Quant2 Ernie
   set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
   set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2")
@@ -270,17 +270,17 @@ if(LINUX AND WITH_MKLDNN)
   set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
   set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
   download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
-  inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH})
+  inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE})
 
   ### Save FP32 model or INT8 model from Quant model
   set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
   set(QUANT2_FP32_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_fp32")
-  save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_FP32_RESNET50_SAVE_PATH} ${QUANT2_INT8_RESNET50_SAVE_PATH} ${QUANT2_IC_OPS_TO_QUANTIZE})
+  save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_FP32_RESNET50_SAVE_PATH} ${QUANT2_INT8_RESNET50_SAVE_PATH})
 
   set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8")
   set(QUANT2_FP32_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_fp32")
-  save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_FP32_ERNIE_SAVE_PATH} ${QUANT2_INT8_ERNIE_SAVE_PATH})
+  save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_FP32_ERNIE_SAVE_PATH} ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE})
 
   # Convert Quant2 model to dot and pdf files
   set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
...
@@ -167,7 +167,8 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
                  batch_size=1,
                  batch_num=1,
                  skip_batch_num=0,
-                 transform_to_int8=False):
+                 target='quant'):
+        assert target in ['quant', 'int8', 'fp32']
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         inference_scope = fluid.executor.global_scope()
@@ -183,17 +184,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
             graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
             if (self._debug):
                 graph.draw('.', 'quant_orig', graph.all_op_nodes())
-            if (transform_to_int8):
-                transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
-                    self._quantized_ops,
-                    _op_ids_to_skip=self._op_ids_to_skip,
-                    _scope=inference_scope,
-                    _place=place,
-                    _core=core,
-                    _debug=self._debug)
-                graph = transform_to_mkldnn_int8_pass.apply(graph)
-            else:
+            quant_transform_pass = Quant2Int8MkldnnPass(
+                self._quantized_ops,
+                _op_ids_to_skip=self._op_ids_to_skip,
+                _scope=inference_scope,
+                _place=place,
+                _core=core,
+                _debug=self._debug)
+            if (target == 'quant'):
                 graph = self._prepare_for_fp32_mkldnn(graph)
+            elif (target == 'int8'):
+                graph = quant_transform_pass.apply(graph)
+            else:  # target == fp32
+                graph = quant_transform_pass.prepare_and_optimize_fp32(graph)
 
             inference_program = graph.to_program()
@@ -222,18 +225,7 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
             images = np.array(images).astype('float32')
             labels = np.array([x[1] for x in data]).astype('int64')
 
-            if (transform_to_int8 == True):
-                # INT8 models obtained from Quant models do not have accuracy measuring layers
-                start = time.time()
-                out = exe.run(inference_program,
-                              feed={feed_target_names[0]: images},
-                              fetch_list=fetch_targets)
-                batch_time = (time.time() - start) * 1000  # in miliseconds
-                outputs.append(out[0])
-                # Calculate accuracy result
-                batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
-                                                                  labels)
-            else:
+            if (target == 'fp32'):
                 # FP32 models have accuracy measuring layers
                 labels = labels.reshape([-1, 1])
                 start = time.time()
@@ -246,6 +238,18 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
                 batch_time = (time.time() - start) * 1000  # in miliseconds
                 batch_acc1, batch_acc5 = out[1][0], out[2][0]
                 outputs.append(batch_acc1)
+            else:
+                # Quant INT8 models do not have accuracy measuring layers
+                start = time.time()
+                out = exe.run(inference_program,
+                              feed={feed_target_names[0]: images},
+                              fetch_list=fetch_targets)
+                batch_time = (time.time() - start) * 1000  # in miliseconds
+                outputs.append(out[0])
+                # Calculate accuracy result
+                batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
+                                                                  labels)
+
             infer_accs1.append(batch_acc1)
             infer_accs5.append(batch_acc5)
             samples = len(data)
@@ -274,28 +278,37 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
         return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
 
-    def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
+    def _print_performance(self, title, fps, lat):
+        _logger.info('{0}: avg fps: {1:.2f}, avg latency: {2:.4f} ms'.format(
+            title, fps, lat))
+
+    def _print_accuracy(self, title, acc1, acc5):
+        _logger.info(
+            '{0}: avg top1 accuracy: {1:.4f}, avg top5 accuracy: {2:.4f}'.
+            format(title, acc1, acc5))
+
+    def _summarize_performance(self, int8_fps, int8_lat, fp32_fps, fp32_lat):
         _logger.info('--- Performance summary ---')
-        _logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} ms'.format(
-            fp32_fps, fp32_lat))
-        _logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} ms'.format(
-            int8_fps, int8_lat))
+        self._print_performance('INT8', int8_fps, int8_lat)
+        if fp32_lat >= 0:
+            self._print_performance('FP32', fp32_fps, fp32_lat)
 
-    def _compare_accuracy(self, fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
-                          threshold):
+    def _summarize_accuracy(self, quant_acc1, quant_acc5, int8_acc1, int8_acc5,
+                            fp32_acc1, fp32_acc5):
         _logger.info('--- Accuracy summary ---')
+        self._print_accuracy('Quant', quant_acc1, quant_acc5)
+        self._print_accuracy('INT8', int8_acc1, int8_acc5)
+        if fp32_acc1 >= 0:
+            self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
+
+    def _compare_accuracy(self, threshold, quant_acc1, int8_acc1):
         _logger.info(
-            'Accepted top1 accuracy drop threshold: {0}. (condition: (FP32_top1_acc - IN8_top1_acc) <= threshold)'
+            'Accepted top1 accuracy drop threshold: {0}. (condition: (Quant_top1_acc - IN8_top1_acc) <= threshold && Quant_top1_acc > 0.5 && INT8_top1_acc > 0.5)'
             .format(threshold))
-        _logger.info(
-            'FP32: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'.
-            format(fp32_acc1, fp32_acc5))
-        _logger.info(
-            'INT8: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'.
-            format(int8_acc1, int8_acc5))
-        assert fp32_acc1 > 0.0
-        assert int8_acc1 > 0.0
-        assert fp32_acc1 - int8_acc1 <= threshold
+        # We assume valid accuracy to be at least 0.5
+        assert quant_acc1 > 0.5
+        assert int8_acc1 > 0.5
+        assert quant_acc1 - int8_acc1 <= threshold
 
     def test_graph_transformation(self):
         if not fluid.core.is_compiled_with_mkldnn():
@@ -303,10 +316,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
 
         quant_model_path = test_case_args.quant_model
         assert quant_model_path, 'The Quant model path cannot be empty. Please, use the --quant_model option.'
-        fp32_model_path = test_case_args.fp32_model
-        assert fp32_model_path, 'The FP32 model path cannot be empty. Please, use the --fp32_model option.'
         data_path = test_case_args.infer_data
         assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
+        fp32_model_path = test_case_args.fp32_model
         batch_size = test_case_args.batch_size
         batch_num = test_case_args.batch_num
         skip_batch_num = test_case_args.skip_batch_num
@@ -323,8 +335,9 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
             self._op_ids_to_skip = set(
                 map(int, test_case_args.op_ids_to_skip.split(',')))
 
-        _logger.info('FP32 & Quant INT8 prediction run.')
+        _logger.info('Quant & INT8 prediction run.')
         _logger.info('Quant model: {}'.format(quant_model_path))
-        _logger.info('FP32 model: {}'.format(fp32_model_path))
+        if fp32_model_path:
+            _logger.info('FP32 model: {}'.format(fp32_model_path))
         _logger.info('Dataset: {}'.format(data_path))
         _logger.info('Batch size: {}'.format(batch_size))
@@ -336,17 +349,20 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
                 map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
             else 'none'))
 
-        _logger.info('--- FP32 prediction start ---')
+        _logger.info('--- Quant prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path), batch_size=batch_size)
-        fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
+        quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
             val_reader,
-            fp32_model_path,
+            quant_model_path,
             batch_size,
             batch_num,
            skip_batch_num,
-            transform_to_int8=False)
-        _logger.info('--- Quant INT8 prediction start ---')
+            target='quant')
+        self._print_performance('Quant', quant_fps, quant_lat)
+        self._print_accuracy('Quant', quant_acc1, quant_acc5)
+
+        _logger.info('--- INT8 prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path), batch_size=batch_size)
         int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
@@ -355,11 +371,29 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
             batch_size,
             batch_num,
            skip_batch_num,
-            transform_to_int8=True)
+            target='int8')
+        self._print_performance('INT8', int8_fps, int8_lat)
+        self._print_accuracy('INT8', int8_acc1, int8_acc5)
 
-        self._summarize_performance(fp32_fps, fp32_lat, int8_fps, int8_lat)
-        self._compare_accuracy(fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
-                               acc_diff_threshold)
+        fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1
+        if fp32_model_path:
+            _logger.info('--- FP32 prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path), batch_size=batch_size)
+            fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
+                val_reader,
+                fp32_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='fp32')
+            self._print_performance('FP32', fp32_fps, fp32_lat)
+            self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
+
+        self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
+        self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, int8_acc5,
+                                 fp32_acc1, fp32_acc5)
+        self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
 
 if __name__ == '__main__':
...
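Note: _compare_accuracy now compares the Quant model against the INT8 model instead of FP32 vs INT8, and additionally requires both top-1 accuracies to exceed 0.5. An illustrative check of the new acceptance condition (the numbers below are made up):

# Illustrative only: the new acceptance condition with made-up numbers.
threshold = 0.1
quant_acc1, int8_acc1 = 0.765, 0.758

assert quant_acc1 > 0.5                      # sanity floor for the Quant model
assert int8_acc1 > 0.5                       # sanity floor for the INT8 model
assert quant_acc1 - int8_acc1 <= threshold   # drop of 0.007 is within 0.1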
@@ -17,8 +17,6 @@ import os
 import sys
 import argparse
 import logging
-import struct
-import six
 import numpy as np
 import time
 import paddle
@@ -143,7 +141,8 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
                  batch_size=1,
                  batch_num=1,
                  skip_batch_num=0,
-                 transform_to_int8=False):
+                 target='quant'):
+        assert target in ['quant', 'int8', 'fp32']
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         inference_scope = fluid.executor.global_scope()
@@ -159,15 +158,19 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
             if (self._debug):
                 graph.draw('.', 'quant_orig', graph.all_op_nodes())
-            if (transform_to_int8):
-                transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
+            if (target != 'quant'):
+                quant_transform_pass = Quant2Int8MkldnnPass(
                     self._quantized_ops,
                     _op_ids_to_skip=self._op_ids_to_skip,
                     _scope=inference_scope,
                     _place=place,
                     _core=core,
                     _debug=self._debug)
-                graph = transform_to_mkldnn_int8_pass.apply(graph)
+                if (target == 'int8'):
+                    graph = quant_transform_pass.apply(graph)
+                else:  # target == fp32
+                    graph = quant_transform_pass.prepare_and_optimize_fp32(
+                        graph)
 
             inference_program = graph.to_program()
@@ -223,26 +226,35 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
         return acc_avg, pps_avg, latency_avg
 
-    def _summarize_performance(self, fp32_pps, fp32_lat, int8_pps, int8_lat):
-        _logger.info('--- Performance summary ---')
-        _logger.info(
-            'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
-            format(fp32_pps, fp32_lat))
+    def _print_performance(self, title, pps, lat):
         _logger.info(
-            'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
-            format(int8_pps, int8_lat))
+            '{0}: avg predictions per sec: {1:.2f}, avg latency: {2:.4f} ms'.
+            format(title, pps, lat))
+
+    def _print_accuracy(self, title, acc):
+        _logger.info('{0}: avg accuracy: {1:.6f}'.format(title, acc))
+
+    def _summarize_performance(self, int8_pps, int8_lat, fp32_pps, fp32_lat):
+        _logger.info('--- Performance summary ---')
+        self._print_performance('INT8', int8_pps, int8_lat)
+        if fp32_lat >= 0:
+            self._print_performance('FP32', fp32_pps, fp32_lat)
 
-    def _compare_accuracy(self, fp32_acc, int8_acc, threshold):
+    def _summarize_accuracy(self, quant_acc, int8_acc, fp32_acc):
         _logger.info('--- Accuracy summary ---')
+        self._print_accuracy('Quant', quant_acc)
+        self._print_accuracy('INT8', int8_acc)
+        if fp32_acc >= 0:
+            self._print_accuracy('FP32', fp32_acc)
+
+    def _compare_accuracy(self, threshold, quant_acc, int8_acc):
         _logger.info(
-            'Accepted accuracy drop threshold: {0}. (condition: (FP32_acc - INT8_acc) <= threshold)'
+            'Accepted accuracy drop threshold: {0}. (condition: (Quant_acc - INT8_acc) <= threshold)'
            .format(threshold))
-        _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
-        _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
         # Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
-        assert fp32_acc > 0.5
+        assert quant_acc > 0.5
         assert int8_acc > 0.5
-        assert fp32_acc - int8_acc <= threshold
+        assert quant_acc - int8_acc <= threshold
 
     def test_graph_transformation(self):
         if not fluid.core.is_compiled_with_mkldnn():
@@ -250,9 +262,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
 
         quant_model_path = test_case_args.quant_model
         assert quant_model_path, 'The Quant model path cannot be empty. Please, use the --quant_model option.'
-        fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else quant_model_path
         data_path = test_case_args.infer_data
         assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
+        fp32_model_path = test_case_args.fp32_model
         labels_path = test_case_args.labels
         batch_size = test_case_args.batch_size
         batch_num = test_case_args.batch_num
@@ -270,8 +282,9 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
             self._op_ids_to_skip = set(
                 map(int, test_case_args.op_ids_to_skip.split(',')))
 
-        _logger.info('FP32 & Quant INT8 prediction run.')
+        _logger.info('Quant & INT8 prediction run.')
         _logger.info('Quant model: {}'.format(quant_model_path))
-        _logger.info('FP32 model: {}'.format(fp32_model_path))
+        if fp32_model_path:
+            _logger.info('FP32 model: {}'.format(fp32_model_path))
         _logger.info('Dataset: {}'.format(data_path))
         _logger.info('Labels: {}'.format(labels_path))
@@ -284,18 +297,20 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
                 map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
             else 'none'))
 
-        _logger.info('--- FP32 prediction start ---')
+        _logger.info('--- Quant prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
-        fp32_acc, fp32_pps, fp32_lat = self._predict(
+        quant_acc, quant_pps, quant_lat = self._predict(
            val_reader,
-            fp32_model_path,
+            quant_model_path,
            batch_size,
            batch_num,
            skip_batch_num,
-            transform_to_int8=False)
-        _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
-        _logger.info('--- Quant INT8 prediction start ---')
+            target='quant')
+        self._print_performance('Quant', quant_pps, quant_lat)
+        self._print_accuracy('Quant', quant_acc)
+
+        _logger.info('--- INT8 prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
         int8_acc, int8_pps, int8_lat = self._predict(
@@ -304,11 +319,29 @@ class QuantInt8NLPComparisonTest(unittest.TestCase):
            batch_size,
            batch_num,
            skip_batch_num,
-            transform_to_int8=True)
-        _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
+            target='int8')
+        self._print_performance('INT8', int8_pps, int8_lat)
+        self._print_accuracy('INT8', int8_acc)
+
+        fp32_acc = fp32_pps = fp32_lat = -1
+        if fp32_model_path:
+            _logger.info('--- FP32 prediction start ---')
+            val_reader = paddle.batch(
+                self._reader_creator(data_path, labels_path),
+                batch_size=batch_size)
+            fp32_acc, fp32_pps, fp32_lat = self._predict(
+                val_reader,
+                fp32_model_path,
+                batch_size,
+                batch_num,
+                skip_batch_num,
+                target='fp32')
+            self._print_performance('FP32', fp32_pps, fp32_lat)
+            self._print_accuracy('FP32', fp32_acc)
 
-        self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
-        self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
+        self._summarize_performance(int8_pps, int8_lat, fp32_pps, fp32_lat)
+        self._summarize_accuracy(quant_acc, int8_acc, fp32_acc)
+        self._compare_accuracy(acc_diff_threshold, quant_acc, int8_acc)
 
 if __name__ == '__main__':
...
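Note: with the _print_performance/_print_accuracy helpers above, per-model results are logged in a uniform format. Given the format strings in this hunk, the output would look like the following (values are illustrative, not measured):

Quant: avg predictions per sec: 95.31, avg latency: 10.4921 ms
Quant: avg accuracy: 0.796612
INT8: avg predictions per sec: 188.72, avg latency: 5.2988 ms
INT8: avg accuracy: 0.791523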
@@ -35,11 +35,6 @@ def parse_args():
         type=str,
         default='',
         help='A path to a Quant model.')
-    parser.add_argument(
-        '--fp32_model_save_path',
-        type=str,
-        default='',
-        help='Saved optimized fp32 model')
     parser.add_argument(
         '--int8_model_save_path',
         type=str,
@@ -65,7 +60,7 @@ def parse_args():
     return test_args, sys.argv[:1] + args
 
-def transform_and_save_model(original_path, save_path, save_type):
+def transform_and_save_int8_model(original_path, save_path):
     place = fluid.CPUPlace()
     exe = fluid.Executor(place)
     inference_scope = fluid.executor.global_scope()
@@ -96,26 +91,18 @@ def transform_and_save_model(original_path, save_path, save_type):
             _place=place,
             _core=core,
             _debug=test_args.debug)
-        graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
-        if save_type == 'FP32':
-            graph = transform_to_mkldnn_int8_pass.apply_fp32(graph)
-        elif save_type == 'INT8':
-            graph = transform_to_mkldnn_int8_pass.apply(graph)
+        graph = transform_to_mkldnn_int8_pass.apply(graph)
         inference_program = graph.to_program()
         with fluid.scope_guard(inference_scope):
             fluid.io.save_inference_model(save_path, feed_target_names,
                                           fetch_targets, exe, inference_program)
-        print("Success! Transformed Quant_{0} model can be found at {1}\n".
-              format(save_type, save_path))
+        print(
+            "Success! INT8 model obtained from the Quant model can be found at {}\n"
+            .format(save_path))
 
 if __name__ == '__main__':
     global test_args
     test_args, remaining_args = parse_args()
-    if test_args.fp32_model_save_path:
-        transform_and_save_model(test_args.quant_model_path,
-                                 test_args.fp32_model_save_path, 'FP32')
-    if test_args.int8_model_save_path:
-        transform_and_save_model(test_args.quant_model_path,
-                                 test_args.int8_model_save_path, 'INT8')
+    transform_and_save_int8_model(test_args.quant_model_path,
+                                  test_args.int8_model_save_path)
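Note: with the FP32 branch removed, save_quant_model.py now saves only the INT8 model. A typical invocation might look like the following; the paths are placeholders, and the op list is the one introduced for the NLP tests in the CMake changes above:

python save_quant_model.py \
    --quant_model_path=/path/to/quant_model \
    --int8_model_save_path=/path/to/output/int8_model \
    --ops_to_quantize="fc,reshape2,transpose2,matmul,elementwise_add" \
    --debug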