未验证 提交 d4cf5666 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] clear the scope listener after run (#41947)

* clear the listener after run

* only sync variables in program

* refine code

* fit for lod_tensor_blocking_queue
上级 23ad2166
......@@ -121,6 +121,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Prepare(feed_names, feed_tensors, is_build);
if (is_build) {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_);
}
......@@ -128,6 +131,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope();
}
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
......@@ -162,6 +168,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Convert(&op_func_nodes);
} else {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_);
}
......@@ -169,6 +178,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope();
}
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
......
......@@ -172,6 +172,8 @@ void build_variable_scope(const framework::BlockDesc& block,
auto* ptr = inner_scope->Var(var_name);
VLOG(3) << "Initialize Variable " << var_name;
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
......
......@@ -642,6 +642,28 @@ void VariableScope::CheckExist(const std::string& name) const {
"%s not in VariableScope.", name));
}
void VariableScope::ClearListener() {
if (scope_ && listener_ && scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << scope_;
scope_->DelListener(listener_);
}
if (local_scope_ && listener_ && local_scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << local_scope_;
local_scope_->DelListener(listener_);
}
}
void VariableScope::ResetListener() {
if (scope_ && listener_ && !scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << scope_;
scope_->AddListener(listener_);
}
if (local_scope_ && listener_ && !local_scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << local_scope_;
local_scope_->AddListener(listener_);
}
}
// Binds this listener to the VariableScope it will forward scope events to.
// Uses the member initializer list instead of assignment in the constructor
// body (C++ Core Guidelines C.49): the member is initialized once, directly.
VariableScopeListener::VariableScopeListener(VariableScope* var_scope)
    : var_scope_(var_scope) {}
......
......@@ -238,6 +238,10 @@ class VariableScope : public ScopeBase {
bool GetVarSikpInplace(int id) const;
void ClearListener();
void ResetListener();
friend class VariableScopeListener;
private:
......
......@@ -25,19 +25,21 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
startup_prog_(startup_prog),
main_prog_(main_prog),
global_scope_(VariableScope(scope)) {
// NOTE(zhiqiu): it is needed to sync the variables in scope to
// variable_scope,
// since some variables only exist in the startup program, e.g.,
// lod_tensor_blocking_queue_0 used in dataloader.
// These variables may be created in scope during runing startup program with
// original executor.
// NOTE(zhiqiu): it is needed to sync the variables in scope to
// variable_scope, since some variables only exist in scope.
// For example, 'lod_tensor_blocking_queue_0' used in dataloader.
// These variables may be created in scope, but do not exist as
// variables in the program.
if (scope) {
auto name_list = scope->LocalVarNames();
for (auto name : name_list) {
VLOG(4) << "Sync Variable from variable scope: " << name;
auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v);
const std::string blocking_queue_prefix = "lod_tensor_blocking_queue";
auto vars = scope->LocalVarNames();
for (const auto& name : vars) {
if (name.find(blocking_queue_prefix) != std::string::npos) {
if (!global_scope_.HasVar(name)) {
auto* v = scope->Var(name);
VLOG(4) << "Sync Variable from scope to variable scope: " << name;
global_scope_.AddVar(name, *v);
}
}
}
}
......
......@@ -289,6 +289,11 @@ void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener);
}
// Returns true iff `listener` is currently registered on this scope.
// NOTE(review): like DelListener above, this walks listeners_ without a
// lock — presumably callers serialize listener mutation; confirm if this
// scope is shared across threads.
bool Scope::HasListener(const std::shared_ptr<ScopeListener>& listener) {
  for (const auto& registered : listeners_) {
    if (registered == listener) {
      return true;
    }
  }
  return false;
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) {
......
......@@ -154,6 +154,8 @@ class Scope : public ScopeBase {
void DelListener(const std::shared_ptr<ScopeListener>& listener);
bool HasListener(const std::shared_ptr<ScopeListener>& listener);
protected:
struct KeyHasher {
std::size_t operator()(const std::string& key) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册