diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index f77f382792e8430700035ca3292be0a827273b38..5f6166e16db3d962dac8671ace548b1202d020b2 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -408,7 +408,6 @@ CollectiveComm::CollectiveComm( "CollectiveComm inputs should not contain duplicated input device"); ModeTrait::from_mode(param.mode).add_output_var(this, inp_cn); - m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG"); if (c_debug != nullptr and strcmp(c_debug, "1") == 0) { @@ -469,6 +468,8 @@ void CollectiveComm::opr_register() { hash, m_key, m_nr_devices, m_rank, get_megray_backend(m_backend), m_group_client); + m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); + m_init = true; }