未验证 提交 3779412c 编写于 作者: W wanghuancoder 提交者: GitHub

[IR] trace run when sync bn (#56362)

* trace run when sync bn
上级 5356f2e0
......@@ -344,6 +344,19 @@ void NewIRInterpreter::UpdateSyncOpNum() {
VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_;
}
void NewIRInterpreter::UpdateNcclOpNum() {
static std::set<std::string> nccl_op_set = {
"pd.sync_batch_norm_", "pd.sync_batch_norm", "pd.sync_batch_norm_grad"};
int64_t nccl_op_num = 0;
for (auto& ins : vec_instruction_base_) {
if (nccl_op_set.count(ins->Name())) {
nccl_op_num = nccl_op_num + 1;
}
}
nccl_op_num_ = nccl_op_num;
VLOG(4) << "Update nccl op num, nccl op num is: " << nccl_op_num;
}
// Note(zhangbo):
// When there is a KQueueSync type OP in the model, breadth traversal is better
// than depth traversal. For example: OP(O) ->(direct_run)-> OP(A)
......@@ -852,7 +865,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
VLOG(4) << "Done PreAnalysis";
// Run
if (FLAGS_enable_new_ir_in_executor_trace_run ||
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
......@@ -867,7 +880,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
is_build_ = true;
is_shared_results_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_trace_run ||
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
......@@ -1182,6 +1195,9 @@ void NewIRInterpreter::PreAnalysis() {
UpdateSyncOpNum();
VLOG(4) << "Done UpdateSyncOpNum";
UpdateNcclOpNum();
VLOG(4) << "Done UpdateNcclOpNum";
}
::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) {
......
......@@ -84,6 +84,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
private:
// build graph
void UpdateSyncOpNum();
void UpdateNcclOpNum();
void AnalyseExecuteOrderForTrace(
std::map<size_t, std::set<size_t>> op_downstream_map,
InstructionSchedulingPriorityLess compare);
......@@ -148,6 +149,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// used for Trace
int64_t sync_op_num_{-1};
int64_t nccl_op_num_{-1};
std::vector<size_t> trace_execute_order_;
std::vector<HookFunc> hookfuncs_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册