diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 0b171e1dcfa90c3ad8f5a9ace8a9342baaf76e61..b576e5f39e677e630d56e7602d25d2e1e84bcaad 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -93,6 +93,42 @@ static void CheckTensorNANOrInf(const std::string& name, "Tensor %s contains NAN", name); } +void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope) { + auto& global_block = pdesc.Block(0); + + const Scope* ancestor_scope = scope; + while (ancestor_scope->parent()) { + ancestor_scope = ancestor_scope->parent(); + } + + if (ancestor_scope != scope) { + for (auto& var : global_block.AllVars()) { + if (var->Name() == framework::kEmptyVarName) { + continue; + } + + if (var->Persistable()) { + auto* ptr = const_cast(ancestor_scope)->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + VLOG(3) << "Create Variable " << var->Name() + << " global, which pointer is " << ptr; + } else { + auto* ptr = scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + VLOG(3) << "Create Variable " << var->Name() + << " locally, which pointer is " << ptr; + } + } + } else { + for (auto& var : global_block.AllVars()) { + auto* ptr = scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + VLOG(3) << "Create variable " << var->Name() << ", which pointer is " + << ptr; + } + } +} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { platform::RecordBlock b(block_id); @@ -184,8 +220,8 @@ static bool has_fetch_operators( void Executor::Run(const ProgramDesc& program, Scope* scope, std::map& feed_targets, std::map& fetch_targets, - const std::string& feed_holder_name, - const std::string& fetch_holder_name, bool create_vars) { + bool create_vars, const std::string& feed_holder_name, + const std::string& fetch_holder_name) { platform::RecordBlock b(kProgramId); bool has_feed_ops = has_feed_operators(program.Block(0), feed_targets, feed_holder_name); @@ -281,39 +317,16 @@ std::unique_ptr Executor::Prepare( void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope, bool create_vars) { - auto& block = ctx->prog_.Block(ctx->block_id_); - Scope* local_scope = scope; if (create_vars) { if (create_local_scope) { local_scope = &scope->NewScope(); - for (auto& var : block.AllVars()) { - if (var->Name() == framework::kEmptyVarName) { - continue; - } - - if (var->Persistable()) { - auto* ptr = scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create Variable " << var->Name() - << " global, which pointer is " << ptr; - } else { - auto* ptr = local_scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create Variable " << var->Name() - << " locally, which pointer is " << ptr; - } - } } else { - for (auto& var : block.AllVars()) { - auto* ptr = local_scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create variable " << var->Name() << ", which pointer is " - << ptr; - } } // if (create_local_scope) } // if (create_vars) + CreateVariables(ctx->prog_, local_scope); + for (auto& op : ctx->ops_) { VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index d8dd82469af06a4c5c6a37d2249ee23413884a91..688ba09f9b93851907ac2bcba9400243695ad852 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -53,13 +53,15 @@ class Executor { void Run(const ProgramDesc& program, Scope* scope, std::map& feed_targets, std::map& fetch_targets, + bool create_vars = true, const std::string& feed_holder_name = "feed", - const std::string& fetch_holder_name = "fetch", - bool create_vars = true); + const std::string& fetch_holder_name = "fetch"); static std::unique_ptr Prepare( const ProgramDesc& program, int block_id); + void CreateVariables(const ProgramDesc& pdesc, Scope* scope); + void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, bool create_vars = true); diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index c1e1f49caaa5a60df0e97289aada465b45213971..11084b301acf8dcdc14069e35d2b5378dcc9a20a 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -57,7 +57,7 @@ class Scope { /// nullptr if cannot find. Variable* FindVar(const std::string& name) const; - const Scope& parent() const { return *parent_; } + const Scope* parent() const { return parent_; } /// Find the scope or an ancestor scope that contains the given variable. const Scope* FindScope(const Variable* var) const; diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index dce541c0971a6ff9a3728e915fe8c3d009c23550..68dd020f395c901e8ab7638348fd7bae25f5dc30 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -169,8 +169,14 @@ void TestInference(const std::string& dirname, // 6. Run the inference program { + const bool create_vars = false; + if (!create_vars) { + executor.CreateVariables(*inference_program, scope); + } + // Ignore the profiling results of the first run - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + executor.Run( + *inference_program, scope, feed_targets, fetch_targets, create_vars); // Enable the profiler paddle::platform::EnableProfiler(state); @@ -181,7 +187,8 @@ void TestInference(const std::string& dirname, "run_inference", paddle::platform::DeviceContextPool::Instance().Get(place)); - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + executor.Run( + *inference_program, scope, feed_targets, fetch_targets, create_vars); } // Disable the profiler and print the timing information diff --git a/paddle/fluid/operators/go_op.cc b/paddle/fluid/operators/go_op.cc index cfa797717d78aa72e1b931b6db6e153270b3424e..58fe32446217e07235b40b9b78190094e57e4951 100644 --- a/paddle/fluid/operators/go_op.cc +++ b/paddle/fluid/operators/go_op.cc @@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase { // TODO(varunarora): Consider moving this root scope lookup to scope.h. const framework::Scope *root_scope = &scope; - const framework::Scope *parent_scope = &(root_scope->parent()); + const framework::Scope *parent_scope = root_scope->parent(); while (parent_scope != nullptr) { root_scope = parent_scope; - parent_scope = &(parent_scope->parent()); + parent_scope = parent_scope->parent(); } framework::BlockDesc *block = Attr(kBlock);