Unverified commit 4cddb43c authored by Wojciech Uss, committed by GitHub

Add support for Ernie NLP model to the Slim QAT (#22506)

* a test for Ernie QAT INT8 accuracy check

test=develop

* Remove NLP comparison test to split PRs

test=develop

* Fix typo and tabs, delete commented lines

test=develop

* re-combine the 2 PRs, test=develop
Co-authored-by: Michał Gallus <sand3r@interia.eu>
Co-authored-by: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>
Parent 5a1a9a1e
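As a quick orientation before the diff: a minimal usage sketch of the renamed Qat2Int8MkldnnPass with its new quantized-ops argument, mirroring the test code changed below. The model path and the op set are placeholders, not values mandated by this commit.

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.executor.global_scope()

with fluid.scope_guard(scope):
    # '/path/to/qat_model' is a placeholder for a QAT-trained inference model.
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        '/path/to/qat_model', exe, 'model', 'params')
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    # Ops to quantize: e.g. {'conv2d', 'pool2d'} for CNN models, or an
    # FC-centric set for Ernie-like NLP models.
    quantized_ops = {'conv2d', 'pool2d'}
    mkldnn_int8_pass = Qat2Int8MkldnnPass(
        quantized_ops, _scope=scope, _place=place, _core=core, _debug=False)
    graph = mkldnn_int8_pass.apply(graph)
    int8_program = graph.to_program()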
......@@ -226,15 +226,18 @@ if(WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
### Image classification tests
set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin")
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
## Image classification models
# download dataset if necessary
download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz")
# ImageNet small dataset
# It may already have been downloaded for the INT8 QAT unit tests
set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
download_int8_data(${IMAGENET_DATA_DIR} ${IMAGENET_DATA_ARCHIVE})
# build test binary to be used in subsequent tests
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
# resnet50 int8
......@@ -296,7 +299,7 @@ if(WITH_MKLDNN)
### optimized FP32 vs. QAT INT8 tests
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
......@@ -304,8 +307,8 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
# MobileNet FP32 vs. QAT INT8
# The FP32 model should already be downloaded for slim QAT unit tests
set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
download_qat_data(${QAT2_MobileNet_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8")
download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
......
......@@ -479,7 +479,7 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(),
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
auto prim_creator =
......
......@@ -17,10 +17,10 @@ from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass']
__all__ = ['QatInt8MkldnnPass', 'Qat2Int8MkldnnPass']
class FakeQAT2MkldnnINT8KernelPass(object):
class QatInt8MkldnnPass(object):
"""
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. The following transformations are done in this pass:
......@@ -48,13 +48,13 @@ class FakeQAT2MkldnnINT8KernelPass(object):
# The original graph will be rewritten.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import FakeQAT2MkldnnINT8KernelPass
import QatInt8MkldnnPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(),
mkldnn_pass = QatInt8MkldnnPass(fluid.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
......@@ -276,7 +276,7 @@ class FakeQAT2MkldnnINT8KernelPass(object):
graph.safe_remove_nodes(all_unused_vars)
class FakeQAT2MkldnnINT8PerfPass(object):
class Qat2Int8MkldnnPass(object):
"""
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations:
......@@ -290,7 +290,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
"""
def __init__(self, _scope=None, _place=None, _core=None, _debug=False):
def __init__(self,
_quantized_ops,
_scope=None,
_place=None,
_core=None,
_debug=False):
self._scope = _scope
self._place = _place
self._core = _core
......@@ -305,6 +310,10 @@ class FakeQAT2MkldnnINT8PerfPass(object):
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._quantized_ops = _quantized_ops
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale'
]
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
......@@ -324,8 +333,9 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
graph = self._update_conv_relu_scales(graph)
graph = self._update_pooling_scales(graph)
graph = self._update_relu_output_scales(graph)
graph = self._propagate_scales(graph)
graph = self._set_dummy_fc_out_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
......@@ -346,6 +356,12 @@ class FakeQAT2MkldnnINT8PerfPass(object):
tensor.set(scale, core.CPUPlace())
return tensor
def _is_conv_quantized(self):
return any(op_type in self._quantized_ops for op_type in self._conv_ops)
def _is_fc_quantized(self):
return 'fc' in self._quantized_ops
def _gather_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._quantize_types:
......@@ -371,34 +387,94 @@ class FakeQAT2MkldnnINT8PerfPass(object):
self._weight_scales[input_name] = _max_range
return graph
def _update_pooling_scales(self, graph):
def _propagate_scales(self, graph):
def _update_scale_op_in_scale(op, input, output):
unsigned, tensor = self._var_quant_scales[output]
scale = np.array(tensor) * op.op().attr("scale")
new_tensor = self._convert_scale2tensor(scale.astype(np.float64))
self._var_quant_scales[input] = (unsigned, new_tensor)
def _update_scales(graph):
waiting_for_scale = set()
for op in graph.all_op_nodes():
if op.name() in self._scale_immutable_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
tensor_names = [input_name, output_name]
# Scale is not quantized, so if it doesn't have any scales
# to propagate, its tensors won't be added to the waiting list.
if all(name not in self._var_quant_scales for name in tensor_names) \
and op.name() != 'scale':
waiting_for_scale.update(tensor_names)
continue
if input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
elif output_name in self._var_quant_scales:
if op.name() == 'scale':
_update_scale_op_in_scale(op, input_name,
output_name)
else:
self._var_quant_scales[
input_name] = self._var_quant_scales[
output_name]
return waiting_for_scale
waiting_for_scale = _update_scales(graph)
while len(waiting_for_scale) != 0:
waiting_for_scale = _update_scales(graph)
return graph
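To illustrate the fixed-point nature of _propagate_scales without the IrGraph machinery, here is a minimal standalone sketch of the same idea (plain tuples and dicts, hypothetical helper name): each scale-immutable op copies a known scale between its input and output, and passes are repeated until the waiting set is resolved.

def propagate_scales(immutable_ops, var_scales):
    """immutable_ops: list of (op_type, in_var, out_var); var_scales: {var: scale}."""
    def one_pass():
        waiting = set()
        for _, x, out in immutable_ops:
            if x not in var_scales and out not in var_scales:
                waiting.update([x, out])      # nothing known yet for this op
            elif x in var_scales:
                var_scales[out] = var_scales[x]
            else:
                var_scales[x] = var_scales[out]
        return waiting

    prev, waiting = None, one_pass()
    # The real pass loops until the waiting set is empty; the extra "unchanged"
    # check here is only a termination guard for this toy version.
    while waiting and waiting != prev:
        prev, waiting = waiting, one_pass()
    return var_scales

# The scale known for 'a' reaches 'c' across two scale-immutable ops:
# pass 1 resolves 'b', pass 2 resolves 'c' -> {'a': 0.12, 'b': 0.12, 'c': 0.12}
print(propagate_scales(
    [('transpose2', 'b', 'c'), ('reshape2', 'a', 'b')], {'a': 0.12}))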
def _set_dummy_fc_out_scales(self, graph):
'''
For the output tensors of FC that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since FCs that
have an unassigned output scale will have a force_fp32_output attr
set to True.
'''
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
input_name = op.input("X")[0]
if op.name() in self._fc_ops:
input_name = op.input("Input")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales:
if input_name in self._var_quant_scales and \
output_name not in self._var_quant_scales:
# use input scale as a "dummy" scale
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
def _remove_fake_ops(self, graph):
'''
When FC isn't quantized:
Remove fake (de)quantize ops that do not surround mul.
When FC is quantized:
Remove all fake (de)quantize ops.
'''
is_fc_quantized = self._is_fc_quantized()
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
next_op = op_out.outputs[0]
if next_op.name() not in self._mul_ops:
if next_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
prev_op = op_in.inputs[0]
if prev_op.name() not in self._mul_ops:
if prev_op.name() not in self._mul_ops or is_fc_quantized:
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
......@@ -444,6 +520,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op)
elif self._is_fc_quantized() and op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph
def _dequantize_conv_weights(self, graph, op_node):
......@@ -472,13 +550,20 @@ class FakeQAT2MkldnnINT8PerfPass(object):
def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
if self._is_conv_quantized():
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
if self._is_fc_quantized():
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'],
[False, False])
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
......@@ -528,6 +613,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
np.abs(weights.reshape(weights.shape[0], -1)).astype(
np.float64),
axis=axis)
scales[scales == np.Inf] = 0.0
lod_tensor = self._convert_scale2tensor(scales)
use_unsigned_int = False
......@@ -546,46 +632,41 @@ class FakeQAT2MkldnnINT8PerfPass(object):
ids.append(op.id())
return set(ids) if len(ids) else set([-1])
def _transform_to_quantize_mkldnn(self, graph, op_node):
"""
Transform fake_quantize_xx op to quantize mkldnn op in the graph.
"""
input_var_node = graph._find_node_by_name(op_node.inputs,
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
scale_in = self._s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
attrs={
'data_format': 'MKLDNNLAYOUT',
'use_mkldnn': 1,
'Scale': scale_in,
'is_negative_input': 1
},
inputs={'Input': input_var_node},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, quant_op_node)
graph.link_to(quant_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return quant_op_node
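For intuition on the scale computed above: assuming self._s8_max is the positive int8 limit (127, which is not shown in this excerpt), the quantize op maps float values in [-InScale, InScale] onto the signed int8 range. A toy calculation with a made-up InScale:

import numpy as np

s8_max = 127.0   # assumed value of self._s8_max (not shown in this excerpt)
in_scale = 6.35  # hypothetical InScale learned during QAT
scale_in = s8_max / in_scale  # 'Scale' attribute passed to the 'quantize' op
print(np.round(np.array([-6.35, 0.0, 3.175]) * scale_in))  # -> [-127.   0.  64.]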
def _update_relu_output_scales(self, graph):
def _update_scale(graph, ops, op_out_name, predicate):
'''
Sets the type of an output scale of the given op type(s) to 'unsigned int8' if the
predicate applied to the op passes. Typically, the predicate checks whether the op's
activation is set to relu.
'''
for op in graph.all_op_nodes():
if op.name() in ops:
out_name = op.output(op_out_name)[0]
if out_name in self._var_quant_scales and predicate(op.op(
)):
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
if self._is_conv_quantized():
conv_predicate = lambda op: op.attr("fuse_activation") == 'relu' and \
op.attr("fuse_residual_connection") == False
graph = _update_scale(graph, self._conv_ops, "Output",
conv_predicate)
if self._is_fc_quantized():
fc_predicate = lambda op: op.attr("activation_type") == 'relu'
graph = _update_scale(graph, self._fc_ops, "Out", fc_predicate)
def _update_conv_relu_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
out_name = op.output("Output")[0]
if out_name in self._var_quant_scales and \
op.op().attr("fuse_activation") == 'relu' and \
op.op().attr("fuse_residual_connection") == False:
_, tensor = self._var_quant_scales[out_name]
self._var_quant_scales[out_name] = (True, tensor)
return graph
def _get_data_layout(self):
return 'NHWC' if self._is_conv_quantized() else 'NCHW'
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
cpp_graph = graph.graph
ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
ir_pass.set('quantize_enabled_op_types', self._quantized_ops)
ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph))
ir_pass.apply(cpp_graph)
......@@ -593,8 +674,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
graph = self._apply_pass(graph, 'cpu_quantize_pass',
['quant_var_scales'],
[self._var_quant_scales])
graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout()])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
......@@ -24,8 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
......@@ -53,10 +53,6 @@ def parse_args():
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
......@@ -68,15 +64,20 @@ def parse_args():
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
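For reference, a quick illustration of how the comma-separated flag becomes the set later handed to Qat2Int8MkldnnPass (the op lists are examples only, not values mandated by this test):

# Example values only; pass the ops you actually want quantized for a given model.
cnn_ops = set("conv2d,pool2d".split(','))           # -> {'conv2d', 'pool2d'}
nlp_ops = set("fc,reshape2,transpose2".split(','))  # -> {'fc', 'reshape2', 'transpose2'}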
class TestQatInt8Comparison(unittest.TestCase):
class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 inference.
Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference.
"""
def _reader_creator(self, data_file='data.bin'):
......@@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
mkldnn_int8_pass = QatInt8MkldnnPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
......@@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
......@@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
......@@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--labels', type=str, default='', help='File with labels.')
parser.add_argument(
'--batch_num',
type=int,
default=1,
help='Number of batches to process. 0 or less means all.')
parser.add_argument(
'--acc_diff_threshold',
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class QatInt8NLPComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 NLP inference.
"""
def _reader_creator(self, data_file=None, labels_file=None):
assert data_file, "The dataset file is missing."
assert labels_file, "The labels file is missing."
def reader():
with open(data_file, 'r') as df:
with open(labels_file, 'r') as lf:
data_lines = df.readlines()
labels_lines = lf.readlines()
assert len(data_lines) == len(
labels_lines
), "The number of labels does not match the length of the dataset."
for i in range(len(data_lines)):
data_fields = data_lines[i].split(';')
assert len(
data_fields
) >= 2, "The number of data fields in the dataset is less than 2"
buffers = []
shape = []
for j in range(2):
data = data_fields[j].split(':')
assert len(
data
) >= 2, "Size of data in the dataset is less than 2"
# Shape is stored under index 0, while data under 1
shape = data[0].split()
shape.pop(0)
shape_np = np.array(shape).astype("int64")
buffer_i = data[1].split()
buffer_np = np.array(buffer_i).astype("int64")
buffer_np.shape = tuple(shape_np)
buffers.append(buffer_np)
label = labels_lines[i]
yield buffers[0], buffers[1], int(label)
return reader
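To make the expected file layout concrete, here is a standalone sketch that parses one hypothetical line the same way the reader above does. The field values are made up; each line holds at least two fields separated by ';', each field is "<shape tokens>:<space-separated int64 values>", and the leading shape token is dropped, as in shape.pop(0) above.

import numpy as np

# One hypothetical dataset line with two fields.
line = "1 1 4:8 2 9 4;1 1 4:1 1 1 1"
buffers = []
for field in line.split(';')[:2]:
    shape_str, data_str = field.split(':')
    shape = shape_str.split()
    shape.pop(0)                         # first shape token is dropped
    shape_np = np.array(shape).astype("int64")
    buf = np.array(data_str.split()).astype("int64")
    buf.shape = tuple(shape_np)
    buffers.append(buf)
print(buffers[0].shape, buffers[1].shape)  # -> (1, 4) (1, 4)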
def _get_batch_correct(self, batch_output=None, labels=None):
total = len(batch_output)
assert total > 0, "The batch output is empty."
correct = 0
for n, output in enumerate(batch_output[0]):
max_idx = np.where(output == output.max())
if max_idx == labels[n]:
correct += 1
return correct
def _predict(self,
test_reader=None,
model_path=None,
batch_size=1,
batch_num=1,
skip_batch_num=0,
transform_to_int8=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
total_correct = 0
total_samples = 0
batch_times = []
ppses = [] # predictions per second
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
input0 = np.array([x[0] for x in data]).astype('int64')
input1 = np.array([x[1] for x in data]).astype('int64')
labels = np.array([x[2] for x in data]).astype('int64')
start = time.time()
out = exe.run(inference_program,
feed={
feed_target_names[0]: input0,
feed_target_names[1]: input1
},
fetch_list=fetch_targets)
batch_time = (time.time() - start) * 1000 # in milliseconds
batch_times.append(batch_time)
batch_correct = self._get_batch_correct(out, labels)
batch_len = len(data)
total_samples += batch_len
total_correct += batch_correct
batch_acc = float(batch_correct) / float(batch_len)
pps = batch_len / batch_time * 1000
ppses.append(pps)
latency = batch_time / batch_len
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
'batch {0}{4}, acc: {1:.4f}, latency: {2:.4f} ms, predictions per sec: {3:.2f}'
.format(iters, batch_acc, latency, pps, appx))
# Postprocess benchmark data
infer_total_time = time.time() - infer_start_time
batch_latencies = batch_times[skip_batch_num:]
batch_latency_avg = np.average(batch_latencies)
latency_avg = batch_latency_avg / batch_size
ppses = ppses[skip_batch_num:]
pps_avg = np.average(ppses)
acc_avg = float(np.sum(total_correct)) / float(total_samples)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return acc_avg, pps_avg, latency_avg
def _summarize_performance(self, fp32_pps, fp32_lat, int8_pps, int8_lat):
_logger.info('--- Performance summary ---')
_logger.info(
'FP32: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(fp32_pps, fp32_lat))
_logger.info(
'INT8: avg predictions per sec: {0:.2f}, avg latency: {1:.4f} ms'.
format(int8_pps, int8_lat))
def _compare_accuracy(self, fp32_acc, int8_acc, threshold):
_logger.info('--- Accuracy summary ---')
_logger.info(
'Accepted accuracy drop threshold: {0}. (condition: (FP32_acc - INT8_acc) <= threshold)'
.format(threshold))
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
_logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
# Random outputs give an accuracy of about 0.33; we assume a valid accuracy is at least 0.5
assert fp32_acc > 0.5
assert int8_acc > 0.5
assert fp32_acc - int8_acc <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
qat_model_path = test_case_args.qat_model
data_path = test_case_args.infer_data
labels_path = test_case_args.labels
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Labels: {0}'.format(labels_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
fp32_acc, fp32_pps, fp32_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, labels_path), batch_size=batch_size)
int8_acc, int8_pps, int8_lat = self._predict(
val_reader,
qat_model_path,
batch_size,
batch_num,
skip_batch_num,
transform_to_int8=True)
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
......@@ -24,7 +24,7 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
......@@ -42,6 +42,11 @@ def parse_args():
type=str,
default='',
help='Saved optimized and quantized INT8 model')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
......@@ -60,8 +65,9 @@ def transform_and_save_model(original_path, save_path, save_type):
fetch_targets] = fluid.io.load_inference_model(original_path, exe,
'model', 'params')
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope, _place=place, _core=core)
quantized_ops = set(test_args.quantized_ops.split(','))
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
quantized_ops, _scope=inference_scope, _place=place, _core=core)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if save_type == 'FP32':
......
......@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
......@@ -149,8 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......