From b44917d09befcc6a300b2ef7c6ff86302d085f07 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 12 Feb 2018 08:45:06 +0000 Subject: [PATCH] Implement IsPersistable() in c++. --- paddle/fluid/inference/io.cc | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 58d7ab40bfa..8132347d630 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) { inputfs.close(); } -bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc& main_program) { - if (var->Persistable()) { - // There are many unreachable variables in the program - for (size_t i = 0; i < main_program.Size(); ++i) { - const framework::BlockDesc& block = main_program.Block(i); - for (auto* op : block.AllOps()) { - if (op->Type() == framework::kFeedOpType) { - continue; - } - for (auto input_argument_name : op->InputArgumentNames()) { - if (input_argument_name == var->Name()) { - return true; - } - } - } - } +bool IsPersistable(const framework::VarDesc* var) { + if (var->Persistable() && + var->GetType() != framework::proto::VarDesc::FEED_MINIBATCH && + var->GetType() != framework::proto::VarDesc::FETCH_LIST) { + return true; } return false; } @@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor, std::vector paramlist; for (auto* var : global_block.AllVars()) { - if (IsParameter(var, main_program)) { - VLOG(3) << "parameter's name: " << var->Name(); + if (IsPersistable(var)) { + VLOG(3) << "persistable variable's name: " << var->Name(); framework::VarDesc* new_var = load_block->Var(var->Name()); new_var->SetShape(var->GetShape()); @@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor, executor.Run(*load_program, &scope, 0, true, true); - VLOG(3) << "Ran loading successfully"; delete load_program; } -- GitLab