未验证 提交 b757466b 编写于 作者: P Pei Yang 提交者: GitHub

fix trt dynamic ernie serialization unit test (#26228)

上级 ea6716a5
...@@ -309,7 +309,8 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() { ...@@ -309,7 +309,8 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
BriefNode *brief_node = itr.second; BriefNode *brief_node = itr.second;
if (!Agent(brief_node->node).marked()) { if (!Agent(brief_node->node).marked()) {
VLOG(4) << brief_node->node->id() << " node not a trt candidate."; VLOG(4) << brief_node->node->id() << " node named "
<< brief_node->node->Name() << " is not a trt candidate.";
continue; continue;
} }
......
...@@ -471,19 +471,10 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -471,19 +471,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz") inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif() endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_serialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized) ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized)
set(TEST_TRT_ERNIE_SER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_serialized/")
if (NOT EXISTS ${TEST_TRT_ERNIE_SER_MODEL})
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_serialized.tgz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie_deserialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_serialized)
endif() endif()
set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite") set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite")
......
...@@ -123,8 +123,11 @@ void trt_ernie(bool with_fp16, std::vector<float> result) { ...@@ -123,8 +123,11 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape); opt_input_shape);
AnalysisConfig* config_deser = new AnalysisConfig(config);
std::vector<float> out_data; std::vector<float> out_data;
run(config, &out_data); run(config, &out_data); // serialize
run(*config_deser, &out_data); // deserialize
for (size_t i = 0; i < out_data.size(); i++) { for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6); EXPECT_NEAR(result[i], out_data[i], 1e-6);
} }
......
...@@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) { ...@@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
std::vector<float> out_data; std::vector<float> out_data;
run(config, &out_data); run(config, &out_data);
for (size_t i = 0; i < out_data.size(); i++) { for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6); EXPECT_NEAR(result[i], out_data[i], 1e-5);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册