From bc93a2090617fc41b841b86e0eb7a3530b1b0b50 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Tue, 18 Jun 2019 06:09:24 +0200 Subject: [PATCH] Cherry pick #18077 and #18111 unify FP32 vs. INT8 comparison tests output, reuse C-API INT8 unit test application (#18145) * unify FP32 vs. INT8 comparison tests output (#18111) test=release/1.5 * reuse C-API INT8 unit test application (#18077) test=release/1.5 --- cmake/generic.cmake | 31 +++- .../fluid/inference/analysis/CMakeLists.txt | 38 ++++- .../fluid/inference/tests/api/CMakeLists.txt | 141 +++++++++++------- ...> analyzer_image_classification_tester.cc} | 0 .../fluid/inference/tests/api/tester_helper.h | 59 ++++++-- paddle/fluid/inference/tests/test.cmake | 28 +++- .../contrib/slim/tests/qat_int8_comparison.py | 56 ++++--- .../test_mkldnn_int8_quantization_strategy.py | 22 +-- 8 files changed, 259 insertions(+), 116 deletions(-) rename paddle/fluid/inference/tests/api/{analyzer_resnet50_tester.cc => analyzer_image_classification_tester.cc} (100%) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index c5bedf376ba..3e3a5ba66c8 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -363,10 +363,10 @@ function(cc_binary TARGET_NAME) target_link_libraries(${TARGET_NAME} ${os_dependency_modules}) endfunction(cc_binary) -function(cc_test TARGET_NAME) +function(cc_test_build TARGET_NAME) if(WITH_TESTING) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS ARGS) + set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_executable(${TARGET_NAME} ${cc_test_SRCS}) if(WIN32) @@ -379,9 +379,18 @@ function(cc_test TARGET_NAME) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) common_link(${TARGET_NAME}) + endif() +endfunction() + +function(cc_test_run TARGET_NAME) + if(WITH_TESTING) + set(oneValueArgs "") + set(multiValueArgs COMMAND ARGS) + cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_test(NAME ${TARGET_NAME} - COMMAND ${TARGET_NAME} ${cc_test_ARGS} - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + COMMAND ${cc_test_COMMAND} + ARGS ${cc_test_ARGS} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G @@ -389,6 +398,20 @@ function(cc_test TARGET_NAME) # No unit test should exceed 10 minutes. set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600) endif() +endfunction() + +function(cc_test TARGET_NAME) + if(WITH_TESTING) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS ARGS) + cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + cc_test_build(${TARGET_NAME} + SRCS ${cc_test_SRCS} + DEPS ${cc_test_DEPS}) + cc_test_run(${TARGET_NAME} + COMMAND ${TARGET_NAME} + ARGS ${cc_test_ARGS}) + endif() endfunction(cc_test) function(nv_library TARGET_NAME) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 7a795bda820..d79fb529092 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -23,18 +23,46 @@ cc_library(analysis SRCS cc_test(test_dot SRCS dot_tester.cc DEPS analysis) +function(inference_analysis_test_build TARGET) + if(WITH_TESTING) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS EXTRA_DEPS) + cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + inference_base_test_build(${TARGET} + SRCS ${analysis_test_SRCS} + DEPS analysis pass ${GLOB_PASS_LIB} ${analysis_test_EXTRA_DEPS}) + endif() +endfunction() + +function(inference_analysis_test_run TARGET) + if(WITH_TESTING) + set(options "") + set(oneValueArgs "") + set(multiValueArgs COMMAND ARGS) + cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + inference_base_test_run(${TARGET} + COMMAND ${analysis_test_COMMAND} + ARGS ${analysis_test_ARGS}) + endif() +endfunction() + function(inference_analysis_test TARGET) if(WITH_TESTING) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS ARGS EXTRA_DEPS) cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - inference_base_test(${TARGET} + inference_base_test_build(${TARGET} SRCS ${analysis_test_SRCS} - DEPS analysis pass ${GLOB_PASS_LIB} ${analysis_test_EXTRA_DEPS} - ARGS --inference_model_dir=${WORD2VEC_MODEL_DIR} ${analysis_test_ARGS}) + DEPS analysis pass ${GLOB_PASS_LIB} ${analysis_test_EXTRA_DEPS}) + inference_base_test_run(${TARGET} + COMMAND ${TARGET} + ARGS ${analysis_test_ARGS}) endif() endfunction(inference_analysis_test) -inference_analysis_test(test_analyzer SRCS analyzer_tester.cc - EXTRA_DEPS reset_tensor_array paddle_inference_api) +inference_analysis_test(test_analyzer + SRCS analyzer_tester.cc + EXTRA_DEPS reset_tensor_array paddle_inference_api + ARGS --inference_model_dir=${WORD2VEC_MODEL_DIR}) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 3422af32512..db57d39ebf8 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -4,9 +4,15 @@ if(WITH_GPU AND TENSORRT_FOUND) set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor) endif() -function(download_model install_dir model_name) +function(download_data install_dir data_file) if (NOT EXISTS ${install_dir}) - inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${data_file}) + endif() +endfunction() + +function(download_int8_data install_dir data_file) + if (NOT EXISTS ${install_dir}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file}) endif() endfunction() @@ -23,21 +29,31 @@ function(inference_analysis_api_test target install_dir filename) ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt) endfunction() -function(inference_analysis_api_int8_test target model_dir data_dir filename) - inference_analysis_test(${target} SRCS ${filename} - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark +function(inference_analysis_api_int8_test_build TARGET_NAME filename) + inference_analysis_test_build(${TARGET_NAME} SRCS ${filename} + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark) +endfunction() + +function(inference_analysis_api_int8_test_run TARGET_NAME test_binary model_dir data_path) + inference_analysis_test_run(${TARGET_NAME} + COMMAND ${test_binary} ARGS --infer_model=${model_dir}/model - --infer_data=${data_dir}/data.bin + --infer_data=${data_path} --warmup_batch_size=100 --batch_size=50 --paddle_num_threads=${CPU_NUM_THREADS_ON_CI} --iterations=2) endfunction() -function(inference_analysis_api_test_with_fake_data target install_dir filename model_name disable_fc) - download_model(${install_dir} ${model_name}) - inference_analysis_test(${target} SRCS ${filename} - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${install_dir}/model + +function(inference_analysis_api_test_with_fake_data_build TARGET_NAME filename) + inference_analysis_test_build(${TARGET_NAME} SRCS ${filename} + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}) +endfunction() + +function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary model_dir disable_fc) + inference_analysis_test_run(${TARGET_NAME} + COMMAND ${test_binary} + ARGS --infer_model=${model_dir}/model --disable_mkldnn_fc=${disable_fc}) endfunction() @@ -141,73 +157,82 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR}) endif() inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc) +### Image classification tests with fake data +set(IMG_CLASS_TEST_APP "test_analyzer_image_classification") +set(IMG_CLASS_TEST_APP_SRC "analyzer_image_classification_tester.cc") + +# build test binary to be used in subsequent tests +inference_analysis_api_test_with_fake_data_build(${IMG_CLASS_TEST_APP} ${IMG_CLASS_TEST_APP_SRC}) + # googlenet -inference_analysis_api_test_with_fake_data(test_analyzer_googlenet - "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" false) +set(GOOGLENET_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/googlenet") +download_data(${GOOGLENET_MODEL_DIR} "googlenet.tar.gz") +inference_analysis_api_test_with_fake_data_run(test_analyzer_googlenet ${IMG_CLASS_TEST_APP} + ${GOOGLENET_MODEL_DIR} false) # resnet50 -inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 - "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" true) +set(RESNET50_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50") +download_data(${RESNET50_MODEL_DIR} "resnet50_model.tar.gz") +inference_analysis_api_test_with_fake_data_run(test_analyzer_resnet50 ${IMG_CLASS_TEST_APP} + ${RESNET50_MODEL_DIR} true) # mobilenet with depthwise_conv op -inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv - "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" false) +set(MOBILENET_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv") +download_data(${MOBILENET_MODEL_DIR} "mobilenet_model.tar.gz") +inference_analysis_api_test_with_fake_data_run(test_analyzer_mobilenet_depthwise_conv ${IMG_CLASS_TEST_APP} + ${MOBILENET_MODEL_DIR} false) -# int8 image classification tests +### INT8 tests if(WITH_MKLDNN) + set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") - if (NOT EXISTS ${INT8_DATA_DIR}) - inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz") - endif() - #resnet50 int8 + ### Image classification tests + set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin") + set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification") + set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc") + + # download dataset if necessary + download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz") + + # build test binary to be used in subsequent tests + inference_analysis_api_int8_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC}) + + # resnet50 int8 set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50") - if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR}) - inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "${INFERENCE_URL}/int8" "resnet50_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) - - #mobilenet int8 - set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet") - if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR}) - inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenetv1_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) + download_int8_data(${INT8_RESNET50_MODEL_DIR} "resnet50_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_resnet50 ${INT8_IMG_CLASS_TEST_APP} ${INT8_RESNET50_MODEL_DIR} ${IMAGENET_DATA_PATH}) + + # mobilenetv1 int8 + set(INT8_MOBILENETV1_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv1") + download_int8_data(${INT8_MOBILENETV1_MODEL_DIR} "mobilenetv1_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv1 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV1_MODEL_DIR} ${IMAGENET_DATA_PATH}) - #mobilenetv2 int8 + # mobilenetv2 int8 set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2") - if (NOT EXISTS ${INT8_MOBILENETV2_MODEL_DIR}) - inference_download_and_uncompress(${INT8_MOBILENETV2_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenet_v2_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) + download_int8_data(${INT8_MOBILENETV2_MODEL_DIR} "mobilenet_v2_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_mobilenetv2 ${INT8_IMG_CLASS_TEST_APP} ${INT8_MOBILENETV2_MODEL_DIR} ${IMAGENET_DATA_PATH}) - #resnet101 int8 + # resnet101 int8 set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101") - if (NOT EXISTS ${INT8_RESNET101_MODEL_DIR}) - inference_download_and_uncompress(${INT8_RESNET101_MODEL_DIR} "${INFERENCE_URL}/int8" "Res101_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) + download_int8_data(${INT8_RESNET101_MODEL_DIR} "Res101_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_resnet101 ${INT8_IMG_CLASS_TEST_APP} ${INT8_RESNET101_MODEL_DIR} ${IMAGENET_DATA_PATH}) - #vgg16 int8 + # vgg16 int8 set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16") - if (NOT EXISTS ${INT8_VGG16_MODEL_DIR}) - inference_download_and_uncompress(${INT8_VGG16_MODEL_DIR} "${INFERENCE_URL}/int8" "VGG16_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) + download_int8_data(${INT8_VGG16_MODEL_DIR} "VGG16_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH}) - #vgg19 int8 + # vgg19 int8 set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19") - if (NOT EXISTS ${INT8_VGG19_MODEL_DIR}) - inference_download_and_uncompress(${INT8_VGG19_MODEL_DIR} "${INFERENCE_URL}/int8" "VGG19_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc) + download_int8_data(${INT8_VGG19_MODEL_DIR} "VGG19_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_vgg19 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG19_MODEL_DIR} ${IMAGENET_DATA_PATH}) - #googlenet int8 + # googlenet int8 set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet") - if (NOT EXISTS ${INT8_GOOGLENET_MODEL_DIR}) - inference_download_and_uncompress(${INT8_GOOGLENET_MODEL_DIR} "${INFERENCE_URL}/int8" "GoogleNet_int8_model.tar.gz" ) - endif() - inference_analysis_api_int8_test(test_analyzer_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) + download_int8_data(${INT8_GOOGLENET_MODEL_DIR} "GoogleNet_int8_model.tar.gz" ) + inference_analysis_api_int8_test_run(test_analyzer_int8_googlenet ${INT8_IMG_CLASS_TEST_APP} ${INT8_GOOGLENET_MODEL_DIR} ${IMAGENET_DATA_PATH}) + endif() # bert, max_len=20, embedding_dim=128 diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc similarity index 100% rename from paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc rename to paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index eda86c3b42b..eb786196a88 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -320,7 +320,8 @@ void PredictionRun(PaddlePredictor *predictor, const std::vector> &inputs, std::vector> *outputs, int num_threads, int tid, - const VarType::Type data_type = VarType::FP32) { + const VarType::Type data_type = VarType::FP32, + float *sample_latency = nullptr) { int num_times = FLAGS_repeat; int iterations = inputs.size(); // process the whole dataset ... if (FLAGS_iterations > 0 && @@ -360,6 +361,10 @@ void PredictionRun(PaddlePredictor *predictor, auto batch_latency = elapsed_time / (iterations * num_times); PrintTime(FLAGS_batch_size, num_times, num_threads, tid, batch_latency, iterations, data_type); + + if (sample_latency != nullptr) + *sample_latency = batch_latency / FLAGS_batch_size; + if (FLAGS_record_benchmark) { Benchmark benchmark; benchmark.SetName(FLAGS_model_name); @@ -373,12 +378,14 @@ void TestOneThreadPrediction( const PaddlePredictor::Config *config, const std::vector> &inputs, std::vector> *outputs, bool use_analysis = true, - const VarType::Type data_type = VarType::FP32) { + const VarType::Type data_type = VarType::FP32, + float *sample_latency = nullptr) { auto predictor = CreateTestPredictor(config, use_analysis); if (FLAGS_warmup) { PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0, data_type); } - PredictionRun(predictor.get(), inputs, outputs, 1, 0, data_type); + PredictionRun(predictor.get(), inputs, outputs, 1, 0, data_type, + sample_latency); } void TestMultiThreadPrediction( @@ -430,6 +437,31 @@ void TestPrediction(const PaddlePredictor::Config *config, } } +void SummarizeAccuracy(float avg_acc1_fp32, float avg_acc1_int8) { + LOG(INFO) << "--- Accuracy summary --- "; + LOG(INFO) << "Accepted top1 accuracy drop threshold: " + << FLAGS_quantized_accuracy + << ". (condition: (FP32_top1_acc - INT8_top1_acc) <= threshold)"; + LOG(INFO) << "FP32: avg top1 accuracy: " << std::fixed << std::setw(6) + << std::setprecision(4) << avg_acc1_fp32; + LOG(INFO) << "INT8: avg top1 accuracy: " << std::fixed << std::setw(6) + << std::setprecision(4) << avg_acc1_int8; +} + +void SummarizePerformance(float sample_latency_fp32, + float sample_latency_int8) { + // sample latency in ms + auto throughput_fp32 = 1000.0 / sample_latency_fp32; + auto throughput_int8 = 1000.0 / sample_latency_int8; + LOG(INFO) << "--- Performance summary --- "; + LOG(INFO) << "FP32: avg fps: " << std::fixed << std::setw(6) + << std::setprecision(4) << throughput_fp32 + << ", avg latency: " << sample_latency_fp32 << " ms"; + LOG(INFO) << "INT8: avg fps: " << std::fixed << std::setw(6) + << std::setprecision(4) << throughput_int8 + << ", avg latency: " << sample_latency_int8 << " ms"; +} + void CompareTopAccuracy( const std::vector> &output_slots_quant, const std::vector> &output_slots_ref) { @@ -459,12 +491,10 @@ void CompareTopAccuracy( float avg_acc1_quant = total_accs1_quant / output_slots_quant.size(); float avg_acc1_ref = total_accs1_ref / output_slots_ref.size(); - LOG(INFO) << "Avg top1 INT8 accuracy: " << std::fixed << std::setw(6) - << std::setprecision(4) << avg_acc1_quant; - LOG(INFO) << "Avg top1 FP32 accuracy: " << std::fixed << std::setw(6) - << std::setprecision(4) << avg_acc1_ref; - LOG(INFO) << "Accepted accuracy drop threshold: " << FLAGS_quantized_accuracy; - CHECK_LE(std::abs(avg_acc1_quant - avg_acc1_ref), FLAGS_quantized_accuracy); + SummarizeAccuracy(avg_acc1_ref, avg_acc1_quant); + CHECK_GT(avg_acc1_ref, 0.0); + CHECK_GT(avg_acc1_quant, 0.0); + CHECK_LE(avg_acc1_ref - avg_acc1_quant, FLAGS_quantized_accuracy); } void CompareDeterministic( @@ -510,16 +540,19 @@ void CompareQuantizedAndAnalysis( auto *cfg = reinterpret_cast(config); PrintConfig(cfg, true); std::vector> analysis_outputs; - TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true, VarType::FP32); + float sample_latency_fp32{-1}; + TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true, VarType::FP32, + &sample_latency_fp32); LOG(INFO) << "--- INT8 prediction start ---"; auto *qcfg = reinterpret_cast(qconfig); PrintConfig(qcfg, true); std::vector> quantized_outputs; - TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true, - VarType::INT8); + float sample_latency_int8{-1}; + TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true, VarType::INT8, + &sample_latency_int8); - LOG(INFO) << "--- comparing outputs --- "; + SummarizePerformance(sample_latency_fp32, sample_latency_int8); CompareTopAccuracy(quantized_outputs, analysis_outputs); } diff --git a/paddle/fluid/inference/tests/test.cmake b/paddle/fluid/inference/tests/test.cmake index c93c9ef2f23..444bab1b33d 100644 --- a/paddle/fluid/inference/tests/test.cmake +++ b/paddle/fluid/inference/tests/test.cmake @@ -48,13 +48,35 @@ if(NOT EXISTS ${WORD2VEC_INSTALL_DIR} AND NOT WIN32) endif() set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model") -function (inference_base_test TARGET) +function (inference_base_test_build TARGET) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS ARGS DEPS) + set(multiValueArgs SRCS DEPS) + cmake_parse_arguments(base_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + cc_test_build(${TARGET} SRCS ${base_test_SRCS} DEPS ${base_test_DEPS}) +endfunction() + +function (inference_base_test_run TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs COMMAND ARGS) cmake_parse_arguments(base_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if(WITH_GPU) set(mem_opt "--fraction_of_gpu_memory_to_use=0.5") endif() - cc_test(${TARGET} SRCS ${base_test_SRCS} DEPS ${base_test_DEPS} ARGS ${mem_opt} ${base_test_ARGS}) + cc_test_run(${TARGET} COMMAND ${base_test_COMMAND} ARGS ${mem_opt} ${base_test_ARGS}) endfunction() + +function (inference_base_test TARGET) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS ARGS DEPS) + cmake_parse_arguments(base_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + inference_base_test_build(${TARGET} + SRCS ${base_test_SRCS} + DEPS ${base_test_DEPS}) + inference_base_test_run(${TARGET} + COMMAND ${TARGET} + ARGS ${base_test_ARGS}) +endfunction() + diff --git a/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py b/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py index f8cd5a663ec..6673811a791 100644 --- a/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py +++ b/python/paddle/fluid/contrib/slim/tests/qat_int8_comparison.py @@ -83,8 +83,8 @@ class TestQatInt8Comparison(unittest.TestCase): while step < num: fp.seek(imgs_offset + img_size * step) img = fp.read(img_size) - img = struct.unpack_from('{}f'.format(img_ch * img_w * - img_h), img) + img = struct.unpack_from( + '{}f'.format(img_ch * img_w * img_h), img) img = np.array(img) img.shape = (img_ch, img_w, img_h) fp.seek(labels_offset + label_size * step) @@ -147,6 +147,7 @@ class TestQatInt8Comparison(unittest.TestCase): def _predict(self, test_reader=None, model_path=None, + batch_size=1, batch_num=1, skip_batch_num=0, transform_to_int8=False): @@ -199,7 +200,7 @@ class TestQatInt8Comparison(unittest.TestCase): out = exe.run(inference_program, feed={feed_target_names[0]: images}, fetch_list=fetch_targets) - batch_time = time.time() - start + batch_time = (time.time() - start) * 1000 # in miliseconds outputs.append(out[0]) batch_acc1, batch_acc5 = self._get_batch_accuracy(out[0], labels) @@ -212,14 +213,15 @@ class TestQatInt8Comparison(unittest.TestCase): fpses.append(fps) iters += 1 appx = ' (warm-up)' if iters <= skip_batch_num else '' - _logger.info( - 'batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, ' - 'batch latency: {3:.4f} s, batch fps: {4:.2f}'.format( - iters, batch_acc1, batch_acc5, batch_time, fps, appx)) + _logger.info('batch {0}{5}, acc1: {1:.4f}, acc5: {2:.4f}, ' + 'latency: {3:.4f} ms, fps: {4:.2f}'.format( + iters, batch_acc1, batch_acc5, batch_time / + batch_size, fps, appx)) # Postprocess benchmark data - latencies = batch_times[skip_batch_num:] - latency_avg = np.average(latencies) + batch_latencies = batch_times[skip_batch_num:] + batch_latency_avg = np.average(batch_latencies) + latency_avg = batch_latency_avg / batch_size fpses = fpses[skip_batch_num:] fps_avg = np.average(fpses) infer_total_time = time.time() - infer_start_time @@ -230,13 +232,25 @@ class TestQatInt8Comparison(unittest.TestCase): return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg + def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat): + _logger.info('--- Performance summary ---') + _logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} ms'.format( + fp32_fps, fp32_lat)) + _logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} ms'.format( + int8_fps, int8_lat)) + def _compare_accuracy(self, fp32_acc1, fp32_acc5, int8_acc1, int8_acc5, threshold): - _logger.info('Accepted acc1 diff threshold: {0}'.format(threshold)) - _logger.info('FP32: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format( - fp32_acc1, fp32_acc5)) - _logger.info('INT8: avg acc1: {0:.4f}, avg acc5: {1:.4f}'.format( - int8_acc1, int8_acc5)) + _logger.info('--- Accuracy summary ---') + _logger.info( + 'Accepted top1 accuracy drop threshold: {0}. (condition: (FP32_top1_acc - IN8_top1_acc) <= threshold)' + .format(threshold)) + _logger.info( + 'FP32: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'. + format(fp32_acc1, fp32_acc5)) + _logger.info( + 'INT8: avg top1 accuracy: {0:.4f}, avg top5 accuracy: {1:.4f}'. + format(int8_acc1, int8_acc5)) assert fp32_acc1 > 0.0 assert int8_acc1 > 0.0 assert fp32_acc1 - int8_acc1 <= threshold @@ -257,9 +271,7 @@ class TestQatInt8Comparison(unittest.TestCase): _logger.info('Dataset: {0}'.format(data_path)) _logger.info('Batch size: {0}'.format(batch_size)) _logger.info('Batch number: {0}'.format(batch_num)) - _logger.info('Accuracy diff threshold: {0}. ' - '(condition: (fp32_acc - int8_acc) <= threshold)' - .format(acc_diff_threshold)) + _logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold)) _logger.info('--- QAT FP32 prediction start ---') val_reader = paddle.batch( @@ -267,6 +279,7 @@ class TestQatInt8Comparison(unittest.TestCase): fp32_output, fp32_acc1, fp32_acc5, fp32_fps, fp32_lat = self._predict( val_reader, qat_model_path, + batch_size, batch_num, skip_batch_num, transform_to_int8=False) @@ -277,17 +290,12 @@ class TestQatInt8Comparison(unittest.TestCase): int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict( val_reader, qat_model_path, + batch_size, batch_num, skip_batch_num, transform_to_int8=True) - _logger.info('--- Performance summary ---') - _logger.info('FP32: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format( - fp32_fps, fp32_lat)) - _logger.info('INT8: avg fps: {0:.2f}, avg latency: {1:.4f} s'.format( - int8_fps, int8_lat)) - - _logger.info('--- Comparing accuracy ---') + self._summarize_performance(fp32_fps, fp32_lat, int8_fps, int8_lat) self._compare_accuracy(fp32_acc1, fp32_acc5, int8_acc1, int8_acc5, acc_diff_threshold) diff --git a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py index f1ebb8ae72f..c7429af5ffb 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/tests/test_mkldnn_int8_quantization_strategy.py @@ -172,6 +172,17 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase): com_pass.config(config_path) com_pass.run() + def _compare_accuracy(self, fp32_acc1, int8_acc1, threshold): + _logger.info('--- Accuracy summary ---') + _logger.info( + 'Accepted top1 accuracy drop threshold: {0}. (condition: (FP32_top1_acc - IN8_top1_acc) <= threshold)' + .format(threshold)) + _logger.info('FP32: avg top1 accuracy: {0:.4f}'.format(fp32_acc1)) + _logger.info('INT8: avg top1 accuracy: {0:.4f}'.format(int8_acc1)) + assert fp32_acc1 > 0.0 + assert int8_acc1 > 0.0 + assert fp32_acc1 - int8_acc1 <= threshold + def test_compression(self): if not fluid.core.is_compiled_with_mkldnn(): return @@ -204,15 +215,8 @@ class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase): self._reader_creator(data_path, False), batch_size=batch_size) fp32_model_result = self._predict(val_reader, fp32_model_path) - _logger.info('--- comparing outputs ---') - _logger.info('Avg top1 INT8 accuracy: {0:.4f}'.format(int8_model_result[ - 0])) - _logger.info('Avg top1 FP32 accuracy: {0:.4f}'.format(fp32_model_result[ - 0])) - _logger.info('Accepted accuracy drop threshold: {0}'.format( - accuracy_diff_threshold)) - assert fp32_model_result[0] - int8_model_result[ - 0] <= accuracy_diff_threshold + self._compare_accuracy(fp32_model_result[0], int8_model_result[0], + accuracy_diff_threshold) if __name__ == '__main__': -- GitLab