提交 b38e8225 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

refactor(mgb/opr-mm): update megray communicator init interface and fix ci

GitOrigin-RevId: 55c59879f2cda27f678aaa55c44120c698709a07
上级 5e912edd
...@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, ...@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
return ret; return ret;
} }
// Side-by-side diff: the left column is the removed GroupManager::gather_uid,
// the right column is the added GroupManager::bcast_addr.
//
// bcast_addr (right column) is a rendezvous keyed by `key` for `size` callers:
//  - master_ip / port are in-out parameters: the caller whose rank == root
//    publishes its values, and every caller gets the stored values written
//    back before returning;
//  - key identifies the group; size is the number of callers expected;
//  - all bookkeeping is done while holding m_key2addr_mtx.
// The call blocks on m_bcast_cv until `size` callers have arrived for `key`.
//
// NOTE(review): the cleanup erases m_key2master_ip, m_key2port and
// m_key2addr_flag, but the m_key2addr_size entry is left behind with value 0.
// NOTE(review): m_key2addr_flag[key] stays set until the last caller leaves;
// if the same key could be reused while earlier callers are still leaving,
// a new caller's wait predicate would already be true -- confirm that keys
// are not reused concurrently.
std::vector<std::string> GroupManager::gather_uid(const std::string& uid, void GroupManager::bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank) { const std::string& key, uint32_t size, uint32_t rank, uint32_t root) {
std::unique_lock<std::mutex> lk{m_key2uids_mtx}; std::unique_lock<std::mutex> lk{m_key2addr_mtx};
// Right column: only the root rank stores its address for this key.
if (m_key2uids_size[key] == 0) if (rank == root) {
m_key2uids[key].resize(size); m_key2master_ip[key] = master_ip;
m_key2uids[key][rank] = uid; m_key2port[key] = port;
m_key2uids_size[key]++; }
// Right column: count this arrival; the caller that brings the count to
// `size` sets the flag and wakes all waiters, everyone else waits for it.
if (m_key2uids_size[key] == size) { m_key2addr_size[key]++;
m_key2uids_flag[key] = true; if (m_key2addr_size[key] == size) {
m_gather_uid_cv.notify_all(); m_key2addr_flag[key] = true;
m_bcast_cv.notify_all();
} else { } else {
m_gather_uid_cv.wait( m_bcast_cv.wait(
lk, [&] { return m_key2uids_flag.count(key) > 0; }); lk, [&] { return m_key2addr_flag.count(key) > 0; });
} }
// Right column: copy the published address to the caller, then count this
// caller out; the last one to leave erases the per-key entries.
auto uids = m_key2uids[key]; master_ip = m_key2master_ip[key];
m_key2uids_size[key]--; port = m_key2port[key];
if (m_key2uids_size[key] == 0) { m_key2addr_size[key]--;
m_key2uids.erase(key); if (m_key2addr_size[key] == 0) {
m_key2uids_flag.erase(key); m_key2master_ip.erase(key);
m_key2port.erase(key);
m_key2addr_flag.erase(key);
} }
return uids;
} }
void GroupManager::set_output_shape(const std::string& key, void GroupManager::set_output_shape(const std::string& key,
......
...@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm( ...@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
std::shared_ptr<MegRay::Communicator> comm; std::shared_ptr<MegRay::Communicator> comm;
if (!sm_instance->find(hash, comm)) { if (!sm_instance->find(hash, comm)) {
uint32_t root = 0;
std::string master_ip;
int port = 0;
if (rank == root) {
char* c = MegRay::get_host_ip();
master_ip = std::string(c);
delete c;
port = MegRay::get_free_port();
auto ret = MegRay::create_server(size, port);
mgb_assert(ret == MegRay::Status::MEGRAY_OK);
}
group_client->bcast_addr(master_ip, port, key, size, rank, root);
comm = MegRay::get_communicator(size, rank, backend); comm = MegRay::get_communicator(size, rank, backend);
auto uid = comm->get_uid(); auto ret = comm->init(master_ip.c_str(), port);
auto uids = group_client->gather_uid(uid, key, size, rank); mgb_assert(ret == MegRay::Status::MEGRAY_OK);
mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
sm_instance->emplace(hash, comm); sm_instance->emplace(hash, comm);
} }
return comm; return comm;
......
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
RUNSERVER(opr_register); RUNSERVER(opr_register);
RUNSERVER(set_output_shape); RUNSERVER(set_output_shape);
RUNSERVER(get_output_shape); RUNSERVER(get_output_shape);
RUNSERVER(gather_uid); RUNSERVER(bcast_addr);
RUNSERVER(group_barrier); RUNSERVER(group_barrier);
mgb_assert(false, "invalid rpc request"); mgb_assert(false, "invalid rpc request");
} }
...@@ -49,7 +49,7 @@ private: ...@@ -49,7 +49,7 @@ private:
void opr_register(void* input_ptr, size_t input_len, std::string *output); void opr_register(void* input_ptr, size_t input_len, std::string *output);
void set_output_shape(void* input_ptr, size_t input_len, std::string *output); void set_output_shape(void* input_ptr, size_t input_len, std::string *output);
void get_output_shape(void* input_ptr, size_t input_len, std::string *output); void get_output_shape(void* input_ptr, size_t input_len, std::string *output);
void gather_uid(void* input_ptr, size_t input_len, std::string *output); void bcast_addr(void* input_ptr, size_t input_len, std::string *output);
void group_barrier(void* input_ptr, size_t input_len, std::string *output); void group_barrier(void* input_ptr, size_t input_len, std::string *output);
private: private:
...@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len, ...@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len,
rsp.SerializeToString(output); rsp.SerializeToString(output);
} }
// Side-by-side diff: left column is the removed GroupServerProxy::gather_uid,
// right column is the added GroupServerProxy::bcast_addr.
//
// Server-side RPC handler (right column): takes master_ip and port from the
// request, passes them with key/size/rank/root to m_mgr.bcast_addr (which
// receives master_ip and port by reference), then writes the resulting
// master_ip and port into the response and serializes it into *output.
// INFO_INIT is a macro defined outside this view; presumably it declares
// `req` / `rsp` and parses input_ptr/input_len into `req` -- confirm.
void GroupServerProxy::gather_uid(void* input_ptr, size_t input_len, void GroupServerProxy::bcast_addr(void* input_ptr, size_t input_len,
std::string *output) { std::string *output) {
INFO_INIT(mm_handler, GatherUid); INFO_INIT(mm_handler, BcastAddr);
auto uid = req.uid(); std::string master_ip = req.master_ip();
auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank()); int port = req.port();
for (size_t i = 0;i < uids.size();i++) { m_mgr.bcast_addr(master_ip, port, req.key(), req.size(), req.rank(), req.root());
rsp.add_uids(); rsp.set_master_ip(master_ip);
rsp.set_uids(i, uids[i].data(), uids[i].size()); rsp.set_port(port);
}
rsp.SerializeToString(output); rsp.SerializeToString(output);
} }
...@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) { ...@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) {
} }
return shape; return shape;
} }
// Side-by-side diff: left column is the removed GroupClientProxy::gather_uid,
// right column is the added GroupClientProxy::bcast_addr.
//
// Client-side RPC stub (right column): fills a BcastAddr request with the
// caller's master_ip, port, key, size, rank and root, issues it through
// SOLVE_REQUEST, and overwrites the in-out parameters master_ip and port
// with the values carried by the response.
// INFO_INIT and SOLVE_REQUEST are macros defined outside this view;
// presumably INFO_INIT introduces `func_name`, `req` and `rsp` -- confirm.
std::vector<std::string> GroupClientProxy::gather_uid(const std::string& uid,
const std::string& key, uint32_t size, uint32_t rank) { void GroupClientProxy::bcast_addr(std::string& master_ip,
INFO_INIT(mm_handler, gather_uid, GatherUid); int& port, const std::string& key, uint32_t size,
req.set_uid(uid.data(), uid.size()); uint32_t rank, uint32_t root) {
INFO_INIT(mm_handler, bcast_addr, BcastAddr);
req.set_master_ip(master_ip.data(), master_ip.size());
req.set_port(port);
req.set_key(key.data(), key.size()); req.set_key(key.data(), key.size());
req.set_size(size); req.set_size(size);
req.set_rank(rank); req.set_rank(rank);
req.set_root(root);
SOLVE_REQUEST(func_name, req, rsp); SOLVE_REQUEST(func_name, req, rsp);
std::vector<std::string> rst; master_ip = rsp.master_ip();
for (size_t i = 0;i < size;i++) { port = rsp.port();
rst.push_back(rsp.uids(i));
}
return rst;
} }
uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
......
...@@ -82,9 +82,9 @@ class GroupManager { ...@@ -82,9 +82,9 @@ class GroupManager {
RegisterInfo opr_register(const std::string& key, size_t nr_devices, RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash); bool is_root, int rank, uint64_t comp_node_hash);
//! gather uids from all ranks //! broadcast master_ip and port
std::vector<std::string> gather_uid(const std::string& uid, void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank); const std::string& key, uint32_t size, uint32_t rank, uint32_t root);
//! Set output shape of this key //! Set output shape of this key
void set_output_shape(const std::string& key, const TensorShape& shape); void set_output_shape(const std::string& key, const TensorShape& shape);
...@@ -102,12 +102,13 @@ class GroupManager { ...@@ -102,12 +102,13 @@ class GroupManager {
std::unordered_map<std::string, GroupInfo> m_key2group_info; std::unordered_map<std::string, GroupInfo> m_key2group_info;
std::mutex m_key2group_info_mtx; std::mutex m_key2group_info_mtx;
//! key -> uid //! key -> addr
std::unordered_map<std::string, std::vector<std::string>> m_key2uids; std::unordered_map<std::string, std::string> m_key2master_ip;
std::unordered_map<std::string, uint32_t> m_key2uids_size; std::unordered_map<std::string, int> m_key2port;
std::unordered_map<std::string, bool> m_key2uids_flag; std::unordered_map<std::string, uint32_t> m_key2addr_size;
std::mutex m_key2uids_mtx; std::unordered_map<std::string, bool> m_key2addr_flag;
std::condition_variable m_gather_uid_cv; std::mutex m_key2addr_mtx;
std::condition_variable m_bcast_cv;
//! barrier //! barrier
uint32_t m_barrier_size; uint32_t m_barrier_size;
...@@ -133,8 +134,8 @@ class GroupClient { ...@@ -133,8 +134,8 @@ class GroupClient {
bool is_root, int rank, bool is_root, int rank,
uint64_t comp_node_hash) = 0; uint64_t comp_node_hash) = 0;
virtual std::vector<std::string> gather_uid(const std::string& uid, virtual void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank) = 0; const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0;
virtual void set_output_shape(const std::string& key, virtual void set_output_shape(const std::string& key,
const TensorShape& shape) = 0; const TensorShape& shape) = 0;
......
...@@ -37,8 +37,8 @@ public: ...@@ -37,8 +37,8 @@ public:
int rank, int rank,
uint64_t comp_node_hash) override; uint64_t comp_node_hash) override;
std::vector<std::string> gather_uid(const std::string& uid, void bcast_addr(std::string& master_ip, int& port, const std::string& key,
const std::string& key, uint32_t size, uint32_t rank) override; uint32_t size, uint32_t rank, uint32_t root) override;
void set_output_shape(const std::string& key, void set_output_shape(const std::string& key,
const TensorShape& shape) override; const TensorShape& shape) override;
......
...@@ -16,15 +16,18 @@ message OprRegisterResponse { ...@@ -16,15 +16,18 @@ message OprRegisterResponse {
int32 root_rank = 3; int32 root_rank = 3;
} }
// Side-by-side diff: left column removes GatherUidRequest / GatherUidResponse,
// right column adds BcastAddrRequest / BcastAddrResponse.
//
// BcastAddrRequest carries the caller's master_ip and port together with the
// group key, the group size, the caller's rank and the root rank.
// BcastAddrResponse carries the master_ip and port returned to the caller.
message GatherUidRequest { message BcastAddrRequest {
bytes uid = 1; string master_ip = 1;
string key = 2; int32 port = 2;
uint32 size = 3; string key = 3;
uint32 rank = 4; uint32 size = 4;
} uint32 rank = 5;
uint32 root = 6;
message GatherUidResponse { }
repeated bytes uids = 1;
message BcastAddrResponse {
string master_ip = 1;
int32 port = 2;
} }
message SetOutputShapeRequest { message SetOutputShapeRequest {
......
...@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient { ...@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient {
} }
// Side-by-side diff: the only change is the `override` specifier added in
// the right column. The method forwards all arguments unchanged to
// m_mgr.opr_register and returns its result.
RegisterInfo opr_register(const std::string& key, size_t nr_devices, RegisterInfo opr_register(const std::string& key, size_t nr_devices,
bool is_root, int rank, uint64_t comp_node_hash) { bool is_root, int rank, uint64_t comp_node_hash) override {
return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash);
} }
// Side-by-side diff: left column is the removed mock gather_uid, right column
// is the added mock bcast_addr.
//
// bcast_addr (right column) forwards all arguments unchanged to
// m_mgr.bcast_addr; master_ip and port are passed through by reference, so
// whatever the manager writes into them is visible to the caller. The
// `return` of a void expression is legal C++ and returns nothing.
std::vector<std::string> gather_uid(const std::string& uid, void bcast_addr(std::string& master_ip, int& port,
const std::string& key, uint32_t size, uint32_t rank) { const std::string& key, uint32_t size,
return m_mgr.gather_uid(uid, key, size, rank); uint32_t rank, uint32_t root) override {
return m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
} }
void set_output_shape(const std::string& key, void set_output_shape(const std::string& key,
......
Subproject commit d06c215dc1425fa932e20ecfaab7b07c0343a5bc Subproject commit e14e4f84c1349598ba17c49923168db47a4e9642
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册