未验证 提交 04294f80 编写于 作者: R Ruibiao Chen 提交者: GitHub

Fix shrink downstream op bugs for standalone executor (#43330)

上级 c96f7a29
......@@ -629,46 +629,75 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
var2min_rw_op->at(rw_var).push_back(cur_op);
}
std::map<int, std::list<int>> get_downstream_map(
const std::map<int, std::set<int>>& op2dependences,
std::vector<std::vector<bool>>* op_happens_before) {
// step1: convert op2dependences to downstream_map directly
// op2dependences is op -> it's dependences.
// we want to get op -> [next ops] map,
// where ops is the next instruction of op.
std::map<int, std::list<int>> downstream;
void AddDownstreamOp(int prior_op_idx, int posterior_op_idx,
std::map<int, std::list<int>>* op_downstream_map) {
if (op_downstream_map->find(prior_op_idx) == op_downstream_map->end()) {
op_downstream_map->emplace(std::make_pair(prior_op_idx, std::list<int>()));
}
op_downstream_map->at(prior_op_idx).push_back(posterior_op_idx);
}
void AddDownstreamOp(int prior_op_idx, int posterior_op_idx,
std::map<int, std::list<int>>* op_downstream_map,
const std::vector<std::vector<bool>>& op_happens_before) {
if (op_downstream_map->find(prior_op_idx) != op_downstream_map->end()) {
for (int op_idx : op_downstream_map->at(prior_op_idx)) {
if (op_happens_before[op_idx][posterior_op_idx]) {
VLOG(7) << "Find dependencies " << prior_op_idx << "->" << op_idx
<< "->" << posterior_op_idx << ", skip adding " << prior_op_idx
<< "->" << posterior_op_idx;
return;
}
}
}
AddDownstreamOp(prior_op_idx, posterior_op_idx, op_downstream_map);
}
size_t CountDownstreamMap(const std::map<int, std::list<int>>& downstream_map) {
size_t count = 0;
for (auto pair : downstream_map) {
count += pair.second.size();
}
return count;
}
const std::string StringizeDownstreamMap(
const std::map<int, std::list<int>>& downstream_map) {
std::ostringstream oss;
for (auto pair : downstream_map) {
oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
oss << std::endl;
}
return oss.str();
}
// convert op2dependences to downstream_map directly. op2dependences is op ->
// it's dependences, we want to get op -> [next ops] map, where ops is the next
// instruction of op.
std::map<int, std::list<int>> GetDownstreamMap(
const std::map<int, std::set<int>>& op2dependences) {
std::map<int, std::list<int>> downstream_map;
for (auto& item : op2dependences) {
int op = item.first;
for (auto dep_op : item.second) {
if (downstream.find(dep_op) == downstream.end())
downstream[dep_op] = std::list<int>();
downstream[dep_op].push_back(op);
AddDownstreamOp(dep_op, op, &downstream_map);
}
}
auto downstream_map_to_str = [&]() -> std::string {
std::ostringstream oss;
for (auto pair : downstream) {
oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
oss << std::endl;
}
return oss.str();
};
auto downstream_map_count = [&]() -> size_t {
size_t count = 0;
for (auto pair : downstream) {
count += pair.second.size();
}
return count;
};
VLOG(6) << "downstream count: " << CountDownstreamMap(downstream_map);
VLOG(6) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(downstream_map);
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
return downstream_map;
}
// step2: remove unnecessary downstream ops
void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
std::vector<std::vector<bool>>* op_happens_before,
size_t op_num) {
// remove unnecessary downstream ops
// for example, a->b->c
// a: b, c
// b: c
......@@ -676,9 +705,6 @@ std::map<int, std::list<int>> get_downstream_map(
// a: b
// b: c
// NOTE(zhiqiu): the size of downstream != size of op2dependences
// since there are some ops that have no downstream-op.
auto op_num = op2dependences.size();
// happens_before[i][j] means i should be executed before j
op_happens_before->resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
......@@ -696,10 +722,10 @@ std::map<int, std::list<int>> get_downstream_map(
size_t op = q.front();
q.pop();
visited[op] = true;
if (!downstream.count(op)) {
if (!downstream_map->count(op)) {
continue;
}
for (auto next : downstream[op]) {
for (auto next : downstream_map->at(op)) {
if (!visited[next]) {
PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false,
paddle::platform::errors::AlreadyExists(
......@@ -721,11 +747,15 @@ std::map<int, std::list<int>> get_downstream_map(
// shrink, find the downstream op that has no other op in the
// downstream list happens before it
for (size_t i = 0; i < op_num; ++i) {
if (downstream_map->find(i) == downstream_map->end()) {
continue;
}
std::list<int> minumum_nexts;
for (size_t item : downstream[i]) {
for (size_t item : downstream_map->at(i)) {
bool not_after_any = true;
// find the op that is not executed after any
for (size_t other_item : downstream[i]) {
for (size_t other_item : downstream_map->at(i)) {
if ((*op_happens_before)[other_item][item]) {
VLOG(8) << "happens_before: " << other_item << "->" << item
<< ", so skip " << item;
......@@ -738,12 +768,11 @@ std::map<int, std::list<int>> get_downstream_map(
minumum_nexts.push_back(item);
}
}
downstream[i] = minumum_nexts;
downstream_map->at(i) = minumum_nexts;
}
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
return downstream;
VLOG(6) << "downstream count: " << CountDownstreamMap(*downstream_map);
VLOG(6) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(*downstream_map);
}
std::map<int, std::list<int>> build_op_downstream_map(
......@@ -825,6 +854,14 @@ std::map<int, std::list<int>> build_op_downstream_map(
}
}
// NOTE(zhiqiu): the size of downstream != size of op2dependences since there
// are some ops that have no downstream-op.
std::map<int, std::list<int>> op_downstream_map =
GetDownstreamMap(op2dependences);
ShrinkDownstreamMap(&op_downstream_map, op_happens_before,
vec_instruction.size());
// add dependences for random op, make sure that the random op is scheduled
// sequentially
const std::set<std::string> random_op_set = {
......@@ -846,7 +883,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (random_op_set.count(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx);
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
}
dependence_op_idx = op_idx;
}
......@@ -872,7 +910,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (size_t op_idx = 0; op_idx < vec_instruction.size(); ++op_idx) {
if (is_comm_op(vec_instruction[op_idx].OpBase()->Type())) {
if (dependence_op_idx != -1) {
op2dependences[op_idx].insert(dependence_op_idx);
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
......@@ -900,7 +939,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
VLOG(4) << "Add depend from "
<< vec_instruction[dependence_op_idx].OpBase()->Type() << " to "
<< vec_instruction[op_idx].OpBase()->Type();
op2dependences[op_idx].insert(dependence_op_idx);
AddDownstreamOp(dependence_op_idx, op_idx, &op_downstream_map,
*op_happens_before);
}
}
}
......@@ -956,7 +996,8 @@ std::map<int, std::list<int>> build_op_downstream_map(
j < static_cast<size_t>(first_read_fused_out_op); ++j) {
for (auto var_id : outputs) {
if (is_write(vec_instruction[j], var_id)) {
op2dependences[first_read_fused_out_op].insert(j);
AddDownstreamOp(j, first_read_fused_out_op, &op_downstream_map,
*op_happens_before);
VLOG(4) << j << " -> " << first_read_fused_out_op;
VLOG(4)
<< "Add depend from " << vec_instruction[j].OpBase()->Type()
......@@ -990,6 +1031,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (auto var_id : outputs) {
if (is_read(vec_instruction[j], var_id)) {
AddDownstreamOp(target, j, &op_downstream_map, *op_happens_before);
op2dependences[j].insert(target);
VLOG(4) << target << " -> " << j;
VLOG(4) << "Add depend from "
......@@ -1000,14 +1042,12 @@ std::map<int, std::list<int>> build_op_downstream_map(
}
}
}
for (auto pair : op2dependences) {
std::ostringstream oss;
oss << pair.first << " Depends on " << pair.second.size() << " ops: ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
VLOG(10) << oss.str();
}
return get_downstream_map(op2dependences, op_happens_before);
VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map);
VLOG(8) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(op_downstream_map);
return op_downstream_map;
}
} // namespace interpreter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册