未验证 提交 80992253 编写于 作者: L Leo Chen 提交者: GitHub

[cherry-pick] Cherry pick pr of new-exec (#42009)

* [new-exec] shrink downstream map (#41471)

* shrink downstream map

* shrink last live ops of var

* add comment

* fix bug

* add dependency for send/recv to support pp parallel (#41652)

* [new-exec] clear the scope listener after run (#41947)

* clear the listener after run

* only sync variables in program

* refine code

* fit for lod_tensor_blocking_queue
上级 60c212d5
...@@ -121,6 +121,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -121,6 +121,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Prepare(feed_names, feed_tensors, is_build); Prepare(feed_names, feed_tensors, is_build);
if (is_build) { if (is_build) {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
...@@ -128,6 +131,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -128,6 +131,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
...@@ -162,6 +168,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -162,6 +168,9 @@ paddle::framework::FetchList InterpreterCore::Run(
Convert(&op_func_nodes); Convert(&op_func_nodes);
} else { } else {
// add listener before run and is_build=true
global_scope_->ResetListener();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
...@@ -169,6 +178,9 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -169,6 +178,9 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
...@@ -192,7 +204,8 @@ void InterpreterCore::BuildOperatorDependences() { ...@@ -192,7 +204,8 @@ void InterpreterCore::BuildOperatorDependences() {
// Schedule // Schedule
auto op_nums = vec_instruction_.size(); auto op_nums = vec_instruction_.size();
dependecy_count_.resize(op_nums); dependecy_count_.resize(op_nums);
auto op2downstream = interpreter::build_op_downstream_map(vec_instruction_); auto op2downstream = interpreter::build_op_downstream_map(
vec_instruction_, &op_happens_before_);
for (size_t op = 0; op < vec_instruction_.size(); ++op) { for (size_t op = 0; op < vec_instruction_.size(); ++op) {
auto op_list = op2downstream[op]; auto op_list = op2downstream[op];
std::vector<size_t> downsteam_vector(op_list.begin(), op_list.end()); std::vector<size_t> downsteam_vector(op_list.begin(), op_list.end());
...@@ -213,18 +226,21 @@ void InterpreterCore::Convert( ...@@ -213,18 +226,21 @@ void InterpreterCore::Convert(
auto op_nums = nodes.size(); auto op_nums = nodes.size();
vec_instruction_.reserve(op_nums); vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx]; auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
auto& instr = vec_instruction_.back(); }
BuildOperatorDependences();
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& instr = vec_instruction_[op_idx];
OpInOutInfo info; OpInOutInfo info;
std::vector<size_t> gc_check_input_list; std::set<size_t> gc_check_inputs;
for (auto& item : op_func_node.input_index) { for (auto& item : instr.Inputs()) {
for (auto id : item.second) { for (auto id : item.second) {
if (id == kEmptyVarIndex) { if (id == kEmptyVarIndex) {
continue; continue;
...@@ -232,38 +248,24 @@ void InterpreterCore::Convert( ...@@ -232,38 +248,24 @@ void InterpreterCore::Convert(
input_var2op_info_.at(id).push_back(op_idx); input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed // var can be gc-ed
if (!info.IsBuilt()) { if (!info.IsBuilt()) {
info.Build(op_func_node.operator_base_.get()); info.Build(instr.OpBase());
} }
auto* var_desc = global_scope_->VarDesc(id); auto* var_desc = global_scope_->VarDesc(id);
if (var_desc) { if (var_desc) {
if (info.IsInArgBufferNeeded(var_desc->Name())) { if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_input_list.push_back(id); gc_check_inputs.insert(id);
} }
} else { } else {
gc_check_input_list.push_back(id); gc_check_inputs.insert(id);
} }
} }
} }
std::sort(gc_check_input_list.begin(), gc_check_input_list.end());
auto last =
std::unique(gc_check_input_list.begin(), gc_check_input_list.end());
gc_check_input_list.erase(last, gc_check_input_list.end());
for (auto var_id : gc_check_input_list) { for (auto var_id : gc_check_inputs) {
paddle::framework::Variable* var = global_scope_->Var(var_id); paddle::framework::Variable* var = global_scope_->Var(var_id);
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() || if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) { var->IsType<LoDTensorArray>()) {
vec_meta_info[var_id].var_ref_count_++; last_live_ops_[var_id].insert(op_idx);
// TODO(zhiqiu): not all var needs to be checked, var need to be checked
// only
// after the last_live_op. For example,
// b = op1(a)
// c = op2(a, b)
// in this case, a is the input of op1 and op2, we only need to check
// a after op2, because op2 always uses a after op1.
instr.AddGCCheckVar(var_id);
VLOG(4) << "clear " << global_scope_->GetNameById(var_id) << " after "
<< instr.OpBase()->Type();
} else { } else {
VLOG(4) << "not clear " << global_scope_->GetNameById(var_id) VLOG(4) << "not clear " << global_scope_->GetNameById(var_id)
<< " after " << instr.OpBase()->Type() << " after " << instr.OpBase()->Type()
...@@ -276,19 +278,45 @@ void InterpreterCore::Convert( ...@@ -276,19 +278,45 @@ void InterpreterCore::Convert(
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// checkout ouput // checkout ouput
for (auto& item : vec_instruction_[i].Outputs()) { for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) { for (auto var_id : item.second) {
if (input_var2op_info_.at(id).size() == 0) { if (input_var2op_info_.at(var_id).size() == 0) {
// output var not be used by any kernel last_live_ops_[var_id].insert(i);
vec_instruction_[i].AddGCCheckVar(id);
VLOG(4) << "clear " << global_scope_->GetNameById(id) << " after "
<< vec_instruction_[i].OpBase()->Type();
vec_meta_info[id].var_ref_count_++;
} }
} }
} }
} }
BuildOperatorDependences(); // shrink, find the downstream op that has no other op in the
// downstream list happens before it
// For example,
// b = op1(a)
// c = op2(a, b)
// in this case, a is the input of op1 and op2, we only need to check
// a after op2, because op2 always uses a after op1.
for (size_t i = 0; i < last_live_ops_.size(); ++i) {
std::set<size_t> minumum_last_live_ops;
for (size_t item : last_live_ops_[i]) {
bool not_before_any = true;
// find the op that is not executed before any
for (size_t other_item : last_live_ops_[i]) {
if (op_happens_before_[item][other_item]) {
VLOG(8) << "happens_before: " << item << "->" << other_item
<< ", so skip " << item;
not_before_any = false;
break;
}
}
if (not_before_any) {
VLOG(8) << "last live op of var " << i << " "
<< global_scope_->GetNameById(i) << " : " << item << " "
<< vec_instruction_[item].OpBase()->Type();
minumum_last_live_ops.insert(item);
vec_instruction_[item].AddGCCheckVar(i);
}
}
last_live_ops_[i] = minumum_last_live_ops;
vec_meta_info[i].var_ref_count_ = last_live_ops_[i].size();
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
BuildAndCacheInstructionCtx(&vec_instruction_[i]); BuildAndCacheInstructionCtx(&vec_instruction_[i]);
......
...@@ -109,6 +109,11 @@ class InterpreterCore { ...@@ -109,6 +109,11 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
// op_happens_before_[i][j] == true means op[i] happens before op[j]
std::vector<std::vector<bool>> op_happens_before_;
// last_live_ops_[i] contains the id of operatos that last access var[i]
std::map<size_t, std::set<size_t>> last_live_ops_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::atomic<size_t> unfinished_op_numer_{0}; std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
......
...@@ -172,6 +172,8 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -172,6 +172,8 @@ void build_variable_scope(const framework::BlockDesc& block,
auto* ptr = inner_scope->Var(var_name); auto* ptr = inner_scope->Var(var_name);
VLOG(3) << "Initialize Variable " << var_name; VLOG(3) << "Initialize Variable " << var_name;
// NOTE(zhiqiu): if var exists in scope and the type is right,
// InitializeVariable will not create a new variable.
InitializeVariable(ptr, var_desc->GetType()); InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is " VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType()); << ptr << " type is " << static_cast<int>(var_desc->GetType());
...@@ -614,23 +616,125 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences, ...@@ -614,23 +616,125 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
} }
std::map<int, std::list<int>> get_downstream_map( std::map<int, std::list<int>> get_downstream_map(
const std::map<int, std::set<int>>& op2dependences) { const std::map<int, std::set<int>>& op2dependences,
// op2dependences is op -> it's dependences. we want to get op -> [ops] map, std::vector<std::vector<bool>>* op_happens_before) {
// step1: convert op2dependences to downstream_map directly
// op2dependences is op -> it's dependences.
// we want to get op -> [next ops] map,
// where ops is the next instruction of op. // where ops is the next instruction of op.
std::map<int, std::list<int>> result; std::map<int, std::list<int>> downstream;
for (auto& item : op2dependences) { for (auto& item : op2dependences) {
int op = item.first; int op = item.first;
for (auto dep_op : item.second) { for (auto dep_op : item.second) {
if (result.find(dep_op) == result.end()) if (downstream.find(dep_op) == downstream.end())
result[dep_op] = std::list<int>(); downstream[dep_op] = std::list<int>();
result[dep_op].push_back(op); downstream[dep_op].push_back(op);
} }
} }
return std::move(result);
auto downstream_map_to_str = [&]() -> std::string {
std::ostringstream oss;
for (auto pair : downstream) {
oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
oss << std::endl;
}
return oss.str();
};
auto downstream_map_count = [&]() -> size_t {
size_t count = 0;
for (auto pair : downstream) {
count += pair.second.size();
}
return count;
};
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
// step2: remove unneccessary downstream ops
// for example, a->b->c
// a: b, c
// b: c
// =>
// a: b
// b: c
// NOTE(zhiqiu): the size of downstream != size of op2dependences
// since there are some ops that have no downstream-op.
auto op_num = op2dependences.size();
// happens_before[i][j] means i should be executed before j
op_happens_before->resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
(*op_happens_before)[i].resize(op_num);
std::fill((*op_happens_before)[i].begin(), (*op_happens_before)[i].end(),
false);
}
// bfs to get all next ops
auto bfs = [&](size_t op_idx) {
std::queue<size_t> q;
std::vector<bool> visited(op_num, false);
q.push(op_idx);
while (!q.empty()) {
size_t op = q.front();
q.pop();
visited[op] = true;
if (!downstream.count(op)) {
continue;
}
for (auto next : downstream[op]) {
if (!visited[next]) {
PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false,
paddle::platform::errors::AlreadyExists(
"There exists circle in graph, expected "
"%d->%d, but already got %d->%d",
op_idx, next, next, op_idx));
(*op_happens_before)[op_idx][next] = true;
VLOG(8) << "happens before: " << op_idx << " " << next;
q.push(next);
}
}
}
};
for (size_t i = 0; i < op_num; ++i) {
bfs(i);
}
// shrink, find the downstream op that has no other op in the
// downstream list happens before it
for (size_t i = 0; i < op_num; ++i) {
std::list<int> minumum_nexts;
for (size_t item : downstream[i]) {
bool not_after_any = true;
// find the op that is not executed after any
for (size_t other_item : downstream[i]) {
if ((*op_happens_before)[other_item][item]) {
VLOG(8) << "happens_before: " << other_item << "->" << item
<< ", so skip " << item;
not_after_any = false;
break;
}
}
if (not_after_any) {
VLOG(8) << "downstream op of " << i << ": " << item;
minumum_nexts.push_back(item);
}
}
downstream[i] = minumum_nexts;
}
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
return std::move(downstream);
} }
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction) { const std::vector<Instruction>& vec_instruction,
std::vector<std::vector<bool>>* op_happens_before) {
auto var2min_rw_op = std::map< auto var2min_rw_op = std::map<
int, std::list<int>>(); // # map from variable id to read / write op id. int, std::list<int>>(); // # map from variable id to read / write op id.
auto var2recent_write_op = auto var2recent_write_op =
...@@ -710,7 +814,12 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -710,7 +814,12 @@ std::map<int, std::list<int>> build_op_downstream_map(
// sequentially // sequentially
const std::set<std::string> random_op_set = { const std::set<std::string> random_op_set = {
"bernoulli", "poisson", "multinomial", "gaussian_random", "bernoulli", "poisson", "multinomial", "gaussian_random",
"uniform_random", "randint", "randperm", "exponential"}; "truncated_gaussian_random", "uniform_random", "randint", "randperm",
"exponential",
"sampling_id"
"dropout",
"class_center_sample",
};
int dependence_op_idx = -1; int dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
...@@ -723,13 +832,26 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -723,13 +832,26 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
// add dependency for communication op // add dependency for communication op
auto is_comm_op = [](std::string op) -> bool {
const std::set<std::string> special_comm_op_set = {
"send", "recv", "send_v2", "recv_v2",
};
const std::string communication_op_prefix = "c_"; const std::string communication_op_prefix = "c_";
if (op.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op)) {
return true;
}
return false;
};
dependence_op_idx = -1; dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (vec_instruction[op_idx].OpBase()->Type().find( if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
communication_op_prefix) != std::string::npos) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx); op2dependences[op_idx].insert(dependence_op_idx);
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
} }
dependence_op_idx = op_idx; dependence_op_idx = op_idx;
} }
...@@ -833,10 +955,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -833,10 +955,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size(); for (size_t j = first_read_fused_out_op + 1; j < vec_instruction.size();
++j) { ++j) {
if (j == target + 1 && if (j == target + 1 &&
vec_instruction[target].OpBase()->Type().find( is_comm_op(vec_instruction[target].OpBase()->Type()) &&
communication_op_prefix) != std::string::npos && is_comm_op(vec_instruction[j].OpBase()->Type())) {
vec_instruction[j].OpBase()->Type().find(communication_op_prefix) !=
std::string::npos) {
VLOG(4) << "Found consecutive communication ops, " VLOG(4) << "Found consecutive communication ops, "
<< vec_instruction[target].OpBase()->Type() << " -> " << vec_instruction[target].OpBase()->Type() << " -> "
<< vec_instruction[j].OpBase()->Type(); << vec_instruction[j].OpBase()->Type();
...@@ -857,13 +977,13 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -857,13 +977,13 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
for (auto pair : op2dependences) { for (auto pair : op2dependences) {
VLOG(10) << pair.first << " Depends on " << pair.second.size();
std::ostringstream oss; std::ostringstream oss;
oss << pair.first << " Depends on " << pair.second.size() << " ops: ";
std::copy(pair.second.begin(), pair.second.end(), std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " ")); std::ostream_iterator<int>(oss, " "));
VLOG(10) << oss.str(); VLOG(10) << oss.str();
} }
return std::move(get_downstream_map(op2dependences)); return std::move(get_downstream_map(op2dependences, op_happens_before));
} }
} // namespace interpreter } // namespace interpreter
......
...@@ -116,7 +116,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -116,7 +116,8 @@ void build_op_func_list(const platform::Place& place,
VariableScope* var_scope, bool use_local_scope = true); VariableScope* var_scope, bool use_local_scope = true);
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction); const std::vector<Instruction>& vec_instruction,
std::vector<std::vector<bool>>* op_happens_before);
void add_fetch(const std::vector<std::string>& fetch_names, void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block); framework::BlockDesc* block);
......
...@@ -641,6 +641,28 @@ void VariableScope::CheckExist(const std::string& name) const { ...@@ -641,6 +641,28 @@ void VariableScope::CheckExist(const std::string& name) const {
"%s not in VariableScope.", name)); "%s not in VariableScope.", name));
} }
void VariableScope::ClearListener() {
if (scope_ && listener_ && scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << scope_;
scope_->DelListener(listener_);
}
if (local_scope_ && listener_ && local_scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << local_scope_;
local_scope_->DelListener(listener_);
}
}
void VariableScope::ResetListener() {
if (scope_ && listener_ && !scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << scope_;
scope_->AddListener(listener_);
}
if (local_scope_ && listener_ && !local_scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << local_scope_;
local_scope_->AddListener(listener_);
}
}
VariableScopeListener::VariableScopeListener(VariableScope* var_scope) { VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_ = var_scope; var_scope_ = var_scope;
} }
......
...@@ -238,6 +238,10 @@ class VariableScope : public ScopeBase { ...@@ -238,6 +238,10 @@ class VariableScope : public ScopeBase {
bool GetVarSikpInplace(int id) const; bool GetVarSikpInplace(int id) const;
void ClearListener();
void ResetListener();
friend class VariableScopeListener; friend class VariableScopeListener;
private: private:
......
...@@ -24,22 +24,24 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -24,22 +24,24 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
startup_prog_(startup_prog), startup_prog_(startup_prog),
main_prog_(main_prog), main_prog_(main_prog),
global_scope_(VariableScope(scope)) { global_scope_(VariableScope(scope)) {
// NOTE(zhiqiu): it is needed to sync thhe variables in scope to // NOTE(zhiqiu): it is needed to sync the variables in scope to
// variable_scope, // variable_scope, since the some variable only exists in scope.
// since the some variable only exists in startup program, e.g, // For example, 'lod_tensor_blocking_queue_0' used in dataloader.
// lod_tensor_blocking_queue_0 used in dataloader. // These variables may be created in scope, and it is not existed as
// These variables may be created in scope during runing startup program with // variable in program.
// original executor.
if (scope) { if (scope) {
auto name_list = scope->LocalVarNames(); const std::string blocking_queue_prefix = "lod_tensor_blocking_queue";
for (auto name : name_list) { auto vars = scope->LocalVarNames();
VLOG(4) << "Sync Variable from variable scope: " << name; for (const auto& name : vars) {
auto v = scope->Var(name); if (name.find(blocking_queue_prefix) != std::string::npos) {
if (!global_scope_.HasVar(name)) { if (!global_scope_.HasVar(name)) {
auto* v = scope->Var(name);
VLOG(4) << "Sync Variable from scope to variable scope: " << name;
global_scope_.AddVar(name, *v); global_scope_.AddVar(name, *v);
} }
} }
} }
}
// NOTE(zhiqiu): for startup_program, initialize scope and run once // NOTE(zhiqiu): for startup_program, initialize scope and run once
// if startup_program is empty, the scope is initialize during first run // if startup_program is empty, the scope is initialize during first run
......
...@@ -289,6 +289,11 @@ void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) { ...@@ -289,6 +289,11 @@ void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener); listeners_.remove(listener);
} }
bool Scope::HasListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
return it != listeners_.end();
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) { void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) { for (auto iter = vars_.begin(); iter != vars_.end();) {
......
...@@ -154,6 +154,8 @@ class Scope : public ScopeBase { ...@@ -154,6 +154,8 @@ class Scope : public ScopeBase {
void DelListener(const std::shared_ptr<ScopeListener>& listener); void DelListener(const std::shared_ptr<ScopeListener>& listener);
bool HasListener(const std::shared_ptr<ScopeListener>& listener);
protected: protected:
struct KeyHasher { struct KeyHasher {
std::size_t operator()(const std::string& key) const { std::size_t operator()(const std::string& key) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册