From 7eec162cf5173bab3cca416b2ee6b748156a1077 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Sat, 8 Aug 2020 12:57:32 +0800 Subject: [PATCH] fix test shape after adding stack op --- .../api/trt_dynamic_shape_ernie_deserialize_test.cc | 9 ++++----- .../inference/tests/api/trt_dynamic_shape_ernie_test.cc | 6 +++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc index 6526b874365..8570fad28cc 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc @@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector result) { config.SwitchUseFeedFetchOps(false); - int head_number = 12; int batch = 1; int min_seq_len = 1; int max_seq_len = 128; @@ -104,23 +103,23 @@ void trt_ernie(bool with_fp16, std::vector result) { {"read_file_0.tmp_0", min_shape}, {"read_file_0.tmp_1", min_shape}, {"read_file_0.tmp_2", min_shape}, - {"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}}; + {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}}; std::map> max_input_shape = { {"read_file_0.tmp_0", max_shape}, {"read_file_0.tmp_1", max_shape}, {"read_file_0.tmp_2", max_shape}, - {"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}}; + {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}}; std::map> opt_input_shape = { {"read_file_0.tmp_0", opt_shape}, {"read_file_0.tmp_1", opt_shape}, {"read_file_0.tmp_2", opt_shape}, - {"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}}; + {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}}; auto precision = AnalysisConfig::Precision::kFloat32; if (with_fp16) { precision = AnalysisConfig::Precision::kHalf; } - config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); + config.EnableTensorRtEngine(1 << 30, 1, 3, precision, true, false); config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape); std::vector out_data; diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index bf5f8828e0b..4ce987169f6 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -103,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector result) { {"read_file_0.tmp_0", min_shape}, {"read_file_0.tmp_1", min_shape}, {"read_file_0.tmp_2", min_shape}, - {"matmul_0.tmp_0", min_shape}}; + {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}}; std::map> max_input_shape = { {"read_file_0.tmp_0", max_shape}, {"read_file_0.tmp_1", max_shape}, {"read_file_0.tmp_2", max_shape}, - {"matmul_0.tmp_0", max_shape}}; + {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}}; std::map> opt_input_shape = { {"read_file_0.tmp_0", opt_shape}, {"read_file_0.tmp_1", opt_shape}, {"read_file_0.tmp_2", opt_shape}, - {"matmul_0.tmp_0", opt_shape}}; + {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}}; auto precision = AnalysisConfig::Precision::kFloat32; if (with_fp16) { -- GitLab