提交 80a3fd2e 编写于 作者: W Wojciech Uss 提交者: Tao Luo

Cherry-pick of #17814 and #18030 (#18067)

Added unit test for QAT FP32 & INT8 comparison (#17814)
Disable MKLDNN FC in Resnet50 test (#18030)

test=release/1.5
上级 a114a39e
...@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME) ...@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG}) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes. # No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
......
...@@ -33,14 +33,12 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename) ...@@ -33,14 +33,12 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
--paddle_num_threads=${CPU_NUM_THREADS_ON_CI} --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2) --iterations=2)
endfunction() endfunction()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug) function(inference_analysis_api_test_with_fake_data target install_dir filename model_name disable_fc)
if(mkl_debug)
set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
endif()
download_model(${install_dir} ${model_name}) download_model(${install_dir} ${model_name})
inference_analysis_test(${target} SRCS ${filename} inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model) ARGS --infer_model=${install_dir}/model
--disable_mkldnn_fc=${disable_fc})
endfunction() endfunction()
function(inference_analysis_api_test_with_refer_result target install_dir filename) function(inference_analysis_api_test_with_refer_result target install_dir filename)
......
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include <iostream> #include <iostream>
#include "paddle/fluid/inference/tests/api/tester_helper.h" #include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_bool(disable_mkldnn_fc, false, "Disable usage of MKL-DNN's FC op");
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -48,7 +50,8 @@ void profile(bool use_mkldnn = false) { ...@@ -48,7 +50,8 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass"); if (!FLAGS_disable_mkldnn_fc)
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> outputs; std::vector<std::vector<PaddleTensor>> outputs;
...@@ -80,7 +83,8 @@ void compare(bool use_mkldnn = false) { ...@@ -80,7 +83,8 @@ void compare(bool use_mkldnn = false) {
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) { if (use_mkldnn) {
cfg.EnableMKLDNN(); cfg.EnableMKLDNN();
cfg.pass_builder()->AppendPass("fc_mkldnn_pass"); if (!FLAGS_disable_mkldnn_fc)
cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
} }
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
...@@ -36,7 +36,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -36,7 +36,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& ksize, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const memory::data_type& dt, const std::string& suffix) { const memory::data_type& dt, const memory::format& fmt,
const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
...@@ -45,6 +46,7 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -45,6 +46,7 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKeyVec(&key, strides); platform::MKLDNNHandler::AppendKeyVec(&key, strides);
platform::MKLDNNHandler::AppendKeyVec(&key, paddings); platform::MKLDNNHandler::AppendKeyVec(&key, paddings);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::MKLDNNHandler::AppendKey(&key, suffix);
return key; return key;
} }
...@@ -115,8 +117,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -115,8 +117,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type dt = mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
const std::string key = CreateKey(ctx, src_tz, pooling_type, ksize, strides, auto fmt = input->format();
paddings, dt, ctx.op().Output("Out")); const std::string key =
CreateKey(ctx, src_tz, pooling_type, ksize, strides, paddings, dt, fmt,
ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p"; const std::string key_pool_p = key + "@pool_p";
const std::string key_pool_pd = key + "@pool_pd"; const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
...@@ -294,9 +298,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -294,9 +298,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable // Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context // This name will be used as key when referring info from device context
const std::string key = const std::string key = CreateKey(ctx, diff_src_tz, pooling_type, ksize,
CreateKey(ctx, diff_src_tz, pooling_type, ksize, strides, paddings, strides, paddings, memory::data_type::f32,
memory::data_type::f32, ctx.op().Input("Out")); in_x->format(), ctx.op().Input("Out"));
const std::string key_pool_bwd_p = key + "@pool_bwd_p"; const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p"; const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
......
...@@ -11,6 +11,18 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen ...@@ -11,6 +11,18 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen
--batch_size 50) --batch_size 50)
endfunction() endfunction()
function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --qat_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1)
endfunction()
# NOTE: TODOOOOOOOOOOO # NOTE: TODOOOOOOOOOOO
# temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs # temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
# Need to figure out the root cause and then add it back # Need to figure out the root cause and then add it back
...@@ -62,6 +74,74 @@ endif() ...@@ -62,6 +74,74 @@ endif()
# with MKL-DNN, we remove it here for not repeating test, or not testing on other systems. # with MKL-DNN, we remove it here for not repeating test, or not testing on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy) list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
# QAT FP32 & INT8 comparison python api tests
if(LINUX AND WITH_MKLDNN)
set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")
# ImageNet small dataset
# May be already downloaded for INT8v2 unit tests
if (NOT EXISTS ${DATASET_DIR})
inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
endif()
# QAT ResNet50
set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT ResNet101
set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT GoogleNet
set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV1
set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT MobileNetV2
set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG16
set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
if (NOT EXISTS ${QAT_VGG16_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
# QAT VGG19
set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
if (NOT EXISTS ${QAT_VGG19_MODEL_DIR})
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
# with MKL-DNN, we remove it here to not test it on other systems.
list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py)
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py) py_test(${src} SRCS ${src}.py)
endforeach() endforeach()
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import sys
import argparse
import logging
import struct
import six
import numpy as np
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
parser.add_argument(
'--skip_batch_num',
type=int,
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
type=int,
default=1,
help='Number of batches to process. 0 or less means all.')
parser.add_argument(
'--acc_diff_threshold',
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class TestQatInt8Comparison(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 inference.
"""
def _reader_creator(self, data_file='data.bin'):
def reader():
with open(data_file, 'rb') as fp:
num = fp.read(8)
num = struct.unpack('q', num)[0]
imgs_offset = 8
img_ch = 3
img_w = 224
img_h = 224
img_pixel_size = 4
img_size = img_ch * img_h * img_w * img_pixel_size
label_size = 8
labels_offset = imgs_offset + num * img_size
step = 0
while step < num:
fp.seek(imgs_offset + img_size * step)
img = fp.read(img_size)
img = struct.unpack_from('{}f'.format(img_ch * img_w *
img_h), img)
img = np.array(img)
img.shape = (img_ch, img_w, img_h)
fp.seek(labels_offset + label_size * step)
label = fp.read(label_size)
label = struct.unpack('q', label)[0]
yield img, int(label)
step += 1
return reader
def _get_batch_accuracy(self, batch_output=None, labels=None):
total = 0
correct = 0
correct_5 = 0
for n, result in enumerate(batch_output):
index = result.argsort()
top_1_index = index[-1]
top_5_index = index[-5:]
total += 1
if top_1_index == labels[n]:
correct += 1
if labels[n] in top_5_index:
correct_5 += 1
acc1 = float(correct) / float(total)
acc5 = float(correct_5) / float(total)
return acc1, acc5
def _prepare_for_fp32_mkldnn(self, graph):
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in ['depthwise_conv2d']:
input_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Input")[0])
weight_var_node = graph._find_node_by_name(
op_node.inputs, op_node.input("Filter")[0])
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), op_node.output("Output")[0])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
}
conv_op_node = graph.create_op_node(
op_type='conv2d',
attrs=attrs,
inputs={
'Input': input_var_node,
'Filter': weight_var_node
},
outputs={'Output': output_var_node})
graph.link_to(input_var_node, conv_op_node)
graph.link_to(weight_var_node, conv_op_node)
graph.link_to(conv_op_node, output_var_node)
graph.safe_remove_nodes(op_node)
return graph
def _predict(self,
test_reader=None,
model_path=None,
batch_num=1,
skip_batch_num=0,
transform_to_int8=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (transform_to_int8):
mkldnn_int8_pass = TransformForMkldnnPass(
scope=inference_scope, place=place)
mkldnn_int8_pass.apply(graph)
else:
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
outputs = []
infer_accs1 = []
infer_accs5 = []
fpses = []
batch_times = []
total_samples = 0
top1 = 0.0
top5 = 0.0
iters = 0
infer_start_time = time.time()
for data in test_reader():
if batch_num > 0 and iters >= batch_num:
break
if iters == skip_batch_num:
total_samples = 0
infer_start_time = time.time()
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype('int64')
start = time.time()
out = exe.run(inference_program,
feed={feed_target_names[0]: images},
fetch_list=fetch_targets)
batch_time = time.time() - start
outputs.append(out[0])
batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
labels)
infer_accs1.append(batch_acc1)
infer_accs5.append(batch_acc5)
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
fps = samples / batch_time
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
_logger.info(
'batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, '
'batch latency: {3:.4f} s, batch fps: {4:.2f}'.format(
iters, batch_acc1, batch_acc5, batch_time, fps, appx))
# Postprocess benchmark data
latencies = batch_times[skip_batch_num:]
latency_avg = np.average(latencies)
fpses = fpses[skip_batch_num:]
fps_avg = np.average(fpses)
infer_total_time = time.time() - infer_start_time
acc1_avg = np.mean(infer_accs1)
acc5_avg = np.mean(infer_accs5)
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _compare_accuracy(self, fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
threshold):
_logger.info('Accepted acc1 diff threshold: {0}'.format(threshold))
_logger.info('FP32: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
fp32_acc1, fp32_acc5))
_logger.info('INT8: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
int8_acc1, int8_acc5))
assert fp32_acc1 > 0.0
assert int8_acc1 > 0.0
assert fp32_acc1 - int8_acc1 <= threshold
def test_graph_transformation(self):
if not fluid.core.is_compiled_with_mkldnn():
return
qat_model_path = test_case_args.qat_model
data_path = test_case_args.infer_data
batch_size = test_case_args.batch_size
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
_logger.info('Dataset: {0}'.format(data_path))
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy diff threshold: {0}. '
'(condition: (fp32_acc - int8_acc) <= threshold)'
.format(acc_diff_threshold))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
val_reader,
qat_model_path,
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
val_reader,
qat_model_path,
batch_num,
skip_batch_num,
transform_to_int8=True)
_logger.info('--- Performance summary ---')
_logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
fp32_fps, fp32_lat))
_logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
int8_fps, int8_lat))
_logger.info('--- Comparing accuracy ---')
self._compare_accuracy(fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
acc_diff_threshold)
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册