未验证 提交 6aa6b8cf 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14918 from luotao1/mobilenet_test

add test_analyzer_mobilenet
...@@ -30,6 +30,13 @@ function(inference_analysis_api_test_with_fake_data target install_dir filename ...@@ -30,6 +30,13 @@ function(inference_analysis_api_test_with_fake_data target install_dir filename
ARGS --infer_model=${install_dir}/model) ARGS --infer_model=${install_dir}/model)
endfunction() endfunction()
function(inference_analysis_api_test_with_refer_result target install_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt
--refer_result=${install_dir}/result.txt)
endfunction()
# RNN1 # RNN1
if(NOT APPLE AND WITH_MKLML) if(NOT APPLE AND WITH_MKLML)
set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
...@@ -83,14 +90,21 @@ set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") ...@@ -83,14 +90,21 @@ set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR}) if (NOT EXISTS ${OCR_INSTALL_DIR})
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
endif() endif()
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) inference_analysis_api_test_with_refer_result(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
# resnet50 # resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz") "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz")
# mobilenet with depthwise_conv op # mobilenet with depthwise_conv op
inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv
"${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz") "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz")
# anakin # anakin
......
...@@ -93,18 +93,20 @@ void profile(bool use_mkldnn = false) { ...@@ -93,18 +93,20 @@ void profile(bool use_mkldnn = false) {
SetInput(&input_slots_all); SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads); input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
const float ocr_result_data[] = { std::string line;
5.273636460856323538e-08, 3.296741795111302054e-07, std::ifstream file(FLAGS_refer_result);
1.873261190610264748e-08, 3.403730275408634043e-08, std::getline(file, line);
3.383312474625199684e-08}; auto refer = ProcessALine(line);
PADDLE_ENFORCE_EQ(outputs.size(), 1UL); file.close();
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(size, 0); auto &output = outputs.front();
float *result = static_cast<float *>(outputs[0].data.data()); size_t numel = output.data.length() / PaddleDtypeSize(output.dtype);
for (size_t i = 0; i < std::min(5UL, size); i++) { CHECK_EQ(numel, refer.data.size());
EXPECT_NEAR(result[i], ocr_result_data[i], 1e-3); for (size_t i = 0; i < numel; ++i) {
CHECK_LT(
fabs(static_cast<float *>(output.data.data())[i] - refer.data[i]),
1e-5);
} }
} }
} }
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
DEFINE_string(model_name, "", "model name"); DEFINE_string(model_name, "", "model name");
DEFINE_string(infer_model, "", "model path"); DEFINE_string(infer_model, "", "model path");
DEFINE_string(infer_data, "", "data file"); DEFINE_string(infer_data, "", "data file");
DEFINE_string(refer_result, "", "reference result for comparison");
DEFINE_int32(batch_size, 1, "batch size."); DEFINE_int32(batch_size, 1, "batch size.");
DEFINE_int32(repeat, 1, "Running the inference program repeat times."); DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册