From 4812eda5830a1216c0bad115ee86a3c41abbcc46 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 23 Nov 2021 20:02:51 +0800 Subject: [PATCH] set feed var skip inplace, test=develop (#37467) --- .../framework/new_executor/interpretercore.cc | 19 ++++++++++++++++++- .../framework/new_executor/interpretercore.h | 4 +++- .../new_executor/new_executor_defs.cc | 10 ++++++++++ .../new_executor/new_executor_defs.h | 5 +++++ .../new_executor/standalone_executor.cc | 2 +- 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index bea36168a78..c7b922ad9e4 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -93,7 +93,8 @@ paddle::framework::FetchList InterpreterCore::Run( return *(fetch_var->GetMutable()); } -paddle::framework::FetchList InterpreterCore::Run() { +paddle::framework::FetchList InterpreterCore::Run( + const std::vector& feed_names) { if (!is_build_) { if (create_local_scope_ && global_scope_->GetMutableLocalScope() != @@ -113,6 +114,7 @@ paddle::framework::FetchList InterpreterCore::Run() { paddle::framework::interpreter::build_op_func_list( place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; + SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph Convert(&op_func_nodes); @@ -260,6 +262,13 @@ void InterpreterCore::BuildInplace() { for (auto& pair : in_to_outs) { auto iter = inputs.find(pair.first); if (iter != inputs.end() && !iter->second.empty()) { + auto in_var_desc = global_scope_->VarDesc(iter->second[0]); + if (in_var_desc && in_var_desc->Persistable()) { + continue; + } + if (global_scope_->GetVarSikpInplace(iter->second[0])) { + continue; + } if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { auto iterout = outputs.find(pair.second); if (iterout != outputs.end() && !iterout->second.empty()) { @@ -578,6 +587,7 @@ void InterpreterCore::Prepare( paddle::framework::interpreter::build_op_func_list( place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; + SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph Convert(&op_func_nodes); } @@ -604,5 +614,12 @@ interpreter::CostInfo InterpreterCore::DryRun( return cost_info; } +void InterpreterCore::SetFeedVarsInplaceSkip( + const std::vector& feed_names) { + for (auto& feed_name : feed_names) { + global_scope_->SetVarSikpInplace(feed_name, true); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 204e4ff3e4d..70392faf6d4 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -49,7 +49,7 @@ class InterpreterCore { const std::vector& feed_names, const std::vector& feed_tensors); - paddle::framework::FetchList Run(); + paddle::framework::FetchList Run(const std::vector& feed_names); interpreter::CostInfo DryRun( const std::vector& feed_names, @@ -84,6 +84,8 @@ class InterpreterCore { void BuildOperatorDependences(); + void SetFeedVarsInplaceSkip(const std::vector& feed_names); + bool is_build_; const platform::Place& place_; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 2fd27bc0765..0d7c5187e91 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -598,6 +598,16 @@ paddle::framework::VarDesc* VariableScope::VarDesc(int id) const { return vec_meta_info_[id].var_desc_; } +void VariableScope::SetVarSikpInplace(const std::string& name, bool skip) { + CheckExist(name); + vec_meta_info_[VarId(name)].sikp_inplace_ = skip; +} + +bool VariableScope::GetVarSikpInplace(int id) const { + CheckExist(id); + return vec_meta_info_[id].sikp_inplace_; +} + void VariableScope::CheckExist(int id) const { PADDLE_ENFORCE_LT(id, var_list_.size(), platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index a21aa47b899..94631b7ae64 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -145,6 +145,7 @@ struct OpKernelFunc { struct VariableMetaInfo { int var_ref_count_{0}; framework::VarDesc* var_desc_{nullptr}; + bool sikp_inplace_{false}; VariableMetaInfo() {} VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc) @@ -228,6 +229,10 @@ class VariableScope : public ScopeBase { return listener_; } + void SetVarSikpInplace(const std::string& name, bool skip); + + bool GetVarSikpInplace(int id) const; + friend class VariableScopeListener; private: diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 56999e4463b..51885543d12 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -68,7 +68,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( const std::vector& fetch_names) { auto core = GetInterpreterCore(feed_names, fetch_names, false); VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; - return core->Run(); + return core->Run(feed_names); } framework::interpreter::CostInfo StandaloneExecutor::DryRun( -- GitLab