未验证 提交 f0f4dd1e 编写于 作者: R Ruibiao Chen 提交者: GitHub

Improve stream analyzer (#49314)

* Memory search for stream analyzer

* Shrink redundant waiters
上级 3d09bd53
......@@ -258,13 +258,15 @@ void StreamAnalyzer::AnalyseAllEventInfo(
const std::vector<size_t>& next_instr_ids =
run_type_info[cur_instr_id][DownstreamRunType::kEventRun];
std::set<size_t> waiter_instr_ids;
std::set<size_t> visited_next_instr_id;
for (size_t next_instr_id : next_instr_ids) {
AnalyseEventInfoForTwoInstructions(instructions,
run_type_info,
cur_instr_id,
next_instr_id,
&waiter_instr_ids);
&waiter_instr_ids,
&visited_next_instr_id);
}
for (size_t waiter_instr_id : waiter_instr_ids) {
......@@ -302,7 +304,14 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions(
const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
const size_t cur_instr_id,
const size_t next_instr_id,
std::set<size_t>* waiter_instr_ids) const {
std::set<size_t>* waiter_instr_ids,
std::set<size_t>* visited_next_instr_id) const {
if (visited_next_instr_id->find(next_instr_id) !=
visited_next_instr_id->end()) {
return;
}
visited_next_instr_id->insert(next_instr_id);
// NOTE(Ruibiao): Though depend_op as next_instr is no_need_buffer, we should
// also wait event for it. Because depend_op is used to build dependencies for
// fused vars in some scenarios. In those cases, we do not know which vars may
......@@ -338,21 +347,26 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions(
// between cur_instr and next_instr.
for (size_t instr_id :
run_type_info[next_instr_id][DownstreamRunType::kDirectRun]) {
AnalyseEventInfoForTwoInstructions(
instructions, run_type_info, cur_instr_id, instr_id, waiter_instr_ids);
AnalyseEventInfoForTwoInstructions(instructions,
run_type_info,
cur_instr_id,
instr_id,
waiter_instr_ids,
visited_next_instr_id);
}
}
// waiter instr should only wait events for the last recorder instrs in each
// stream
void StreamAnalyzer::ShrinkEventInfo(
const DependencyBuilder& dependency_builder,
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info) const {
for (auto& context_item : *event_info) {
for (auto& waiter_item : context_item.second) {
size_t waiter_instr_id = waiter_item.first;
std::set<size_t>& recorder_instr_ids = waiter_item.second;
for (auto& item : *event_info) {
// shrink redundant recorders, waiter instrs should only wait for the last
// recorder instrs in each stream
std::map<size_t, std::set<size_t>>& waiter_recorder_map = item.second;
for (auto& waiter_recorder : waiter_recorder_map) {
size_t waiter_instr_id = waiter_recorder.first;
std::set<size_t>& recorder_instr_ids = waiter_recorder.second;
std::set<size_t> unnecessary_recorder_instr_ids;
for (size_t cur_instr_id : recorder_instr_ids) {
for (size_t next_instr_id : recorder_instr_ids) {
......@@ -370,6 +384,38 @@ void StreamAnalyzer::ShrinkEventInfo(
recorder_instr_ids.erase(unnecessary_recorder_instr_id);
}
}
// shrink redundant waiters, recorder instrs should only wait by the first
// waiter instrs in each stream
std::map<size_t, std::set<size_t>> recorder_waiter_map;
for (auto& waiter_recorder : waiter_recorder_map) {
size_t waiter_instr_id = waiter_recorder.first;
std::set<size_t>& recorder_instr_ids = waiter_recorder.second;
for (size_t record_instr_id : recorder_instr_ids) {
recorder_waiter_map[record_instr_id].insert(waiter_instr_id);
}
}
for (auto& recorder_waiter : recorder_waiter_map) {
size_t recorder_instr_id = recorder_waiter.first;
std::set<size_t>& waiter_instr_ids = recorder_waiter.second;
std::set<size_t> unnecessary_waiter_instr_ids;
for (size_t cur_instr_id : waiter_instr_ids) {
for (size_t next_instr_id : waiter_instr_ids) {
if (dependency_builder.OpHappensBefore(cur_instr_id, next_instr_id)) {
unnecessary_waiter_instr_ids.insert(next_instr_id);
break;
}
}
}
for (size_t unnecessary_wiater_instr_id : unnecessary_waiter_instr_ids) {
VLOG(8) << "Shrink event : " << recorder_instr_id << " -> "
<< unnecessary_wiater_instr_id;
waiter_recorder_map[unnecessary_wiater_instr_id].erase(
recorder_instr_id);
}
}
}
}
......
......@@ -64,7 +64,8 @@ class StreamAnalyzer {
const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
const size_t cur_instr_id,
const size_t next_instr_id,
std::set<size_t>* waiter_instr_ids) const;
std::set<size_t>* waiter_instr_ids,
std::set<size_t>* visited_next_instr_id) const;
void ShrinkEventInfo(
const DependencyBuilder& dependency_builder,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册