未验证 提交 d1160248 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add bf16 gru model test (#31158)

上级 2f116534
...@@ -142,6 +142,19 @@ function(inference_analysis_api_lexical_test_run TARGET_NAME test_binary infer_m ...@@ -142,6 +142,19 @@ function(inference_analysis_api_lexical_test_run TARGET_NAME test_binary infer_m
--iterations=2) --iterations=2)
endfunction() endfunction()
# Registers a lexical-analysis inference test that runs the model with
# oneDNN/MKL-DNN bfloat16 enabled (--enable_bf16=true). Mirrors
# inference_analysis_api_lexical_test_run above, differing only in the
# bf16 flag and the fixed batch size / thread count used on CI.
#
# Arguments:
#   TARGET_NAME  - name of the CTest target to create
#   test_binary  - the compiled lexical-analysis test executable
#   infer_model  - path to the model directory (--infer_model)
#   data_path    - path to the input data file (--infer_data)
function(inference_analysis_api_lexical_bfloat16_test_run TARGET_NAME test_binary infer_model data_path)
  inference_analysis_test_run(${TARGET_NAME}
          COMMAND ${test_binary}
          ARGS --infer_model=${infer_model}
               --infer_data=${data_path}
               --batch_size=50
               --cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
               --with_accuracy_layer=true
               --use_analysis=true
               --enable_bf16=true
               # two iterations keep the CI run fast while still exercising
               # the bf16 execution path end to end
               --iterations=2)
endfunction()
function(preprocess_data2bin_test_run target py_script_source data_dir output_file) function(preprocess_data2bin_test_run target py_script_source data_dir output_file)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${py_script_source} py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/${py_script_source}
ARGS --data_dir=${data_dir} ARGS --data_dir=${data_dir}
...@@ -421,6 +434,8 @@ if(WITH_MKLDNN) ...@@ -421,6 +434,8 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${LEXICAL_TEST_APP} ${LEXICAL_TEST_APP_SRC}) inference_analysis_api_test_build(${LEXICAL_TEST_APP} ${LEXICAL_TEST_APP_SRC})
# run lexcial analysis test # run lexcial analysis test
inference_analysis_api_lexical_test_run(test_analyzer_lexical_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH}) inference_analysis_api_lexical_test_run(test_analyzer_lexical_gru ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH})
# run bfloat16 lexical analysis test
inference_analysis_api_lexical_bfloat16_test_run(test_analyzer_lexical_gru_bfloat16 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH} ${GRU_DATA_PATH})
### optimized FP32 vs. Quant INT8 tests ### optimized FP32 vs. Quant INT8 tests
......
...@@ -38,6 +38,7 @@ void SetAnalysisConfig(AnalysisConfig *cfg, ...@@ -38,6 +38,7 @@ void SetAnalysisConfig(AnalysisConfig *cfg,
cfg->SwitchSpecifyInputNames(false); cfg->SwitchSpecifyInputNames(false);
cfg->SetCpuMathLibraryNumThreads(num_threads); cfg->SetCpuMathLibraryNumThreads(num_threads);
cfg->EnableMKLDNN(); cfg->EnableMKLDNN();
cfg->pass_builder()->AppendPass("mkldnn_placement_pass");
} }
std::vector<size_t> ReadSentenceLod(std::ifstream &file, size_t offset, std::vector<size_t> ReadSentenceLod(std::ifstream &file, size_t offset,
...@@ -210,7 +211,7 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) { ...@@ -210,7 +211,7 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
if (FLAGS_use_analysis) { if (FLAGS_use_analysis) {
AnalysisConfig analysis_cfg; AnalysisConfig analysis_cfg;
SetAnalysisConfig(&analysis_cfg, FLAGS_cpu_num_threads); SetAnalysisConfig(&analysis_cfg, FLAGS_cpu_num_threads);
analysis_cfg.pass_builder()->AppendPass("mkldnn_placement_pass"); if (FLAGS_enable_bf16) analysis_cfg.EnableMkldnnBfloat16();
std::vector<double> acc_analysis(3); std::vector<double> acc_analysis(3);
acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true); acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true);
for (size_t i = 0; i < acc_analysis.size(); i++) { for (size_t i = 0; i < acc_analysis.size(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册