From d11602481caccb058f55b15f40511fc0d3dafe3f Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Thu, 25 Feb 2021 04:14:24 +0100 Subject: [PATCH] Add bf16 gru model test (#31158) --- paddle/fluid/inference/tests/api/CMakeLists.txt | 15 +++++++++++++++ .../api/analyzer_lexical_analysis_gru_tester.cc | 3 ++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index a173328e64a..bb8faf30fdd 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -142,6 +142,19 @@ function(inference_analysis_api_lexical_test_run TARGET_NAME test_binary infer_m --iterations=2) endfunction() +function(inference_analysis_api_lexical_bfloat16_test_run TARGET_NAME test_binary infer_model data_path) + inference_analysis_test_run(${TARGET_NAME} + COMMAND ${test_binary} + ARGS --infer_model=${infer_model} + --infer_data=${data_path} + --batch_size=50 + --cpu_num_threads=${CPU_NUM_THREADS_ON_CI} + --with_accuracy_layer=true + --use_analysis=true + --enable_bf16=true + --iterations=2) +endfunction() + function(preprocess_data2bin_test_run target py_script_source data_dir output_file) py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${py_script_source} ARGS --data_dir=${data_dir} @@ -421,6 +434,8 @@ if(WITH_MKLDNN) inference_analysis_api_test_build(${LEXICAL_TEST_APP} ${LEXICAL_TEST_APP_SRC}) # run lexcial analysis test inference_analysis_api_lexical_test_run(test_analyzer_lexical_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH}) + # run bfloat16 lexical analysis test + inference_analysis_api_lexical_bfloat16_test_run(test_analyzer_lexical_gru_bfloat16 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH}) ### optimized FP32 vs. Quant INT8 tests diff --git a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc index 7c5757ce9d4..024313837e0 100644 --- a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc @@ -38,6 +38,7 @@ void SetAnalysisConfig(AnalysisConfig *cfg, cfg->SwitchSpecifyInputNames(false); cfg->SetCpuMathLibraryNumThreads(num_threads); cfg->EnableMKLDNN(); + cfg->pass_builder()->AppendPass("mkldnn_placement_pass"); } std::vector ReadSentenceLod(std::ifstream &file, size_t offset, @@ -210,7 +211,7 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) { if (FLAGS_use_analysis) { AnalysisConfig analysis_cfg; SetAnalysisConfig(&analysis_cfg, FLAGS_cpu_num_threads); - analysis_cfg.pass_builder()->AppendPass("mkldnn_placement_pass"); + if (FLAGS_enable_bf16) analysis_cfg.EnableMkldnnBfloat16(); std::vector acc_analysis(3); acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true); for (size_t i = 0; i < acc_analysis.size(); i++) { -- GitLab