diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 0b75964b94e914c580b3f2ffb9e1538668702832..72b7477f2b8703c52d472b713226f02321bb0124 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -629,46 +629,75 @@ void update_var_min_rw_op(const std::map>& op2dependences, var2min_rw_op->at(rw_var).push_back(cur_op); } -std::map> get_downstream_map( - const std::map>& op2dependences, - std::vector>* op_happens_before) { - // step1: convert op2dependences to downstream_map directly - // op2dependences is op -> it's dependences. - // we want to get op -> [next ops] map, - // where ops is the next instruction of op. - std::map> downstream; +void AddDownstreamOp(int prior_op_idx, int posterior_op_idx, + std::map>* op_downstream_map) { + if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) { + op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list())); + } + op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx); +} + +void AddDownstreamOp(int prior_op_idx, int posterior_op_idx, + std::map>* op_downstream_map, + const std::vector>& op_happens_before) { + if (op_downstream_map->find(prior_op_idx) != op_downstream_map->end()) { + for (int op_idx : op_downstream_map->at(prior_op_idx)) { + if (op_happens_before[op_idx][posterior_op_idx]) { + VLOG(7) << "Find dependencies " << prior_op_idx << "->" << op_idx + << "->" << posterior_op_idx << ", skip adding " << prior_op_idx + << "->" << posterior_op_idx; + return; + } + } + } + + AddDownstreamOp(prior_op_idx, posterior_op_idx, op_downstream_map); +} + +size_t CountDownstreamMap(const std::map>& downstream_map) { + size_t count = 0; + for (auto pair : downstream_map) { + count += pair.second.size(); + } + return count; +} + +const std::string StringizeDownstreamMap( + const std::map>& downstream_map) { + std::ostringstream oss; + for (auto pair : downstream_map) { + oss << pair.first << " -> "; + std::copy(pair.second.begin(), pair.second.end(), + std::ostream_iterator(oss, " ")); + oss << std::endl; + } + return oss.str(); +} + +// convert op2dependences to downstream_map directly. op2dependences is op -> +// it's dependences, we want to get op -> [next ops] map, where ops is the next +// instruction of op. +std::map> GetDownstreamMap( + const std::map>& op2dependences) { + std::map> downstream_map; for (auto& item : op2dependences) { int op = item.first; for (auto dep_op : item.second) { - if (downstream.find(dep_op) == downstream.end()) - downstream[dep_op] = std::list(); - downstream[dep_op].push_back(op); + AddDownstreamOp(dep_op, op, &downstream_map); } } - auto downstream_map_to_str = [&]() -> std::string { - std::ostringstream oss; - for (auto pair : downstream) { - oss << pair.first << " -> "; - std::copy(pair.second.begin(), pair.second.end(), - std::ostream_iterator(oss, " ")); - oss << std::endl; - } - return oss.str(); - }; - - auto downstream_map_count = [&]() -> size_t { - size_t count = 0; - for (auto pair : downstream) { - count += pair.second.size(); - } - return count; - }; + VLOG(6) << "downstream count: " << CountDownstreamMap(downstream_map); + VLOG(6) << "downstream_map: " << std::endl + << StringizeDownstreamMap(downstream_map); - VLOG(6) << "downstream count: " << downstream_map_count(); - VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); + return downstream_map; +} - // step2: remove unnecessary downstream ops +void ShrinkDownstreamMap(std::map>* downstream_map, + std::vector>* op_happens_before, + size_t op_num) { + // remove unnecessary downstream ops // for example, a->b->c // a: b, c // b: c @@ -676,9 +705,6 @@ std::map> get_downstream_map( // a: b // b: c - // NOTE(zhiqiu): the size of downstream != size of op2dependences - // since there are some ops that have no downstream-op. - auto op_num = op2dependences.size(); // happens_before[i][j] means i should be executed before j op_happens_before->resize(op_num); for (size_t i = 0; i < op_num; ++i) { @@ -696,10 +722,10 @@ std::map> get_downstream_map( size_t op = q.front(); q.pop(); visited[op] = true; - if (!downstream.count(op)) { + if (!downstream_map->count(op)) { continue; } - for (auto next : downstream[op]) { + for (auto next : downstream_map->at(op)) { if (!visited[next]) { PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false, paddle::platform::errors::AlreadyExists( @@ -721,11 +747,15 @@ std::map> get_downstream_map( // shrink, find the downstream op that has no other op in the // downstream list happens before it for (size_t i = 0; i < op_num; ++i) { + if (downstream_map->find(i) == downstream_map->end()) { + continue; + } + std::list minumum_nexts; - for (size_t item : downstream[i]) { + for (size_t item : downstream_map->at(i)) { bool not_after_any = true; // find the op that is not executed after any - for (size_t other_item : downstream[i]) { + for (size_t other_item : downstream_map->at(i)) { if ((*op_happens_before)[other_item][item]) { VLOG(8) << "happens_before: " << other_item << "->" << item << ", so skip " << item; @@ -738,12 +768,11 @@ std::map> get_downstream_map( minumum_nexts.push_back(item); } } - downstream[i] = minumum_nexts; + downstream_map->at(i) = minumum_nexts; } - VLOG(6) << "downstream count: " << downstream_map_count(); - VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); - - return downstream; + VLOG(6) << "downstream count: " << CountDownstreamMap(*downstream_map); + VLOG(6) << "downstream_map: " << std::endl + << StringizeDownstreamMap(*downstream_map); } std::map> build_op_downstream_map( @@ -825,6 +854,14 @@ std::map> build_op_downstream_map( } } + // NOTE(zhiqiu): the size of downstream != size of op2dependences since there + // are some ops that have no downstream-op. + std::map> op_downstream_map = + GetDownstreamMap(op2dependences); + + ShrinkDownstreamMap(&op_downstream_map, op_happens_before, + vec_instruction.size()); + // add dependences for random op, make sure that the random op is scheduled // sequentially const std::set random_op_set = { @@ -846,7 +883,8 @@ std::map> build_op_downstream_map( for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) { if (dependence_op_idx != -1) { - op2dependences[op_idx].insert(dependence_op_idx); + AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, + *op_happens_before); } dependence_op_idx = op_idx; } @@ -872,7 +910,8 @@ std::map> build_op_downstream_map( for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) { if (dependence_op_idx != -1) { - op2dependences[op_idx].insert(dependence_op_idx); + AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, + *op_happens_before); VLOG(4) << "Add depend from " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[op_idx].OpBase()->Type(); @@ -900,7 +939,8 @@ std::map> build_op_downstream_map( VLOG(4) << "Add depend from " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[op_idx].OpBase()->Type(); - op2dependences[op_idx].insert(dependence_op_idx); + AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map, + *op_happens_before); } } } @@ -956,7 +996,8 @@ std::map> build_op_downstream_map( j < static_cast(first_read_fused_out_op); ++j) { for (auto var_id : outputs) { if (is_write(vec_instruction[j], var_id)) { - op2dependences[first_read_fused_out_op].insert(j); + AddDownstreamOp(j, first_read_fused_out_op, &op_downstream_map, + *op_happens_before); VLOG(4) << j << " -> " << first_read_fused_out_op; VLOG(4) << "Add depend from " << vec_instruction[j].OpBase()->Type() @@ -990,6 +1031,7 @@ std::map> build_op_downstream_map( for (auto var_id : outputs) { if (is_read(vec_instruction[j], var_id)) { + AddDownstreamOp(target, j, &op_downstream_map, *op_happens_before); op2dependences[j].insert(target); VLOG(4) << target << " -> " << j; VLOG(4) << "Add depend from " @@ -1000,14 +1042,12 @@ std::map> build_op_downstream_map( } } } - for (auto pair : op2dependences) { - std::ostringstream oss; - oss << pair.first << " Depends on " << pair.second.size() << " ops: "; - std::copy(pair.second.begin(), pair.second.end(), - std::ostream_iterator(oss, " ")); - VLOG(10) << oss.str(); - } - return get_downstream_map(op2dependences, op_happens_before); + + VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map); + VLOG(8) << "downstream_map: " << std::endl + << StringizeDownstreamMap(op_downstream_map); + + return op_downstream_map; } } // namespace interpreter