diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index e59e392dfd16e82c01ded8ca40099c2b71bdabcf..b2368e3a27abe6382b7460222e3fccce6f1beb08 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -67,26 +67,26 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { } } -const std::vector ProgramDesc::GetFeedVarNames() { +const std::vector ProgramDesc::GetFeedTargetNames() { BlockDesc *global_block = blocks_[0].get(); - std::vector feed_var_names; + std::vector feed_target_names; for (auto *op : global_block->AllOps()) { - if (op->Type() == "feed") { - feed_var_names.insert(feed_var_names.begin(), op->Output("Out")[0]); + if (op->Type() == kFeedOpType) { + feed_target_names.insert(feed_target_names.begin(), op->Output("Out")[0]); } } - return feed_var_names; + return feed_target_names; } -const std::vector ProgramDesc::GetFetchVarNames() { +const std::vector ProgramDesc::GetFetchTargetNames() { BlockDesc *global_block = blocks_[0].get(); - std::vector fetch_var_names; + std::vector fetch_target_names; for (auto *op : global_block->AllOps()) { - if (op->Type() == "fetch") { - fetch_var_names.push_back(op->Input("X")[0]); + if (op->Type() == kFetchOpType) { + fetch_target_names.push_back(op->Input("X")[0]); } } - return fetch_var_names; + return fetch_target_names; } } // namespace framework diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 2c3883275a23c62a3d085467424c48505661fba3..b9741b31393a474e06fd156a2f3354844d53187c 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -45,9 +45,8 @@ class ProgramDesc { proto::ProgramDesc *Proto(); - const std::vector GetFeedVarNames(); - - const std::vector GetFetchVarNames(); + const std::vector GetFeedTargetNames(); + const std::vector GetFetchTargetNames(); private: proto::ProgramDesc desc_; diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc index 5173779c623a9963d9a2177c31a9248617c4db4e..ac2aedd88b61cde18e8fb9c05d34dd62daf62ab7 100644 --- a/paddle/inference/example.cc +++ b/paddle/inference/example.cc @@ -40,15 +40,15 @@ int main(int argc, char** argv) { std::string dirname = FLAGS_dirname; // 2. Initialize the inference program - auto* inference_program = paddle::inference::Load(*executor, *scope, dirname); + auto inference_program = paddle::inference::Load(*executor, *scope, dirname); // 3. Optional: perform optimization on the inference_program - // 4. Get the feed_var_names and fetch_var_names - const std::vector& feed_var_names = - inference_program->GetFeedVarNames(); - const std::vector& fetch_var_names = - inference_program->GetFetchVarNames(); + // 4. Get the feed_target_names and fetch_target_names + const std::vector& feed_target_names = + inference_program->GetFeedTargetNames(); + const std::vector& fetch_target_names = + inference_program->GetFetchTargetNames(); // 5. Generate input paddle::framework::LoDTensor input; @@ -68,14 +68,14 @@ int main(int argc, char** argv) { std::map fetch_targets; // set_feed_variable - for (size_t i = 0; i < feed_var_names.size(); ++i) { - feed_targets[feed_var_names[i]] = &feeds[i]; + for (size_t i = 0; i < feed_target_names.size(); ++i) { + feed_targets[feed_target_names[i]] = &feeds[i]; } // get_fetch_variable - fetchs.resize(fetch_var_names.size()); - for (size_t i = 0; i < fetch_var_names.size(); ++i) { - fetch_targets[fetch_var_names[i]] = &fetchs[i]; + fetchs.resize(fetch_target_names.size()); + for (size_t i = 0; i < fetch_target_names.size(); ++i) { + fetch_targets[fetch_target_names[i]] = &fetchs[i]; } // Run the inference program @@ -97,7 +97,6 @@ int main(int argc, char** argv) { std::cout << std::endl; } - delete inference_program; delete scope; delete executor; diff --git a/paddle/inference/io.cc b/paddle/inference/io.cc index 98b33d656d254f30577c5cca658e8535ac7582f3..f6d901381e781f161689f05315d4e0fe63610f84 100644 --- a/paddle/inference/io.cc +++ b/paddle/inference/io.cc @@ -21,11 +21,11 @@ namespace inference { const std::string kFeedOpType = "feed"; bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc* main_program) { + const framework::ProgramDesc& main_program) { if (var->Persistable()) { // There are many unreachable variables in the program - for (size_t i = 0; i < main_program->Size(); ++i) { - const framework::BlockDesc& block = main_program->Block(i); + for (size_t i = 0; i < main_program.Size(); ++i) { + const framework::BlockDesc& block = main_program.Block(i); for (auto* op : block.AllOps()) { if (op->Type() == kFeedOpType) { continue; @@ -44,12 +44,12 @@ bool IsParameter(const framework::VarDesc* var, void LoadPersistables(framework::Executor& executor, framework::Scope& scope, const std::string& dirname, - framework::ProgramDesc* main_program) { - framework::BlockDesc* global_block = main_program->MutableBlock(0); + const framework::ProgramDesc& main_program) { + const framework::BlockDesc& global_block = main_program.Block(0); framework::ProgramDesc* load_program = new framework::ProgramDesc(); framework::BlockDesc* load_block = load_program->MutableBlock(0); - for (auto* var : global_block->AllVars()) { + for (auto* var : global_block.AllVars()) { if (IsParameter(var, main_program)) { LOG(INFO) << "parameter's name: " << var->Name(); @@ -72,9 +72,9 @@ void LoadPersistables(framework::Executor& executor, delete load_program; } -framework::ProgramDesc* Load(framework::Executor& executor, - framework::Scope& scope, - const std::string& dirname) { +std::unique_ptr Load(framework::Executor& executor, + framework::Scope& scope, + const std::string& dirname) { std::string model_filename = dirname + "/__model__"; LOG(INFO) << "loading model from " << model_filename; std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); @@ -86,10 +86,10 @@ framework::ProgramDesc* Load(framework::Executor& executor, inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.close(); - framework::ProgramDesc* main_program = - new framework::ProgramDesc(program_desc_str); + std::unique_ptr main_program( + new framework::ProgramDesc(program_desc_str)); - LoadPersistables(executor, scope, dirname, main_program); + LoadPersistables(executor, scope, dirname, *main_program); return main_program; } diff --git a/paddle/inference/io.h b/paddle/inference/io.h index 400f5af8c53146cfc9d91ae15c77eb4364c5e3c8..dccb700e9565b3482152cfcf399b2369edf01c7b 100644 --- a/paddle/inference/io.h +++ b/paddle/inference/io.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include "paddle/framework/block_desc.h" @@ -26,16 +27,16 @@ namespace paddle { namespace inference { bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc* main_program); + const framework::ProgramDesc& main_program); void LoadPersistables(framework::Executor& executor, framework::Scope& scope, const std::string& dirname, - framework::ProgramDesc* main_program); + const framework::ProgramDesc& main_program); -framework::ProgramDesc* Load(framework::Executor& executor, - framework::Scope& scope, - const std::string& dirname); +std::unique_ptr Load(framework::Executor& executor, + framework::Scope& scope, + const std::string& dirname); } // namespace inference } // namespace paddle