提交 b87a9a74 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/opr-mm): add debug log for CollectiveComm opr

GitOrigin-RevId: 11bcf110434bf7750a14f6a6041273a3d2c0fcff
上级 d6b098a0
......@@ -410,6 +410,11 @@ CollectiveComm::CollectiveComm(
ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
m_debug_mode = true;
}
add_equivalence_component<PODHash<Param>>(&m_param);
add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
......@@ -536,6 +541,11 @@ void CollectiveComm::do_execute(ExecEnv& env) {
opr_register();
cn.activate();
if (m_debug_mode) {
mgb_log_debug("collective comm: executing %s, rank = %d, key = %s",
cname(), rank(), key().c_str());
}
owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
trait.exec(this);
owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
......
......@@ -123,6 +123,7 @@ private:
std::shared_ptr<MegRay::Context> m_megray_ctx;
std::shared_ptr<MegRay::Communicator> m_megray_comm;
bool m_init = false;
bool m_debug_mode = false;
//! dev buffers for each outputs
SmallVector<std::shared_ptr<DeviceTensorND>> m_dev_buffers;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册