提交 b38e8225 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

refactor(mgb/opr-mm): update megray communicator init interface and fix ci

GitOrigin-RevId: 55c59879f2cda27f678aaa55c44120c698709a07
上级 5e912edd
......@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
return ret;
}
std::vector<std::string> GroupManager::gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
std::unique_lock<std::mutex> lk{m_key2uids_mtx};
if (m_key2uids_size[key] == 0)
m_key2uids[key].resize(size);
m_key2uids[key][rank] = uid;
m_key2uids_size[key]++;
if (m_key2uids_size[key] == size) {
m_key2uids_flag[key] = true;
m_gather_uid_cv.notify_all();
void GroupManager::bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) {
std::unique_lock<std::mutex> lk{m_key2addr_mtx};
if (rank == root) {
m_key2master_ip[key] = master_ip;
m_key2port[key] = port;
}
m_key2addr_size[key]++;
if (m_key2addr_size[key] == size) {
m_key2addr_flag[key] = true;
m_bcast_cv.notify_all();
} else {
m_gather_uid_cv.wait(
lk, [&] { return m_key2uids_flag.count(key) > 0; });
m_bcast_cv.wait(
lk, [&] { return m_key2addr_flag.count(key) > 0; });
}
auto uids = m_key2uids[key];
m_key2uids_size[key]--;
if (m_key2uids_size[key] == 0) {
m_key2uids.erase(key);
m_key2uids_flag.erase(key);
master_ip = m_key2master_ip[key];
port = m_key2port[key];
m_key2addr_size[key]--;
if (m_key2addr_size[key] == 0) {
m_key2master_ip.erase(key);
m_key2port.erase(key);
m_key2addr_flag.erase(key);
}
return uids;
}
void GroupManager::set_output_shape(const std::string& key,
......
......@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
std::shared_ptr<MegRay::Communicator> comm;
if (!sm_instance->find(hash, comm)) {
uint32_t root = 0;
std::string master_ip;
int port = 0;
if (rank == root) {
char* c = MegRay::get_host_ip();
master_ip = std::string(c);
delete c;
port = MegRay::get_free_port();
auto ret = MegRay::create_server(size, port);
mgb_assert(ret == MegRay::Status::MEGRAY_OK);
}
group_client->bcast_addr(master_ip, port, key, size, rank, root);
comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid();
auto uids = group_client->gather_uid(uid, key, size, rank);
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
auto ret = comm->init(master_ip.c_str(), port);
mgb_assert(ret == MegRay::Status::MEGRAY_OK);
sm_instance->emplace(hash, comm);
}
return comm;
......
......@@ -41,7 +41,7 @@ public:
RUNSERVER(opr_register);
RUNSERVER(set_output_shape);
RUNSERVER(get_output_shape);
RUNSERVER(gather_uid);
RUNSERVER(bcast_addr);
RUNSERVER(group_barrier);
mgb_assert(false, "invalid rpc request");
}
......@@ -49,7 +49,7 @@ private:
void opr_register(void* input_ptr, size_t input_len, std::string *output);
void set_output_shape(void* input_ptr, size_t input_len, std::string *output);
void get_output_shape(void* input_ptr, size_t input_len, std::string *output);
void gather_uid(void* input_ptr, size_t input_len, std::string *output);
void bcast_addr(void* input_ptr, size_t input_len, std::string *output);
void group_barrier(void* input_ptr, size_t input_len, std::string *output);
private:
......@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len,
rsp.SerializeToString(output);
}
void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len,
void GroupServerProxy::bcast_addr(void* input_ptr, size_t input_len,
std::string *output) {
INFO_INIT(mm_handler, GatherUid);
auto uid = req.uid();
auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank());
for (size_t i = 0;i < uids.size();i++) {
rsp.add_uids();
rsp.set_uids(i, uids[i].data(), uids[i].size());
}
INFO_INIT(mm_handler, BcastAddr);
std::string master_ip = req.master_ip();
int port = req.port();
m_mgr.bcast_addr(master_ip, port, req.key(), req.size(), req.rank(), req.root());
rsp.set_master_ip(master_ip);
rsp.set_port(port);
rsp.SerializeToString(output);
}
......@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) {
}
return shape;
}
std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
INFO_INIT(mm_handler, gather_uid, GatherUid);
req.set_uid(uid.data(), uid.size());
void GroupClientProxy::bcast_addr(std::string& master_ip,
int& port, const std::string& key, uint32_t size,
uint32_t rank, uint32_t root) {
INFO_INIT(mm_handler, bcast_addr, BcastAddr);
req.set_master_ip(master_ip.data(), master_ip.size());
req.set_port(port);
req.set_key(key.data(), key.size());
req.set_size(size);
req.set_rank(rank);
req.set_root(root);
SOLVE_REQUEST(func_name, req, rsp);
std::vector<std::string> rst;
for (size_t i = 0;i < size;i++) {
rst.push_back(rsp.uids(i));
}
return rst;
master_ip = rsp.master_ip();
port = rsp.port();
}
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
......
......@@ -82,9 +82,9 @@ class GroupManager {
RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash);
//! gather uids from all ranks
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank);
//! broadcast master_ip and port
void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank, uint32_t root);
//! Set output shape of this key
void set_output_shape(const std::string& key, const TensorShape& shape);
......@@ -102,12 +102,13 @@ class GroupManager {
std::unordered_map<std::string, GroupInfo> m_key2group_info;
std::mutex m_key2group_info_mtx;
//! key -> uid
std::unordered_map<std::string, std::vector<std::string>> m_key2uids;
std::unordered_map<std::string, uint32_t> m_key2uids_size;
std::unordered_map<std::string, bool> m_key2uids_flag;
std::mutex m_key2uids_mtx;
std::condition_variable m_gather_uid_cv;
//! key -> addr
std::unordered_map<std::string, std::string> m_key2master_ip;
std::unordered_map<std::string, int> m_key2port;
std::unordered_map<std::string, uint32_t> m_key2addr_size;
std::unordered_map<std::string, bool> m_key2addr_flag;
std::mutex m_key2addr_mtx;
std::condition_variable m_bcast_cv;
//! barrier
uint32_t m_barrier_size;
......@@ -133,8 +134,8 @@ class GroupClient {
bool is_root, int rank,
uint64_t comp_node_hash) = 0;
virtual std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) = 0;
virtual void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0;
virtual void set_output_shape(const std::string& key,
const TensorShape& shape) = 0;
......
......@@ -37,8 +37,8 @@ public:
int rank,
uint64_t comp_node_hash) override;
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) override;
void bcast_addr(std::string& master_ip, int& port, const std::string& key,
uint32_t size, uint32_t rank, uint32_t root) override;
void set_output_shape(const std::string& key,
const TensorShape& shape) override;
......
......@@ -16,15 +16,18 @@ message OprRegisterResponse {
int32 root_rank = 3;
}
message GatherUidRequest {
bytes uid = 1;
string key = 2;
uint32 size = 3;
uint32 rank = 4;
}
message GatherUidResponse {
repeated bytes uids = 1;
message BcastAddrRequest {
string master_ip = 1;
int32 port = 2;
string key = 3;
uint32 size = 4;
uint32 rank = 5;
uint32 root = 6;
}
message BcastAddrResponse {
string master_ip = 1;
int32 port = 2;
}
message SetOutputShapeRequest {
......
......@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient {
}
RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash) {
bool is_root, int rank, uint64_t comp_node_hash) override {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
}
std::vector<std::string> gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) {
return m_mgr.gather_uid(uid, key, size, rank);
void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size,
uint32_t rank, uint32_t root) override {
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
}
void set_output_shape(const std::string& key,
......
Subproject commit d06c215dc1425fa932e20ecfaab7b07c0343a5bc
Subproject commit e14e4f84c1349598ba17c49923168db47a4e9642
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册