diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 28a7412623fc78077dee1e610a00eb4c7a3abc88..f77f382792e8430700035ca3292be0a827273b38 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -410,6 +410,11 @@ CollectiveComm::CollectiveComm( ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); + const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); + if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { + m_debug_mode = true; + } + add_equivalence_component>(&m_param); add_equivalence_component>(&m_nr_devices); m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest(); @@ -536,6 +541,11 @@ void CollectiveComm::do_execute(ExecEnv& env) { opr_register(); cn.activate(); + if (m_debug_mode) { + mgb_log_debug("collective comm: executing %s, rank = %d, key = %s", + cname(), rank(), key().c_str()); + } + owner_graph()->event().signal_inplace(this, cn); trait.exec(this); owner_graph()->event().signal_inplace(this, cn); diff --git a/src/opr-mm/include/megbrain/opr/collective_comm.h b/src/opr-mm/include/megbrain/opr/collective_comm.h index 0af65cfc1ddde111dd5baa79fd7ba994e3fc6b0e..9edc37d8f18a6d362d67a82dda5b4a1ce9d7b71a 100644 --- a/src/opr-mm/include/megbrain/opr/collective_comm.h +++ b/src/opr-mm/include/megbrain/opr/collective_comm.h @@ -123,6 +123,7 @@ private: std::shared_ptr m_megray_ctx; std::shared_ptr m_megray_comm; bool m_init = false; + bool m_debug_mode = false; //! dev buffers for each outputs SmallVector> m_dev_buffers;