未验证 提交 4812eda5 编写于 作者: W wanghuancoder 提交者: GitHub

set feed var skip inplace, test=develop (#37467)

上级 df14dbf0
...@@ -93,7 +93,8 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -93,7 +93,8 @@ paddle::framework::FetchList InterpreterCore::Run(
return *(fetch_var->GetMutable<framework::FetchList>()); return *(fetch_var->GetMutable<framework::FetchList>());
} }
paddle::framework::FetchList InterpreterCore::Run() { paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names) {
if (!is_build_) { if (!is_build_) {
if (create_local_scope_ && if (create_local_scope_ &&
global_scope_->GetMutableLocalScope() != global_scope_->GetMutableLocalScope() !=
...@@ -113,6 +114,7 @@ paddle::framework::FetchList InterpreterCore::Run() { ...@@ -113,6 +114,7 @@ paddle::framework::FetchList InterpreterCore::Run() {
paddle::framework::interpreter::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_, create_local_scope_); place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph // convert vec func_list to graph
Convert(&op_func_nodes); Convert(&op_func_nodes);
...@@ -260,6 +262,13 @@ void InterpreterCore::BuildInplace() { ...@@ -260,6 +262,13 @@ void InterpreterCore::BuildInplace() {
for (auto& pair : in_to_outs) { for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first); auto iter = inputs.find(pair.first);
if (iter != inputs.end() && !iter->second.empty()) { if (iter != inputs.end() && !iter->second.empty()) {
auto in_var_desc = global_scope_->VarDesc(iter->second[0]);
if (in_var_desc && in_var_desc->Persistable()) {
continue;
}
if (global_scope_->GetVarSikpInplace(iter->second[0])) {
continue;
}
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second); auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) { if (iterout != outputs.end() && !iterout->second.empty()) {
...@@ -578,6 +587,7 @@ void InterpreterCore::Prepare( ...@@ -578,6 +587,7 @@ void InterpreterCore::Prepare(
paddle::framework::interpreter::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_, create_local_scope_); place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph // convert vec func_list to graph
Convert(&op_func_nodes); Convert(&op_func_nodes);
} }
...@@ -604,5 +614,12 @@ interpreter::CostInfo InterpreterCore::DryRun( ...@@ -604,5 +614,12 @@ interpreter::CostInfo InterpreterCore::DryRun(
return cost_info; return cost_info;
} }
void InterpreterCore::SetFeedVarsInplaceSkip(
const std::vector<std::string>& feed_names) {
for (auto& feed_name : feed_names) {
global_scope_->SetVarSikpInplace(feed_name, true);
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -49,7 +49,7 @@ class InterpreterCore { ...@@ -49,7 +49,7 @@ class InterpreterCore {
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
paddle::framework::FetchList Run(); paddle::framework::FetchList Run(const std::vector<std::string>& feed_names);
interpreter::CostInfo DryRun( interpreter::CostInfo DryRun(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
...@@ -84,6 +84,8 @@ class InterpreterCore { ...@@ -84,6 +84,8 @@ class InterpreterCore {
void BuildOperatorDependences(); void BuildOperatorDependences();
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
bool is_build_; bool is_build_;
const platform::Place& place_; const platform::Place& place_;
......
...@@ -598,6 +598,16 @@ paddle::framework::VarDesc* VariableScope::VarDesc(int id) const { ...@@ -598,6 +598,16 @@ paddle::framework::VarDesc* VariableScope::VarDesc(int id) const {
return vec_meta_info_[id].var_desc_; return vec_meta_info_[id].var_desc_;
} }
void VariableScope::SetVarSikpInplace(const std::string& name, bool skip) {
CheckExist(name);
vec_meta_info_[VarId(name)].sikp_inplace_ = skip;
}
bool VariableScope::GetVarSikpInplace(int id) const {
CheckExist(id);
return vec_meta_info_[id].sikp_inplace_;
}
void VariableScope::CheckExist(int id) const { void VariableScope::CheckExist(int id) const {
PADDLE_ENFORCE_LT(id, var_list_.size(), PADDLE_ENFORCE_LT(id, var_list_.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
......
...@@ -145,6 +145,7 @@ struct OpKernelFunc { ...@@ -145,6 +145,7 @@ struct OpKernelFunc {
struct VariableMetaInfo { struct VariableMetaInfo {
int var_ref_count_{0}; int var_ref_count_{0};
framework::VarDesc* var_desc_{nullptr}; framework::VarDesc* var_desc_{nullptr};
bool sikp_inplace_{false};
VariableMetaInfo() {} VariableMetaInfo() {}
VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc) VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
...@@ -228,6 +229,10 @@ class VariableScope : public ScopeBase { ...@@ -228,6 +229,10 @@ class VariableScope : public ScopeBase {
return listener_; return listener_;
} }
void SetVarSikpInplace(const std::string& name, bool skip);
bool GetVarSikpInplace(int id) const;
friend class VariableScopeListener; friend class VariableScopeListener;
private: private:
......
...@@ -68,7 +68,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -68,7 +68,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
const std::vector<std::string>& fetch_names) { const std::vector<std::string>& fetch_names) {
auto core = GetInterpreterCore(feed_names, fetch_names, false); auto core = GetInterpreterCore(feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run(); return core->Run(feed_names);
} }
framework::interpreter::CostInfo StandaloneExecutor::DryRun( framework::interpreter::CostInfo StandaloneExecutor::DryRun(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册