diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index b9741b31393a474e06fd156a2f3354844d53187c..8e958eab6ee08436ca73b13bac010e66c7df2b8b 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -16,6 +16,7 @@ limitations under the License. */ #include <memory> #include <vector> +#include "paddle/framework/block_desc.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/proto_desc.h" #include "paddle/platform/macros.h" diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index db63cd12c9a58d62cfc0b80a155148eea3b73cab..bff8e0bceaca9749101b2c45edddba526d565624 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -21,11 +21,12 @@ limitations under the License. */ #include <string> #include <vector> -#include "paddle/framework/feed_fetch_type.h" namespace paddle { namespace framework { +const std::string kFeedOpType = "feed"; +const std::string kFetchOpType = "fetch"; const std::string kDropOutOpType = "dropout"; const std::string kBatchNormOpType = "batch_norm"; diff --git a/paddle/inference/io.cc b/paddle/inference/io.cc index 556f235d1644e90330d30b882d7ca1a3684292ba..60ad7af1c0a469beb6a07bf057a8647fcb98cca8 100644 --- a/paddle/inference/io.cc +++ b/paddle/inference/io.cc @@ -22,11 +22,11 @@ namespace paddle { namespace inference { bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc* main_program) { + const framework::ProgramDesc& main_program) { if (var->Persistable()) { // There are many unreachable variables in the program - for (size_t i = 0; i < main_program->Size(); ++i) { - const framework::BlockDesc& block = main_program->Block(i); + for (size_t i = 0; i < main_program.Size(); ++i) { + const framework::BlockDesc& block = main_program.Block(i); for (auto* op : block.AllOps()) { if (op->Type() == framework::kFeedOpType) { continue; @@ -45,12 +45,12 @@ bool IsParameter(const framework::VarDesc* var, void LoadPersistables(framework::Executor& executor, 
framework::Scope& scope, const std::string& dirname, - framework::ProgramDesc* main_program) { - framework::BlockDesc* global_block = main_program->MutableBlock(0); + const framework::ProgramDesc& main_program) { + const framework::BlockDesc& global_block = main_program.Block(0); framework::ProgramDesc* load_program = new framework::ProgramDesc(); framework::BlockDesc* load_block = load_program->MutableBlock(0); - for (auto* var : global_block->AllVars()) { + for (auto* var : global_block.AllVars()) { if (IsParameter(var, main_program)) { VLOG(3) << "parameter's name: " << var->Name(); @@ -73,9 +73,9 @@ void LoadPersistables(framework::Executor& executor, delete load_program; } -framework::ProgramDesc* Load(framework::Executor& executor, - framework::Scope& scope, - const std::string& dirname) { +std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor, + framework::Scope& scope, + const std::string& dirname) { std::string model_filename = dirname + "/__model__"; LOG(INFO) << "loading model from " << model_filename; std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); @@ -87,10 +87,10 @@ framework::ProgramDesc* Load(framework::Executor& executor, inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.close(); - framework::ProgramDesc* main_program = - new framework::ProgramDesc(program_desc_str); + std::unique_ptr<framework::ProgramDesc> main_program( + new framework::ProgramDesc(program_desc_str)); - LoadPersistables(executor, scope, dirname, main_program); + LoadPersistables(executor, scope, dirname, *main_program); return main_program; } diff --git a/paddle/inference/io.h b/paddle/inference/io.h index fa9a620764f548ab4de67eae1c92b5bef12b22ab..962b6c4e20d30de3cc28eae1c8c5c33b3ab5f6ac 100644 --- a/paddle/inference/io.h +++ b/paddle/inference/io.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include <memory> #include <string> #include <vector> #include "paddle/framework/executor.h" @@ -26,11 +27,11 @@ namespace inference { void LoadPersistables(framework::Executor& executor, framework::Scope& scope, const std::string& dirname, - framework::ProgramDesc* main_program); + const framework::ProgramDesc& main_program); -framework::ProgramDesc* Load(framework::Executor& executor, - framework::Scope& scope, - const std::string& dirname); +std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor, + framework::Scope& scope, + const std::string& dirname); } // namespace inference } // namespace paddle diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc index a2cdd60752672fb08b52ca58bab0760a27b10fad..26dc2aee04261d9a1fd29b4d75bfacc7870c09d8 100644 --- a/paddle/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc @@ -31,7 +31,7 @@ void TestInference(const std::string& dirname, auto* scope = new paddle::framework::Scope(); // 2. Initialize the inference_program and load all parameters from file - auto* inference_program = paddle::inference::Load(executor, *scope, dirname); + auto inference_program = paddle::inference::Load(executor, *scope, dirname); // 3. Get the feed_target_names and fetch_target_names const std::vector<std::string>& feed_target_names = @@ -39,14 +39,14 @@ void TestInference(const std::string& dirname, const std::vector<std::string>& fetch_target_names = inference_program->GetFetchTargetNames(); - // 4. Prepare inputs + // 4. Prepare inputs: set up maps for feed targets std::map<std::string, paddle::framework::LoDTensor*> feed_targets; for (size_t i = 0; i < feed_target_names.size(); ++i) { // Please make sure that cpu_feeds[i] is right for feed_target_names[i] feed_targets[feed_target_names[i]] = cpu_feeds[i]; } - // 5. Define Tensor to get the outputs + // 5. 
Define Tensor to get the outputs: set up maps for fetch targets std::map<std::string, paddle::framework::LoDTensor*> fetch_targets; for (size_t i = 0; i < fetch_target_names.size(); ++i) { fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; } @@ -55,7 +55,6 @@ void TestInference(const std::string& dirname, // 6. Run the inference program executor.Run(*inference_program, scope, feed_targets, fetch_targets); - delete inference_program; delete scope; }