From 27205461ae4e75adbebc35128b4c64a66441dbb4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 15 Aug 2020 16:36:07 +0800 Subject: [PATCH] feat(mgb/opr-mm): add register info cache for multi-machine oprs GitOrigin-RevId: d5ae3c5a7c6d8d1939cd4cf92c25045a5e7159e0 --- src/opr-mm/impl/collective_comm.cpp | 16 ++++++++-- src/opr-mm/impl/group_manager.cpp | 16 ++++++++++ src/opr-mm/impl/io_remote.cpp | 32 +++++++++++++++---- .../include/megbrain/opr/group_manager.h | 16 ++++++++++ 4 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index f01094b3c..efb7a75cf 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make( void CollectiveComm::opr_register() { if (m_init) return; + auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - auto reg_info = m_group_client->opr_register( - m_key, m_nr_devices, m_is_root, m_rank, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + reg_info = m_group_client->opr_register( + m_key, m_nr_devices, m_is_root, m_rank, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_rank = reg_info.rank; m_root = reg_info.root_rank; diff --git a/src/opr-mm/impl/group_manager.cpp b/src/opr-mm/impl/group_manager.cpp index 02d262dc0..698bb2f02 100644 --- a/src/opr-mm/impl/group_manager.cpp +++ b/src/opr-mm/impl/group_manager.cpp @@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { return m_barrier_size; } +void RegInfoCache::set_info(const std::string& key, + const GroupManager::RegisterInfo& info) { + std::unique_lock lock(RegInfoCache::mtx); + RegInfoCache::key2info[key] = info; +} + +bool RegInfoCache::has_info(const std::string& key) { + std::unique_lock lock(RegInfoCache::mtx); + return RegInfoCache::key2info.find(key) != RegInfoCache::key2info.end(); +} + +GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) { + std::unique_lock lock(RegInfoCache::mtx); + return RegInfoCache::key2info[key]; +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 0dead49ca..d26b24917 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var, void RemoteSend::scn_do_execute() { if (!m_init) { auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - // rank 0 for RemoteSend - auto reg_info = m_group_client->opr_register(m_key, 2, 0, false, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + // rank 0 for RemoteSend + reg_info = m_group_client->opr_register(m_key, 2, 0, false, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client); @@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph, void RemoteRecv::scn_do_execute() { if (!m_init) { auto&& comp_node = output(0)->comp_node(); + bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph; + struct GroupManager::RegisterInfo reg_info; - // rank 1 for RemoteRecv - auto reg_info = m_group_client->opr_register( - m_key, 2, false, 1, - comp_node.get_uid()); + if (use_cache and RegInfoCache::has_info(m_key)) { + reg_info = RegInfoCache::get_info(m_key); + } else { + // rank 1 for RemoteRecv + reg_info = m_group_client->opr_register( + m_key, 2, false, 1, + comp_node.get_uid()); + if (use_cache) { + RegInfoCache::set_info(m_key, reg_info); + } + } m_megray_comm = MegRayCommBuilder::get_megray_comm( reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client); diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index 35a633b41..3ee7222f6 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -145,6 +145,22 @@ class GroupClient { virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; }; +/*! + * Cache RegisterInfo returned from GroupManager. This feature is only enabled + * in imperative runtime mode, so that multi-machine operators do not have to + * call opr_register repeatedly in each iter + */ +namespace RegInfoCache { + +static std::mutex mtx; +static std::unordered_map key2info; + +void set_info(const std::string& key, const GroupManager::RegisterInfo& info); +bool has_info(const std::string& key); +GroupManager::RegisterInfo get_info(const std::string& key); + +} // namespace RegInfoCache + } // namespace opr } // namespace mgb -- GitLab