diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index dfa90a3fe632bd4fc620ab55d52f079434727a64..c5bedf376ba6bc9d0eadfc20f5671648878de1f7 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -385,7 +385,7 @@ function(cc_test TARGET_NAME)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
     set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
-    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true ${MKL_DEBUG_FLAG})
+    set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
     # No unit test should exceed 10 minutes.
     set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
   endif()
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 548ef382576a968e32e99b0d351276c486a06bbd..3422af325129e44592531e4322e7303c7563531e 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -33,14 +33,12 @@ function(inference_analysis_api_int8_test target model_dir data_dir filename)
            --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
            --iterations=2)
 endfunction()
-function(inference_analysis_api_test_with_fake_data target install_dir filename model_name mkl_debug)
-  if(mkl_debug)
-    set(MKL_DEBUG_FLAG MKL_DEBUG_CPU_TYPE=7)
-  endif()
+function(inference_analysis_api_test_with_fake_data target install_dir filename model_name disable_fc)
   download_model(${install_dir} ${model_name})
   inference_analysis_test(${target} SRCS ${filename}
     EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-    ARGS --infer_model=${install_dir}/model)
+    ARGS --infer_model=${install_dir}/model
+         --disable_mkldnn_fc=${disable_fc})
 endfunction()
 
 function(inference_analysis_api_test_with_refer_result target install_dir filename)
diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
index 602d59457c0bb99796e59658f0f17c3ad957179d..07934f96dc55ef1e80b54472c74975ff62b6add3 100644
--- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <iostream>
 #include "paddle/fluid/inference/tests/api/tester_helper.h"
 
+DEFINE_bool(disable_mkldnn_fc, false, "Disable usage of MKL-DNN's FC op");
+
 namespace paddle {
 namespace inference {
 namespace analysis {
@@ -48,7 +50,8 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    if (!FLAGS_disable_mkldnn_fc)
+      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
   std::vector<std::vector<PaddleTensor>> outputs;
 
@@ -80,7 +83,8 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
+    if (!FLAGS_disable_mkldnn_fc)
+      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
index 5d8e81921157cbdf35f7016741ab45c362b7261f..c635fd11c37aec5fd75dedfa5ba575868a564232 100644
--- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
@@ -36,7 +36,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
                       const std::vector<int>& ksize,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
-                      const memory::data_type& dt, const std::string& suffix) {
+                      const memory::data_type& dt, const memory::format& fmt,
+                      const std::string& suffix) {
   std::string key;
   key.reserve(platform::MKLDNNHandler::MaxKeyLength);
   platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
@@ -45,6 +46,7 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
   platform::MKLDNNHandler::AppendKeyVec(&key, strides);
   platform::MKLDNNHandler::AppendKeyVec(&key, paddings);
   platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
+  platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
   platform::MKLDNNHandler::AppendKey(&key, suffix);
   return key;
 }
@@ -115,8 +117,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     mkldnn::memory::data_type dt =
         paddle::framework::ToMKLDNNDataType(input->type());
-    const std::string key = CreateKey(ctx, src_tz, pooling_type, ksize, strides,
-                                      paddings, dt, ctx.op().Output("Out"));
+    auto fmt = input->format();
+    const std::string key =
+        CreateKey(ctx, src_tz, pooling_type, ksize, strides, paddings, dt, fmt,
+                  ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
@@ -294,9 +298,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     // Get an unique name from "argument" name of "Out" variable
     // This name will be used as key when referring info from device context
-    const std::string key =
-        CreateKey(ctx, diff_src_tz, pooling_type, ksize, strides, paddings,
-                  memory::data_type::f32, ctx.op().Input("Out"));
+    const std::string key = CreateKey(ctx, diff_src_tz, pooling_type, ksize,
+                                      strides, paddings, memory::data_type::f32,
+                                      in_x->format(), ctx.op().Input("Out"));
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
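Note on the pool_mkldnn_op.cc hunks above: the pooling primitive is cached in the device context under a string key built by CreateKey, and this patch adds the MKL-DNN memory format of the input (and, in the backward kernel, in_x->format()) to that key. Without the format, two executions that differ only in input layout would map to the same cached primitive. The following standalone Python sketch is illustrative only (it is not Paddle code; the helper name and format strings are invented for the example) and shows why including the format separates the cache entries:

    # make_pool_key mimics CreateKey: it concatenates the attributes that identify
    # a cached pooling primitive; passing fmt corresponds to the patched version.
    def make_pool_key(dims, pooling_type, ksize, strides, paddings, dt, fmt=None, suffix=''):
        parts = [dims, pooling_type, ksize, strides, paddings, dt]
        if fmt is not None:
            parts.append(fmt)  # the fix: the memory format becomes part of the key
        parts.append(suffix)
        return '@'.join(str(p) for p in parts)

    # Identical dims, dtype and attributes but different layouts now yield distinct keys.
    key_plain = make_pool_key([1, 3, 224, 224], 'max', [2, 2], [2, 2], [0, 0], 'f32', 'nchw', 'Out')
    key_blocked = make_pool_key([1, 3, 224, 224], 'max', [2, 2], [2, 2], [0, 0], 'f32', 'nChw16c', 'Out')
    assert key_plain != key_blocked
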
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 6c8a487dbcf8bc5153504cc521999a407ccbe082..23607d5052c3ec495e53eb3907d61d651b09f292 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -11,6 +11,18 @@ function(inference_analysis_python_api_int8_test target model_dir data_dir filen
                  --batch_size 50)
 endfunction()
 
+function(inference_qat_int8_test target model_dir data_dir test_script use_mkldnn)
+  py_test(${target} SRCS ${test_script}
+          ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
+               OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
+               FLAGS_use_mkldnn=${use_mkldnn}
+          ARGS --qat_model ${model_dir}/model
+               --infer_data ${data_dir}/data.bin
+               --batch_size 25
+               --batch_num 2
+               --acc_diff_threshold 0.1)
+endfunction()
+
 # NOTE: TODOOOOOOOOOOO
 # temporarily disable test_distillation_strategy since it always failed on a specified machine with 4 GPUs
 # Need to figure out the root cause and then add it back
@@ -62,6 +74,74 @@ endif()
 # with MKL-DNN, we remove it here for not repeating test, or not testing on other systems.
 list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
 
+# QAT FP32 & INT8 comparison python api tests
+if(LINUX AND WITH_MKLDNN)
+  set(DATASET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
+  set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
+  set(QAT_MODELS_BASE_URL "${INFERENCE_URL}/int8/QAT_models")
+  set(MKLDNN_QAT_TEST_FILE "qat_int8_comparison.py")
+  set(MKLDNN_QAT_TEST_FILE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_QAT_TEST_FILE}")
+
+  # ImageNet small dataset
+  # May be already downloaded for INT8v2 unit tests
+  if (NOT EXISTS ${DATASET_DIR})
+    inference_download_and_uncompress(${DATASET_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
+  endif()
+
+  # QAT ResNet50
+  set(QAT_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_QAT")
+  if (NOT EXISTS ${QAT_RESNET50_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_resnet50_mkldnn ${QAT_RESNET50_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT ResNet101
+  set(QAT_RESNET101_MODEL_DIR "${QAT_DATA_DIR}/ResNet101_QAT")
+  if (NOT EXISTS ${QAT_RESNET101_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_RESNET101_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet101_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_resnet101_mkldnn ${QAT_RESNET101_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT GoogleNet
+  set(QAT_GOOGLENET_MODEL_DIR "${QAT_DATA_DIR}/GoogleNet_QAT")
+  if (NOT EXISTS ${QAT_GOOGLENET_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_GOOGLENET_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "GoogleNet_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_googlenet_mkldnn ${QAT_GOOGLENET_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT MobileNetV1
+  set(QAT_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV1_QAT")
+  if (NOT EXISTS ${QAT_MOBILENETV1_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV1_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_mobilenetv1_mkldnn ${QAT_MOBILENETV1_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT MobileNetV2
+  set(QAT_MOBILENETV2_MODEL_DIR "${QAT_DATA_DIR}/MobileNetV2_QAT")
+  if (NOT EXISTS ${QAT_MOBILENETV2_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_MOBILENETV2_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNetV2_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_mobilenetv2_mkldnn ${QAT_MOBILENETV2_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT VGG16
+  set(QAT_VGG16_MODEL_DIR "${QAT_DATA_DIR}/VGG16_QAT")
+  if (NOT EXISTS ${QAT_VGG16_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_VGG16_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG16_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_vgg16_mkldnn ${QAT_VGG16_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  # QAT VGG19
+  set(QAT_VGG19_MODEL_DIR "${QAT_DATA_DIR}/VGG19_QAT")
+  if (NOT EXISTS ${QAT_VGG19_MODEL_DIR})
+    inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
+  endif()
+  inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+endif()
+
+# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
+# with MKL-DNN, we remove it here to not test it on other systems.
+list(REMOVE_ITEM TEST_OPS qat_int8_comparison.py)
+
 foreach(src ${TEST_OPS})
   py_test(${src} SRCS ${src}.py)
 endforeach()
diff --git a/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8cd5a663ec4fbc3a1817782e9bf4a7773ca609a
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py
@@ -0,0 +1,298 @@
+# copyright (c) 2019 paddlepaddle authors. all rights reserved.
+#
+# licensed under the apache license, version 2.0 (the "license");
+# you may not use this file except in compliance with the license.
+# you may obtain a copy of the license at
+#
+#     http://www.apache.org/licenses/license-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the license is distributed on an "as is" basis,
+# without warranties or conditions of any kind, either express or implied.
+# see the license for the specific language governing permissions and
+# limitations under the license.
+
+import unittest
+import os
+import sys
+import argparse
+import logging
+import struct
+import six
+import numpy as np
+import time
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
+from paddle.fluid import core
+
+logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')
+    parser.add_argument(
+        '--skip_batch_num',
+        type=int,
+        default=0,
+        help='Number of the first minibatches to skip in performance statistics.'
+    )
+    parser.add_argument(
+        '--qat_model', type=str, default='', help='A path to a QAT model.')
+    parser.add_argument('--infer_data', type=str, default='', help='Data file.')
+    parser.add_argument(
+        '--batch_num',
+        type=int,
+        default=1,
+        help='Number of batches to process. 0 or less means all.')
+    parser.add_argument(
+        '--acc_diff_threshold',
+        type=float,
+        default=0.01,
+        help='Accepted accuracy difference threshold.')
+
+    test_args, args = parser.parse_known_args(namespace=unittest)
+
+    return test_args, sys.argv[:1] + args
+
+
+class TestQatInt8Comparison(unittest.TestCase):
+    """
+    Test for accuracy comparison of QAT FP32 and INT8 inference.
+ """ + + def _reader_creator(self, data_file='data.bin'): + def reader(): + with open(data_file, 'rb') as fp: + num = fp.read(8) + num = struct.unpack('q', num)[0] + imgs_offset = 8 + img_ch = 3 + img_w = 224 + img_h = 224 + img_pixel_size = 4 + img_size = img_ch * img_h * img_w * img_pixel_size + label_size = 8 + labels_offset = imgs_offset + num * img_size + + step = 0 + while step < num: + fp.seek(imgs_offset + img_size * step) + img = fp.read(img_size) + img = struct.unpack_from('{}f'.format(img_ch * img_w * + img_h), img) + img = np.array(img) + img.shape = (img_ch, img_w, img_h) + fp.seek(labels_offset + label_size * step) + label = fp.read(label_size) + label = struct.unpack('q', label)[0] + yield img, int(label) + step += 1 + + return reader + + def _get_batch_accuracy(self, batch_output=None, labels=None): + total = 0 + correct = 0 + correct_5 = 0 + for n, result in enumerate(batch_output): + index = result.argsort() + top_1_index = index[-1] + top_5_index = index[-5:] + total += 1 + if top_1_index == labels[n]: + correct += 1 + if labels[n] in top_5_index: + correct_5 += 1 + acc1 = float(correct) / float(total) + acc5 = float(correct_5) / float(total) + return acc1, acc5 + + def _prepare_for_fp32_mkldnn(self, graph): + ops = graph.all_op_nodes() + for op_node in ops: + name = op_node.name() + if name in ['depthwise_conv2d']: + input_var_node = graph._find_node_by_name( + op_node.inputs, op_node.input("Input")[0]) + weight_var_node = graph._find_node_by_name( + op_node.inputs, op_node.input("Filter")[0]) + output_var_node = graph._find_node_by_name( + graph.all_var_nodes(), op_node.output("Output")[0]) + attrs = { + name: op_node.op().attr(name) + for name in op_node.op().attr_names() + } + + conv_op_node = graph.create_op_node( + op_type='conv2d', + attrs=attrs, + inputs={ + 'Input': input_var_node, + 'Filter': weight_var_node + }, + outputs={'Output': output_var_node}) + + graph.link_to(input_var_node, conv_op_node) + graph.link_to(weight_var_node, conv_op_node) + graph.link_to(conv_op_node, output_var_node) + graph.safe_remove_nodes(op_node) + + return graph + + def _predict(self, + test_reader=None, + model_path=None, + batch_num=1, + skip_batch_num=0, + transform_to_int8=False): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + inference_scope = fluid.executor.global_scope() + with fluid.scope_guard(inference_scope): + if os.path.exists(os.path.join(model_path, '__model__')): + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(model_path, exe) + else: + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model( + model_path, exe, 'model', 'params') + + graph = IrGraph(core.Graph(inference_program.desc), for_test=True) + if (transform_to_int8): + mkldnn_int8_pass = TransformForMkldnnPass( + scope=inference_scope, place=place) + mkldnn_int8_pass.apply(graph) + else: + graph = self._prepare_for_fp32_mkldnn(graph) + inference_program = graph.to_program() + + dshape = [3, 224, 224] + outputs = [] + infer_accs1 = [] + infer_accs5 = [] + fpses = [] + batch_times = [] + total_samples = 0 + top1 = 0.0 + top5 = 0.0 + iters = 0 + infer_start_time = time.time() + for data in test_reader(): + if batch_num > 0 and iters >= batch_num: + break + if iters == skip_batch_num: + total_samples = 0 + infer_start_time = time.time() + if six.PY2: + images = map(lambda x: x[0].reshape(dshape), data) + if six.PY3: + images = list(map(lambda x: x[0].reshape(dshape), data)) + images = np.array(images).astype('float32') + 
+                labels = np.array([x[1] for x in data]).astype('int64')
+
+                start = time.time()
+                out = exe.run(inference_program,
+                              feed={feed_target_names[0]: images},
+                              fetch_list=fetch_targets)
+                batch_time = time.time() - start
+                outputs.append(out[0])
+                batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0],
+                                                                  labels)
+                infer_accs1.append(batch_acc1)
+                infer_accs5.append(batch_acc5)
+                samples = len(data)
+                total_samples += samples
+                batch_times.append(batch_time)
+                fps = samples / batch_time
+                fpses.append(fps)
+                iters += 1
+                appx = ' (warm-up)' if iters <= skip_batch_num else ''
+                _logger.info(
+                    'batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, '
+                    'batch latency: {3:.4f} s, batch fps: {4:.2f}'.format(
+                        iters, batch_acc1, batch_acc5, batch_time, fps, appx))
+
+            # Postprocess benchmark data
+            latencies = batch_times[skip_batch_num:]
+            latency_avg = np.average(latencies)
+            fpses = fpses[skip_batch_num:]
+            fps_avg = np.average(fpses)
+            infer_total_time = time.time() - infer_start_time
+            acc1_avg = np.mean(infer_accs1)
+            acc5_avg = np.mean(infer_accs5)
+            _logger.info('Total inference run time: {:.2f} s'.format(
+                infer_total_time))
+
+            return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
+
+    def _compare_accuracy(self, fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
+                          threshold):
+        _logger.info('Accepted acc1 diff threshold: {0}'.format(threshold))
+        _logger.info('FP32: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
+            fp32_acc1, fp32_acc5))
+        _logger.info('INT8: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format(
+            int8_acc1, int8_acc5))
+        assert fp32_acc1 > 0.0
+        assert int8_acc1 > 0.0
+        assert fp32_acc1 - int8_acc1 <= threshold
+
+    def test_graph_transformation(self):
+        if not fluid.core.is_compiled_with_mkldnn():
+            return
+
+        qat_model_path = test_case_args.qat_model
+        data_path = test_case_args.infer_data
+        batch_size = test_case_args.batch_size
+        batch_num = test_case_args.batch_num
+        skip_batch_num = test_case_args.skip_batch_num
+        acc_diff_threshold = test_case_args.acc_diff_threshold
+
+        _logger.info('QAT FP32 & INT8 prediction run.')
+        _logger.info('QAT model: {0}'.format(qat_model_path))
+        _logger.info('Dataset: {0}'.format(data_path))
+        _logger.info('Batch size: {0}'.format(batch_size))
+        _logger.info('Batch number: {0}'.format(batch_num))
+        _logger.info('Accuracy diff threshold: {0}. '
+                     '(condition: (fp32_acc - int8_acc) <= threshold)'
+                     .format(acc_diff_threshold))
+
+        _logger.info('--- QAT FP32 prediction start ---')
+        val_reader = paddle.batch(
+            self._reader_creator(data_path), batch_size=batch_size)
+        fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict(
+            val_reader,
+            qat_model_path,
+            batch_num,
+            skip_batch_num,
+            transform_to_int8=False)
+
+        _logger.info('--- QAT INT8 prediction start ---')
+        val_reader = paddle.batch(
+            self._reader_creator(data_path), batch_size=batch_size)
+        int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
+            val_reader,
+            qat_model_path,
+            batch_num,
+            skip_batch_num,
+            transform_to_int8=True)
+
+        _logger.info('--- Performance summary ---')
+        _logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
+            fp32_fps, fp32_lat))
+        _logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format(
+            int8_fps, int8_lat))
+
+        _logger.info('--- Comparing accuracy ---')
+        self._compare_accuracy(fp32_acc1, fp32_acc5, int8_acc1, int8_acc5,
+                               acc_diff_threshold)
+
+
+if __name__ == '__main__':
+    global test_case_args
+    test_case_args, remaining_args = parse_args()
+    unittest.main(argv=remaining_args)
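The data file consumed by the new test (the --infer_data argument, data.bin) is a packed binary whose layout follows directly from _reader_creator above: a leading int64 image count, then every image stored as a 3x224x224 float32 array, then one int64 label per image. Below is a minimal sketch for writing a tiny file in that layout, assuming random images and dummy labels; the file name and sample count are arbitrary, only the byte layout is taken from the reader:

    import struct

    import numpy as np


    def write_small_data_bin(path, num=2):
        # Layout mirrors _reader_creator: int64 count at offset 0, then num images
        # of 3 * 224 * 224 float32 each, then num int64 labels after all images.
        with open(path, 'wb') as f:
            f.write(struct.pack('q', num))
            for _ in range(num):
                f.write(np.random.rand(3, 224, 224).astype('float32').tobytes())
            for i in range(num):
                f.write(struct.pack('q', i))


    write_small_data_bin('small_data.bin')

With such a file and a downloaded QAT model directory, the comparison can also be run by hand with the same flags the CMake helper passes, for example --qat_model <model_dir>/model --infer_data small_data.bin --batch_size 2 --batch_num 1 --acc_diff_threshold 0.1.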