提交 eea36739 编写于 作者: T Tao Luo

refine test_helper.h

test=develop
上级 2b791f1f
...@@ -107,12 +107,11 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor, ...@@ -107,12 +107,11 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
} }
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs, void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
const std::string &dirname, const std::string &dirname) {
const bool is_combined = true) {
// Set fake_image_data // Set fake_image_data
PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
std::vector<std::vector<int64_t>> feed_target_shapes = std::vector<std::vector<int64_t>> feed_target_shapes =
GetFeedTargetShapes(dirname, is_combined); GetFeedTargetShapes(dirname, true, "model", "params");
int dim1 = feed_target_shapes[0][1]; int dim1 = feed_target_shapes[0][1];
int dim2 = feed_target_shapes[0][2]; int dim2 = feed_target_shapes[0][2];
int dim3 = feed_target_shapes[0][3]; int dim3 = feed_target_shapes[0][3];
......
...@@ -93,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1, ...@@ -93,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1,
std::unique_ptr<paddle::framework::ProgramDesc> InitProgram( std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
paddle::framework::Executor* executor, paddle::framework::Scope* scope, paddle::framework::Executor* executor, paddle::framework::Scope* scope,
const std::string& dirname, const bool is_combined = false) { const std::string& dirname, const bool is_combined = false,
const std::string& prog_filename = "__model_combined__",
const std::string& param_filename = "__params_combined__") {
std::unique_ptr<paddle::framework::ProgramDesc> inference_program; std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
if (is_combined) { if (is_combined) {
// All parameters are saved in a single file. // All parameters are saved in a single file.
// Hard-coding the file names of program and parameters in unittest. // Hard-coding the file names of program and parameters in unittest.
// The file names should be consistent with that used in Python API // The file names should be consistent with that used in Python API
// `fluid.io.save_inference_model`. // `fluid.io.save_inference_model`.
std::string prog_filename = "model";
std::string param_filename = "params";
inference_program = inference_program =
paddle::inference::Load(executor, scope, dirname + "/" + prog_filename, paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
dirname + "/" + param_filename); dirname + "/" + param_filename);
...@@ -114,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram( ...@@ -114,12 +114,15 @@ std::unique_ptr<paddle::framework::ProgramDesc> InitProgram(
} }
std::vector<std::vector<int64_t>> GetFeedTargetShapes( std::vector<std::vector<int64_t>> GetFeedTargetShapes(
const std::string& dirname, const bool is_combined = false) { const std::string& dirname, const bool is_combined = false,
const std::string& prog_filename = "__model_combined__",
const std::string& param_filename = "__params_combined__") {
auto place = paddle::platform::CPUPlace(); auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place); auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope(); auto* scope = new paddle::framework::Scope();
auto inference_program = InitProgram(&executor, scope, dirname, is_combined); auto inference_program = InitProgram(&executor, scope, dirname, is_combined,
prog_filename, param_filename);
auto& global_block = inference_program->Block(0); auto& global_block = inference_program->Block(0);
const std::vector<std::string>& feed_target_names = const std::vector<std::string>& feed_target_names =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册