// Review notes for the patch below. They sit before the first `diff --git`
// header on purpose: that is the only position where commentary does not
// land inside a file header or a hunk.
//
// Purpose: replace the "gather every rank's uid" rendezvous with a
// "broadcast the root's master_ip/port" rendezvous.
//
// What the hunks show:
//   * impl/group_manager.cpp, include/megbrain/opr/group_manager.h:
//     GroupManager::gather_uid is removed and
//     GroupManager::bcast_addr(master_ip, port, key, size, rank, root) is
//     added. While holding m_key2addr_mtx, the caller with rank == root
//     stores master_ip and port under `key`; every caller increments
//     m_key2addr_size[key]. The caller that makes the count equal to `size`
//     sets m_key2addr_flag[key] and calls m_bcast_cv.notify_all(); every
//     other caller waits on m_bcast_cv until m_key2addr_flag has an entry
//     for `key`. Each caller then copies the stored ip/port into its
//     reference arguments and decrements the count; the caller that brings
//     it to 0 erases the ip, port and flag entries for `key`.
//   * impl/megray_helper.cpp: when sm_instance has no communicator for
//     `hash`, the rank equal to `root` (a local fixed to 0) takes the host
//     ip and a free port from MegRay and calls
//     MegRay::create_server(size, port), asserting MEGRAY_OK. All ranks then
//     call group_client->bcast_addr(...) and initialise the communicator
//     with comm->init(master_ip.c_str(), port), again asserting MEGRAY_OK.
//   * impl/mm_handler.cpp, include/megbrain/opr/mm_handler.h,
//     proto/mm_handler.proto: the gather_uid RPC and the GatherUidRequest /
//     GatherUidResponse messages are replaced by bcast_addr and
//     BcastAddrRequest {master_ip, port, key, size, rank, root} /
//     BcastAddrResponse {master_ip, port}. The server proxy forwards to
//     m_mgr.bcast_addr and returns the resulting ip/port; the client proxy
//     writes the response ip/port back into its reference arguments.
//   * test/mock_client.h: MockGroupClient::bcast_addr forwards to
//     m_mgr.bcast_addr; opr_register and bcast_addr are marked `override`.
//   * third_party/MegRay: submodule commit moves from d06c215d... to
//     e14e4f84....
//
// NOTE(review): `delete c;` is applied to the char* returned by
//   MegRay::get_host_ip(). The allocation is not visible in this patch; if
//   MegRay allocates it with new char[], this must be `delete[] c;` --
//   confirm against the MegRay sources at the new submodule commit.
// NOTE(review): m_key2addr_size[key] is decremented to 0 but never erased,
//   unlike m_key2master_ip / m_key2port / m_key2addr_flag, so one zero-valued
//   entry per key stays in that map (the removed gather_uid did the same
//   with m_key2uids_size).
// NOTE(review): the wait predicate only checks that m_key2addr_flag has an
//   entry for `key`. Presumably each key goes through a single round; if a
//   key were reused before all callers of the previous round had returned,
//   a late caller could see the old flag and read the previous address --
//   verify against callers.
// NOTE(review): line breaks inside this patch have been collapsed and
//   template argument lists appear stripped (e.g. `std::vector
//   GroupManager::gather_uid`, `std::unordered_map m_key2master_ip`); this
//   looks like extraction damage, and the patch will presumably not apply
//   until it is regenerated from the repository.
diff --git a/src/opr-mm/impl/group_manager.cpp b/src/opr-mm/impl/group_manager.cpp index ce0a7528cfd3db9d5dabacea4652807ac8224d95..c5ba66c13277e28bad5e8710d21abe2249439f1d 100644 --- a/src/opr-mm/impl/group_manager.cpp +++ b/src/opr-mm/impl/group_manager.cpp @@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, return ret; } -std::vector GroupManager::gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) { - std::unique_lock lk{m_key2uids_mtx}; - if (m_key2uids_size[key] == 0) - m_key2uids[key].resize(size); - m_key2uids[key][rank] = uid; - m_key2uids_size[key]++; - if (m_key2uids_size[key] == size) { - m_key2uids_flag[key] = true; - m_gather_uid_cv.notify_all(); +void GroupManager::bcast_addr(std::string& master_ip, int& port, + const std::string& key, uint32_t size, uint32_t rank, uint32_t root) { + std::unique_lock lk{m_key2addr_mtx}; + if (rank == root) { + m_key2master_ip[key] = master_ip; + m_key2port[key] = port; + } + m_key2addr_size[key]++; + if (m_key2addr_size[key] == size) { + m_key2addr_flag[key] = true; + m_bcast_cv.notify_all(); } else { - m_gather_uid_cv.wait( - lk, [&] { return m_key2uids_flag.count(key) > 0; }); + m_bcast_cv.wait( + lk, [&] { return m_key2addr_flag.count(key) > 0; }); } - auto uids = m_key2uids[key]; - m_key2uids_size[key]--; - if (m_key2uids_size[key] == 0) { - m_key2uids.erase(key); - m_key2uids_flag.erase(key); + master_ip = m_key2master_ip[key]; + port = m_key2port[key]; + m_key2addr_size[key]--; + if (m_key2addr_size[key] == 0) { + m_key2master_ip.erase(key); + m_key2port.erase(key); + m_key2addr_flag.erase(key); } - return uids; } void GroupManager::set_output_shape(const std::string& key, diff --git a/src/opr-mm/impl/megray_helper.cpp b/src/opr-mm/impl/megray_helper.cpp index c96f49daf2a08820d8433bd9a11edaba0d2e3af1..6fc70b040b1ca5f7f59a379683b24e8655d37cf3 100644 --- a/src/opr-mm/impl/megray_helper.cpp +++ b/src/opr-mm/impl/megray_helper.cpp 
@@ -44,10 +44,22 @@ std::shared_ptr MegRayCommBuilder::get_megray_comm( std::shared_ptr comm; if (!sm_instance->find(hash, comm)) { + uint32_t root = 0; + std::string master_ip; + int port = 0; + if (rank == root) { + char* c = MegRay::get_host_ip(); + master_ip = std::string(c); + delete c; + port = MegRay::get_free_port(); + auto ret = MegRay::create_server(size, port); + mgb_assert(ret == MegRay::Status::MEGRAY_OK); + } + group_client->bcast_addr(master_ip, port, key, size, rank, root); + comm = MegRay::get_communicator(size, rank, backend); - auto uid = comm->get_uid(); - auto uids = group_client->gather_uid(uid, key, size, rank); - mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK); + auto ret = comm->init(master_ip.c_str(), port); + mgb_assert(ret == MegRay::Status::MEGRAY_OK); sm_instance->emplace(hash, comm); } return comm; diff --git a/src/opr-mm/impl/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp index f60248fa427ef2d2f8f897a8b22d9fc25314b57f..1dc9381dcce6b2cd2b974bd946c8adf2f8ad566f 100644 --- a/src/opr-mm/impl/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -41,7 +41,7 @@ public: RUNSERVER(opr_register); RUNSERVER(set_output_shape); RUNSERVER(get_output_shape); - RUNSERVER(gather_uid); + RUNSERVER(bcast_addr); RUNSERVER(group_barrier); mgb_assert(false, "invalid rpc request"); } @@ -49,7 +49,7 @@ private: void opr_register(void* input_ptr, size_t input_len, std::string *output); void set_output_shape(void* input_ptr, size_t input_len, std::string *output); void get_output_shape(void* input_ptr, size_t input_len, std::string *output); - void gather_uid(void* input_ptr, size_t input_len, std::string *output); + void bcast_addr(void* input_ptr, size_t input_len, std::string *output); void group_barrier(void* input_ptr, size_t input_len, std::string *output); private: @@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len, rsp.SerializeToString(output); } -void GroupServerProxy::gather_uid(void* 
input_ptr, size_t input_len, +void GroupServerProxy::bcast_addr(void* input_ptr, size_t input_len, std::string *output) { - INFO_INIT(mm_handler, GatherUid); - auto uid = req.uid(); - auto uids = m_mgr.gather_uid(uid, req.key(), req.size(), req.rank()); - for (size_t i = 0;i < uids.size();i++) { - rsp.add_uids(); - rsp.set_uids(i, uids[i].data(), uids[i].size()); - } + INFO_INIT(mm_handler, BcastAddr); + std::string master_ip = req.master_ip(); + int port = req.port(); + m_mgr.bcast_addr(master_ip, port, req.key(), req.size(), req.rank(), req.root()); + rsp.set_master_ip(master_ip); + rsp.set_port(port); rsp.SerializeToString(output); } @@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) { } return shape; } -std::vector GroupClientProxy::gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) { - INFO_INIT(mm_handler, gather_uid, GatherUid); - req.set_uid(uid.data(), uid.size()); + +void GroupClientProxy::bcast_addr(std::string& master_ip, + int& port, const std::string& key, uint32_t size, + uint32_t rank, uint32_t root) { + INFO_INIT(mm_handler, bcast_addr, BcastAddr); + req.set_master_ip(master_ip.data(), master_ip.size()); + req.set_port(port); req.set_key(key.data(), key.size()); req.set_size(size); req.set_rank(rank); + req.set_root(root); SOLVE_REQUEST(func_name, req, rsp); - std::vector rst; - for (size_t i = 0;i < size;i++) { - rst.push_back(rsp.uids(i)); - } - return rst; + master_ip = rsp.master_ip(); + port = rsp.port(); } uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index c8a172e9a2055011f2ba7ff8645ed0ca7b9a466c..35a633b414459d0bc05942d1d4d34571a96f9e14 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -82,9 +82,9 @@ class GroupManager { RegisterInfo opr_register(const 
std::string& key, size_t nr_devices, bool is_root, int rank, uint64_t comp_node_hash); - //! gather uids from all ranks - std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank); + //! broadcast master_ip and port + void bcast_addr(std::string& master_ip, int& port, + const std::string& key, uint32_t size, uint32_t rank, uint32_t root); //! Set output shape of this key void set_output_shape(const std::string& key, const TensorShape& shape); @@ -102,12 +102,13 @@ class GroupManager { std::unordered_map m_key2group_info; std::mutex m_key2group_info_mtx; - //! key -> uid - std::unordered_map> m_key2uids; - std::unordered_map m_key2uids_size; - std::unordered_map m_key2uids_flag; - std::mutex m_key2uids_mtx; - std::condition_variable m_gather_uid_cv; + //! key -> addr + std::unordered_map m_key2master_ip; + std::unordered_map m_key2port; + std::unordered_map m_key2addr_size; + std::unordered_map m_key2addr_flag; + std::mutex m_key2addr_mtx; + std::condition_variable m_bcast_cv; //! 
barrier uint32_t m_barrier_size; @@ -133,8 +134,8 @@ class GroupClient { bool is_root, int rank, uint64_t comp_node_hash) = 0; - virtual std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) = 0; + virtual void bcast_addr(std::string& master_ip, int& port, + const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0; virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index eaa33f90f05908cd06a4a478b2662971d9a3c3fa..c5e5c0f694d22104c56a7e82a5c0de1819a9b2e0 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -37,8 +37,8 @@ public: int rank, uint64_t comp_node_hash) override; - std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) override; + void bcast_addr(std::string& master_ip, int& port, const std::string& key, + uint32_t size, uint32_t rank, uint32_t root) override; void set_output_shape(const std::string& key, const TensorShape& shape) override; diff --git a/src/opr-mm/proto/mm_handler.proto b/src/opr-mm/proto/mm_handler.proto index 4102b870a2bd80d3e4dbe50a6e80f59662bf0196..00f1d662a5fff9f000941d0b6e3ce76c74592f3d 100644 --- a/src/opr-mm/proto/mm_handler.proto +++ b/src/opr-mm/proto/mm_handler.proto @@ -16,15 +16,18 @@ message OprRegisterResponse { int32 root_rank = 3; } -message GatherUidRequest { - bytes uid = 1; - string key = 2; - uint32 size = 3; - uint32 rank = 4; -} - -message GatherUidResponse { - repeated bytes uids = 1; +message BcastAddrRequest { + string master_ip = 1; + int32 port = 2; + string key = 3; + uint32 size = 4; + uint32 rank = 5; + uint32 root = 6; +} + +message BcastAddrResponse { + string master_ip = 1; + int32 port = 2; } message SetOutputShapeRequest { diff --git a/src/opr-mm/test/mock_client.h b/src/opr-mm/test/mock_client.h 
index 5ca014594978115c6baf393267be67a10c3ef295..fa02f3e6d74e2301d40fb9da947c119831c7d7b2 100644 --- a/src/opr-mm/test/mock_client.h +++ b/src/opr-mm/test/mock_client.h @@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient { } RegisterInfo opr_register(const std::string& key, size_t nr_devices, - bool is_root, int rank, uint64_t comp_node_hash) { + bool is_root, int rank, uint64_t comp_node_hash) override { return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); } - std::vector gather_uid(const std::string& uid, - const std::string& key, uint32_t size, uint32_t rank) { - return m_mgr.gather_uid(uid, key, size, rank); + void bcast_addr(std::string& master_ip, int& port, + const std::string& key, uint32_t size, + uint32_t rank, uint32_t root) override { + return m_mgr.bcast_addr(master_ip, port, key, size, rank, root); } void set_output_shape(const std::string& key, diff --git a/third_party/MegRay b/third_party/MegRay index d06c215dc1425fa932e20ecfaab7b07c0343a5bc..e14e4f84c1349598ba17c49923168db47a4e9642 160000 --- a/third_party/MegRay +++ b/third_party/MegRay @@ -1 +1 @@ -Subproject commit d06c215dc1425fa932e20ecfaab7b07c0343a5bc +Subproject commit e14e4f84c1349598ba17c49923168db47a4e9642