diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index a173328e64ae5ec682fe6b729b8f1ee4f86fb867..bb8faf30fdd87e9045fa3fdc6343c0bca9a0b2ac 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -142,6 +142,19 @@ function(inference_analysis_api_lexical_test_run TARGET_NAME test_binary infer_m
                --iterations=2)
 endfunction()
 
+function(inference_analysis_api_lexical_bfloat16_test_run TARGET_NAME test_binary infer_model data_path)
+  inference_analysis_test_run(${TARGET_NAME}
+          COMMAND ${test_binary}
+          ARGS --infer_model=${infer_model}
+               --infer_data=${data_path}
+               --batch_size=50
+               --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
+               --with_accuracy_layer=true
+               --use_analysis=true
+               --enable_bf16=true
+               --iterations=2)
+endfunction()
+
 function(preprocess_data2bin_test_run target py_script_source data_dir output_file)
   py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${py_script_source}
           ARGS --data_dir=${data_dir}
@@ -421,6 +434,8 @@ if(WITH_MKLDNN)
   inference_analysis_api_test_build(${LEXICAL_TEST_APP} ${LEXICAL_TEST_APP_SRC})
   # run lexcial analysis test
   inference_analysis_api_lexical_test_run(test_analyzer_lexical_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH})
+  # run bfloat16 lexical analysis test
+  inference_analysis_api_lexical_bfloat16_test_run(test_analyzer_lexical_gru_bfloat16 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH})
 
   ### optimized FP32 vs. Quant INT8 tests
 
diff --git a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
index 7c5757ce9d4c630c1f9d96f27310b8f6e86eeac6..024313837e0b63a4ff2325b9cedd75a608c2a879 100644
--- a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
@@ -38,6 +38,7 @@ void SetAnalysisConfig(AnalysisConfig *cfg,
   cfg->SwitchSpecifyInputNames(false);
   cfg->SetCpuMathLibraryNumThreads(num_threads);
   cfg->EnableMKLDNN();
+  cfg->pass_builder()->AppendPass("mkldnn_placement_pass");
 }
 
 std::vector ReadSentenceLod(std::ifstream &file, size_t offset,
@@ -210,7 +211,7 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
   if (FLAGS_use_analysis) {
     AnalysisConfig analysis_cfg;
     SetAnalysisConfig(&analysis_cfg, FLAGS_cpu_num_threads);
-    analysis_cfg.pass_builder()->AppendPass("mkldnn_placement_pass");
+    if (FLAGS_enable_bf16) analysis_cfg.EnableMkldnnBfloat16();
     std::vector acc_analysis(3);
     acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true);
     for (size_t i = 0; i < acc_analysis.size(); i++) {
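
For context on the new BF16 path: the tester reads the `--enable_bf16` gflag (its `DEFINE_bool` is not part of this diff and presumably lives in a shared test helper) and, when it is set, calls `AnalysisConfig::EnableMkldnnBfloat16()` on top of the usual MKL-DNN setup. Below is a minimal, self-contained sketch of the resulting configuration sequence, assuming the public `paddle::AnalysisConfig` API; the flag definition, header name, and `MakeLexicalConfig` helper are illustrative and not part of the patch.

```cpp
// Illustrative sketch of the configuration sequence this patch produces.
// Assumptions: public paddle::AnalysisConfig API; FLAGS_enable_bf16 is defined
// here only for the example (the test suite defines it in a shared helper).
#include <string>

#include <gflags/gflags.h>
#include "paddle_inference_api.h"  // header name may differ per install layout

DEFINE_bool(enable_bf16, false,
            "Run MKL-DNN kernels in bfloat16 where supported.");

paddle::AnalysisConfig MakeLexicalConfig(const std::string &model_dir,
                                         int num_threads) {
  paddle::AnalysisConfig cfg;
  cfg.SetModel(model_dir);   // model/device setup assumed, not shown in the diff
  cfg.DisableGpu();
  cfg.SwitchSpecifyInputNames(false);
  cfg.SetCpuMathLibraryNumThreads(num_threads);
  cfg.EnableMKLDNN();
  // After this patch the placement pass is appended unconditionally inside
  // SetAnalysisConfig(), so the FP32 and BF16 runs share it.
  cfg.pass_builder()->AppendPass("mkldnn_placement_pass");
  // BF16 is opt-in: only the new test target passes --enable_bf16=true.
  if (FLAGS_enable_bf16) cfg.EnableMkldnnBfloat16();
  return cfg;
}
```

On the CMake side, the change simply registers a second test target that runs the same lexical-analysis binary with `--enable_bf16=true`, so the FP32 and BF16 runs use the same model and input data.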