提交 250a1921 编写于 作者: G GaoWei8 提交者: Tao Luo

Add ernie large c++ inference test (#21365)

* add ernie-large test
test=develop

* add ernie large c++ inference test
test=develop
上级 2445fef3
......@@ -76,13 +76,6 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
--disable_mkldnn_fc=${disable_fc})
endfunction()
function(inference_analysis_api_test_with_refer_result target install_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt
--refer_result=${install_dir}/result.txt)
endfunction()
function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
inference_analysis_test_run(${TARGET_NAME}
COMMAND ${test_binary}
......@@ -157,6 +150,14 @@ download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" "Ernie_data.tx
download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz")
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc)
#Ernie large
set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_Large")
download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_large_model.tar.gz" "Ernie_large_data.txt.tar.gz" "Ernie_large_result.txt.tar.gz")
download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
......@@ -180,14 +181,14 @@ set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
### Image classification tests with fake data
set(IMG_CLASS_TEST_APP "test_analyzer_image_classification")
......
......@@ -91,17 +91,18 @@ bool ParseLine(const std::string &line,
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = "placeholder_" + std::to_string(i);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i++], &input_mask);
input_mask.name = "placeholder_3";
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
......@@ -176,9 +177,14 @@ TEST(Analyzer_Ernie, fuse_statis) {
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 74);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 295);
if (FLAGS_ernie_large) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 146);
EXPECT_EQ(num_ops, 859);
} else {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 74);
EXPECT_EQ(num_ops, 295);
}
}
// Compare result of NativeConfig and AnalysisConfig
......
......@@ -44,6 +44,7 @@ DEFINE_string(int8_model, "", "INT8 model path");
DEFINE_string(infer_data, "", "data file");
DEFINE_string(refer_result, "", "reference result for comparison");
DEFINE_int32(batch_size, 1, "batch size");
DEFINE_bool(ernie_large, false, "Test ernie large");
DEFINE_bool(with_accuracy_layer, true,
"Calculate the accuracy while label is in the input");
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册