diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index b26d213ddf7740f6de02d07e3ce7ed50ae64c646..7e16c3619d61c45c8df73cf05034521a01621f7c 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -376,7 +376,8 @@ void InterpreterCore::ExecuteInstructionList(
                         vec_instr.size(), op_run_number_.load()));
 }
 
-void InterpreterCore::RunNextInstruction(const Instruction& instr) {
+void InterpreterCore::RunNextInstructions(
+    const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
   auto& next_instr = instr.next_instruction_;
   auto& atomic_deps = async_work_queue_.AtomicDeps();
   auto IsReady = [&](size_t next_id) {
@@ -395,12 +396,12 @@ void InterpreterCore::RunNextInstruction(const Instruction& instr) {
     // keep all async_ops running in current thread
     for (auto next_id : next_instr.direct_run_) {
       if (IsReady(next_id)) {
-        RunInstructionAsync(next_id);
+        reserved_next_ops->push(next_id);
       }
     }
     for (auto next_id : next_instr.event_wait_run_) {
       if (IsReady(next_id)) {
-        RunInstructionAsync(next_id);
+        reserved_next_ops->push(next_id);
       }
     }
   } else {
@@ -428,25 +429,31 @@ void InterpreterCore::RunNextInstruction(const Instruction& instr) {
           [&, next_id] { RunInstructionAsync(next_id); });
     }
   }
-  if (first_op != 0) RunInstructionAsync(first_op);
+  if (first_op != 0) reserved_next_ops->push(first_op);
 }
 
 void InterpreterCore::RunInstructionAsync(size_t instr_id) {
-  auto& instr_node = vec_instruction_[instr_id];
-  platform::RecordEvent instruction_event(
-      instr_node.kernel_func_.operator_base_->Type());
-  event_manager_.WaitEvent(instr_node, place_);
+  std::queue<size_t> ready_ops;
+  ready_ops.push(instr_id);
+  while (!ready_ops.empty()) {
+    instr_id = ready_ops.front();
+    ready_ops.pop();
+    auto& instr_node = vec_instruction_[instr_id];
+    platform::RecordEvent instruction_event(
+        instr_node.kernel_func_.operator_base_->Type());
+    event_manager_.WaitEvent(instr_node, place_);
 
-  RunInstruction(instr_node);
+    RunInstruction(instr_node);
 
-  event_manager_.RecordEvent(instr_node, place_);
-  op_run_number_.fetch_add(1, std::memory_order_relaxed);
+    event_manager_.RecordEvent(instr_node, place_);
+    op_run_number_.fetch_add(1, std::memory_order_relaxed);
 
-  // GC infomation
-  CheckGC(instr_id, instr_node.gc_check_var_list);
+    // GC infomation
+    CheckGC(instr_id, instr_node.gc_check_var_list);
 
-  RunNextInstruction(instr_node);
+    RunNextInstructions(instr_node, &ready_ops);
+  }
 }
 
 void InterpreterCore::CheckGC(size_t instr_id,
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index 47f23aff4f00e7b48bc25aa0d999dde1016cc07a..d6c916b9ddc4c808796b95abb255ed2f1af57dd8 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -68,7 +68,8 @@ class InterpreterCore {
   void CheckGC(size_t instr_id, const std::vector<size_t>& gc_check_list);
 
   void RunInstructionAsync(size_t instr_id);
-  void RunNextInstruction(const Instruction& instr_id);
+  void RunNextInstructions(const Instruction& instr_id,
+                           std::queue<size_t>* reserved_next_ops);
 
   void AddFetch(const std::vector<std::string>& fetch_names);
   void BuildSkipShareLoDInfo();
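
The patch above replaces recursive successor dispatch (RunNextInstruction calling RunInstructionAsync directly) with an explicit worklist: RunInstructionAsync drains a local std::queue, and RunNextInstructions only enqueues ready ops. Below is a minimal standalone sketch of that pattern under simplified assumptions; it does not use Paddle's actual classes. Node, CollectReadySuccessors, and RunFrom are hypothetical names standing in for Instruction, RunNextInstructions, and RunInstructionAsync, and the dependency counting is plain ints rather than the atomic counters and work-queue hand-off used in the real code.

// A minimal sketch of converting recursive DAG dispatch into an iterative
// worklist, so long op chains no longer grow the call stack.
#include <cstdio>
#include <queue>
#include <vector>

struct Node {
  int id;
  std::vector<size_t> successors;  // indices of downstream nodes
  int pending_deps;                // predecessors still unfinished
};

// Successors whose dependency count drops to zero are pushed onto the
// caller-owned queue instead of being executed recursively.
void CollectReadySuccessors(const Node& n, std::vector<Node>* nodes,
                            std::queue<size_t>* ready_ops) {
  for (size_t next_id : n.successors) {
    if (--(*nodes)[next_id].pending_deps == 0) {
      ready_ops->push(next_id);
    }
  }
}

// One loop owns the queue and keeps executing until no more work is ready
// on this thread.
void RunFrom(size_t start_id, std::vector<Node>* nodes) {
  std::queue<size_t> ready_ops;
  ready_ops.push(start_id);
  while (!ready_ops.empty()) {
    size_t id = ready_ops.front();
    ready_ops.pop();
    std::printf("run node %d\n", (*nodes)[id].id);  // stands in for running the op
    CollectReadySuccessors((*nodes)[id], nodes, &ready_ops);
  }
}

int main() {
  // A linear chain 0 -> 1 -> 2: with recursion the stack depth would equal the
  // chain length, while the worklist keeps stack usage constant.
  std::vector<Node> nodes = {{0, {1}, 0}, {1, {2}, 1}, {2, {}, 1}};
  RunFrom(0, &nodes);
  return 0;
}

The effect mirrors the patch: ops that stay on the current thread are processed iteratively from ready_ops, while (in the real code) ops destined for other threads are still handed to the async work queue as before.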