From 3ae3b86489225becabe50453159c8b88e0b2d905 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Tue, 15 Sep 2020 10:42:30 +0800 Subject: [PATCH] fix trt_dynamic_shape_ernie_deserialize_test (#27290) * fix trt_dynamic_shape_ernie_deserialize_test * support when opt cache dir does not exist --- .../fluid/inference/tests/api/CMakeLists.txt | 7 +++--- ...rt_dynamic_shape_ernie_deserialize_test.cc | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index b3ec4b5714..a1b43de469 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -480,10 +480,9 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz") endif() - # disable test_trt_dynamic_shape_ernie_ser_deser temporary - #inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc - # EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - # ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized) + inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized) endif() diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc index 524e08891f..685f7b6600 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc @@ -12,15 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include +#include #include "paddle/fluid/inference/tests/api/trt_test_helper.h" namespace paddle { namespace inference { +int DeleteCache(std::string path) { + DIR* dir = opendir(path.c_str()); + if (dir == NULL) return 0; + struct dirent* ptr; + while ((ptr = readdir(dir)) != NULL) { + if (std::strcmp(ptr->d_name, ".") == 0 || + std::strcmp(ptr->d_name, "..") == 0) { + continue; + } else if (ptr->d_type == 8) { + std::string file_rm = path + "/" + ptr->d_name; + return remove(file_rm.c_str()); + } + } + return 0; +} + void run(const AnalysisConfig& config, std::vector* out_data) { auto predictor = CreatePaddlePredictor(config); auto input_names = predictor->GetInputNames(); @@ -86,6 +104,11 @@ void run(const AnalysisConfig& config, std::vector* out_data) { void trt_ernie(bool with_fp16, std::vector result) { AnalysisConfig config; std::string model_dir = FLAGS_infer_model; + // Delete serialization cache to perform serialization first rather than + // deserialization. + std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache"; + DeleteCache(opt_cache_dir); + SetConfig(&config, model_dir, true /* use_gpu */); config.SwitchUseFeedFetchOps(false); -- GitLab