diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 2d7fb6ccda328913ae3a74b94ff4f2a118dd61a6..bea36168a786d6b4f275d0d2509c36b3adc557b7 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -22,6 +22,9 @@ PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true, "Use inplace in new executor"); +PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true, + "Use local_scope in new executor(especially used " + "in UT), can turn off for better performance"); DECLARE_bool(check_nan_inf); DECLARE_bool(benchmark); @@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place, exception_notifier_ = main_thread_blocker_.RegisterEvent( kExceptionCaught, [this]() { return exception_holder_.IsCaught(); }); + create_local_scope_ = FLAGS_new_executor_use_local_scope; + if (FLAGS_new_executor_use_local_scope) { + auto local_scope = &global_scope->GetMutableScope()->NewScope(); + local_scope->AddListener(global_scope->Listener()); + local_scope_ = local_scope; + } + VLOG(4) << "create_local_scope_ is " << create_local_scope_; + // prune // optmize graph pass @@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() { async_work_queue_.reset(nullptr); } +void InterpreterCore::SetCopyProgram(std::shared_ptr prog) { + copy_program_ = prog; +} + paddle::framework::FetchList InterpreterCore::Run( const std::vector& feed_names, const std::vector& feed_tensors) { bool is_build = is_build_; + global_scope_->SetLocalScope(local_scope_); Prepare(feed_names, feed_tensors, is_build); if (is_build) { @@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run() { if (!is_build_) { - paddle::framework::interpreter::build_variable_scope(block_, global_scope_); + if (create_local_scope_ && + global_scope_->GetMutableLocalScope() != + global_scope_->GetMutableScope() && + global_scope_->GetMutableLocalScope()) { + VLOG(4) << "Clear previous local scope before run"; + VLOG(4) << global_scope_->GetMutableScope() << " " + << global_scope_->GetMutableLocalScope(); + platform::DeviceContextPool::Instance().Get(place_)->Wait(); + // TODO(zhiqiu): clear the tensor holder of all vars in previous local + // scope? + } + global_scope_->SetLocalScope(local_scope_); + paddle::framework::interpreter::build_variable_scope(block_, global_scope_, + create_local_scope_); std::vector op_func_nodes; paddle::framework::interpreter::build_op_func_list( - place_, block_, &op_func_nodes, global_scope_); + place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; // convert vec func_list to graph Convert(&op_func_nodes); + } else { ExecuteInstructionList(vec_instruction_); } @@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::RunInstruction(const Instruction& instr_node) { auto* op = instr_node.OpBase(); auto place = instr_node.DeviceContext().GetPlace(); - VLOG(4) << "Start run" << place << " " << op->DebugStringEx(global_scope_); + VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_); + Scope* local_scope = create_local_scope_ + ? global_scope_->GetMutableLocalScope() + : global_scope_->GetMutableScope(); auto op_with_kernel = dynamic_cast(op); { @@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { } { platform::RecordEvent compute_event("Compute"); - if (op_with_kernel == nullptr) - instr_node.OpBase()->Run(*global_scope_->GetScope(), place_); - else + if (op_with_kernel == nullptr) { + instr_node.OpBase()->Run(*local_scope, place_); + } else { instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); + } } - VLOG(4) << "End run" << place << " " << op->DebugStringEx(global_scope_); + VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_); /*For profiling/benchmark only*/ if (FLAGS_benchmark) { @@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList( } } - auto event_id = main_thread_blocker_.WaitEvent(); - VLOG(3) << "event_id " << event_id; + auto event_name = main_thread_blocker_.WaitEvent(); + VLOG(3) << "event_name: " << event_name; if (UNLIKELY(exception_holder_.IsCaught())) { VLOG(4) << "Exception caught " << exception_holder_.Type(); @@ -526,8 +560,9 @@ void InterpreterCore::Prepare( VLOG(4) << "Feed inputs"; for (size_t i = 0; i < feed_names.size(); ++i) { auto* feed_var = global_scope_->FindVar(feed_names[i]); - PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound( - "feed_var shall not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL( + feed_var, platform::errors::NotFound( + "Variable %s should not be nullptr.", feed_names[i])); auto feed_tensor = feed_var->GetMutable(); feed_tensor->ShareDataWith(feed_tensors[i]); @@ -536,11 +571,12 @@ void InterpreterCore::Prepare( }; if (!is_build_) { - paddle::framework::interpreter::build_variable_scope(block_, global_scope_); + paddle::framework::interpreter::build_variable_scope(block_, global_scope_, + create_local_scope_); FeedInput(); std::vector op_func_nodes; paddle::framework::interpreter::build_op_func_list( - place_, block_, &op_func_nodes, global_scope_); + place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; // convert vec func_list to graph Convert(&op_func_nodes); @@ -556,6 +592,7 @@ void InterpreterCore::Prepare( interpreter::CostInfo InterpreterCore::DryRun( const std::vector& feed_names, const std::vector& feed_tensors) { + global_scope_->SetLocalScope(local_scope_); Prepare(feed_names, feed_tensors, true); interpreter::CostInfo cost_info; { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index e56784946a3a6ce146c9af149c857d74e7763ce4..204e4ff3e4d6779a7bb21bd05b4cdc591f519ad0 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -55,6 +55,8 @@ class InterpreterCore { const std::vector& feed_names, const std::vector& feed_tensors); + void SetCopyProgram(std::shared_ptr prog); + private: void Convert(std::vector* op_func_nodes); @@ -85,7 +87,13 @@ class InterpreterCore { bool is_build_; const platform::Place& place_; - const BlockDesc& block_; // not owned + const BlockDesc& block_; // not owned + // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will + // copy a new program and block, the copy_program_ here is used to + // hold the program, otherwise block_ maybe not valid after the + // new program is deleted. + std::shared_ptr copy_program_{nullptr}; + VariableScope* global_scope_; // not owned std::vector vec_instruction_; // deconstruct before OpFuncNode @@ -102,6 +110,8 @@ class InterpreterCore { std::unique_ptr gc_; std::vector gc_event_; + bool create_local_scope_{true}; + Scope* local_scope_{nullptr}; // not owned }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index f7f2dd0213409f3f37a9ed2b376341aecb9bda48..07743150b60038e89c88500a38daa73ae910afde 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place, } void build_variable_scope(const framework::BlockDesc& block, - VariableScope* var_scope) { + VariableScope* var_scope, bool use_local_scope) { + VLOG(3) << "Creating Variables"; + auto inner_scope = var_scope->GetMutableScope(); + + // NOTE(zhiqiu): if create_local_scope_ is true, the persistable is + // created in var_scope.scope_ , and other scope is created in local scope. + Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() + : var_scope->GetMutableScope(); + for (auto& var_desc : block.AllVars()) { auto var_name = var_desc->Name(); if (var_name == framework::kEmptyVarName) { continue; } + if (var_desc->Persistable()) { + auto* ptr = inner_scope->Var(var_name); - if (nullptr == var_scope->FindVar(var_name)) { - var_scope->AddVar(var_desc->Name(), var_desc); + VLOG(3) << "Initialize Variable " << var_name; + InitializeVariable(ptr, var_desc->GetType()); + VLOG(3) << "Create Variable " << var_name << " global, which pointer is " + << ptr << " type is " << static_cast(var_desc->GetType()); } else { - auto* var_desc_tmp = var_scope->VarDesc(var_name); - if (nullptr == var_desc_tmp) { - VLOG(3) << "update var:" << var_name << " desc from nullptr into " - << var_desc; - var_scope->SetVarDesc(var_name, var_desc); - } + auto* ptr = local_scope->Var(var_name); + InitializeVariable(ptr, var_desc->GetType()); + VLOG(3) << "Create Variable " << var_name << " locally, which pointer is " + << ptr << "Variable Type " + << static_cast(var_desc->GetType()); } + var_scope->SetVarDesc(var_name, var_desc); } } @@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base, void deal_operator_base(const platform::Place& place, const VariableScope* var_scope, std::shared_ptr op_base, - OpFuncNode* op_func_node) { + OpFuncNode* op_func_node, Scope* local_scope) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); // input, output is prepared. set the other attributes. op_func_node->operator_base_ = op_base; op_func_node->type_ = OpFuncType::kQueueSync; // alway Sync op_func_node->kernel_func_ = nullptr; - op_base->Run(*var_scope->GetScope(), place); // Run without data transformer. + op_base->Run(*local_scope, place); // Run without data transformer. std::unordered_set no_data_transform_index; for (auto& it : op_func_node->input_index) { @@ -288,12 +300,21 @@ std::tuple apply_place_transform_for_var( const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_key, const platform::Place& place, const std::string& var_name, const std::string& outer_name, - const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) { + const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope, + bool use_local_scope = true) { + Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() + : var_scope->GetMutableScope(); + auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); std::string new_var_name = var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1); - var_scope->AddVar(new_var_name, nullptr); + + auto* ptr = local_scope->Var(new_var_name); + InitializeVariable(ptr, static_cast(var->Type())); + VLOG(3) << "Create Variable " << var_name << " locally, which pointer is " + << ptr << "Variable Type " << var->Type(); + var_scope->SetVarDesc(var_name, nullptr); VariableNameMap copy_in_map; copy_in_map["X"] = {var_name}; @@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key, const platform::Place& place, VariableValueMap* ins_map_temp, VariableScope* var_scope, OpFuncNode* op_func_node, - std::vector* copy_func_nodes) { + std::vector* copy_func_nodes, + bool use_local_scope = true) { auto op_base = op_func_node->operator_base_.get(); PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet( "op_base is null, please pass a valid " @@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key, std::string new_var_name; OpFuncNode copy_op_func_node; std::tie(new_var_name, copy_op_func_node) = - apply_place_transform_for_var( - kernel_type_for_var, expected_kernel_key, place, var_name, - var_name_item.first, *op_func_node, var, var_scope); + apply_place_transform_for_var(kernel_type_for_var, + expected_kernel_key, place, var_name, + var_name_item.first, *op_func_node, + var, var_scope, use_local_scope); op_func_node->input_index[var_name_item.first][i] = var_scope->VarId(new_var_name); copy_func_nodes->emplace_back(copy_op_func_node); @@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key, void build_op_func_list(const platform::Place& place, const framework::BlockDesc& block, std::vector* vec_func_list, - VariableScope* var_scope) { + VariableScope* var_scope, bool use_local_scope) { + Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() + : var_scope->GetMutableScope(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); std::vector> ops; // its elements will be moved to vec_func_list @@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place, if (dynamic_cast(op) == nullptr) { // op is not a operatorwithkernel, so direcly run OperatorBase::Run() - deal_operator_base(place, var_scope, ops[i], &op_func_node); + deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); } else { // construct RuntimeContext and analysis KernelType RuntimeContext runtime_context({}, {}); @@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place, // apply_data_transform. op_func_node.operator_base_ = ops[i]; apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope, - &op_func_node, ©_op_to_insert); + &op_func_node, ©_op_to_insert, use_local_scope); for (auto& item : copy_op_to_insert) { vec_func_list->push_back(item); } @@ -631,16 +656,16 @@ std::vector merge_vector(const std::vector& first, } void update_var_min_rw_op(const std::map>& op2dependences, - std::map>& var2min_rw_op, + std::map>* var2min_rw_op, int cur_op, int rw_var) { // rw_var is inputs or outputs of cur_op // this function update the var2min_rw_op set . - if (var2min_rw_op.find(rw_var) == var2min_rw_op.end()) - var2min_rw_op[rw_var] = std::list(); + if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) + (*var2min_rw_op)[rw_var] = std::list(); for (auto dep_op : op2dependences.at(cur_op)) { - var2min_rw_op[rw_var].remove(dep_op); + (*var2min_rw_op)[rw_var].remove(dep_op); } - var2min_rw_op[rw_var].push_back(cur_op); + (*var2min_rw_op)[rw_var].push_back(cur_op); } std::map> get_downstream_map( @@ -702,7 +727,7 @@ std::map> build_op_downstream_map( for (auto& item : vec_instruction[op_idx].Inputs()) { // for all inputs(read only) for (auto var : item.second) { - update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var); + update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); remove_duplicate.insert(var); } } @@ -713,7 +738,7 @@ std::map> build_op_downstream_map( var2recent_write_op[var] = op_idx; if (remove_duplicate.count(var) == 0) { // var in input list and in output list, so remove it. - update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var); + update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); } } } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index f3b1a8a6b4a53a1a0ea57c7eb9c37ae657b895a6..60312d153c361eeb1162d1b82e583bb2c33371e5 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -51,7 +51,7 @@ namespace framework { namespace interpreter { using AtomicVectorSizeT = std::vector>>; -static constexpr char kFetchVarName[] = "fetch_vars"; +static constexpr char kFetchVarName[] = "fetch"; class AsyncWorkQueue { public: @@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place, const platform::Place& dst_place); void build_variable_scope(const framework::BlockDesc& block, - VariableScope* var_scope); + VariableScope* var_scope, + bool use_local_scope = true); void build_op_func_list(const platform::Place& place, const framework::BlockDesc& block, std::vector* vec_func_list, - VariableScope* var_scope); + VariableScope* var_scope, bool use_local_scope = true); std::map> build_op_downstream_map( const std::vector& vec_instruction); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 7d40102cbe7647cb2e7716f2ff3027bbeb7b22a5..2fd27bc076598bbad5a5c7da43d400349ce71171 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -497,7 +497,14 @@ VariableScope::~VariableScope() { } } -const Scope* VariableScope::GetScope() const { return scope_; } +Scope* VariableScope::GetMutableScope() const { return scope_; } + +Scope* VariableScope::GetMutableLocalScope() const { return local_scope_; } + +void VariableScope::SetLocalScope(Scope* local_scope) { + VLOG(4) << "Set local scope: " << local_scope; + local_scope_ = local_scope; +} Variable* VariableScope::FindVar(const std::string& name) const { auto it = name2id_.find(name); @@ -554,8 +561,9 @@ Variable* VariableScope::Var(const std::string& name) const { size_t VariableScope::VarSize() const { return var_list_.size(); } void VariableScope::AddVar(const std::string& name, - framework::VarDesc* var_desc) { // NOLINT - auto v = scope_->Var(name); + framework::VarDesc* var_desc, + bool local_scope) { // NOLINT + auto v = local_scope ? local_scope_->Var(name) : scope_->Var(name); if (nullptr == var_desc) { v->GetMutable(); } else { @@ -606,9 +614,9 @@ VariableScopeListener::VariableScopeListener(VariableScope* var_scope) { var_scope_ = var_scope; } -void VariableScopeListener::onCreateVariable(const std::string& name) { - auto v = var_scope_->scope_->GetVar(name); // must exsit in outer_scope_ - if (!var_scope_->HasVar(name)) { // may exist in variable scope. +void VariableScopeListener::onCreateVariable(const std::string& name, + Variable* v) { + if (!var_scope_->HasVar(name)) { // may exist in variable scope. VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: " << name; var_scope_->name2id_[name] = var_scope_->VarSize(); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 68ea48fd328032de330d28bc958c542af0656cda..a21aa47b899ef75be7871e905ef751a61e291dad 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -155,7 +155,7 @@ class VariableScope; class VariableScopeListener : public ScopeListener { public: explicit VariableScopeListener(VariableScope* var_scope_); - void onCreateVariable(const std::string& name) override; + void onCreateVariable(const std::string& name, Variable* v) override; void onDeleteVariable(const std::string& name) override; void onRenameVariable(const std::string& old_name, const std::string& new_name) override; @@ -177,7 +177,11 @@ class VariableScope : public ScopeBase { public: explicit VariableScope(Scope* scope); - const Scope* GetScope() const; + Scope* GetMutableScope() const; + + Scope* GetMutableLocalScope() const; + + void SetLocalScope(Scope* local_scope); Variable* FindVar(const std::string& name) const; @@ -199,7 +203,8 @@ class VariableScope : public ScopeBase { size_t VarSize() const; - void AddVar(const std::string& name, VarDesc* var_desc); + void AddVar(const std::string& name, VarDesc* var_desc, + bool local_scope = false); void AddVar(const std::string& name, const Variable& var); @@ -219,15 +224,21 @@ class VariableScope : public ScopeBase { return vec_meta_info_; } + const std::shared_ptr& Listener() const { + return listener_; + } + friend class VariableScopeListener; private: std::vector var_list_; std::map name2id_; std::vector vec_meta_info_; - Scope* scope_ = nullptr; + Scope* scope_{nullptr}; + // TODO(zhiqiu): find a better way to support local scope. + Scope* local_scope_{nullptr}; // mutable RWLock vars_lock_; - std::shared_ptr listener_; + std::shared_ptr listener_{nullptr}; }; class NextInstruction { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 3344143e651f62011452e676f4dd53b760d501af..1cef303a05b83865a1097ed97c46e0dd8aea57e7 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -24,30 +24,36 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, startup_prog_(startup_prog), main_prog_(main_prog), global_scope_(VariableScope(scope)) { - // init scope - BuildVariableScope(startup_prog, &global_scope_); - - if (scope != nullptr) { - auto name_list = scope->LocalVarNames(); - for (auto name : name_list) { - auto v = scope->Var(name); - if (!global_scope_.HasVar(name)) { - global_scope_.AddVar(name, *v); + // NOTE(zhiqiu): for startup_program, initialize scope and run once + // if startup_program is empty, the scope is initialize during first run + if (startup_prog.Block(0).AllOps().size() > 0) { + VLOG(4) << "Run startup program"; + // init scope + BuildVariableScope(startup_prog, &global_scope_); + + if (scope != nullptr) { + auto name_list = scope->LocalVarNames(); + for (auto name : name_list) { + auto v = scope->Var(name); + if (!global_scope_.HasVar(name)) { + global_scope_.AddVar(name, *v); + } } } - } - // run startup program - std::vector vec_func_list; - paddle::framework::interpreter::build_op_func_list( - place_, startup_prog.Block(0), &vec_func_list, &global_scope_); + std::vector vec_func_list; + // No need to use_local_scope for startup_program, its variables are + // persistable + paddle::framework::interpreter::build_op_func_list( + place_, startup_prog.Block(0), &vec_func_list, &global_scope_, false); + } } paddle::framework::FetchList StandaloneExecutor::Run( const std::vector& feed_names, const std::vector& feed_tensors, const std::vector& fetch_names) { - auto core = GetInterpreterCore(feed_names, fetch_names); + auto core = GetInterpreterCore(feed_names, fetch_names, true); return core->Run(feed_names, feed_tensors); } @@ -55,15 +61,15 @@ paddle::framework::FetchList StandaloneExecutor::Run( paddle::framework::FetchList StandaloneExecutor::Run( const std::vector& feed_names, const std::vector& fetch_names) { - auto core = GetInterpreterCore(feed_names, fetch_names); - + auto core = GetInterpreterCore(feed_names, fetch_names, false); + VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; return core->Run(); } framework::interpreter::CostInfo StandaloneExecutor::DryRun( const std::vector& feed_names, const std::vector& feed_tensors) { - auto core = GetInterpreterCore(feed_names, {}); + auto core = GetInterpreterCore(feed_names, {}, true); return core->DryRun(feed_names, feed_tensors); } @@ -85,7 +91,7 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc, std::shared_ptr StandaloneExecutor::GetInterpreterCore( const std::vector& feed_names, - const std::vector& fetch_names) { + const std::vector& fetch_names, bool add_fetch_op) { std::ostringstream oss; oss << "feed:"; for (auto& feedname : feed_names) { @@ -100,15 +106,22 @@ std::shared_ptr StandaloneExecutor::GetInterpreterCore( if (iter == interpretercores_.end()) { VLOG(3) << "create interpreter_core for " << oss.str(); - // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a - // new program. - auto new_prog = std::make_shared(main_prog_); - auto* block = new_prog->MutableBlock(0); - interpreter::add_fetch(fetch_names, block); - - auto core = - std::make_shared(place_, *block, &global_scope_); - programs_.emplace(oss.str(), new_prog); + VLOG(3) << "add fetch op: " << add_fetch_op; + std::shared_ptr core = nullptr; + if (add_fetch_op) { + // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy + // a + // new program. + auto new_prog = std::make_shared(main_prog_); + auto* block = new_prog->MutableBlock(0); + interpreter::add_fetch(fetch_names, block); + + core = std::make_shared(place_, *block, &global_scope_); + core->SetCopyProgram(new_prog); + } else { + core = std::make_shared(place_, main_prog_.Block(0), + &global_scope_); + } interpretercores_.emplace(oss.str(), core); return core; } else { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index eb46cb8aabf00d8bb04389cc8f6799872ce2217d..e84df2abb36d99f3ccfc49d41e044fa4fe173018 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -61,14 +61,13 @@ class StandaloneExecutor : public ExecutorBase { std::shared_ptr GetInterpreterCore( const std::vector& feed_names, - const std::vector& fetch_names); + const std::vector& fetch_names, bool add_fetch_op); const platform::Place& place_; const ProgramDesc& startup_prog_; const ProgramDesc& main_prog_; VariableScope global_scope_; - std::unordered_map> programs_; std::unordered_map> interpretercores_; }; diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 4bb94a4e7e5a1857ba10addddbb830b7ab7e8748..49cca5018ced619ef5a0d20bb315e9f88fb3672a 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -67,7 +67,7 @@ Variable* Scope::Var(const std::string& name) { ret = VarInternal(name); } for (auto l : listeners_) { - l->onCreateVariable(name); + l->onCreateVariable(name, ret); } return ret; } @@ -85,7 +85,7 @@ Variable* Scope::Var(std::string* name) { ret = VarInternal(new_name); } for (auto l : listeners_) { - l->onCreateVariable(new_name); + l->onCreateVariable(new_name, ret); } return ret; } diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 892618b7e6cc19fee949826e84188bcff2f6de3f..c18d1d588a356350da741f8766937538c27eb7fd 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -58,7 +58,7 @@ class ScopeListener { // in original Scope. public: virtual ~ScopeListener() {} - virtual void onCreateVariable(const std::string& name) {} + virtual void onCreateVariable(const std::string& name, Variable* v) {} virtual void onDeleteVariable(const std::string& name) {} virtual void onRenameVariable(const std::string& old_name, const std::string& new_name) {} diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index 0837caf9353a3a3dda60453fb80699b716a5c91d..93035dddefee799fff5d7ea673d23f23c6a10746 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -112,13 +112,6 @@ class FetchV2Op : public framework::OperatorWithKernel { } }; -class FetchV2InferVarType : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext *ctx) const override { - ctx->SyncTypeAndDataType("X", "Out"); - } -}; - class FetchV2Kernel { public: void operator()(const framework::ExecutionContext &ctx) const { @@ -211,7 +204,6 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OPERATOR( fetch_v2, ops::FetchV2Op, ops::FetchV2OpProtoMaker, - ops::FetchV2InferVarType, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 09279f1208dc32f53fc714a2a87d5bb448252b79..f04c67c3174b6d8a3a392eaab7d8ca2e48c36808 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -288,7 +288,10 @@ def has_feed_operators(block, feed_targets, feed_holder_name): return feed_count > 0 -def has_fetch_operators(block, fetch_targets, fetch_holder_name): +def has_fetch_operators(block, + fetch_targets, + fetch_holder_name, + fetch_op='fetch'): """ Check whether the block already has fetch operators. Return false if the block does not have any fetch operators. @@ -303,6 +306,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name): fetch_holder_name: the name of the variable that holds the data of all fetch targets. The type of this fetch_holder variable is FETCH_LIST, which is essentially vector. + fetch_op: the operator name of fetch Return: A boolean value that indicates whether a block has fetch operators @@ -311,7 +315,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name): fetch_count = 0 for op in block.ops: - if op.desc.type() == 'fetch': + if op.desc.type() == fetch_op: fetch_count += 1 assert op.desc.output('Out')[0] == fetch_holder_name fetch_target_name = op.desc.input('X')[0] @@ -740,7 +744,7 @@ class Executor(object): fetch_list, feed_var_name, fetch_var_name, - skip_fetch=False): + use_fetch_v2=False): tmp_program = program.clone() global_block = tmp_program.global_block() @@ -775,17 +779,21 @@ class Executor(object): warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." % name) - if skip_fetch: - return tmp_program + + if use_fetch_v2: + fetch_op = 'fetch_v2' + else: + fetch_op = 'fetch' # append fetch_operators - if not has_fetch_operators(global_block, fetch_list, fetch_var_name): + if not has_fetch_operators(global_block, fetch_list, fetch_var_name, + fetch_op): for i, var in enumerate(fetch_list): assert isinstance(var, Variable) or isinstance( var, six.string_types), ( "Wrong type for fetch_list[%s]: %s" % (i, type(var))) global_block.append_op( - type='fetch', + type=fetch_op, inputs={'X': [var]}, outputs={'Out': [fetch_var]}, attrs={'col': i}) @@ -1345,7 +1353,13 @@ class Executor(object): fetch_list=fetch_list, feed_var_name=feed_var_name, fetch_var_name=fetch_var_name, - skip_fetch=True) + use_fetch_v2=True) + + # NPTE(zhiqiu): Construct standalone_executor first, so + # the scope is binded with the variable_scope of standalone_executor + new_exe = self._executor_cache._get_exe_from_cache(program, + scope) + self._feed_data(program, feed, feed_var_name, scope) if hasattr(program, 'lr_sheduler'): from paddle.optimizer.lr import LRScheduler @@ -1360,9 +1374,7 @@ class Executor(object): lr_sheduler._var_name) tensor.set(data, self.place) - return self._executor_cache.run(program, scope, - list(feed.keys()), fetch_list, - return_numpy) + return new_exe.run(list(feed.keys()), fetch_list, return_numpy) # use_prune can be overrided by putting optimize_ops in fetch_list _origin_fetch_list = fetch_list diff --git a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py index eb8fb3b5f9457f721f0e3d3380f3f99f77fd5f89..d6b10c8ca69ae31077b79a4c6ff2a8d4efe1527c 100644 --- a/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py +++ b/python/paddle/fluid/tests/unittests/interpreter/test_standalone_executor.py @@ -309,7 +309,7 @@ class TestException(unittest.TestCase): feed[1]['data'][0] = np.nan self.assertRaises(RuntimeError, self.run_new_executor, feed) - def test_scope(self): + def test_scope_find_temp_var(self): feed = [{ 'id': np.array([1, 2, 3, 4, 5]).astype(np.int64), 'data': np.array([1, 2, 3]).astype(np.float32), @@ -318,7 +318,7 @@ class TestException(unittest.TestCase): 'data': np.array([2, 2, 2]).astype(np.float32), }] self.run_new_executor(feed) - self.assertIsNotNone(paddle.static.global_scope().find_var( + self.assertIsNone(paddle.static.global_scope().find_var( self.fetch_vars.name))