Unverified commit 4812eda5, authored by wanghuancoder, committed by GitHub

set feed var skip inplace, test=develop (#37467)

Parent: df14dbf0
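For context, a hedged reading of why this change exists (the diff itself does not spell it out): a feed variable's buffer holds data supplied by the caller, so if the interpreter's in-place pass let a downstream operator reuse that buffer as its output, the fed data would be silently overwritten. A contrived, self-contained C++ illustration of that hazard, with purely made-up names:

#include <iostream>
#include <vector>

// Toy "relu" kernel: writes its result into *out. In-place reuse means the
// scheduler passed out == in, so the result overwrites the input buffer.
void Relu(const std::vector<float>& in, std::vector<float>* out) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) (*out)[i] = in[i] > 0.f ? in[i] : 0.f;
}

int main() {
  std::vector<float> feed = {-1.f, 2.f, -3.f};  // pretend this is the fed tensor

  // In-place reuse of the feed buffer: the caller's data is silently clobbered.
  Relu(feed, &feed);
  std::cout << "feed[0] after in-place reuse: " << feed[0] << "\n";  // 0, not -1

  // Skipping in-place reuse for the feed variable keeps the fed data intact.
  std::vector<float> feed2 = {-1.f, 2.f, -3.f};
  std::vector<float> out;
  Relu(feed2, &out);
  std::cout << "feed2[0] without reuse: " << feed2[0] << "\n";  // still -1
  return 0;
}

The commit closes that hole by threading the feed names into InterpreterCore::Run, flagging each one in the VariableScope, and having BuildInplace skip flagged variables.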
@@ -93,7 +93,8 @@ paddle::framework::FetchList InterpreterCore::Run(
   return *(fetch_var->GetMutable<framework::FetchList>());
 }
 
-paddle::framework::FetchList InterpreterCore::Run() {
+paddle::framework::FetchList InterpreterCore::Run(
+    const std::vector<std::string>& feed_names) {
   if (!is_build_) {
     if (create_local_scope_ &&
         global_scope_->GetMutableLocalScope() !=
@@ -113,6 +114,7 @@ paddle::framework::FetchList InterpreterCore::Run() {
     paddle::framework::interpreter::build_op_func_list(
         place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
     is_build_ = true;
+    SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
     Convert(&op_func_nodes);
@@ -260,6 +262,13 @@ void InterpreterCore::BuildInplace() {
     for (auto& pair : in_to_outs) {
       auto iter = inputs.find(pair.first);
       if (iter != inputs.end() && !iter->second.empty()) {
+        auto in_var_desc = global_scope_->VarDesc(iter->second[0]);
+        if (in_var_desc && in_var_desc->Persistable()) {
+          continue;
+        }
+        if (global_scope_->GetVarSikpInplace(iter->second[0])) {
+          continue;
+        }
         if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
           auto iterout = outputs.find(pair.second);
           if (iterout != outputs.end() && !iterout->second.empty()) {
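The two guards added above keep persistable variables (e.g. parameters) and variables flagged as skip-inplace (the feed variables) out of in-place buffer reuse. Below is a minimal standalone sketch of the same decision, using hypothetical stand-ins (ToyVarMeta, CanReuseInplace) rather than the real VarDesc/VariableScope types:

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the interpreter's per-variable metadata.
struct ToyVarMeta {
  bool persistable{false};   // e.g. trained parameters
  bool skip_inplace{false};  // e.g. feed variables
};

// Mirrors the intent of the added checks: a variable may only donate its
// buffer to an in-place output if it is neither persistable nor explicitly
// marked to skip in-place reuse.
bool CanReuseInplace(const std::unordered_map<std::string, ToyVarMeta>& scope,
                     const std::string& in_var) {
  auto it = scope.find(in_var);
  if (it == scope.end()) return false;
  if (it->second.persistable) return false;
  if (it->second.skip_inplace) return false;
  return true;
}

int main() {
  std::unordered_map<std::string, ToyVarMeta> scope{
      {"fc_0.w_0", {/*persistable=*/true, /*skip_inplace=*/false}},
      {"image", {/*persistable=*/false, /*skip_inplace=*/true}},  // a feed var
      {"relu_0.tmp_0", {/*persistable=*/false, /*skip_inplace=*/false}}};
  for (const auto& name : {"fc_0.w_0", "image", "relu_0.tmp_0"}) {
    std::cout << name << " -> "
              << (CanReuseInplace(scope, name) ? "inplace ok" : "skip") << "\n";
  }
  return 0;
}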
@@ -578,6 +587,7 @@ void InterpreterCore::Prepare(
     paddle::framework::interpreter::build_op_func_list(
         place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
     is_build_ = true;
+    SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
     Convert(&op_func_nodes);
   }
@@ -604,5 +614,12 @@ interpreter::CostInfo InterpreterCore::DryRun(
   return cost_info;
 }
 
+void InterpreterCore::SetFeedVarsInplaceSkip(
+    const std::vector<std::string>& feed_names) {
+  for (auto& feed_name : feed_names) {
+    global_scope_->SetVarSikpInplace(feed_name, true);
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -49,7 +49,7 @@ class InterpreterCore {
       const std::vector<std::string>& feed_names,
       const std::vector<framework::LoDTensor>& feed_tensors);
 
-  paddle::framework::FetchList Run();
+  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names);
 
   interpreter::CostInfo DryRun(
       const std::vector<std::string>& feed_names,
@@ -84,6 +84,8 @@ class InterpreterCore {
   void BuildOperatorDependences();
 
+  void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
+
   bool is_build_;
 
   const platform::Place& place_;
@@ -598,6 +598,16 @@ paddle::framework::VarDesc* VariableScope::VarDesc(int id) const {
   return vec_meta_info_[id].var_desc_;
 }
 
+void VariableScope::SetVarSikpInplace(const std::string& name, bool skip) {
+  CheckExist(name);
+  vec_meta_info_[VarId(name)].sikp_inplace_ = skip;
+}
+
+bool VariableScope::GetVarSikpInplace(int id) const {
+  CheckExist(id);
+  return vec_meta_info_[id].sikp_inplace_;
+}
+
 void VariableScope::CheckExist(int id) const {
   PADDLE_ENFORCE_LT(id, var_list_.size(),
                     platform::errors::PreconditionNotMet(
@@ -145,6 +145,7 @@ struct OpKernelFunc {
 struct VariableMetaInfo {
   int var_ref_count_{0};
   framework::VarDesc* var_desc_{nullptr};
+  bool sikp_inplace_{false};
 
   VariableMetaInfo() {}
   VariableMetaInfo(int var_ref_count, framework::VarDesc* var_desc)
@@ -228,6 +229,10 @@ class VariableScope : public ScopeBase {
     return listener_;
   }
 
+  void SetVarSikpInplace(const std::string& name, bool skip);
+
+  bool GetVarSikpInplace(int id) const;
+
   friend class VariableScopeListener;
 
  private:
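The bookkeeping pattern above is worth noting: the flag lives in the per-variable metadata slot, the setter is keyed by name (what the feed list provides), while the getter is keyed by integer id (what BuildInplace has at hand). A hypothetical miniature of that arrangement, with invented names, might look like:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the per-variable metadata slot.
struct ToyMetaInfo {
  int ref_count{0};
  bool skip_inplace{false};
};

class ToyScope {
 public:
  int AddVar(const std::string& name) {
    int id = static_cast<int>(meta_.size());
    name_to_id_[name] = id;
    meta_.emplace_back();
    return id;
  }
  // Setter keyed by name, as the feed list supplies names.
  void SetSkipInplace(const std::string& name, bool skip) {
    meta_.at(name_to_id_.at(name)).skip_inplace = skip;  // throws if unknown
  }
  // Getter keyed by id, as the inplace pass works on resolved ids.
  bool GetSkipInplace(int id) const { return meta_.at(id).skip_inplace; }

 private:
  std::unordered_map<std::string, int> name_to_id_;
  std::vector<ToyMetaInfo> meta_;
};

int main() {
  ToyScope scope;
  int image_id = scope.AddVar("image");
  scope.SetSkipInplace("image", true);     // flagged by name at feed time
  assert(scope.GetSkipInplace(image_id));  // queried by id in the inplace pass
  return 0;
}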
@@ -68,7 +68,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& fetch_names) {
   auto core = GetInterpreterCore(feed_names, fetch_names, false);
   VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
-  return core->Run();
+  return core->Run(feed_names);
 }
 
 framework::interpreter::CostInfo StandaloneExecutor::DryRun(
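Putting the pieces together, the call path is: StandaloneExecutor::Run forwards feed_names to InterpreterCore::Run, which on its first (build-time) invocation marks those variables before Convert triggers BuildInplace. A toy, self-contained sketch of that ordering, with invented class and member names:

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical miniature of the call path exercised by this commit: the core
// flags the feed names on its first run, and the in-place pass consults the
// flags when deciding which buffers may be reused.
class ToyCore {
 public:
  void Run(const std::vector<std::string>& feed_names) {
    if (!is_build_) {
      // Mirrors: is_build_ = true; SetFeedVarsInplaceSkip(feed_names);
      is_build_ = true;
      skip_inplace_.insert(feed_names.begin(), feed_names.end());
      BuildInplace();  // runs after the flags are set, so it can honor them
    }
    // ... execute instructions ...
  }

 private:
  void BuildInplace() {
    for (const auto& var : {"image", "relu_0.tmp_0"}) {
      if (skip_inplace_.count(var)) {
        std::cout << var << ": skipped for in-place reuse\n";
        continue;
      }
      std::cout << var << ": eligible for in-place reuse\n";
    }
  }

  bool is_build_{false};
  std::set<std::string> skip_inplace_;
};

int main() {
  ToyCore core;
  core.Run({"image"});  // first run builds the graph and marks the feed
  return 0;
}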