提交 27205461 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/opr-mm): add register info cache for multi-machine oprs

GitOrigin-RevId: d5ae3c5a7c6d8d1939cd4cf92c25045a5e7159e0
上级 a7ff580e
......@@ -687,11 +687,21 @@ SymbolVarArray CollectiveComm::make(
void CollectiveComm::opr_register() {
if (m_init)
return;
auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;
auto reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
reg_info = m_group_client->opr_register(
m_key, m_nr_devices, m_is_root, m_rank,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}
m_rank = reg_info.rank;
m_root = reg_info.root_rank;
......
......@@ -205,4 +205,20 @@ uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) {
return m_barrier_size;
}
void RegInfoCache::set_info(const std::string& key,
const GroupManager::RegisterInfo& info) {
std::unique_lock<std::mutex> lock(RegInfoCache::mtx);
RegInfoCache::key2info[key] = info;
}
bool RegInfoCache::has_info(const std::string& key) {
std::unique_lock<std::mutex> lock(RegInfoCache::mtx);
return RegInfoCache::key2info.find(key) != RegInfoCache::key2info.end();
}
GroupManager::RegisterInfo RegInfoCache::get_info(const std::string& key) {
std::unique_lock<std::mutex> lock(RegInfoCache::mtx);
return RegInfoCache::key2info[key];
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -53,10 +53,19 @@ SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
void RemoteSend::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;
// rank 0 for RemoteSend
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
// rank 0 for RemoteSend
reg_info = m_group_client->opr_register(m_key, 2, 0, false,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
......@@ -153,11 +162,20 @@ SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
void RemoteRecv::scn_do_execute() {
if (!m_init) {
auto&& comp_node = output(0)->comp_node();
bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
struct GroupManager::RegisterInfo reg_info;
// rank 1 for RemoteRecv
auto reg_info = m_group_client->opr_register(
m_key, 2, false, 1,
comp_node.get_uid());
if (use_cache and RegInfoCache::has_info(m_key)) {
reg_info = RegInfoCache::get_info(m_key);
} else {
// rank 1 for RemoteRecv
reg_info = m_group_client->opr_register(
m_key, 2, false, 1,
comp_node.get_uid());
if (use_cache) {
RegInfoCache::set_info(m_key, reg_info);
}
}
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
......
......@@ -145,6 +145,22 @@ class GroupClient {
virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0;
};
/*!
* Cache RegisterInfo returned from GroupManager. This feature is only enabled
* in imperative runtime mode, so that multi-machine operators do not have to
* call opr_register repeatedly in each iter
*/
namespace RegInfoCache {
static std::mutex mtx;
static std::unordered_map<std::string, GroupManager::RegisterInfo> key2info;
void set_info(const std::string& key, const GroupManager::RegisterInfo& info);
bool has_info(const std::string& key);
GroupManager::RegisterInfo get_info(const std::string& key);
} // namespace RegInfoCache
} // namespace opr
} // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册