未验证 提交 04294f80 编写于 作者: R Ruibiao Chen 提交者: GitHub

Fix shrink downstream op bugs for standalone executor (#43330)

上级 c96f7a29
...@@ -629,46 +629,75 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences, ...@@ -629,46 +629,75 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
var2min_rw_op->at(rw_var).push_back(cur_op); var2min_rw_op->at(rw_var).push_back(cur_op);
} }
std::map<int, std::list<int>> get_downstream_map( void AddDownstreamOp(int prior_op_idx, int posterior_op_idx,
const std::map<int, std::set<int>>& op2dependences, std::map<int, std::list<int>>* op_downstream_map) {
std::vector<std::vector<bool>>* op_happens_before) { if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) {
// step1: convert op2dependences to downstream_map directly op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list<int>()));
// op2dependences is op -> it's dependences. }
// we want to get op -> [next ops] map, op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx);
// where ops is the next instruction of op. }
std::map<int, std::list<int>> downstream;
// Adds the edge prior_op_idx -> posterior_op_idx to op_downstream_map, unless
// the edge is redundant.
//
// The edge is treated as redundant when some op already listed as a downstream
// op of prior_op_idx is recorded in op_happens_before as happening before
// posterior_op_idx; in that case nothing is added and the skipped edge is
// logged at VLOG level 7.
//
// prior_op_idx:       index of the op that must run first.
// posterior_op_idx:   index of the op that must run afterwards.
// op_downstream_map:  op -> list of downstream ops; modified in place.
// op_happens_before:  op_happens_before[i][j] is read as "i runs before j";
//                     it is only read here, never updated.
void AddDownstreamOp(int prior_op_idx, int posterior_op_idx,
                     std::map<int, std::list<int>>* op_downstream_map,
                     const std::vector<std::vector<bool>>& op_happens_before) {
  const auto existing = op_downstream_map->find(prior_op_idx);
  if (existing != op_downstream_map->end()) {
    for (int mid_op_idx : existing->second) {
      if (!op_happens_before[mid_op_idx][posterior_op_idx]) {
        continue;
      }
      // prior -> mid -> posterior already orders the two ops.
      VLOG(7) << "Find dependencies " << prior_op_idx << "->" << mid_op_idx
              << "->" << posterior_op_idx << ", skip adding " << prior_op_idx
              << "->" << posterior_op_idx;
      return;
    }
  }
  // No existing path implies the ordering; record the direct edge.
  AddDownstreamOp(prior_op_idx, posterior_op_idx, op_downstream_map);
}
// Returns the total number of downstream edges stored in downstream_map, i.e.
// the sum of the list sizes over all keys. An empty map yields 0.
size_t CountDownstreamMap(const std::map<int, std::list<int>>& downstream_map) {
  size_t count = 0;
  // Iterate by const reference: iterating by value would deep-copy every
  // std::list<int> only to read its size.
  for (const auto& pair : downstream_map) {
    count += pair.second.size();
  }
  return count;
}
// Renders downstream_map as text for logging, one line per key in ascending
// key order:
//   "<op> -> <next0> <next1> ... \n"
// Every downstream op index is followed by a single space, so a key with an
// empty list is rendered as "<op> -> \n". An empty map yields an empty string.
const std::string StringizeDownstreamMap(
    const std::map<int, std::list<int>>& downstream_map) {
  std::ostringstream oss;
  // Iterate by const reference to avoid copying each key/list pair.
  for (const auto& pair : downstream_map) {
    oss << pair.first << " -> ";
    std::copy(pair.second.begin(), pair.second.end(),
              std::ostream_iterator<int>(oss, " "));
    // '\n' instead of std::endl: same text, without a flush on every row.
    oss << '\n';
  }
  return oss.str();
}
// Convert op2dependences to downstream_map directly. op2dependences maps each
// op to its dependencies; we want the inverse map, op -> [next ops], where the
// next ops are the instructions that depend on op.
std::map<int, std::list<int>> GetDownstreamMap(
const std::map<int, std::set<int>>& op2dependences) {
std::map<int, std::list<int>> downstream_map;
for (auto& item : op2dependences) { for (auto& item : op2dependences) {
int op = item.first; int op = item.first;
for (auto dep_op : item.second) { for (auto dep_op : item.second) {
if (downstream.find(dep_op) == downstream.end()) AddDownstreamOp(dep_op, op, &downstream_map);
downstream[dep_op] = std::list<int>();
downstream[dep_op].push_back(op);
} }
} }
auto downstream_map_to_str = [&]() -> std::string { VLOG(6) << "downstream count: " << CountDownstreamMap(downstream_map);
std::ostringstream oss; VLOG(6) << "downstream_map: " << std::endl
for (auto pair : downstream) { << StringizeDownstreamMap(downstream_map);
oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
oss << std::endl;
}
return oss.str();
};
auto downstream_map_count = [&]() -> size_t {
size_t count = 0;
for (auto pair : downstream) {
count += pair.second.size();
}
return count;
};
VLOG(6) << "downstream count: " << downstream_map_count(); return downstream_map;
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); }
// step2: remove unnecessary downstream ops void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
std::vector<std::vector<bool>>* op_happens_before,
size_t op_num) {
// remove unnecessary downstream ops
// for example, a->b->c // for example, a->b->c
// a: b, c // a: b, c
// b: c // b: c
...@@ -676,9 +705,6 @@ std::map<int, std::list<int>> get_downstream_map( ...@@ -676,9 +705,6 @@ std::map<int, std::list<int>> get_downstream_map(
// a: b // a: b
// b: c // b: c
// NOTE(zhiqiu): the size of downstream != size of op2dependences
// since there are some ops that have no downstream-op.
auto op_num = op2dependences.size();
// happens_before[i][j] means i should be executed before j // happens_before[i][j] means i should be executed before j
op_happens_before->resize(op_num); op_happens_before->resize(op_num);
for (size_t i = 0; i < op_num; ++i) { for (size_t i = 0; i < op_num; ++i) {
...@@ -696,10 +722,10 @@ std::map<int, std::list<int>> get_downstream_map( ...@@ -696,10 +722,10 @@ std::map<int, std::list<int>> get_downstream_map(
size_t op = q.front(); size_t op = q.front();
q.pop(); q.pop();
visited[op] = true; visited[op] = true;
if (!downstream.count(op)) { if (!downstream_map->count(op)) {
continue; continue;
} }
for (auto next : downstream[op]) { for (auto next : downstream_map->at(op)) {
if (!visited[next]) { if (!visited[next]) {
PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false, PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false,
paddle::platform::errors::AlreadyExists( paddle::platform::errors::AlreadyExists(
...@@ -721,11 +747,15 @@ std::map<int, std::list<int>> get_downstream_map( ...@@ -721,11 +747,15 @@ std::map<int, std::list<int>> get_downstream_map(
// shrink, find the downstream op that has no other op in the // shrink, find the downstream op that has no other op in the
// downstream list happens before it // downstream list happens before it
for (size_t i = 0; i < op_num; ++i) { for (size_t i = 0; i < op_num; ++i) {
if (downstream_map->find(i) == downstream_map->end()) {
continue;
}
std::list<int> minumum_nexts; std::list<int> minumum_nexts;
for (size_t item : downstream[i]) { for (size_t item : downstream_map->at(i)) {
bool not_after_any = true; bool not_after_any = true;
// find the op that is not executed after any // find the op that is not executed after any
for (size_t other_item : downstream[i]) { for (size_t other_item : downstream_map->at(i)) {
if ((*op_happens_before)[other_item][item]) { if ((*op_happens_before)[other_item][item]) {
VLOG(8) << "happens_before: " << other_item << "->" << item VLOG(8) << "happens_before: " << other_item << "->" << item
<< ", so skip " << item; << ", so skip " << item;
...@@ -738,12 +768,11 @@ std::map<int, std::list<int>> get_downstream_map( ...@@ -738,12 +768,11 @@ std::map<int, std::list<int>> get_downstream_map(
minumum_nexts.push_back(item); minumum_nexts.push_back(item);
} }
} }
downstream[i] = minumum_nexts; downstream_map->at(i) = minumum_nexts;
} }
VLOG(6) << "downstream count: " << downstream_map_count(); VLOG(6) << "downstream count: " << CountDownstreamMap(*downstream_map);
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str(); VLOG(6) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(*downstream_map);
return downstream;
} }
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
...@@ -825,6 +854,14 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -825,6 +854,14 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
// NOTE(zhiqiu): the size of downstream != size of op2dependences since there
// are some ops that have no downstream-op.
std::map<int, std::list<int>> op_downstream_map =
GetDownstreamMap(op2dependences);
ShrinkDownstreamMap(&op_downstream_map, op_happens_before,
vec_instruction.size());
// add dependences for random op, make sure that the random op is scheduled // add dependences for random op, make sure that the random op is scheduled
// sequentially // sequentially
const std::set<std::string> random_op_set = { const std::set<std::string> random_op_set = {
...@@ -846,7 +883,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -846,7 +883,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) { if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx); AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
} }
dependence_op_idx = op_idx; dependence_op_idx = op_idx;
} }
...@@ -872,7 +910,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -872,7 +910,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) { for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) { if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) { if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx); AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type(); << vec_instruction[op_idx].OpBase()->Type();
...@@ -900,7 +939,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -900,7 +939,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to " << vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type(); << vec_instruction[op_idx].OpBase()->Type();
op2dependences[op_idx].insert(dependence_op_idx); AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
} }
} }
} }
...@@ -956,7 +996,8 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -956,7 +996,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
j < static_cast<size_t>(first_read_fused_out_op); ++j) { j < static_cast<size_t>(first_read_fused_out_op); ++j) {
for (auto var_id : outputs) { for (auto var_id : outputs) {
if (is_write(vec_instruction[j], var_id)) { if (is_write(vec_instruction[j], var_id)) {
op2dependences[first_read_fused_out_op].insert(j); AddDownstreamOp(j, first_read_fused_out_op, &op_downstream_map,
*op_happens_before);
VLOG(4) << j << " -> " << first_read_fused_out_op; VLOG(4) << j << " -> " << first_read_fused_out_op;
VLOG(4) VLOG(4)
<< "Add depend from " << vec_instruction[j].OpBase()->Type() << "Add depend from " << vec_instruction[j].OpBase()->Type()
...@@ -990,6 +1031,7 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -990,6 +1031,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (auto var_id : outputs) { for (auto var_id : outputs) {
if (is_read(vec_instruction[j], var_id)) { if (is_read(vec_instruction[j], var_id)) {
AddDownstreamOp(target, j, &op_downstream_map, *op_happens_before);
op2dependences[j].insert(target); op2dependences[j].insert(target);
VLOG(4) << target << " -> " << j; VLOG(4) << target << " -> " << j;
VLOG(4) << "Add depend from " VLOG(4) << "Add depend from "
...@@ -1000,14 +1042,12 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -1000,14 +1042,12 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
} }
for (auto pair : op2dependences) {
std::ostringstream oss; VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map);
oss << pair.first << " Depends on " << pair.second.size() << " ops: "; VLOG(8) << "downstream_map: " << std::endl
std::copy(pair.second.begin(), pair.second.end(), << StringizeDownstreamMap(op_downstream_map);
std::ostream_iterator<int>(oss, " "));
VLOG(10) << oss.str(); return op_downstream_map;
}
return get_downstream_map(op2dependences, op_happens_before);
} }
} // namespace interpreter } // namespace interpreter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册