diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index cf677348aecc701d454ba21ac662980444300bdd..847089dcbf324d96665beb5b7487a4580fed9292 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -344,6 +344,19 @@ void NewIRInterpreter::UpdateSyncOpNum() {
   VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_;
 }
 
+// Counts the instructions whose Name() is one of the sync_batch_norm op names
+// listed below and caches the result in nccl_op_num_.
+void NewIRInterpreter::UpdateNcclOpNum() {
+  static const std::set<std::string> nccl_op_set = {
+      "pd.sync_batch_norm_", "pd.sync_batch_norm", "pd.sync_batch_norm_grad"};
+  int64_t nccl_op_num = 0;
+  for (const auto& ins : vec_instruction_base_) {
+    if (nccl_op_set.count(ins->Name()) > 0) ++nccl_op_num;
+  }
+  nccl_op_num_ = nccl_op_num;
+  VLOG(4) << "Update nccl op num, nccl op num is: " << nccl_op_num;
+}
+
 // Note(zhangbo):
 // When there is a KQueueSync type OP in the model, breadth traversal is better
 // than depth traversal.
For example: OP(O) ->(direct_run)-> OP(A) @@ -852,7 +865,7 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, VLOG(4) << "Done PreAnalysis"; // Run - if (FLAGS_enable_new_ir_in_executor_trace_run || + if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " @@ -867,7 +880,7 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, is_build_ = true; is_shared_results_build_ = true; } else { - if (FLAGS_enable_new_ir_in_executor_trace_run || + if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { TraceRunImpl(); @@ -1182,6 +1195,9 @@ void NewIRInterpreter::PreAnalysis() { UpdateSyncOpNum(); VLOG(4) << "Done UpdateSyncOpNum"; + + UpdateNcclOpNum(); + VLOG(4) << "Done UpdateNcclOpNum"; } ::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index 3669d8f8dd970b239d2332f8abf8a8c9825efc54..841e9136a2ecc1b693ad7d74c638f593e23ec94d 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -84,6 +84,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { private: // build graph void UpdateSyncOpNum(); + void UpdateNcclOpNum(); void AnalyseExecuteOrderForTrace( std::map> op_downstream_map, InstructionSchedulingPriorityLess compare); @@ -148,6 +149,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { // used for Trace int64_t sync_op_num_{-1}; + int64_t nccl_op_num_{-1}; std::vector trace_execute_order_; std::vector hookfuncs_;