diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index f01094b3c7364253f19827b8084d57334c377af9..efb7a75cfe057e66b34c4b3df3f57e9954181f58 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make( void CollectiveComm::opr_register() { if (m_init) return; + auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - auto reg_info = m_group_client->opr_register( - m_key, m_nr_devices, m_is_root, m_rank, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + reg_info = m_group_client->opr_register( + m_key, m_nr_devices, m_is_root, m_rank, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_rank = reg_info.rank; m_root = reg_info.root_rank; diff --git a/src/opr-mm/impl/group_manager.cpp b/src/opr-mm/impl/group_manager.cpp index 02d262dc045f28a73c00faed3b457fd10c5942cd..698bb2f02ccd48b4b3e392114e6d37a98eb826cb 100644 --- a/src/opr-mm/impl/group_manager.cpp +++ b/src/opr-mm/impl/group_manager.cpp @@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { return m_barrier_size; } +void RegInfoCache::set_info(const std::string& key, + const GroupManager::RegisterInfo& info) { + std::unique_lock lock(RegInfoCache::mtx); + RegInfoCache::key2info[key] = info; +} + +bool RegInfoCache::has_info(const std::string& key) { + std::unique_lock lock(RegInfoCache::mtx); + return RegInfoCache::key2info.find(key) != RegInfoCache::key2info.end(); +} + +GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) { + std::unique_lock lock(RegInfoCache::mtx); + return RegInfoCache::key2info[key]; +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 0dead49ca42b747e9e19446ff5680d2a9c1bf875..d26b24917868e5d658df790e470654ac853ac224 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, void RemoteSend::scn_do_execute() { if (!m_init) { auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - // rank 0 for RemoteSend - auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + // rank 0 for RemoteSend + reg_info = m_group_client->opr_register(m_key, 2, 0, false, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); @@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, void RemoteRecv::scn_do_execute() { if (!m_init) { auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - // rank 1 for RemoteRecv - auto reg_info = m_group_client->opr_register( - m_key, 2, false, 1, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + // rank 1 for RemoteRecv + reg_info = m_group_client->opr_register( + m_key, 2, false, 1, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index 35a633b414459d0bc05942d1d4d34571a96f9e14..3ee7222f6a979fbc76914abc7588e9bb2ce2107c 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -145,6 +145,22 @@ class GroupClient { virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; }; +/*! + * Cache RegisterInfo returned from GroupManager. This feature is only enabled + * in imperative runtime mode, so that multi-machine operators do not have to + * call opr_register repeatedly in each iter + */ +namespace RegInfoCache { + +static std::mutex mtx; +static std::unordered_map key2info; + +void set_info(const std::string& key, const GroupManager::RegisterInfo& info); +bool has_info(const std::string& key); +GroupManager::RegisterInfo get_info(const std::string& key); + +} // namespace RegInfoCache + } // namespace opr } // namespace mgb