diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 03e3d6aaf40e4521ac2b858da431901f6930121b..920fec72bd43ae577874ec5113a6e62d079af52b 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -64,8 +64,6 @@ class ContextManager { inline std::string RunTypeToString(DownstreamRunType run_type) { if (run_type == DownstreamRunType::kDirectRun) { return "DirectRun"; - } else if (run_type == DownstreamRunType::kSyncRun) { - return "SyncRun"; } else { return "EventRun"; } @@ -238,16 +236,10 @@ void StreamAnalyzer::AnalyseAllEventInfo( event_info) const { for (size_t cur_instr_id = 0; cur_instr_id < instructions.size(); ++cur_instr_id) { - const Instruction& cur_instr = instructions[cur_instr_id]; - const std::vector>& next_instr_list = - run_type_info[cur_instr_id]; + const std::vector& next_instr_ids = + run_type_info[cur_instr_id][DownstreamRunType::kEventRun]; std::set waiter_instr_ids; - std::vector next_instr_ids = - next_instr_list[DownstreamRunType::kSyncRun]; - next_instr_ids.insert(next_instr_ids.end(), - next_instr_list[DownstreamRunType::kEventRun].begin(), - next_instr_list[DownstreamRunType::kEventRun].end()); for (size_t next_instr_id : next_instr_ids) { AnalyseEventInfoForTwoInstructions(instructions, run_type_info, @@ -257,8 +249,9 @@ void StreamAnalyzer::AnalyseAllEventInfo( } for (size_t waiter_instr_id : waiter_instr_ids) { - (*event_info)[&(cur_instr.DeviceContext())][waiter_instr_id].insert( - cur_instr_id); + (*event_info)[&(instructions[cur_instr_id].DeviceContext())] + [waiter_instr_id] + .insert(cur_instr_id); } } } @@ -284,7 +277,7 @@ void StreamAnalyzer::AnalyseAllRunType( } } -// The caller should guarantee cur_instr and next_instr is kSyncRun or kEventRun +// The caller should guarantee cur_instr and next_instr is kEventRun void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( const std::vector& instructions, const std::vector>>& run_type_info, @@ -311,7 +304,6 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( // can only add event for it with the help of depend_op. if (HasDataDependency(instructions[cur_instr_id], instructions[next_instr_id]) || - run_type_info[next_instr_id][DownstreamRunType::kSyncRun].size() || run_type_info[next_instr_id][DownstreamRunType::kEventRun].size() || instructions[next_instr_id].OpBase()->Type() == "depend") { waiter_instr_ids->insert(next_instr_id); @@ -319,8 +311,8 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( } // NOTE(Ruibiao): If no data dependency from cur_instr to next_instr, and - // simultaneously next_instr has neither sync_run nor event_run downstream - // instr, we try to recursively add events between cur_instr and next_instr's + // simultaneously next_instr has no event_run downstream instr, we try to + // recursively add events between cur_instr and next_instr's // direct-run-instrs. This can delay the event wait and achieve better // scheduling performance in some scenarios. However, when next_instr has too // many direct-run-instrs, it may perform worse than add event directly @@ -393,15 +385,9 @@ DownstreamRunType StreamAnalyzer::AnalyseRunTypeForTwoInstructions( } } - if (cur_instr.KernelType() == OpFuncType::kGpuAsync) { - if (next_instr.KernelType() == OpFuncType::kCpuSync) { - return DownstreamRunType::kSyncRun; - } else { - // cross-stream: kGpuAsync -> kGpuSync, kGpuAsync -> kGpuSync - if (&cur_instr.DeviceContext() != &next_instr.DeviceContext()) { - return DownstreamRunType::kEventRun; - } - } + if (cur_instr.KernelType() == OpFuncType::kGpuAsync && + (&cur_instr.DeviceContext() != &next_instr.DeviceContext())) { + return DownstreamRunType::kEventRun; } return DownstreamRunType::kDirectRun; diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index c76341907330160ce0271cc0a876d4fbb6917a0b..b9a228869d4c96b13a9f742ae1b32d693c339f55 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -26,7 +26,7 @@ namespace paddle { namespace framework { namespace interpreter { -enum DownstreamRunType { kDirectRun, kSyncRun, kEventRun }; +enum DownstreamRunType { kDirectRun, kEventRun }; class StreamAnalyzer { public: