diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index bea36168a786d6b4f275d0d2509c36b3adc557b7..c7b922ad9e424ebd85449061e0554981183a54dc 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -93,7 +93,8 @@ paddle::framework::FetchList InterpreterCore::Run( return *(fetch_var->GetMutable()); } -paddle::framework::FetchList InterpreterCore::Run() { +paddle::framework::FetchList InterpreterCore::Run( + const std::vector& feed_names) { if (!is_build_) { if (create_local_scope_ && global_scope_->GetMutableLocalScope() != @@ -113,6 +114,7 @@ paddle::framework::FetchList InterpreterCore::Run() { paddle::framework::interpreter::build_op_func_list( place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; + SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph Convert(&op_func_nodes); @@ -260,6 +262,13 @@ void InterpreterCore::BuildInplace() { for (auto& pair : in_to_outs) { auto iter = inputs.find(pair.first); if (iter != inputs.end() && !iter->second.empty()) { + auto in_var_desc = global_scope_->VarDesc(iter->second[0]); + if (in_var_desc && in_var_desc->Persistable()) { + continue; + } + if (global_scope_->GetVarSikpInplace(iter->second[0])) { + continue; + } if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { auto iterout = outputs.find(pair.second); if (iterout != outputs.end() && !iterout->second.empty()) { @@ -578,6 +587,7 @@ void InterpreterCore::Prepare( paddle::framework::interpreter::build_op_func_list( place_, block_, &op_func_nodes, global_scope_, create_local_scope_); is_build_ = true; + SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph Convert(&op_func_nodes); } @@ -604,5 +614,12 @@ interpreter::CostInfo InterpreterCore::DryRun( return cost_info; } +void InterpreterCore::SetFeedVarsInplaceSkip( + const std::vector& feed_names) { + for (auto& feed_name : feed_names) { + global_scope_->SetVarSikpInplace(feed_name, true); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 204e4ff3e4d6779a7bb21bd05b4cdc591f519ad0..70392faf6d42cfa8c9ae98db75496671d79db7db 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -49,7 +49,7 @@ class InterpreterCore { const std::vector& feed_names, const std::vector& feed_tensors); - paddle::framework::FetchList Run(); + paddle::framework::FetchList Run(const std::vector& feed_names); interpreter::CostInfo DryRun( const std::vector& feed_names, @@ -84,6 +84,8 @@ class InterpreterCore { void BuildOperatorDependences(); + void SetFeedVarsInplaceSkip(const std::vector& feed_names); + bool is_build_; const platform::Place& place_; diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 2fd27bc076598bbad5a5c7da43d400349ce71171..0d7c5187e91131f5b2a4e926c0258e19b710b009 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -598,6 +598,16 @@ paddle::framework::VarDesc* VariableScope::VarDesc(int id) const { return vec_meta_info_[id].var_desc_; } +void VariableScope::SetVarSikpInplace(const std::string& name, bool skip) { + CheckExist(name); + vec_meta_info_[VarId(name)].sikp_inplace_ = skip; +} + +bool VariableScope::GetVarSikpInplace(int id) const { + CheckExist(id); + return vec_meta_info_[id].sikp_inplace_; +} + void VariableScope::CheckExist(int id) const { PADDLE_ENFORCE_LT(id, var_list_.size(), platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index a21aa47b899ef75be7871e905ef751a61e291dad..94631b7ae6412304b5959f6760d5318f68468300 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -145,6 +145,7 @@ struct OpKernelFunc { struct VariableMetaInfo { int var_ref_count_{0}; framework::VarDesc* var_desc_{nullptr}; + bool sikp_inplace_{false}; VariableMetaInfo() {} VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc) @@ -228,6 +229,10 @@ class VariableScope : public ScopeBase { return listener_; } + void SetVarSikpInplace(const std::string& name, bool skip); + + bool GetVarSikpInplace(int id) const; + friend class VariableScopeListener; private: diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 56999e4463b08ac8d85f2536cf58657f71270a49..51885543d12fb89460e1c5cfcd1ec35ffc813d40 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -68,7 +68,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( const std::vector& fetch_names) { auto core = GetInterpreterCore(feed_names, fetch_names, false); VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; - return core->Run(); + return core->Run(feed_names); } framework::interpreter::CostInfo StandaloneExecutor::DryRun(