From 4e0054f7b238cc9dc48653526dbd7df223b5fa8b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 May 2020 14:48:57 +0800 Subject: [PATCH] fix(mgb/opr-mm): fix megray_helper thread safety GitOrigin-RevId: f7b7c1d97ffab48b97a9ae8d694a7970e668afac --- src/opr-mm/impl/megray_helper.cpp | 28 +++++++++++++++---- .../include/megbrain/opr/megray_helper.h | 6 ++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index b491d3f2..2465f7f5 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp @@ -14,19 +14,35 @@ using namespace mgb; using namespace opr; +bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr& comm) { + std::unique_lock lk(m_mtx); + auto it = m_megray_comms.find(hash); + if (it != m_megray_comms.end()) { + comm = it->second; + return true; + } + return false; +} + +void MegRayCommunicatorBuilder::emplace(uint64_t hash, + std::shared_ptr comm) { + std::unique_lock lk(m_mtx); + m_megray_comms.emplace(hash, comm); +} + std::shared_ptr MegRayCommunicatorBuilder::get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client) { - auto it = m_megray_comms.find(hash); - if (it == m_megray_comms.end()) { - auto comm = MegRay::get_communicator(size, rank, backend); + std::shared_ptr comm; + if (!find(hash, comm)) { + comm = MegRay::get_communicator(size, rank, backend); auto uid = comm->get_uid(); auto uids = group_client->gather_uid(uid, key, size, rank); - comm->init(uids); - m_megray_comms.emplace(hash, std::move(comm)); + mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); + emplace(hash, comm); } - return m_megray_comms[hash]; + return comm; } MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); diff --git a/src/opr-mm/include/megbrain/opr/megray_helper.h b/src/opr-mm/include/megbrain/opr/megray_helper.h index 53dae9e4..255af039 100644 --- a/src/opr-mm/include/megbrain/opr/megray_helper.h +++ b/src/opr-mm/include/megbrain/opr/megray_helper.h @@ -11,6 +11,8 @@ #pragma once +#include + #include "megbrain/utils/metahelper.h" #include "megbrain/opr/group_manager.h" #include "megray.h" @@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData MGB_TYPEINFO_OBJ_DECL; private: + bool find(uint64_t hash, std::shared_ptr& comm); + void emplace(uint64_t hash, std::shared_ptr comm); + std::unordered_map> m_megray_comms; + std::mutex m_mtx; public: std::shared_ptr get_megray_comm( -- GitLab