diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 58d7ab40bfa67595a9c7c61ed431a7cf9509e1f7..8132347d630592db3d98e23a2c77b85568c70c38 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) { inputfs.close(); } -bool IsParameter(const framework::VarDesc* var, - const framework::ProgramDesc& main_program) { - if (var->Persistable()) { - // There are many unreachable variables in the program - for (size_t i = 0; i < main_program.Size(); ++i) { - const framework::BlockDesc& block = main_program.Block(i); - for (auto* op : block.AllOps()) { - if (op->Type() == framework::kFeedOpType) { - continue; - } - for (auto input_argument_name : op->InputArgumentNames()) { - if (input_argument_name == var->Name()) { - return true; - } - } - } - } +bool IsPersistable(const framework::VarDesc* var) { + if (var->Persistable() && + var->GetType() != framework::proto::VarDesc::FEED_MINIBATCH && + var->GetType() != framework::proto::VarDesc::FETCH_LIST) { + return true; } return false; } @@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor, std::vector paramlist; for (auto* var : global_block.AllVars()) { - if (IsParameter(var, main_program)) { - VLOG(3) << "parameter's name: " << var->Name(); + if (IsPersistable(var)) { + VLOG(3) << "persistable variable's name: " << var->Name(); framework::VarDesc* new_var = load_block->Var(var->Name()); new_var->SetShape(var->GetShape()); @@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor, executor.Run(*load_program, &scope, 0, true, true); - VLOG(3) << "Ran loading successfully"; delete load_program; }