提交 78e93286 编写于 作者: W Wojciech Uss 提交者: Tao Luo

Added unit test for QAT FP32 & INT8 comparison (#17814)

* added unit test for QAT FP32 & INT8 comparison

test=develop

* enabled other models and updated filenames

test=develop

* added accuracy check and multiple batch handling

test=develop

* removed quantization_mkldnn_pass.py

test=develop

* cleanup

test=develop

* updated model paths

test=develop

* renamed tests without MKL-DNN

test=develop

* fix reusing mkldnn pool2d primitive

test=develop

* add performance measuring

test=develop

* fix accuracy statistics

test=develop

* removed non-mkldnn tests

test=develop

* added conv2d_depthwise->conv2d mkldnn transformation

test=develop

* format update

test=develop

* fixed creating key for pool2d grad

test=develop

* added pass

* Fix the accuracy issue while using float precision to get the scale.

test=develop

* Fix the format issue when 'X' is not nchw.

test=develop

* removed output comparing and changed number of images

test=develop

* cmake and comment fix

test=develop

* updated acc threshold for QAT comparison tests

test=develop

* added OMP_NUM_THREADS setting

test=develop

* enable all QAT INT8 tests

test=develop

* restored upstream version of a file

test=develop

* modified directory names

test=develop
上级 566bf2ec
......@@ -36,7 +36,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const memory::data_type& dt, const std::string& suffix) {
const memory::data_type& dt, const memory::format& fmt,
const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
......@@ -45,6 +46,7 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKeyVec(&key, strides);
platform::MKLDNNHandler::AppendKeyVec(&key, paddings);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
return key;
}
......@@ -115,8 +117,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
const std::string key = CreateKey(ctx, src_tz, pooling_type, ksize, strides,
paddings, dt, ctx.op().Output("Out"));
auto fmt = input->format();
const std::string key =
CreateKey(ctx, src_tz, pooling_type, ksize, strides, paddings, dt, fmt,
ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p";
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
......@@ -294,9 +298,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key =
CreateKey(ctx, diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, ctx.op().Input("Out"));
const std::string key = CreateKey(ctx, diff_src_tz, pooling_type, ksize,
strides, paddings, memory::data_type::f32,
in_x->format(), ctx.op().Input("Out"));
const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
......
......@@ -11,6 +11,18 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen
--batch_size 50)
endfunction()
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --qat_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1)
endfunction()
# NOTE: TODOOOOOOOOOOO
# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
# Need to figure out the root cause and then add it back
......@@ -62,6 +74,74 @@ endif()
# with MKL-DNN, we remove it here for not repeating test, or not testing on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
# QAT FP32 & INT8 comparison python api tests
if(LINUX AND WITH_MKLDNN)
set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")
# ImageNet small dataset
# May be already downloaded for INT8v2 unit tests
if (NOT EXISTS ${DATASET_DIR})
inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
endif()
# QAT ResNet50
set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT ResNet101
set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT GoogleNet
set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV1
set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV2
set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG16
set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
if (NOT EXISTS ${QAT_VGG16_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG19
set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
if (NOT EXISTS ${QAT_VGG19_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
# with MKL-DNN, we remove it here to not test it on other systems.
list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
type=int,
default=1,
help='Number of batches to process. 0 or less means all.')
parser.add_argument(
'--acc_diff_threshold',
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class TestQatInt8Comparison(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 inference.
"""
def _reader_creator(self, data_file='data.bin'):
def reader():
with open(data_file, 'rb') as fp:
num = fp.read(8)
num = struct.unpack('q', num)[0]
imgs_offset = 8
img_ch = 3
img_w = 224
img_h = 224
img_pixel_size = 4
img_size = img_ch * img_h * img_w * img_pixel_size
label_size = 8
labels_offset = imgs_offset + num * img_size
step = 0
while step < num:
fp.seek(imgs_offset + img_size * step)
img = fp.read(img_size)
img = struct.unpack_from('{}f'.format(img_ch * img_w *
img_h), img)
img = np.array(img)
img.shape = (img_ch, img_w, img_h)
fp.seek(labels_offset + label_size * step)
label = fp.read(label_size)
label = struct.unpack('q', label)[0]
yield img, int(label)
step += 1
return reader
def _get_batch_accuracy(self, batch_output=None, labels=None):
total = 0
correct = 0
correct_5 = 0
for n, result in enumerate(batch_output):
index = result.argsort()
top_1_index = index[-1]
top_5_index = index[-5:]
total += 1
if top_1_index == labels[n]:
correct += 1
if labels[n] in top_5_index:
correct_5 += 1
acc1 = float(correct) / float(total)
acc5 = float(correct_5) / float(total)
return acc1, acc5
def _prepare_for_fp32_mkldnn(self, graph):
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in ['depthwise_conv2d']:
input_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Input")[0])
weight_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Filter")[0])
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), op_node.output("Output")[0])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
}
conv_op_node = graph.create_op_node(
op_type='conv2d',
attrs=attrs,
inputs={
'Input': input_var_node,
'Filter': weight_var_node
},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, conv_op_node)
graph.link_to(weight_var_node, conv_op_node)
graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return graph
def _predict(self,
test_reader=None,
model_path=None,
batch_num=1,
skip_batch_num=0,
transform_to_int8=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (transform_to_int8):
mkldnn_int8_pass = TransformForMkldnnPass(
scope=inference_scope, place=place)
mkldnn_int8_pass.apply(graph)
else:
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
outputs = []
infer_accs1 = []
infer_accs5 = []
fpses = []
batch_times = []
total_samples = 0
top1 = 0.0
top5 = 0.0
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64')
start = time.time()
out = exe.run(inference_program,
feed={feed_target_names[0]: images},
fetch_list=fetch_targets)
batch_time = time.time() - start
outputs.append(out[0])
batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
labels)
infer_accs1.append(batch_acc1)
infer_accs5.append(batch_acc5)
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
fps = samples / batch_time
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
'batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, '
'batch latency: {3:.4f} s, batch fps: {4:.2f}'.format(
iters, batch_acc1, batch_acc5, batch_time, fps, appx))
# Postprocess benchmark data
latencies = batch_times[skip_batch_num:]
latency_avg = np.average(latencies)
fpses = fpses[skip_batch_num:]
fps_avg = np.average(fpses)
infer_total_time = time.time() - infer_start_time
acc1_avg = np.mean(infer_accs1)
acc5_avg = np.mean(infer_accs5)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _compare_accuracy(self, fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
threshold):
_logger.info('Accepted acc1 diff threshold: {0}'.format(threshold))
_logger.info('FP32: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
fp32_acc1, fp32_acc5))
_logger.info('INT8: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
int8_acc1, int8_acc5))
assert fp32_acc1 > 0.0
assert int8_acc1 > 0.0
assert fp32_acc1 - int8_acc1 <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
qat_model_path = test_case_args.qat_model
data_path = test_case_args.infer_data
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy diff threshold: {0}. '
'(condition: (fp32_acc - int8_acc) <= threshold)'
.format(acc_diff_threshold))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
val_reader,
qat_model_path,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader,
qat_model_path,
batch_num,
skip_batch_num,
transform_to_int8=True)
_logger.info('--- Performance summary ---')
_logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
fp32_fps, fp32_lat))
_logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
int8_fps, int8_lat))
_logger.info('--- Comparing accuracy ---')
self._compare_accuracy(fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
acc_diff_threshold)
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册