提交 4e0054f7 编写于 作者: M Megvii Engine Team

fix(mgb/opr-mm): fix megray_helper thread safety

GitOrigin-RevId: f7b7c1d97ffab48b97a9ae8d694a7970e668afac
上级 35d46dbb
...@@ -14,19 +14,35 @@ ...@@ -14,19 +14,35 @@
using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
// Looks up the communicator cached under `hash` while holding m_mtx.
// On a hit, copies the cached shared_ptr into `comm` and returns true.
// On a miss, leaves `comm` untouched and returns false.
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    const auto entry = m_megray_comms.find(hash);
    if (entry == m_megray_comms.end()) {
        return false;
    }
    comm = entry->second;
    return true;
}
// Inserts `comm` into the cache under `hash` while holding m_mtx.
// Uses unordered_map::emplace, so an entry that already exists for `hash`
// is kept and the new one is not stored.
void MegRayCommunicatorBuilder::emplace(
        uint64_t hash, std::shared_ptr<MegRay::Communicator> comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    m_megray_comms.emplace(hash, comm);
}
// get_megray_comm: returns the communicator associated with `hash`, building
// and caching one when no entry exists yet.
//
// NOTE: every line below is a flattened side-by-side diff: the pre-change text
// comes first, followed by the post-change text of the same line.
//
// Both versions, on a cache miss: obtain a communicator from
// MegRay::get_communicator(size, rank, backend), collect uids through
// group_client->gather_uid(uid, key, size, rank), call init(uids), and store
// the communicator under `hash`.
// The post-change version performs the lookup and the insertion through the
// mutex-guarded find()/emplace() helpers, asserts that init() returned
// MegRay::Status::MEGRAY_OK, and returns the local `comm` instead of indexing
// the map again.
//
// NOTE(review): find() and emplace() acquire the mutex separately, so two
// threads that miss on the same `hash` could presumably each build a
// communicator, with only the first insertion kept in the map — confirm
// whether callers can race on a single hash.
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm( std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank, uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend, MegRay::Backend backend,
std::shared_ptr<mgb::opr::GroupClient> group_client) { std::shared_ptr<mgb::opr::GroupClient> group_client) {
auto it = m_megray_comms.find(hash); std::shared_ptr<MegRay::Communicator> comm;
if (it == m_megray_comms.end()) { if (!find(hash, comm)) {
auto comm = MegRay::get_communicator(size, rank, backend); comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid(); auto uid = comm->get_uid();
auto uids = group_client->gather_uid(uid, key, size, rank); auto uids = group_client->gather_uid(uid, key, size, rank);
comm->init(uids); mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
m_megray_comms.emplace(hash, std::move(comm)); emplace(hash, comm);
} }
return m_megray_comms[hash]; return comm;
} }
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder); MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder);
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#pragma once #pragma once
#include <mutex>
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h" #include "megbrain/opr/group_manager.h"
#include "megray.h" #include "megray.h"
...@@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData ...@@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData
MGB_TYPEINFO_OBJ_DECL; MGB_TYPEINFO_OBJ_DECL;
private: private:
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms; std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
std::mutex m_mtx;
public: public:
std::shared_ptr<MegRay::Communicator> get_megray_comm( std::shared_ptr<MegRay::Communicator> get_megray_comm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册