From 438aad24a5a82d5e5302543a7f56bfd8f414aaf6 Mon Sep 17 00:00:00 2001
From: Liu Yiqun
Date: Fri, 26 Jan 2018 04:07:02 +0000
Subject: [PATCH] Update the inference unittest using the new Executor.Run().

---
 paddle/inference/inference.cc                 | 103 ++----------------
 paddle/inference/inference.h                  |  18 ++-
 .../book/test_inference_recognize_digits.cc   |  56 +++++++---
 3 files changed, 59 insertions(+), 118 deletions(-)

diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc
index 2c4d717a13..51d43a63ee 100644
--- a/paddle/inference/inference.cc
+++ b/paddle/inference/inference.cc
@@ -14,13 +14,13 @@ limitations under the License. */
 
 #include "inference.h"
 #include <fstream>
-#include "paddle/framework/executor.h"
-#include "paddle/framework/init.h"
-#include "paddle/framework/scope.h"
 
 namespace paddle {
 
-void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
+framework::ProgramDesc* InferenceEngine::LoadInferenceModel(
+    framework::Executor& exe,
+    framework::Scope* scope,
+    const std::string& dirname) {
   std::string model_filename = dirname + "/__model__";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
@@ -34,6 +34,7 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
 
   program_ = new framework::ProgramDesc(program_desc_str);
   GenerateLoadProgram(dirname);
+  exe.Run(*load_program_, scope, 0, true, true);
 
   framework::BlockDesc* global_block = program_->MutableBlock(0);
   feed_var_names_.clear();
@@ -45,6 +46,8 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
       fetch_var_names_.push_back(op->Input("X")[0]);
     }
   }
+
+  return program_;
 }
 
 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
@@ -92,96 +95,4 @@ void InferenceEngine::GenerateLoadProgram(const std::string& dirname) {
     }
   }
 }
-
-void InferenceEngine::PrependFeedOp() {
-  if (!program_) {
-    LOG(FATAL) << "Please initialize the program_ first.";
-  }
-
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
-
-  // create_var
-  framework::VarDesc* feed_var = global_block->Var("feed");
-  feed_var->SetType(framework::proto::VarDesc::FEED_MINIBATCH);
-  feed_var->SetPersistable(true);
-
-  // prepend feed_op
-  for (size_t i = 0; i < feed_var_names_.size(); ++i) {
-    std::string var_name = feed_var_names_[i];
-    LOG(INFO) << "feed var's name: " << var_name;
-
-    // prepend_op
-    framework::OpDesc* op = global_block->PrependOp();
-    op->SetType("feed");
-    op->SetInput("X", {"feed"});
-    op->SetOutput("Out", {var_name});
-    op->SetAttr("col", {static_cast<int>(i)});
-    op->CheckAttrs();
-  }
-}
-
-void InferenceEngine::AppendFetchOp() {
-  if (!program_) {
-    LOG(FATAL) << "Please initialize the program_ first.";
-  }
-
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
-
-  // create_var
-  framework::VarDesc* fetch_var = global_block->Var("fetch");
-  fetch_var->SetType(framework::proto::VarDesc::FETCH_LIST);
-  fetch_var->SetPersistable(true);
-
-  // append fetch_op
-  for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
-    std::string var_name = fetch_var_names_[i];
-    LOG(INFO) << "fetch var's name: " << var_name;
-
-    // append_op
-    framework::OpDesc* op = global_block->AppendOp();
-    op->SetType("fetch");
-    op->SetInput("X", {var_name});
-    op->SetOutput("Out", {"fetch"});
-    op->SetAttr("col", {static_cast<int>(i)});
-    op->CheckAttrs();
-  }
-}
-
-void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds,
-                              std::vector<framework::LoDTensor>& fetchs) {
-  if (!program_ || !load_program_) {
-    LOG(FATAL) << "Please initialize the program_ and load_program_ first.";
-  }
-
-  if (feeds.size() != feed_var_names_.size()) {
-    LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors.";
-  }
-
-  auto* place = new platform::CPUPlace();
-  framework::InitDevices();
-  framework::Executor* executor = new framework::Executor(*place);
-  framework::Scope* scope = new framework::Scope();
-
-  executor->Run(*load_program_, scope, 0, true, true);
-
-  std::map<std::string, const framework::LoDTensor*> feed_targets;
-  std::map<std::string, framework::LoDTensor*> fetch_targets;
-
-  // set_feed_variable
-  for (size_t i = 0; i < feed_var_names_.size(); ++i) {
-    feed_targets[feed_var_names_[i]] = &feeds[i];
-  }
-
-  // get_fetch_variable
-  fetchs.resize(fetch_var_names_.size());
-  for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
-    fetch_targets[fetch_var_names_[i]] = &fetchs[i];
-  }
-
-  executor->Run(*program_, scope, feed_targets, fetch_targets);
-
-  delete place;
-  delete scope;
-  delete executor;
-}
 }  // namespace paddle
diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h
index 26f259824b..60caa41c70 100644
--- a/paddle/inference/inference.h
+++ b/paddle/inference/inference.h
@@ -15,8 +15,10 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/framework/block_desc.h"
+#include "paddle/framework/executor.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/program_desc.h"
+#include "paddle/framework/scope.h"
 
 namespace paddle {
 
@@ -28,15 +30,21 @@ public:
     delete load_program_;
   }
 
-  void LoadInferenceModel(const std::string& dirname);
-  void Execute(const std::vector<framework::LoDTensor>& feeds,
-               std::vector<framework::LoDTensor>& fetchs);
+  framework::ProgramDesc* LoadInferenceModel(framework::Executor& exe,
+                                             framework::Scope* scope,
+                                             const std::string& dirname);
+
+  const std::vector<std::string>& GetFeedVarNames() const {
+    return feed_var_names_;
+  }
+
+  const std::vector<std::string>& GetFetchVarNames() const {
+    return fetch_var_names_;
+  }
 
 private:
   bool IsParameter(const framework::VarDesc* var);
   void GenerateLoadProgram(const std::string& dirname);
-  void PrependFeedOp();
-  void AppendFetchOp();
 
 private:
   framework::ProgramDesc* program_;
diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc
index d0e811914c..0dfaf9a0ee 100644
--- a/paddle/inference/tests/book/test_inference_recognize_digits.cc
+++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc
@@ -16,11 +16,12 @@ limitations under the License. */
 #include <time.h>
 #include <sstream>
 #include "gflags/gflags.h"
+#include "paddle/framework/init.h"
 #include "paddle/inference/inference.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
 
-TEST(inference, recognize_digits) {
+TEST(recognize_digits, CPU) {
   if (FLAGS_dirname.empty()) {
     LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
   }
@@ -28,33 +29,54 @@
   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
   std::string dirname = FLAGS_dirname;
 
+  // 0. Initialize all the devices
+  paddle::framework::InitDevices();
+
+  // 1. Define place, executor and scope
+  auto place = paddle::platform::CPUPlace();
+  auto executor = paddle::framework::Executor(place);
+  auto* scope = new paddle::framework::Scope();
+
+  // 2. Initialize the inference_program and load all parameters from file
   paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname);
+  paddle::framework::ProgramDesc* inference_program =
+      engine->LoadInferenceModel(executor, scope, dirname);
+
+  // 3. Get the feed_var_names and fetch_var_names
+  const std::vector<std::string>& feed_target_names = engine->GetFeedVarNames();
+  const std::vector<std::string>& fetch_target_names =
+      engine->GetFetchVarNames();
 
+  // 4. Prepare inputs
+  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
   paddle::framework::LoDTensor input;
   srand(time(0));
   float* input_ptr =
-      input.mutable_data<float>({1, 784}, paddle::platform::CPUPlace());
+      input.mutable_data<float>({1, 28, 28}, paddle::platform::CPUPlace());
   for (int i = 0; i < 784; ++i) {
     input_ptr[i] = rand() / (static_cast<float>(RAND_MAX));
   }
+  feed_targets[feed_target_names[0]] = &input;
+
+  // 5. Define Tensor to get the outputs
+  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
+  paddle::framework::LoDTensor output;
+  fetch_targets[fetch_target_names[0]] = &output;
+
+  // 6. Run the inference program
+  executor.Run(*inference_program, scope, feed_targets, fetch_targets);
 
-  std::vector<paddle::framework::LoDTensor> feeds;
-  feeds.push_back(input);
-  std::vector<paddle::framework::LoDTensor> fetchs;
-  engine->Execute(feeds, fetchs);
-
-  for (size_t i = 0; i < fetchs.size(); ++i) {
-    LOG(INFO) << fetchs[i].dims();
-    std::stringstream ss;
-    ss << "result:";
-    float* output_ptr = fetchs[i].data<float>();
-    for (int j = 0; j < fetchs[i].numel(); ++j) {
-      ss << " " << output_ptr[j];
-    }
-    LOG(INFO) << ss.str();
+  // 7. Use the output as you expect.
+  LOG(INFO) << output.dims();
+  std::stringstream ss;
+  ss << "result:";
+  float* output_ptr = output.data<float>();
+  for (int j = 0; j < output.numel(); ++j) {
+    ss << " " << output_ptr[j];
   }
+  LOG(INFO) << ss.str();
 
+  delete scope;
   delete engine;
 }
-- 
GitLab
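
Usage sketch (illustrative, not part of the patch): with this change applied, a caller drives inference end to end roughly as below, mirroring the updated unittest. The model directory "some/model/dir" is a placeholder, and a single feed target of shape {1, 28, 28} plus a single fetch target are assumed.

#include <map>
#include <string>
#include "paddle/framework/init.h"
#include "paddle/inference/inference.h"

int main() {
  // Initialize devices once per process before running any program.
  paddle::framework::InitDevices();

  // An executor bound to the CPU place, and a scope to hold all variables.
  paddle::platform::CPUPlace place;
  paddle::framework::Executor executor(place);
  auto* scope = new paddle::framework::Scope();

  // Load the inference program; LoadInferenceModel also runs the generated
  // load program so that all parameters are restored into the scope.
  paddle::InferenceEngine* engine = new paddle::InferenceEngine();
  paddle::framework::ProgramDesc* inference_program =
      engine->LoadInferenceModel(executor, scope, "some/model/dir");

  // Bind one input tensor to the first feed target and one output tensor
  // to the first fetch target, using the names recorded while loading.
  paddle::framework::LoDTensor input, output;
  float* input_ptr = input.mutable_data<float>({1, 28, 28}, place);
  for (int i = 0; i < 784; ++i) {
    input_ptr[i] = 0.0f;  // fill with real pixel values in practice
  }

  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  feed_targets[engine->GetFeedVarNames()[0]] = &input;
  fetch_targets[engine->GetFetchVarNames()[0]] = &output;

  // Run the inference program; the result is written into `output`.
  executor.Run(*inference_program, scope, feed_targets, fetch_targets);

  delete engine;
  delete scope;
  return 0;
}

Note that InferenceEngine no longer owns an Executor or Scope; ownership stays with the caller, which allows the same scope to be reused across multiple Run() calls.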