From eea36739ccc2f5cde74a13ee8dd46da4de1d2223 Mon Sep 17 00:00:00 2001
From: Tao Luo
Date: Wed, 7 Nov 2018 13:40:54 +0800
Subject: [PATCH] refine test_helper.h

test=develop
---
 paddle/fluid/inference/tests/api/tester_helper.h |  5 ++---
 paddle/fluid/inference/tests/test_helper.h       | 13 ++++++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 79468da03..8c5888d8d 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -107,12 +107,11 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
 }
 
 void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
-                       const std::string &dirname,
-                       const bool is_combined = true) {
+                       const std::string &dirname) {
   // Set fake_image_data
   PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
   std::vector<std::vector<int64_t>> feed_target_shapes =
-      GetFeedTargetShapes(dirname, is_combined);
+      GetFeedTargetShapes(dirname, true, "model", "params");
   int dim1 = feed_target_shapes[0][1];
   int dim2 = feed_target_shapes[0][2];
   int dim3 = feed_target_shapes[0][3];
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 00976a399..2118fcfd4 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -93,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1,
 
 std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
     paddle::framework::Executor* executor, paddle::framework::Scope* scope,
-    const std::string& dirname, const bool is_combined = false) {
+    const std::string& dirname, const bool is_combined = false,
+    const std::string& prog_filename = "__model_combined__",
+    const std::string& param_filename = "__params_combined__") {
   std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
   if (is_combined) {
     // All parameters are saved in a single file.
     // Hard-coding the file names of program and parameters in unittest.
     // The file names should be consistent with that used in Python API
     // `fluid.io.save_inference_model`.
-    std::string prog_filename = "model";
-    std::string param_filename = "params";
     inference_program =
         paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
                                 dirname + "/" + param_filename);
@@ -114,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
 }
 
 std::vector<std::vector<int64_t>> GetFeedTargetShapes(
-    const std::string& dirname, const bool is_combined = false) {
+    const std::string& dirname, const bool is_combined = false,
+    const std::string& prog_filename = "__model_combined__",
+    const std::string& param_filename = "__params_combined__") {
   auto place = paddle::platform::CPUPlace();
   auto executor = paddle::framework::Executor(place);
   auto* scope = new paddle::framework::Scope();
 
-  auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+  auto inference_program = InitProgram(&executor, scope, dirname, is_combined,
+                                       prog_filename, param_filename);
   auto& global_block = inference_program->Block(0);
 
   const std::vector<std::string>& feed_target_names =
--
GitLab