提交 4e0054f7 编写于 作者: M Megvii Engine Team

fix(mgb/opr-mm): fix megray_helper thread safety

GitOrigin-RevId: f7b7c1d97ffab48b97a9ae8d694a7970e668afac
上级 35d46dbb
......@@ -14,19 +14,35 @@
using namespace mgb;
using namespace opr;
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
std::unique_lock<std::mutex> lk(m_mtx);
auto it = m_megray_comms.find(hash);
if (it != m_megray_comms.end()) {
comm = it->second;
return true;
}
return false;
}
void MegRayCommunicatorBuilder::emplace(uint64_t hash,
std::shared_ptr<MegRay::Communicator> comm) {
std::unique_lock<std::mutex> lk(m_mtx);
m_megray_comms.emplace(hash, comm);
}
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm(
uint64_t hash, std::string key, uint32_t size, uint32_t rank,
MegRay::Backend backend,
std::shared_ptr<mgb::opr::GroupClient> group_client) {
auto it = m_megray_comms.find(hash);
if (it == m_megray_comms.end()) {
auto comm = MegRay::get_communicator(size, rank, backend);
std::shared_ptr<MegRay::Communicator> comm;
if (!find(hash, comm)) {
comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid();
auto uids = group_client->gather_uid(uid, key, size, rank);
comm->init(uids);
m_megray_comms.emplace(hash, std::move(comm));
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
emplace(hash, comm);
}
return m_megray_comms[hash];
return comm;
}
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder);
......
......@@ -11,6 +11,8 @@
#pragma once
#include <mutex>
#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
......@@ -25,7 +27,11 @@ class MegRayCommunicatorBuilder final : public mgb::UserDataContainer::UserData
MGB_TYPEINFO_OBJ_DECL;
private:
bool find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm);
void emplace(uint64_t hash, std::shared_ptr<MegRay::Communicator> comm);
std::unordered_map<uint64_t, std::shared_ptr<MegRay::Communicator>> m_megray_comms;
std::mutex m_mtx;
public:
std::shared_ptr<MegRay::Communicator> get_megray_comm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册