diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 20a6e53479323afc5aa153d08f371029dace7f72..80553765df12b079640e7814a7398b10ce8bca88 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -192,7 +192,8 @@ void InterpreterCore::BuildOperatorDependences() { // Schedule auto op_nums = vec_instruction_.size(); dependecy_count_.resize(op_nums); - auto op2downstream = interpreter::build_op_downstream_map(vec_instruction_); + auto op2downstream = interpreter::build_op_downstream_map( + vec_instruction_, &op_happens_before_); for (size_t op = 0; op < vec_instruction_.size(); ++op) { auto op_list = op2downstream[op]; std::vector downsteam_vector(op_list.begin(), op_list.end()); @@ -213,18 +214,21 @@ void InterpreterCore::Convert( auto op_nums = nodes.size(); vec_instruction_.reserve(op_nums); - for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { auto& op_func_node = nodes[op_idx]; auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); - vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); - auto& instr = vec_instruction_.back(); + } + BuildOperatorDependences(); + + // calculate last_live_ops_ + for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { + auto& instr = vec_instruction_[op_idx]; OpInOutInfo info; - std::vector gc_check_input_list; + std::set gc_check_inputs; - for (auto& item : op_func_node.input_index) { + for (auto& item : instr.Inputs()) { for (auto id : item.second) { if (id == kEmptyVarIndex) { continue; @@ -232,38 +236,24 @@ void InterpreterCore::Convert( input_var2op_info_.at(id).push_back(op_idx); // var can be gc-ed if (!info.IsBuilt()) { - info.Build(op_func_node.operator_base_.get()); + info.Build(instr.OpBase()); } auto* var_desc = global_scope_->VarDesc(id); if (var_desc) { if (info.IsInArgBufferNeeded(var_desc->Name())) { - gc_check_input_list.push_back(id); + gc_check_inputs.insert(id); } } else { - gc_check_input_list.push_back(id); + gc_check_inputs.insert(id); } } } - std::sort(gc_check_input_list.begin(), gc_check_input_list.end()); - auto last = - std::unique(gc_check_input_list.begin(), gc_check_input_list.end()); - gc_check_input_list.erase(last, gc_check_input_list.end()); - for (auto var_id : gc_check_input_list) { + for (auto var_id : gc_check_inputs) { paddle::framework::Variable* var = global_scope_->Var(var_id); if (var->IsType() || var->IsType() || var->IsType()) { - vec_meta_info[var_id].var_ref_count_++; - // TODO(zhiqiu): not all var needs to be checked, var need to be checked - // only - // after the last_live_op. For example, - // b = op1(a) - // c = op2(a, b) - // in this case, a is the input of op1 and op2, we only need to check - // a after op2, because op2 always uses a after op1. - instr.AddGCCheckVar(var_id); - VLOG(4) << "clear " << global_scope_->GetNameById(var_id) << " after " - << instr.OpBase()->Type(); + last_live_ops_[var_id].insert(op_idx); } else { VLOG(4) << "not clear " << global_scope_->GetNameById(var_id) << " after " << instr.OpBase()->Type() @@ -276,19 +266,45 @@ void InterpreterCore::Convert( for (size_t i = 0; i < vec_instruction_.size(); ++i) { // checkout ouput for (auto& item : vec_instruction_[i].Outputs()) { - for (auto id : item.second) { - if (input_var2op_info_.at(id).size() == 0) { - // output var not be used by any kernel - vec_instruction_[i].AddGCCheckVar(id); - VLOG(4) << "clear " << global_scope_->GetNameById(id) << " after " - << vec_instruction_[i].OpBase()->Type(); - vec_meta_info[id].var_ref_count_++; + for (auto var_id : item.second) { + if (input_var2op_info_.at(var_id).size() == 0) { + last_live_ops_[var_id].insert(i); } } } } - BuildOperatorDependences(); + // shrink, find the downstream op that has no other op in the + // downstream list happens before it + // For example, + // b = op1(a) + // c = op2(a, b) + // in this case, a is the input of op1 and op2, we only need to check + // a after op2, because op2 always uses a after op1. + for (size_t i = 0; i < last_live_ops_.size(); ++i) { + std::set minumum_last_live_ops; + for (size_t item : last_live_ops_[i]) { + bool not_before_any = true; + // find the op that is not executed before any + for (size_t other_item : last_live_ops_[i]) { + if (op_happens_before_[item][other_item]) { + VLOG(8) << "happens_before: " << item << "->" << other_item + << ", so skip " << item; + not_before_any = false; + break; + } + } + if (not_before_any) { + VLOG(8) << "last live op of var " << i << " " + << global_scope_->GetNameById(i) << " : " << item << " " + << vec_instruction_[item].OpBase()->Type(); + minumum_last_live_ops.insert(item); + vec_instruction_[item].AddGCCheckVar(i); + } + } + last_live_ops_[i] = minumum_last_live_ops; + vec_meta_info[i].var_ref_count_ = last_live_ops_[i].size(); + } for (size_t i = 0; i < vec_instruction_.size(); ++i) { BuildAndCacheInstructionCtx(&vec_instruction_[i]); diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index c1ade85e1384c0e1f6fe3f3d6480b606e8a24391..3af0ddb675a45157332928c1997d73d3096aff7b 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -109,6 +109,11 @@ class InterpreterCore { std::vector vec_instruction_; // deconstruct before OpFuncNode + // op_happens_before_[i][j] == true means op[i] happens before op[j] + std::vector> op_happens_before_; + // last_live_ops_[i] contains the id of operatos that last access var[i] + std::map> last_live_ops_; + std::vector dependecy_count_; std::atomic unfinished_op_numer_{0}; std::vector> input_var2op_info_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 63fcf0cffaa84c47412958da3c5a44f42a27a091..be05acd7b71ee6c9c770427c9218eecefb275c62 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -614,23 +614,125 @@ void update_var_min_rw_op(const std::map>& op2dependences, } std::map> get_downstream_map( - const std::map>& op2dependences) { - // op2dependences is op -> it's dependences. we want to get op -> [ops] map, + const std::map>& op2dependences, + std::vector>* op_happens_before) { + // step1: convert op2dependences to downstream_map directly + // op2dependences is op -> it's dependences. + // we want to get op -> [next ops] map, // where ops is the next instruction of op. - std::map> result; + std::map> downstream; for (auto& item : op2dependences) { int op = item.first; for (auto dep_op : item.second) { - if (result.find(dep_op) == result.end()) - result[dep_op] = std::list(); - result[dep_op].push_back(op); + if (downstream.find(dep_op) == downstream.end()) + downstream[dep_op] = std::list(); + downstream[dep_op].push_back(op); } } - return std::move(result); + + auto downstream_map_to_str = [&]() -> std::string { + std::ostringstream oss; + for (auto pair : downstream) { + oss << pair.first << " -> "; + std::copy(pair.second.begin(), pair.second.end(), + std::ostream_iterator(oss, " ")); + oss << std::endl; + } + return oss.str(); + }; + + auto downstream_map_count = [&]() -> size_t { + size_t count = 0; + for (auto pair : downstream) { + count += pair.second.size(); + } + return count; + }; + + VLOG(6) << "downstream count: " << downstream_map_count(); + VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); + + // step2: remove unneccessary downstream ops + // for example, a->b->c + // a: b, c + // b: c + // => + // a: b + // b: c + + // NOTE(zhiqiu): the size of downstream != size of op2dependences + // since there are some ops that have no downstream-op. + auto op_num = op2dependences.size(); + // happens_before[i][j] means i should be executed before j + op_happens_before->resize(op_num); + for (size_t i = 0; i < op_num; ++i) { + (*op_happens_before)[i].resize(op_num); + std::fill((*op_happens_before)[i].begin(), (*op_happens_before)[i].end(), + false); + } + + // bfs to get all next ops + auto bfs = [&](size_t op_idx) { + std::queue q; + std::vector visited(op_num, false); + q.push(op_idx); + while (!q.empty()) { + size_t op = q.front(); + q.pop(); + visited[op] = true; + if (!downstream.count(op)) { + continue; + } + for (auto next : downstream[op]) { + if (!visited[next]) { + PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false, + paddle::platform::errors::AlreadyExists( + "There exists circle in graph, expected " + "%d->%d, but already got %d->%d", + op_idx, next, next, op_idx)); + (*op_happens_before)[op_idx][next] = true; + VLOG(8) << "happens before: " << op_idx << " " << next; + q.push(next); + } + } + } + }; + + for (size_t i = 0; i < op_num; ++i) { + bfs(i); + } + + // shrink, find the downstream op that has no other op in the + // downstream list happens before it + for (size_t i = 0; i < op_num; ++i) { + std::list minumum_nexts; + for (size_t item : downstream[i]) { + bool not_after_any = true; + // find the op that is not executed after any + for (size_t other_item : downstream[i]) { + if ((*op_happens_before)[other_item][item]) { + VLOG(8) << "happens_before: " << other_item << "->" << item + << ", so skip " << item; + not_after_any = false; + break; + } + } + if (not_after_any) { + VLOG(8) << "downstream op of " << i << ": " << item; + minumum_nexts.push_back(item); + } + } + downstream[i] = minumum_nexts; + } + VLOG(6) << "downstream count: " << downstream_map_count(); + VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); + + return std::move(downstream); } std::map> build_op_downstream_map( - const std::vector& vec_instruction) { + const std::vector& vec_instruction, + std::vector>* op_happens_before) { auto var2min_rw_op = std::map< int, std::list>(); // # map from variable id to read / write op id. auto var2recent_write_op = @@ -873,13 +975,13 @@ std::map> build_op_downstream_map( } } for (auto pair : op2dependences) { - VLOG(10) << pair.first << " Depends on " << pair.second.size(); std::ostringstream oss; + oss << pair.first << " Depends on " << pair.second.size() << " ops: "; std::copy(pair.second.begin(), pair.second.end(), std::ostream_iterator(oss, " ")); VLOG(10) << oss.str(); } - return std::move(get_downstream_map(op2dependences)); + return std::move(get_downstream_map(op2dependences, op_happens_before)); } } // namespace interpreter diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 044a9ea368cbc506ce4a30bb82562177263786f9..56683330ee6cb90f40645bfa0160516d30cf5418 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -116,7 +116,8 @@ void build_op_func_list(const platform::Place& place, VariableScope* var_scope, bool use_local_scope = true); std::map> build_op_downstream_map( - const std::vector& vec_instruction); + const std::vector& vec_instruction, + std::vector>* op_happens_before); void add_fetch(const std::vector& fetch_names, framework::BlockDesc* block);