Unverified commit 4cddb43c, authored by Wojciech Uss, committed by GitHub

Add support for Ernie NLP model to the Slim QAT (#22506)

* a test for Ernie QAT INT8 accuracy check

test=develop

* Remove NLP comparison test to split PRs

test=develop

* Fix typo and tabs, delete commented lines

test=develop

* re-combine the 2 PRs, test=develop
Co-authored-by: Michał Gallus <sand3r@interia.eu>
Co-authored-by: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>
Parent commit: 5a1a9a1e
@@ -226,15 +226,18 @@ if(WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
-### Image classification tests
+## Image classification models
set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin")
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
-# download dataset if necessary
+# ImageNet small dataset
-download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz")
+# May be already downloaded for INT8 QAT unit tests
set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
download_int8_data(${IMAGENET_DATA_DIR} ${IMAGENET_DATA_ARCHIVE})
# build test binary to be used in subsequent tests
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
# resnet50 int8
@@ -296,7 +299,7 @@ if(WITH_MKLDNN)
### optimized FP32 vs. QAT INT8 tests
-set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
+set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
@@ -304,8 +307,8 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
# MobileNet FP32 vs. QAT INT8
# The FP32 model should already be downloaded for slim QAT unit tests
set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
download_qat_data(${QAT2_MobileNet_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8")
download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......
@@ -479,7 +479,7 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
-platform::ThreadIDasStr(), input->format(),
+platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
auto prim_creator =
......
@@ -17,10 +17,10 @@ from .... import core
from ....framework import IrGraph
from ....framework import IrNode
-__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass']
+__all__ = ['QatInt8MkldnnPass', 'Qat2Int8MkldnnPass']
-class FakeQAT2MkldnnINT8KernelPass(object):
+class QatInt8MkldnnPass(object):
"""
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. Following transformations did in this pass:
@@ -48,13 +48,13 @@ class FakeQAT2MkldnnINT8KernelPass(object):
# The original graph will be rewrite.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
-import FakeQAT2MkldnnINT8KernelPass
+import QatInt8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
-mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(),
+mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
@@ -276,7 +276,7 @@ class FakeQAT2MkldnnINT8KernelPass(object):
graph.safe_remove_nodes(all_unused_vars)
-class FakeQAT2MkldnnINT8PerfPass(object):
+class Qat2Int8MkldnnPass(object):
"""
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations:
@@ -290,7 +290,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
"""
-def __init__(self, _scope=None, _place=None, _core=None, _debug=False):
+def __init__(self,
_quantized_ops,
_scope=None,
_place=None,
_core=None,
_debug=False):
self._scope = _scope
self._place = _place
self._core = _core
@@ -305,6 +310,10 @@ class FakeQAT2MkldnnINT8PerfPass(object):
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._quantized_ops = _quantized_ops
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale'
]
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
@@ -324,8 +333,9 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
-graph = self._update_conv_relu_scales(graph)
+graph = self._update_relu_output_scales(graph)
-graph = self._update_pooling_scales(graph)
+graph = self._propagate_scales(graph)
graph = self._set_dummy_fc_out_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
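For orientation, a minimal usage sketch of the renamed pass after this change; the op set and variable names below are illustrative and mirror the test code added later in this commit:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

    # Build an IrGraph for inference (in the tests it comes from a loaded QAT model).
    graph = IrGraph(core.Graph(fluid.Program().desc), for_test=True)
    place = fluid.CPUPlace()
    # Operators selected for INT8 quantization; the tests use "conv2d,pool2d" for the
    # image classification models and "fc,reshape2,transpose2" for Ernie.
    quantized_ops = {'conv2d', 'pool2d'}
    mkldnn_pass = Qat2Int8MkldnnPass(quantized_ops,
                                     _scope=fluid.global_scope(),
                                     _place=place,
                                     _core=core,
                                     _debug=False)
    graph = mkldnn_pass.apply(graph)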
@@ -346,6 +356,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
tensor.set(scale, core.CPUPlace())
return tensor
def _is_conv_quantized(self):
return any(op_type in self._quantized_ops for op_type in self._conv_ops)
def _is_fc_quantized(self):
return 'fc' in self._quantized_ops
def _gather_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._quantize_types:
@@ -371,34 +387,94 @@ class FakeQAT2MkldnnINT8PerfPass(object):
self._weight_scales[input_name] = _max_range
return graph
-def _update_pooling_scales(self, graph):
+def _propagate_scales(self, graph):
def _update_scale_op_in_scale(op, input, output):
unsigned, tensor = self._var_quant_scales[output]
scale = np.array(tensor) * op.op().attr("scale")
new_tensor = self._convert_scale2tensor(scale.astype(np.float64))
self._var_quant_scales[input] = (unsigned, new_tensor)
def _update_scales(graph):
waiting_for_scale = set()
for op in graph.all_op_nodes():
-if op.name() in self._pool_ops:
+if op.name() in self._scale_immutable_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
tensor_names = [input_name, output_name]
# Scale is not quantized, so if it doesn't have any scales
# to propagate, its tensors won't be added to the waiting list.
if all(name not in self._var_quant_scales for name in tensor_names) \
and op.name() != 'scale':
waiting_for_scale.update(tensor_names)
continue
if input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
elif output_name in self._var_quant_scales:
if op.name() == 'scale':
_update_scale_op_in_scale(op, input_name,
output_name)
else:
self._var_quant_scales[
input_name] = self._var_quant_scales[
output_name]
return waiting_for_scale
waiting_for_scale = _update_scales(graph)
while len(waiting_for_scale) != 0:
waiting_for_scale = _update_scales(graph)
return graph
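As a side note, the fixed-point propagation above can be pictured with a small standalone sketch; the op tuples and scale values are made up, and the real pass stores (use_unsigned, LoDTensor) pairs keyed by tensor names in an IrGraph:

    # Toy version of _propagate_scales: push known scales forward and backward
    # through scale-immutable ops, sweeping until nothing is left waiting.
    ops = [('reshape2', 'b', 'c'), ('transpose2', 'a', 'b')]  # (op_type, input, output)
    scales = {'a': 0.5}  # only tensor 'a' has a quantization scale at the start

    def sweep(ops, scales):
        waiting = set()
        for _, x, out in ops:
            if x in scales:
                scales[out] = scales[x]      # propagate forward
            elif out in scales:
                scales[x] = scales[out]      # propagate backward
            else:
                waiting.update((x, out))     # retry in the next sweep
        return waiting

    # Assumes every waiting tensor eventually receives a scale, as the pass does.
    while sweep(ops, scales):
        pass
    print(scales)  # {'a': 0.5, 'b': 0.5, 'c': 0.5}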
def _set_dummy_fc_out_scales(self, graph):
'''
For the output tensors of FC that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since FCs that
have an unassigend output scale will have a force_fp32_output attr
set to True.
'''
for op in graph.all_op_nodes():
if op.name() in self._fc_ops:
input_name = op.input("Input")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales and \
output_name not in self._var_quant_scales:
# use input scale as a "dummy" scale
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
def _remove_fake_ops(self, graph):
'''
When FC isn't quantized:
Remove fake (de)quantize ops that do not surround mul.
When FC is quantized:
Remove all fake (de)quantize ops.
'''
is_fc_quantized = self._is_fc_quantized()
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
next_op = op_out.outputs[0]
-if next_op.name() not in self._mul_ops:
+if next_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
prev_op = op_in.inputs[0]
-if prev_op.name() not in self._mul_ops:
+if prev_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
@@ -444,6 +520,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op)
elif self._is_fc_quantized() and op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph
def _dequantize_conv_weights(self, graph, op_node):
@@ -472,13 +550,20 @@ class FakeQAT2MkldnnINT8PerfPass(object):
def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
if self._is_conv_quantized():
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
-graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
+graph = self._apply_pass(graph,
'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
if self._is_fc_quantized():
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'],
[False, False])
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
@@ -528,6 +613,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
np.abs(weights.reshape(weights.shape[0], -1)).astype(
np.float64),
axis=axis)
scales[scales == np.Inf] = 0.0
lod_tensor = self._convert_scale2tensor(scales)
use_unsigned_int = False
@@ -546,46 +632,41 @@ class FakeQAT2MkldnnINT8PerfPass(object):
ids.append(op.id())
return set(ids) if len(ids) else set([-1])
-def _transform_to_quantize_mkldnn(self, graph, op_node):
+def _update_relu_output_scales(self, graph):
-"""
+def _update_scale(graph, ops, op_out_name, predicate):
-Transform fake_quantize_xx op to quantize mkldnn op in the graph.
+'''
-"""
+Sets the type of an output scale of a passed op type(s) to 'unsigned int8' if the
-input_var_node = graph._find_node_by_name(op_node.inputs,
+predicate applied on op passes. Typically, the predicate checks if op's
-op_node.input("X")[0])
+activation is set to relu.
-output_var_node = graph._find_node_by_name(op_node.outputs,
+'''
op_node.output("Out")[0])
scale_in = self._s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
attrs={
'data_format': 'MKLDNNLAYOUT',
'use_mkldnn': 1,
'Scale': scale_in,
'is_negative_input': 1
},
inputs={'Input': input_var_node},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, quant_op_node)
graph.link_to(quant_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return quant_op_node
def _update_conv_relu_scales(self, graph):
for op in graph.all_op_nodes():
-if op.name() in self._conv_ops:
+if op.name() in ops:
-out_name = op.output("Output")[0]
+out_name = op.output(op_out_name)[0]
-if out_name in self._var_quant_scales and \
+if out_name in self._var_quant_scales and predicate(op.op(
-op.op().attr("fuse_activation") == 'relu' and \
+)):
op.op().attr("fuse_residual_connection") == False:
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
if self._is_conv_quantized():
conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \
op.attr("fuse_residual_connection") == False
graph = _update_scale(graph, self._conv_ops, "Output",
conv_predicate)
if self._is_fc_quantized():
fc_predicate = lambda op: op.attr("activation_type") == 'relu'
graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate)
return graph
def _get_data_layout(self):
return 'NHWC' if self._is_conv_quantized() else 'NCHW'
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
cpp_graph = graph.graph
-ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
+ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph))
ir_pass.apply(cpp_graph)
@@ -593,8 +674,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
-graph = self._apply_pass(graph, 'cpu_quantize_pass',
+graph = self._apply_pass(
-['quant_var_scales'],
+graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
-[self._var_quant_scales])
+[self._var_quant_scales, self._get_data_layout()])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
-function(_inference_analysis_python_api_int8_test target model_dir data_dir filename use_mkldnn)
+function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn)
py_test(${target} SRCS ${filename}
ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --infer_model ${model_dir}/model
---infer_data ${data_dir}/data.bin
+--infer_data ${data_path}
--int8_model_save_path int8_models/${target}
--warmup_batch_size ${WARMUP_BATCH_SIZE}
--batch_size 50)
endfunction()
-function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
+function(inference_analysis_python_api_int8_test target model_dir data_path filename)
-_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} False)
+_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False)
endfunction()
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_dir filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename} True)
endfunction()
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size)
@@ -25,13 +21,29 @@ function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target
inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename})
endfunction()
-function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
+function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename)
-py_test(${target} SRCS ${test_script}
+_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True)
endfunction()
function(download_qat_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif()
endfunction()
function(download_qat_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif()
endfunction()
function(inference_qat_int8_image_classification_test target model_dir dataset_path)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
-FLAGS_use_mkldnn=${use_mkldnn}
+FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/model
---infer_data ${data_dir}/data.bin
+--infer_data ${dataset_path}
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1)
@@ -39,24 +51,53 @@ endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
-function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
+function(inference_qat2_int8_image_classification_test target model_dir data_path quantized_ops)
-py_test(${target} SRCS ${test_script}
+py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_image_classification_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
-FLAGS_use_mkldnn=${use_mkldnn}
+FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/float
---infer_data ${data_dir}/data.bin
+--infer_data ${data_path}
--batch_size 10
--batch_num 2
--acc_diff_threshold 0.1
--quantized_ops ${quantized_ops}
--qat2)
endfunction()
-function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script)
+# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
-py_test(${target} SRCS ${test_script}
+function(inference_qat2_int8_nlp_test target model_dir data_path labels_path quantized_ops)
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/qat_int8_nlp_comparison.py"
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true
ARGS --qat_model ${model_dir}/float
--infer_data ${data_path}
--labels ${labels_path}
--batch_size 10
--batch_num 2
--quantized_ops ${quantized_ops}
--acc_diff_threshold 0.1)
endfunction()
function(download_qat_data install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file})
endif()
endfunction()
function(download_qat_model install_dir data_file)
if (NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
endif()
endfunction()
function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path quantized_ops)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py
ARGS --qat_model_path ${qat_model_dir}
--fp32_model_save_path ${fp32_model_save_path}
---int8_model_save_path ${int8_model_save_path})
+--int8_model_save_path ${int8_model_save_path}
--quantized_ops ${quantized_ops})
endfunction()
if(WIN32)
@@ -66,137 +107,151 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()
# int8 image classification python api test
if(LINUX AND WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
-set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
+#### Image classification dataset: ImageNet (small)
-set(MKLDNN_INT8_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_INT8_TEST_FILE}")
+# The dataset should already be downloaded for INT8v2 unit tests
set(IMAGENET_DATA_PATH "${INFERENCE_DEMO_INSTALL_DIR}/imagenet/data.bin")
#### INT8 image classification python api test
# Models should be already downloaded for INT8v2 unit tests
set(INT8_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(INT8_IC_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
set(INT8_IC_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${INT8_IC_TEST_FILE}")
# googlenet int8
-set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
+set(INT8_GOOGLENET_MODEL_DIR "${INT8_INSTALL_DIR}/googlenet")
-inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH} 10)
+inference_analysis_python_api_int8_test_custom_warmup_batch_size(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH} 10)
# mobilenet int8
-set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1")
+set(INT8_MOBILENET_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
-inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
-inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test_mkldnn(test_slim_int8_mobilenet_mkldnn ${INT8_MOBILENET_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# temporarily adding WITH_SLIM_MKLDNN_FULL_TEST FLAG for QA testing the following UTs locally,
# since the following UTs cost too much time on CI test.
if (WITH_SLIM_MKLDNN_FULL_TEST)
# resnet50 int8
-set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
+set(INT8_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
-inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# mobilenetv2 int8
-set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
+set(INT8_MOBILENETV2_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv2")
-inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# resnet101 int8
-set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
+set(INT8_RESNET101_MODEL_DIR "${INT8_INSTALL_DIR}/resnet101")
-inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# vgg16 int8
-set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
+set(INT8_VGG16_MODEL_DIR "${INT8_INSTALL_DIR}/vgg16")
-inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
# vgg19 int8
-set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
+set(INT8_VGG19_MODEL_DIR "${INT8_INSTALL_DIR}/vgg19")
-inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE_PATH})
+inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH} ${INT8_IC_TEST_FILE_PATH})
endif()
endif()
-# Since test_mkldnn_int8_quantization_strategy only supports testing on Linux
+#### QAT FP32 & INT8 comparison python api tests
# with MKL-DNN, we remove it here for not repeating test, or not testing on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
-# QAT FP32 & INT8 comparison python api tests
+set(QAT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
if(LINUX AND WITH_MKLDNN)
-set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
+### QATv1 for image classification
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")
# ImageNet small dataset
# May be already downloaded for INT8v2 unit tests
if (NOT EXISTS ${DATASET_DIR})
inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
endif()
# QAT ResNet50
-set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
+set(QAT_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_QAT")
-if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
+set(QAT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" )
+download_qat_model(${QAT_RESNET50_MODEL_DIR} ${QAT_RESNET50_MODEL_ARCHIVE})
-endif()
+inference_qat_int8_image_classification_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH})
inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT ResNet101
-set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
+set(QAT_RESNET101_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet101_QAT")
-if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
+set(QAT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" )
+download_qat_model(${QAT_RESNET101_MODEL_DIR} ${QAT_RESNET101_MODEL_ARCHIVE})
-endif()
+# inference_qat_int8_image_classification_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH})
# inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT GoogleNet
-set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
+set(QAT_GOOGLENET_MODEL_DIR "${QAT_INSTALL_DIR}/GoogleNet_QAT")
-if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
+set(QAT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" )
+download_qat_model(${QAT_GOOGLENET_MODEL_DIR} ${QAT_GOOGLENET_MODEL_ARCHIVE})
-endif()
+inference_qat_int8_image_classification_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH})
inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV1
-set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
+set(QAT_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV1_QAT")
-if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
+set(QAT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" )
+download_qat_model(${QAT_MOBILENETV1_MODEL_DIR} ${QAT_MOBILENETV1_MODEL_ARCHIVE})
-endif()
+inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH})
inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV2
-set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
+set(QAT_MOBILENETV2_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNetV2_QAT")
-if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
+set(QAT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" )
+download_qat_model(${QAT_MOBILENETV2_MODEL_DIR} ${QAT_MOBILENETV2_MODEL_ARCHIVE})
-endif()
+inference_qat_int8_image_classification_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH})
inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG16
-set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
+set(QAT_VGG16_MODEL_DIR "${QAT_INSTALL_DIR}/VGG16_QAT")
-if (NOT EXISTS ${QAT_VGG16_MODEL_DIR})
+set(QAT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" )
+download_qat_model(${QAT_VGG16_MODEL_DIR} ${QAT_VGG16_MODEL_ARCHIVE})
-endif()
+# inference_qat_int8_image_classification_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH})
# inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG19
-set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
+set(QAT_VGG19_MODEL_DIR "${QAT_INSTALL_DIR}/VGG19_QAT")
-if (NOT EXISTS ${QAT_VGG19_MODEL_DIR})
+set(QAT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
-inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
+download_qat_model(${QAT_VGG19_MODEL_DIR} ${QAT_VGG19_MODEL_ARCHIVE})
-endif()
+# inference_qat_int8_image_classification_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH})
# inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
-set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
+### QATv2 for image classification
if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
-set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
+set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")
if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
-inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
+# QAT2 ResNet50
-endif()
+set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
-inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
-# Save qat2 fp32 model or qat2 int8 model
+# QAT2 MobileNetV1
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
set(QAT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
download_qat_model(${QAT2_MOBILENETV1_MODEL_DIR} ${QAT2_MOBILENETV1_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
### QATv2 for NLP
set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")
set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
download_qat_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})
# QAT2 Ernie
set(QAT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
set(QAT2_ERNIE_MODEL_DIR "${QAT_INSTALL_DIR}/Ernie_qat")
download_qat_model(${QAT2_ERNIE_MODEL_DIR} ${QAT2_ERNIE_MODEL_ARCHIVE})
inference_qat2_int8_nlp_test(test_qat2_int8_ernie_mkldnn ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QAT2_NLP_QUANTIZED_OPS})
# Save QAT2 FP32 model or QAT2 INT8 model
set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8")
set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32")
set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py")
save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT} true)
set(QAT2_INT8_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_int8")
set(QAT2_FP32_ERNIE_SAVE_PATH "${QAT_INSTALL_DIR}/Ernie_qat2_fp32")
save_qat_model_test(save_qat2_model_ernie ${QAT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QAT2_FP32_ERNIE_SAVE_PATH} ${QAT2_INT8_ERNIE_SAVE_PATH} ${QAT2_NLP_QUANTIZED_OPS})
endif()
-# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
+# Since the tests for QAT FP32 & INT8 comparison support only testing on Linux
# with MKL-DNN, we remove it here to not test it on other systems.
-list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py)
+list(REMOVE_ITEM TEST_OPS
test_mkldnn_int8_quantization_strategy
qat_int8_image_classification_comparison
qat_int8_nlp_comparison)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
......
@@ -24,8 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
-from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
+from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
-from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
+from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
@@ -53,10 +53,6 @@ def parse_args():
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
@@ -68,15 +64,20 @@ def parse_args():
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
-class TestQatInt8Comparison(unittest.TestCase):
+class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
"""
-Test for accuracy comparison of QAT FP32 and INT8 inference.
+Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference.
"""
def _reader_creator(self, data_file='data.bin'):
@@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
if (test_case_args.qat2):
-transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
+transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
-mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
+mkldnn_int8_pass = QatInt8MkldnnPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
@@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
@@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
@@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
......
# copyright (c) 2020 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--labels', type=str, default='', help='File with labels.')
parser.add_argument(
'--batch_num',
type=int,
default=1,
help='Number of batches to process. 0 or less means all.')
parser.add_argument(
'--acc_diff_threshold',
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class QatInt8NLPComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 NLP inference.
"""
def _reader_creator(self, data_file=None, labels_file=None):
assert data_file, "The dataset file is missing."
assert labels_file, "The labels file is missing."
def reader():
with open(data_file, 'r') as df:
with open(labels_file, 'r') as lf:
data_lines = df.readlines()
labels_lines = lf.readlines()
assert len(data_lines) == len(
labels_lines
), "The number of labels does not match the length of the dataset."
for i in range(len(data_lines)):
data_fields = data_lines[i].split(';')
assert len(
data_fields
) >= 2, "The number of data fields in the dataset is less than 2"
buffers = []
shape = []
for j in range(2):
data = data_fields[j].split(':')
assert len(
data
) >= 2, "Size of data in the dataset is less than 2"
# Shape is stored under index 0, while data under 1
shape = data[0].split()
shape.pop(0)
shape_np = np.array(shape).astype("int64")
buffer_i = data[1].split()
buffer_np = np.array(buffer_i).astype("int64")
buffer_np.shape = tuple(shape_np)
buffers.append(buffer_np)
label = labels_lines[i]
yield buffers[0], buffers[1], int(label)
return reader
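To make the expected input format concrete, here is a made-up dataset line parsed the same way as in the reader above: two ';'-separated fields, each "<name and shape>:<space-separated int64 values>", with the leading token of the shape part skipped; the label comes from a separate file.

    import numpy as np

    # Hypothetical line (field names and values invented) in the layout the reader expects.
    line = "src_ids 1 3:101 2040 102;pos_ids 1 3:0 1 2"
    for field in line.split(';'):
        data = field.split(':')
        shape = np.array(data[0].split()[1:]).astype("int64")  # drop the name token
        buffer_np = np.array(data[1].split()).astype("int64")
        buffer_np.shape = tuple(shape)
        print(buffer_np.shape, buffer_np)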
def _get_batch_correct(self, batch_output=None, labels=None):
total = len(batch_output)
assert total > 0, "The batch output is empty."
correct = 0
for n, output in enumerate(batch_output[0]):
max_idx = np.where(output == output.max())
if max_idx == labels[n]:
correct += 1
return correct
def _predict(self,
test_reader=None,
model_path=None,
batch_size=1,
batch_num=1,
skip_batch_num=0,
transform_to_int8=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
total_correct = 0
total_samples = 0
batch_times = []
ppses = [] # predictions per second
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
input0 = np.array([x[0] for x in data]).astype('int64')
input1 = np.array([x[1] for x in data]).astype('int64')
labels = np.array([x[2] for x in data]).astype('int64')
start = time.time()
out = exe.run(inference_program,
feed={
feed_target_names[0]: input0,
feed_target_names[1]: input1
},
fetch_list=fetch_targets)
batch_time = (time.time() - start) * 1000 # in miliseconds
batch_times.append(batch_time)
batch_correct = self._get_batch_correct(out, labels)
batch_len = len(data)
total_samples += batch_len
total_correct += batch_correct
batch_acc = float(batch_correct) / float(batch_len)
pps = batch_len / batch_time * 1000
ppses.append(pps)
latency = batch_time / batch_len
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
'batch {0}{4}, acc: {1:.4f}, latency: {2:.4f} ms, predictions per sec: {3:.2f}'
.format(iters, batch_acc, latency, pps, appx))
# Postprocess benchmark data
infer_total_time = time.time() - infer_start_time
batch_latencies = batch_times[skip_batch_num:]
batch_latency_avg = np.average(batch_latencies)
latency_avg = batch_latency_avg / batch_size
ppses = ppses[skip_batch_num:]
pps_avg = np.average(ppses)
acc_avg = float(np.sum(total_correct)) / float(total_samples)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return acc_avg, pps_avg, latency_avg
def _summarize_performance(self, fp32_pps, fp32_lat, int8_pps, int8_lat):
_logger.info('--- Performance summary ---')
_logger.info(
'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(fp32_pps, fp32_lat))
_logger.info(
'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(int8_pps, int8_lat))
def _compare_accuracy(self, fp32_acc, int8_acc, threshold):
_logger.info('--- Accuracy summary ---')
_logger.info(
'Accepted accuracy drop threshold: {0}. (condition: (FP32_acc - INT8_acc) <= threshold)'
.format(threshold))
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
_logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
# Random outputs give accuracy about 0.33, we assume valid accuracy to be at least 0.5
assert fp32_acc > 0.5
assert int8_acc > 0.5
assert fp32_acc - int8_acc <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
qat_model_path = test_case_args.qat_model
data_path = test_case_args.infer_data
labels_path = test_case_args.labels
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Labels: {0}'.format(labels_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
fp32_acc, fp32_pps, fp32_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=True)
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
@@ -24,7 +24,7 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
-from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
+from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
@@ -42,6 +42,11 @@ def parse_args():
type=str,
default='',
help='Saved optimized and quantized INT8 model')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
@@ -60,8 +65,9 @@ def transform_and_save_model(original_path, save_path, save_type):
fetch_targets] = fluid.io.load_inference_model(original_path, exe,
'model', 'params')
-transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
+quantized_ops = set(test_args.quantized_ops.split(','))
-_scope=inference_scope, _place=place, _core=core)
+transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
quantized_ops, _scope=inference_scope, _place=place, _core=core)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if save_type == 'FP32':
......
@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
-from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
+from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
@@ -149,8 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
-mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
+mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place)
_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......