diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index fa33610096b7ccbdc461242a81ce1a1e89ade889..8d700f51012e2fadbc6448b6825f68210a1edbfc 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -258,13 +258,15 @@ void StreamAnalyzer::AnalyseAllEventInfo( const std::vector& next_instr_ids = run_type_info[cur_instr_id][DownstreamRunType::kEventRun]; std::set waiter_instr_ids; + std::set visited_next_instr_id; for (size_t next_instr_id : next_instr_ids) { AnalyseEventInfoForTwoInstructions(instructions, run_type_info, cur_instr_id, next_instr_id, - &waiter_instr_ids); + &waiter_instr_ids, + &visited_next_instr_id); } for (size_t waiter_instr_id : waiter_instr_ids) { @@ -302,7 +304,14 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( const std::vector>>& run_type_info, const size_t cur_instr_id, const size_t next_instr_id, - std::set* waiter_instr_ids) const { + std::set* waiter_instr_ids, + std::set* visited_next_instr_id) const { + if (visited_next_instr_id->find(next_instr_id) != + visited_next_instr_id->end()) { + return; + } + visited_next_instr_id->insert(next_instr_id); + // NOTE(Ruibiao): Though depend_op as next_instr is no_need_buffer, we should // also wait event for it. Because depend_op is used to build dependencies for // fused vars in some scenarios. In those cases, we do not know which vars may @@ -338,21 +347,26 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( // between cur_instr and next_instr. for (size_t instr_id : run_type_info[next_instr_id][DownstreamRunType::kDirectRun]) { - AnalyseEventInfoForTwoInstructions( - instructions, run_type_info, cur_instr_id, instr_id, waiter_instr_ids); + AnalyseEventInfoForTwoInstructions(instructions, + run_type_info, + cur_instr_id, + instr_id, + waiter_instr_ids, + visited_next_instr_id); } } -// waiter instr should only wait events for the last recorder instrs in each -// stream void StreamAnalyzer::ShrinkEventInfo( const DependencyBuilder& dependency_builder, std::map>>* event_info) const { - for (auto& context_item : *event_info) { - for (auto& waiter_item : context_item.second) { - size_t waiter_instr_id = waiter_item.first; - std::set& recorder_instr_ids = waiter_item.second; + for (auto& item : *event_info) { + // shrink redundant recorders, waiter instrs should only wait for the last + // recorder instrs in each stream + std::map>& waiter_recorder_map = item.second; + for (auto& waiter_recorder : waiter_recorder_map) { + size_t waiter_instr_id = waiter_recorder.first; + std::set& recorder_instr_ids = waiter_recorder.second; std::set unnecessary_recorder_instr_ids; for (size_t cur_instr_id : recorder_instr_ids) { for (size_t next_instr_id : recorder_instr_ids) { @@ -370,6 +384,38 @@ void StreamAnalyzer::ShrinkEventInfo( recorder_instr_ids.erase(unnecessary_recorder_instr_id); } } + + // shrink redundant waiters, recorder instrs should only wait by the first + // waiter instrs in each stream + std::map> recorder_waiter_map; + for (auto& waiter_recorder : waiter_recorder_map) { + size_t waiter_instr_id = waiter_recorder.first; + std::set& recorder_instr_ids = waiter_recorder.second; + for (size_t record_instr_id : recorder_instr_ids) { + recorder_waiter_map[record_instr_id].insert(waiter_instr_id); + } + } + + for (auto& recorder_waiter : recorder_waiter_map) { + size_t recorder_instr_id = recorder_waiter.first; + std::set& waiter_instr_ids = recorder_waiter.second; + std::set unnecessary_waiter_instr_ids; + for (size_t cur_instr_id : waiter_instr_ids) { + for (size_t next_instr_id : waiter_instr_ids) { + if (dependency_builder.OpHappensBefore(cur_instr_id, next_instr_id)) { + unnecessary_waiter_instr_ids.insert(next_instr_id); + break; + } + } + } + + for (size_t unnecessary_wiater_instr_id : unnecessary_waiter_instr_ids) { + VLOG(8) << "Shrink event : " << recorder_instr_id << " -> " + << unnecessary_wiater_instr_id; + waiter_recorder_map[unnecessary_wiater_instr_id].erase( + recorder_instr_id); + } + } } } diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index de0e6c741c245109e7a52931df6eedbdc8533ba1..1dd13c90da331171098fcfec69775e5818e5106e 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -64,7 +64,8 @@ class StreamAnalyzer { const std::vector>>& run_type_info, const size_t cur_instr_id, const size_t next_instr_id, - std::set* waiter_instr_ids) const; + std::set* waiter_instr_ids, + std::set* visited_next_instr_id) const; void ShrinkEventInfo( const DependencyBuilder& dependency_builder,