未验证 提交 d4cf5666 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] clear the scope listener after run (#41947)

* clear the listener after run

* only sync variables in program

* refine code

* fit for lod_tensor_blocking_queue
上级 23ad2166
...@@ -121,6 +121,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -121,6 +121,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Prepare(feed_names, feed_tensors, is_build); Prepare(feed_names, feed_tensors, is_build);
if (is_build) { if (is_build) {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
...@@ -128,6 +131,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -128,6 +131,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
...@@ -162,6 +168,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -162,6 +168,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Convert(&op_func_nodes); Convert(&op_func_nodes);
} else { } else {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
...@@ -169,6 +178,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -169,6 +178,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
......
...@@ -172,6 +172,8 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -172,6 +172,8 @@ void build_variable_scope(const framework::BlockDesc& block,
auto* ptr = inner_scope->Var(var_name); auto* ptr = inner_scope->Var(var_name);
VLOG(3) << "Initialize Variable " << var_name; VLOG(3) << "Initialize Variable " << var_name;
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType()); InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is " VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType()); << ptr << " type is " << static_cast<int>(var_desc->GetType());
......
...@@ -642,6 +642,28 @@ void VariableScope::CheckExist(const std::string& name) const { ...@@ -642,6 +642,28 @@ void VariableScope::CheckExist(const std::string& name) const {
"%s not in VariableScope.", name)); "%s not in VariableScope.", name));
} }
void VariableScope::ClearListener() {
if (scope_ && listener_ && scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << scope_;
scope_->DelListener(listener_);
}
if (local_scope_ && listener_ && local_scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << local_scope_;
local_scope_->DelListener(listener_);
}
}
void VariableScope::ResetListener() {
if (scope_ && listener_ && !scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << scope_;
scope_->AddListener(listener_);
}
if (local_scope_ && listener_ && !local_scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << local_scope_;
local_scope_->AddListener(listener_);
}
}
VariableScopeListener::VariableScopeListener(VariableScope* var_scope) { VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_ = var_scope; var_scope_ = var_scope;
} }
......
...@@ -238,6 +238,10 @@ class VariableScope : public ScopeBase { ...@@ -238,6 +238,10 @@ class VariableScope : public ScopeBase {
bool GetVarSikpInplace(int id) const; bool GetVarSikpInplace(int id) const;
void ClearListener();
void ResetListener();
friend class VariableScopeListener; friend class VariableScopeListener;
private: private:
......
...@@ -25,19 +25,21 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -25,19 +25,21 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
startup_prog_(startup_prog), startup_prog_(startup_prog),
main_prog_(main_prog), main_prog_(main_prog),
global_scope_(VariableScope(scope)) { global_scope_(VariableScope(scope)) {
// NOTE(zhiqiu): it is needed to sync thhe variables in scope to // NOTE(zhiqiu): it is needed to sync the variables in scope to
// variable_scope, // variable_scope, since the some variable only exists in scope.
// since the some variable only exists in startup program, e.g, // For example, 'lod_tensor_blocking_queue_0' used in dataloader.
// lod_tensor_blocking_queue_0 used in dataloader. // These variables may be created in scope, and it is not existed as
// These variables may be created in scope during runing startup program with // variable in program.
// original executor.
if (scope) { if (scope) {
auto name_list = scope->LocalVarNames(); const std::string blocking_queue_prefix = "lod_tensor_blocking_queue";
for (auto name : name_list) { auto vars = scope->LocalVarNames();
VLOG(4) << "Sync Variable from variable scope: " << name; for (const auto& name : vars) {
auto v = scope->Var(name); if (name.find(blocking_queue_prefix) != std::string::npos) {
if (!global_scope_.HasVar(name)) { if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v); auto* v = scope->Var(name);
VLOG(4) << "Sync Variable from scope to variable scope: " << name;
global_scope_.AddVar(name, *v);
}
} }
} }
} }
......
...@@ -289,6 +289,11 @@ void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) { ...@@ -289,6 +289,11 @@ void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener); listeners_.remove(listener);
} }
bool Scope::HasListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
return it != listeners_.end();
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) { void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) { for (auto iter = vars_.begin(); iter != vars_.end();) {
......
...@@ -154,6 +154,8 @@ class Scope : public ScopeBase { ...@@ -154,6 +154,8 @@ class Scope : public ScopeBase {
void DelListener(const std::shared_ptr<ScopeListener>& listener); void DelListener(const std::shared_ptr<ScopeListener>& listener);
bool HasListener(const std::shared_ptr<ScopeListener>& listener);
protected: protected:
struct KeyHasher { struct KeyHasher {
std::size_t operator()(const std::string& key) const { std::size_t operator()(const std::string& key) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册