// Patch summary (unified diff, flattened onto one line in this copy).
//
// Files touched: src/opr-mm/impl/megray_helper.cpp and
// src/opr-mm/include/megbrain/opr/megray_helper.h.
//
// Added MegRayCommunicatorBuilder::find(hash, comm):
//   Locks m_mtx, looks `hash` up in m_megray_comms; on a hit copies the stored
//   pointer into `comm` and returns true, otherwise returns false and leaves
//   `comm` untouched.
//
// Added MegRayCommunicatorBuilder::emplace(hash, comm):
//   Locks m_mtx and calls m_megray_comms.emplace(hash, comm).
//
// Changed MegRayCommunicatorBuilder::get_megray_comm(...):
//   Uses find() instead of touching m_megray_comms directly. On a miss it
//   creates a communicator with MegRay::get_communicator(size, rank, backend),
//   gathers uids through group_client->gather_uid(uid, key, size, rank),
//   asserts that comm->init(uids) returns MegRay::Status::MEGRAY_OK, stores the
//   communicator with emplace(), and returns the local `comm` rather than
//   indexing the map again.
//
// NOTE(review): find() and emplace() each take m_mtx separately, so the lock is
//   not held while the communicator is being built. If two callers can arrive
//   with the same hash concurrently (not visible from this patch), both would
//   build a communicator; the second emplace() call would presumably not
//   replace the first entry, and that caller would return its own uncached
//   instance. Confirm whether this is acceptable.
// NOTE(review): comm->init(uids) is evaluated inside mgb_assert(...). Confirm
//   that mgb_assert evaluates its argument in every build configuration;
//   otherwise init() would be skipped.
// NOTE(review): template arguments and the added include name appear to be
//   missing in this copy (e.g. `std::shared_ptr& comm`, `#include +`),
//   presumably lost during extraction. Verify against the original commit
//   before applying.
diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index b491d3f2403db1198d2a363e0a3f529f6beb6308..2465f7f5427b7cf46ec83966645e70a888e1362c 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp @@ -14,19 +14,35 @@ using namespace mgb; using namespace opr; +bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr& comm) { + std::unique_lock lk(m_mtx); + auto it = m_megray_comms.find(hash); + if (it != m_megray_comms.end()) { + comm = it->second; + return true; + } + return false; +} + +void MegRayCommunicatorBuilder::emplace(uint64_t hash, + std::shared_ptr comm) { + std::unique_lock lk(m_mtx); + m_megray_comms.emplace(hash, comm); +} + std::shared_ptr MegRayCommunicatorBuilder::get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client) { - auto it = m_megray_comms.find(hash); - if (it == m_megray_comms.end()) { - auto comm = MegRay::get_communicator(size, rank, backend); + std::shared_ptr comm; + if (!find(hash, comm)) { + comm = MegRay::get_communicator(size, rank, backend); auto uid = comm->get_uid(); auto uids = group_client->gather_uid(uid, key, size, rank); - comm->init(uids); - m_megray_comms.emplace(hash, std::move(comm)); + mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); + emplace(hash, comm); } - return m_megray_comms[hash]; + return comm; } MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h index 53dae9e47cb8cb9ee5f76f23a756a0e83dbbf5dd..255af039d71a1d1553998e500cb27d317276daf3 100644 --- a/src/opr-mm/include/megbrain/opr/megray_helper.h +++ b/src/opr-mm/include/megbrain/opr/megray_helper.h @@ -11,6 +11,8 @@ #pragma once +#include + #include "megbrain/utils/metahelper.h" #include "megbrain/opr/group_manager.h" #include "megray.h" @@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : 
public mgb::UserDataContainer::UserData MGB_TYPEINFO_OBJ_DECL; private: + bool find(uint64_t hash, std::shared_ptr& comm); + void emplace(uint64_t hash, std::shared_ptr comm); + std::unordered_map> m_megray_comms; + std::mutex m_mtx; public: std::shared_ptr get_megray_comm(