diff --git a/paddle/fluid/framework/new_executor/event_manager.cc b/paddle/fluid/framework/new_executor/event_manager.cc index 64018cea670e48eab1058700ef8a8ef827daeb37..bd83f49db1d0e3bc1b8d111c32048ebb9df5b930 100644 --- a/paddle/fluid/framework/new_executor/event_manager.cc +++ b/paddle/fluid/framework/new_executor/event_manager.cc @@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction, VLOG(3) << "Deal StreamWaitEventOrSync for " << instruction.kernel_func_.operator_base_->Type(); - auto* dev_ctx = instruction.dev_ctx_; - WaitOrSync(instruction.intput_events_, dev_ctx); + for (auto& event_iter : instruction.intput_events_) { + VLOG(3) << "wait var_id: " << event_iter.var_id_ + << " 's event with waiter_type: " << event_iter.waiter_type_; + event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_); + } } void EventManager::RecordEvent(const Instruction& instruction, @@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction, } } -void EventManager::WaitOrSync(const std::vector& events, - const platform::DeviceContext* dev_ctx) { - for (auto& event_iter : events) { - if (event_iter.is_sync_) { - VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_; - event_iter.event_->Wait(platform::kCPU, dev_ctx); - } else { - VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_; - event_iter.event_->Wait(platform::kCUDA, dev_ctx); - } - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/event_manager.h b/paddle/fluid/framework/new_executor/event_manager.h index 2289be346e21486e817813ca3acba93b2efcea94..d23c240469f964442609f08c6461b83371f181c3 100644 --- a/paddle/fluid/framework/new_executor/event_manager.h +++ b/paddle/fluid/framework/new_executor/event_manager.h @@ -24,10 +24,6 @@ class EventManager { const platform::Place& place); void WaitEvent(const Instruction& instruction, const platform::Place& place); - - private: - void WaitOrSync(const std::vector& events, - const platform::DeviceContext* dev_ctx); }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 2da9c275c3dcd1d97f9a96a126208723f4f259dd..dad39f1471e210337f43c7757b4d087cd490204a 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -183,8 +183,7 @@ void InterpreterCore::Convert() { } } - stream_analyzer_.Schedule(vec_func_list_, filter_next, i, - &vec_instruction_); + stream_analyzer_.Schedule(filter_next, &vec_instruction_, i); for (auto inst_id : filter_next) { dependecy_count_[inst_id]++; diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index d7fb6b7fd91b9a5de7734e170e62d24d8228c579..974fcb0a24ac8b960e9230130e0a896c8e90c09d 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -99,7 +99,6 @@ class InterpreterCore { InterpreterCoreGarbageCollector gc_; std::vector gc_event_; - std::unique_ptr group_thread_pool_; }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 3ab02ac2b7fa20f16e8a8ae93d56f7fb5b5181fb..91f334ffefd0554f906b83547d723337521b09f2 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place, OpKernelComputeFunc(kernel_iter->second); copy_op_func_node.kernel_func_(copy_exec_ctx); VLOG(3) << "Run " << memcpy_op_type << " done."; - copy_op_func_node.type_ = OpFuncType::kQueueAsync; + // NOTE(Aurelius84): memcpy_op is expensive operation, so we tag them + // as kQueueSync and execute them in thread pool. + copy_op_func_node.type_ = OpFuncType::kQueueSync; copy_op_func_node.dev_ctx_ = dev_ctx; op_list->push_back(copy_op); vec_func_list->push_back(copy_op_func_node); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index ca880c76c4d7010e9fda9edfa2b8f72cce856f84..9c0444b3157cb15ab51b74746b3f4a7cb5f9b5da 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -25,11 +25,6 @@ namespace paddle { namespace framework { -namespace interpretercore { -static constexpr char kMemcpyH2D[] = "memcpy_h2d"; -static constexpr char kMemcpyD2H[] = "memcpy_d2h"; -} // namespace interpretercore - using OpKernelComputeFunc = std::function; using OpKernelMap = std::unordered_map; @@ -496,11 +491,11 @@ struct NextInstruction { struct EventInter { explicit EventInter(size_t var_id, std::shared_ptr event, - bool is_sync) - : var_id_(var_id), event_(event), is_sync_(is_sync) {} + platform::DeviceType waiter_type) + : var_id_(var_id), event_(event), waiter_type_(waiter_type) {} size_t var_id_; std::shared_ptr event_; - bool is_sync_; + platform::DeviceType waiter_type_; }; struct InstructionInfo { @@ -543,5 +538,18 @@ struct OpFuncNode { OpFuncType type_; }; +namespace interpretercore { +static constexpr char kMemcpyH2D[] = "memcpy_h2d"; +static constexpr char kMemcpyD2H[] = "memcpy_d2h"; + +static bool IsMemcpyH2D(const Instruction& instr) { + return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D; +} + +static bool IsMemcpyD2H(const Instruction& instr) { + return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H; +} +} // namespace interpretercore + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 13bbda0f31f42081ea74243bc18bb3b8be0092ba..a9322d8fc88edb58af24a4ac560c635a94e29137 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -22,7 +22,7 @@ namespace framework { * Parse the var_ids that need to be associated with an event. * The caller should guarantee front_op and back_op satisfy the * following conditions: - * 1. kQueueAsync -> kQueueAsync + * 1. kQueueSync -> kQueueAsync * 2. kQueueAsync -> kQueueSync * * For example: matmul(gpu) -> out_var -> memcpy_d2h @@ -48,7 +48,7 @@ std::vector StreamAnalyzer::ParseEventVarIds( void StreamAnalyzer::AssociateInputWithEvents( const std::vector& new_event_var_id, Instruction* next_instr, - bool is_sync) { + platform::DeviceType waiter_type) { for (auto var_id : new_event_var_id) { if (var_id2event_.count(var_id) == 0) { auto device_event = std::make_shared( @@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents( } // Add events for next_instr.inputs next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id), - is_sync); + waiter_type); } } -void StreamAnalyzer::Schedule(const std::vector& op_func_nodes, - const std::vector& downstream_ops, - size_t op_index, - std::vector* instructions) { - auto& op_func_type = op_func_nodes[op_index].type_; +void StreamAnalyzer::Schedule(const std::vector& downstream_ops, + std::vector* instructions, + size_t op_index) { auto& cur_instr = instructions->at(op_index); auto& next_instruction = cur_instr.next_instruction_; + std::vector event_var_ids; + for (auto next_op_id : downstream_ops) { + auto& next_instr = instructions->at(next_op_id); - if (op_func_type == OpFuncType::kQueueSync) { - // all downstream ops of kQueueSync can directly run, such as CPU -> Any - next_instruction.direct_run_ = downstream_ops; - } else { // kQueueAsync - std::vector event_var_ids; - for (auto next_op_id : downstream_ops) { - auto& next_instr = instructions->at(next_op_id); - // case 1: GPU -> GPU(same stream) - if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) { - next_instruction.direct_run_.emplace_back(next_op_id); - continue; - } + if (IsDirectRun(cur_instr, next_instr)) { + next_instruction.direct_run_.emplace_back(next_op_id); + } else { // Always insert events between different stream auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(), new_event_var_ids.end()); - bool is_sync = - (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); - AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync); + auto waiter_type = GetWaiterType(next_instr); + AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type); - if (is_sync) { // GPU -> CPU + if (waiter_type == platform::kCPU) { // GPU -> CPU next_instruction.synchronize_run_.emplace_back(next_op_id); } else { // GPU -> GPU(different stream) next_instruction.event_wait_run_.emplace_back(next_op_id); } } - // Create events for these cross-stream vars - VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() - << " event_var_ids.size: " << event_var_ids.size(); - for (auto var_id : event_var_ids) { - cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), - false /*not used*/); - } + } + // Create events for these cross-stream vars + VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() + << " event_var_ids.size: " << event_var_ids.size(); + for (auto var_id : event_var_ids) { + cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), + platform::kCUDA /*not used*/); } } @@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( return dev_ctx; } +/* + * NOTE(dev): The following cases are considered as directly run: + * + * 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU + * 2. D2H -> CPU + * 3. CPU -> H2D + */ +bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, + const Instruction& next_instr) { + return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ || + interpretercore::IsMemcpyD2H(cur_instr) || + interpretercore::IsMemcpyH2D(next_instr)); +} + +platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { + if (instr.type_ == OpFuncType::kQueueSync) { + return platform::kCPU; + } else { + return platform::kCUDA; + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.h b/paddle/fluid/framework/new_executor/stream_analyzer.h index ee94c21fc529a854acc400ab3085fa7cd54ae26f..dc2af389e36b0f86e9b096f6309483d3ee0d2b2a 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/stream_analyzer.h @@ -29,9 +29,8 @@ class StreamAnalyzer { ~StreamAnalyzer() {} - void Schedule(const std::vector& op_func_nodes, - const std::vector& downstream_ops, size_t op_index, - std::vector* instructions); + void Schedule(const std::vector& downstream_ops, + std::vector* instructions, size_t op_index); platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node, const OperatorBase& op_base); @@ -41,7 +40,14 @@ class StreamAnalyzer { const Instruction& next_instr); void AssociateInputWithEvents(const std::vector& new_event_var_id, - Instruction* next_instr, bool is_sync); + Instruction* next_instr, + platform::DeviceType waiter_type); + + bool IsDirectRun(Instruction& cur_instr, // NOLINT + const Instruction& next_instr); + + platform::DeviceType GetWaiterType(const Instruction& instr); + platform::Place place_; platform::DeviceContextPool d2h_ctx_pool_; platform::DeviceContextPool h2d_ctx_pool_;