From c7e6c658fdf2305d5a60317317233575ade5b1c8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 29 May 2020 19:39:41 +0800 Subject: [PATCH] refactor(mge/distribute): use is_root (and rank) in stead of rank and root at collective comm GitOrigin-RevId: dccdb715533576db97bbc21fe61e640141c1e8b6 --- .../megengine/distributed/functional.py | 67 +++++---------- python_module/megengine/distributed/helper.py | 9 +- .../megengine/optimizer/optimizer.py | 14 ++- python_module/src/cpp/opr_defs.cpp | 28 +++--- python_module/src/cpp/opr_defs.h | 16 ++-- .../test/unit/distributed/test_functional.py | 4 +- src/opr-mm/impl/collective_comm.cpp | 49 ++++++----- src/opr-mm/impl/collective_comm.oprdecl | 8 +- src/opr-mm/impl/group_manager.cpp | 86 +++++++++++++------ src/opr-mm/impl/io_remote.cpp | 19 ++-- src/opr-mm/impl/mm_handler.cpp | 25 +++--- .../include/megbrain/opr/collective_comm.h | 49 ++++++----- .../include/megbrain/opr/group_manager.h | 42 ++++++--- src/opr-mm/include/megbrain/opr/mm_handler.h | 7 +- src/opr-mm/proto/mm_handler.proto | 11 ++- src/opr-mm/test/collective_comm.cpp | 70 +++++++-------- src/opr-mm/test/io_remote.cpp | 10 ++- 17 files changed, 286 insertions(+), 228 deletions(-) diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py index ec8d1c558..ee404d126 100644 --- a/python_module/megengine/distributed/functional.py +++ b/python_module/megengine/distributed/functional.py @@ -26,25 +26,17 @@ def reduce_sum( tensor: Tensor, key: str, nr_ranks: Optional[int] = None, - rank: Optional[int] = None, - root: Optional[int] = 0, + is_root: Optional[bool] = None, ) -> Tensor: """Create reduce_sum operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default - :param root: rank of root node, use 0 as default + :param is_root: whether this is a root node """ return _collective_comm( - tensor, - key, - CollParam.Mode.REDUCE_SUM, - nr_ranks, - rank, - root, - device=tensor.device, + tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, ) @@ -52,24 +44,21 @@ def broadcast( tensor: Tensor, key: str, nr_ranks: Optional[int] = None, - rank: Optional[int] = None, - root: Optional[int] = 0, + is_root: Optional[bool] = None, ) -> Tensor: """Create broadcast operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default - :param root: rank of root node, use 0 as default + :param is_root: whether this is a root node """ if key is None: key = tensor._symvar.name + if is_root is None: + is_root = get_rank() == 0 - if rank is None: - rank = get_rank() - - if rank == root: + if is_root: inp = tensor else: inp = tensor._symvar.owner_graph @@ -79,8 +68,7 @@ def broadcast( key, CollParam.Mode.BROADCAST, nr_ranks, - rank, - root, + is_root, dtype=tensor.dtype, device=tensor.device, ) @@ -94,9 +82,9 @@ def all_gather( :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default + :param rank: rank of this node """ - return 
_collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0) + return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank) def reduce_scatter_sum( @@ -107,69 +95,58 @@ def reduce_scatter_sum( :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default + :param rank: rank of this node """ return _collective_comm( - tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank + tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank, ) -def all_reduce_sum( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None -) -> Tensor: +def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: """Create all_reduce_sum operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank) + return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks) -def all_reduce_max( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None -) -> Tensor: +def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: """Create all_reduce_max operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank) + return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks) -def all_reduce_min( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None -) -> Tensor: +def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: """Create all_reduce_min operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank) + return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks) def bcast_param( inp: Union[Buffer, Parameter], key: str, nr_ranks: Optional[int] = None, - rank: Optional[int] = None, - root: Optional[int] = 0, + is_root: Optional[bool] = None, ) -> None: """Broadcast parameters among devices :param inp: input Buffer or Parameter to be synchronized :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default - :param root: rank of root node, use 0 as default + :param is_root: whether this is a root node """ if not is_distributed(): return assert isinstance(inp, (Buffer, Parameter)) - bcast_res = broadcast(inp, key, nr_ranks, rank, root) + bcast_res = broadcast(inp, key, nr_ranks, is_root) add_update(inp, bcast_res, alpha=0) diff --git 
a/python_module/megengine/distributed/helper.py b/python_module/megengine/distributed/helper.py index 4514228eb..0010b3735 100644 --- a/python_module/megengine/distributed/helper.py +++ b/python_module/megengine/distributed/helper.py @@ -19,8 +19,8 @@ def collective_comm_symvar( key: str, op: CollParam.Mode, nr_ranks: Optional[int] = None, + is_root: Optional[bool] = None, rank: Optional[int] = None, - root: Optional[int] = 0, dtype: Optional[type] = None, device: Optional[mgb.CompNode] = None, comp_graph: Optional[mgb.CompGraph] = None, @@ -31,8 +31,7 @@ def collective_comm_symvar( :param key: unique identifier for collective communication :param op: mode of collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default - :param rank: rank of the current process, use util.get_rank() as default - :param root: rank of root node, use 0 as default + :param is_root: whether this node is root node :param dtype: output data type, use dtype of inp as default :param device: output comp node, use comp node of inp as default :param comp_graph: output comp graph, use comp graph of inp as default @@ -41,8 +40,8 @@ def collective_comm_symvar( inp, key=str(key), nr_devices=nr_ranks if nr_ranks is not None else get_world_size(), - rank=rank if rank is not None else get_rank(), - root=root, + is_root=is_root if is_root is not None else (get_rank() == 0), + rank=rank if rank is not None else -1, server_addr=get_master_ip(), port=get_master_port(), param=CollParam(mode=op), diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py index 295d1b6a0..a6c9a5e2f 100644 --- a/python_module/megengine/optimizer/optimizer.py +++ b/python_module/megengine/optimizer/optimizer.py @@ -17,7 +17,13 @@ import numpy as np from .._internal.config import opr_priority_scope from ..core import Buffer, Parameter, Tensor, TensorDict from ..core.graph import get_default_graph -from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed +from ..distributed import ( + all_reduce_sum, + bcast_param, + get_rank, + get_world_size, + is_distributed, +) from ..distributed.util import get_group_id from ..functional import add_update from ..functional import grad as grad_func @@ -222,7 +228,11 @@ class Optimizer(metaclass=ABCMeta): def bcast_param(self): for group in self.param_groups: for param in group["params"]: - bcast_param(param, "bcast_param_" + str(get_group_id())) + bcast_param( + param, + "bcast_param_" + str(get_group_id()), + is_root=(get_rank() == 0), + ) def state_dict(self) -> Dict: r"""Export the optimizer state. 
diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index be74e8c93..b2a7bc43b 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -93,10 +93,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, } SymbolVar _Opr::collective_comm_with_input( - SymbolVar inpvar, const std::string& key, - const size_t nr_devices, const uint32_t rank, const uint32_t root, - const std::string& server_addr, const int port, - PyObject* params, PyObject* dtype, + SymbolVar inpvar, const std::string& key, const size_t nr_devices, + const bool is_root, const int rank, const std::string& server_addr, + const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { SymbolVarArray inputs(1, inpvar); @@ -111,15 +110,15 @@ SymbolVar _Opr::collective_comm_with_input( if (dtype != Py_None) { _dtype = npy::dtype_np2mgb(dtype); } - return CollectiveComm::make(inputs, graph, key, nr_devices, rank, root, group_mgr, - dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0]; + return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank, + group_mgr, dev_buffer_arr, param, _dtype, + backend, config, disable.get_val())[0]; } SymbolVar _Opr::collective_comm_without_input( - CompGraph& cg, const std::string& key, - const size_t nr_devices, const uint32_t rank, const uint32_t root, - const std::string& server_addr, const int port, - PyObject* params, PyObject* dtype, + CompGraph& cg, const std::string& key, const size_t nr_devices, + const bool is_root, const int rank, const std::string& server_addr, + const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { SymbolVarArray inputs; @@ -134,8 +133,9 @@ SymbolVar _Opr::collective_comm_without_input( if (dtype != Py_None) { _dtype = npy::dtype_np2mgb(dtype); } - return CollectiveComm::make(inputs, &graph, key, nr_devices, rank, root, group_mgr, - dev_buffer_arr, param, _dtype, backend, config, disable.get_val())[0]; + return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank, + group_mgr, dev_buffer_arr, param, _dtype, + backend, config, disable.get_val())[0]; } #else @@ -172,7 +172,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, SymbolVar _Opr::collective_comm_with_input( SymbolVar inpvar, const std::string& key, - const size_t nr_devices, const uint32_t rank, const uint32_t root, + const size_t nr_devices, const bool is_root, const int rank, const std::string& server_addr, const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { @@ -181,7 +181,7 @@ SymbolVar _Opr::collective_comm_with_input( SymbolVar _Opr::collective_comm_without_input( CompGraph& cg, const std::string& key, - const size_t nr_devices, const uint32_t rank, const uint32_t root, + const size_t nr_devices, const bool is_root, const int rank, const std::string& server_addr, const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { diff --git a/python_module/src/cpp/opr_defs.h b/python_module/src/cpp/opr_defs.h index 5c3583557..920f2d032 100644 --- a/python_module/src/cpp/opr_defs.h +++ b/python_module/src/cpp/opr_defs.h @@ 
-94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port, static SymbolVar collective_comm_with_input( SymbolVar inpvar, const std::string& key, const size_t nr_devices, - const uint32_t rank, const uint32_t root, const std::string& server_addr, - const int port, PyObject* params, PyObject* dtype, - const std::string& backend, SharedND* output_buf, - const OperatorNodeConfig& config, const SharedScalar& disable); + const bool is_root, const int rank, const std::string& server_addr, const int port, + PyObject* params, PyObject* dtype, const std::string& backend, + SharedND* output_buf, const OperatorNodeConfig& config, + const SharedScalar& disable); static SymbolVar collective_comm_without_input( CompGraph& graph, const std::string& key, const size_t nr_devices, - const uint32_t rank, const uint32_t root, const std::string& server_addr, - const int port, PyObject* params, PyObject* dtype, - const std::string& backend, SharedND* output_buf, - const OperatorNodeConfig& config, const SharedScalar& disable); + const bool is_root, const int rank, const std::string& server_addr, const int port, + PyObject* params, PyObject* dtype, const std::string& backend, + SharedND* output_buf, const OperatorNodeConfig& config, + const SharedScalar& disable); // misc static SymbolVarArray extern_c_opr_placeholder( diff --git a/python_module/test/unit/distributed/test_functional.py b/python_module/test/unit/distributed/test_functional.py index f3c981b2f..d2bee6bdf 100644 --- a/python_module/test/unit/distributed/test_functional.py +++ b/python_module/test/unit/distributed/test_functional.py @@ -102,7 +102,7 @@ def test_all_gather(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_gather(inp, "x") + output = dist.functional.all_gather(inp, "x", rank=rank) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -135,7 +135,7 @@ def test_reduce_scatter_sum(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.reduce_scatter_sum(inp, "x") + output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank) assert np.allclose(output.numpy(), expect) def check(shape, backend): diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index d80df9b76..d735bb1cb 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -368,8 +368,8 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { CollectiveComm::CollectiveComm( VarNodeArray inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const uint32_t rank, - const uint32_t root, std::shared_ptr group_client, + const std::string& key, const size_t nr_devices, const bool is_root, + const int rank, std::shared_ptr group_client, const Param& param, const DType& dtype, const std::string& backend, const SmallVector>& dev_buffer_arr, const OperatorNodeConfig& config, @@ -380,9 +380,9 @@ CollectiveComm::CollectiveComm( m_backend(backend), m_group_client{std::move(group_client)}, m_nr_devices(nr_devices), + m_is_root(is_root), m_rank(rank), m_key(key), - m_root(root), m_dev_buffers(dev_buffer_arr), m_disable{disable} { for (auto i : inputs) { @@ -422,28 +422,28 @@ CollectiveComm::CollectiveComm( SymbolVarArray CollectiveComm::make( const SymbolVarArray& inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const 
uint32_t rank, - const uint32_t root, std::shared_ptr group_client, - const Param& param, const DType& dtype, const std::string& backend, - const OperatorNodeConfig& config, + const std::string& key, const size_t nr_devices, const bool is_root, + const int rank, std::shared_ptr group_client, + const Param& param, const DType& dtype, const std::string& backend, + const OperatorNodeConfig& config, const std::shared_ptr& disable) { SmallVector> dev_buffer_arr(nr_devices, nullptr); - return make(inputs, graph, key, nr_devices, rank, root, group_client, + return make(inputs, graph, key, nr_devices, is_root, rank, group_client, dev_buffer_arr, param, dtype, backend, config); } SymbolVarArray CollectiveComm::make( const SymbolVarArray& inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const uint32_t rank, - const uint32_t root, std::shared_ptr group_client, + const std::string& key, const size_t nr_devices, const bool is_root, + const int rank, std::shared_ptr group_client, const SmallVector>& dev_buffer_arr, const Param& param, const DType& dtype, const std::string& backend, const OperatorNodeConfig& config, const std::shared_ptr& disable) { auto inpvars = cg::to_var_node_array(inputs); auto opr = graph->insert_opr(std::make_unique( - inpvars, graph, key, nr_devices, rank, root, std::move(group_client), + inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client), param, dtype, backend, dev_buffer_arr, config, disable)); mgb_assert(!opr->output().empty()); return cg::to_symbol_var_array(opr->output()); @@ -452,11 +452,14 @@ SymbolVarArray CollectiveComm::make( void CollectiveComm::opr_register() { if (m_init) return; - auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) - .cuda_env(); + auto&& comp_node = output(0)->comp_node(); + + auto reg_info = m_group_client->opr_register( + m_key, m_nr_devices, m_is_root, m_rank, + comp_node.get_uid()); - auto hash = m_group_client->opr_register(m_key, m_nr_devices, m_rank, - reinterpret_cast(cuda_env.stream)); + m_rank = reg_info.rank; + m_root = reg_info.root_rank; MegRayCommunicatorBuilder* builder; @@ -468,7 +471,7 @@ void CollectiveComm::opr_register() { } m_megray_comm = builder->get_megray_comm( - hash, m_key, m_nr_devices, m_rank, + reg_info.hash, m_key, m_nr_devices, m_rank, get_megray_backend(m_backend), m_group_client); m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0))); @@ -606,8 +609,8 @@ VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const { } auto gvar = CollectiveComm::make( - og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_rank, m_root, - m_group_client, mode, m_dtype, m_backend, + og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root, + m_rank, m_group_client, mode, m_dtype, m_backend, OperatorNodeConfig{}.comp_node_arr(cn_arr)); if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) { @@ -733,11 +736,11 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, const OperatorNodeConfig& config) { auto&& opr = opr_.cast_final_safe(); - return opr::CollectiveComm::make(to_symbol_var_array(inputs), - ctx.owner_graph(opr_, inputs), opr.key(), - opr.nr_devices(), opr.rank(), opr.root(), - opr.group_client(), opr.dev_buffers(), - opr.param(), opr.dtype(), opr.backend(), config)[0] + return opr::CollectiveComm::make( + to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), + opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), + opr.group_client(), 
opr.dev_buffers(), opr.param(), + opr.dtype(), opr.backend(), config)[0] .node() ->owner_opr(); } diff --git a/src/opr-mm/impl/collective_comm.oprdecl b/src/opr-mm/impl/collective_comm.oprdecl index 5ff1489fc..996b6e150 100644 --- a/src/opr-mm/impl/collective_comm.oprdecl +++ b/src/opr-mm/impl/collective_comm.oprdecl @@ -6,8 +6,8 @@ decl_raw_opr( 'to the same NCCL operation.', 'str'), Doc('nr_devices', 'Total number of devices involved in the NCCL ' 'operation to which this operator belongs.', 'int'), - Doc('rank', 'Rank of this operator', 'int'), - Doc('root', 'root rank of broadcast or reduce operation'), + Doc('is_root', 'whether this node is root node', 'bool'), + Doc('rank', 'rank of this node, if is -1, generate one', 'int'), Doc('server_addr', 'rpc server ip address'), Doc('port', 'server rpc listening port'), Doc('param', 'The only component of *param* is *mode*, which refers to ' @@ -28,12 +28,12 @@ decl_raw_opr( body = [ 'if isinstance(input, _mgb.SymbolVar):', (' output = _mgb._Opr.collective_comm_with_input(input, key, ' - 'nr_devices, rank, root, server_addr, port, ' + 'nr_devices, is_root, rank, server_addr, port, ' '[param.serialize()], dtype, backend, output_buffer, config, disable)'), 'else:', ' assert isinstance(input, _mgb.CompGraph)', (' output = _mgb._Opr.collective_comm_without_input(input, key, ' - 'nr_devices, rank, root, server_addr, port, ' + 'nr_devices, is_root, rank, server_addr, port, ' '[param.serialize()], dtype, backend, output_buffer, config, disable)') ], desc = ('collective communication between multiple CompNodes on multiple ' diff --git a/src/opr-mm/impl/group_manager.cpp b/src/opr-mm/impl/group_manager.cpp index 9e2820d75..ce0a7528c 100644 --- a/src/opr-mm/impl/group_manager.cpp +++ b/src/opr-mm/impl/group_manager.cpp @@ -16,16 +16,60 @@ using namespace opr; /* ================= GroupInfo ================= */ +void GroupInfo::sort_opr_infos() { + auto cmp = [](const GroupInfo::OprInfo& a, const GroupInfo::OprInfo& b) { + return a.comp_node_hash < b.comp_node_hash; + }; + std::sort(m_opr_infos.begin(), m_opr_infos.end(), cmp); +} + +void GroupInfo::gen_infos_from_opr_infos() { + // generate rank + bool rank_assgined = true; + for (auto& opr_info:m_opr_infos) { + if(opr_info.rank < 0) { + rank_assgined = false; + break; + } + } + if (!rank_assgined) { + for (size_t i = 0; i < m_opr_infos.size(); i++) { + m_opr_infos[i].rank = i; + m_rank_map.insert({m_opr_infos[i].comp_node_hash, i}); + } + } else { + for (size_t i = 0; i < m_opr_infos.size(); i++) { + m_rank_map.insert( + {m_opr_infos[i].comp_node_hash, m_opr_infos[i].rank}); + } + } + + // generate root rank + for (auto& opr_info:m_opr_infos) { + if (opr_info.is_root) { + m_root_rank = opr_info.rank; + break; + } + } + + // generate group hash + auto xxhash = XXHash{}; + for (auto&& opr_info : m_opr_infos) { + xxhash.update(&opr_info.comp_node_hash, sizeof(uint64_t)) + .update(&opr_info.rank, sizeof(int)); + } + m_hash = xxhash.digest(); +} + void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices, - uint32_t rank, uintptr_t stream) { + bool is_root, int rank, uint64_t comp_node_hash) { std::unique_lock lk{m_group_mtx}; if (m_nr_expected_devs == 0) { m_nr_expected_devs = nr_expected_devices; } else { mgb_assert(m_nr_expected_devs == nr_expected_devices); } - OprInfo opr_info = {rank, stream}; - m_opr_infos.push_back(std::move(opr_info)); + m_opr_infos.push_back({comp_node_hash, is_root, rank}); m_nr_registered_devs++; m_count++; if (m_nr_registered_devs > nr_expected_devices) { @@ 
-38,6 +82,8 @@ void GroupInfo::add_opr(const std::string& key, size_t nr_expected_devices, key.c_str(), nr_expected_devices, m_nr_registered_devs); } if (m_nr_expected_devs == m_nr_registered_devs) { + sort_opr_infos(); + gen_infos_from_opr_infos(); m_register_cv.notify_all(); } else { m_register_cv.wait(lk, @@ -66,6 +112,8 @@ void GroupInfo::clear() { m_count--; if (m_count == 0) { m_opr_infos.clear(); + m_rank_map.clear(); + m_root_rank = -1; m_nr_expected_devs = 0; m_nr_registered_devs = 0; m_output_shape.invalidate(); @@ -77,14 +125,18 @@ void GroupInfo::clear() { /* ================= GroupManager ================= */ -uint64_t GroupManager::opr_register(const std::string& key, size_t nr_devices, - uint32_t rank, uintptr_t stream) { +GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key, + size_t nr_devices, + bool is_root, int rank, + uint64_t comp_node_hash) { + GroupManager::RegisterInfo ret{0, 0, 0}; auto&& group = get_group(key); - group.add_opr(key, nr_devices, rank, stream); - auto&& opr_infos = group.opr_infos(); - uint64_t hash = get_hash_key(opr_infos, rank); + group.add_opr(key, nr_devices, is_root, rank, comp_node_hash); + ret.rank = group.get_rank(comp_node_hash); + ret.root_rank = group.get_root_rank(); + ret.hash = group.get_group_hash() + ret.rank; group.clear(); - return hash; + return ret; } std::vector GroupManager::gather_uid(const std::string& uid, @@ -126,22 +178,6 @@ GroupInfo& GroupManager::get_group(const std::string& key) { return m_key2group_info[key]; } -uint64_t GroupManager::get_hash_key(const std::vector& _infos, - uint32_t rank) { - auto cmp = [](const GroupInfo::OprInfo& lhs, const GroupInfo::OprInfo& rhs) { - return lhs.rank < rhs.rank; - }; - auto infos = _infos; - std::sort(infos.begin(), infos.end(), cmp); - auto xxhash = XXHash{}; - for (auto&& opr_info : infos) { - xxhash.update(&opr_info.rank, sizeof(uint32_t)) - .update(&opr_info.stream, sizeof(uintptr_t)); - } - xxhash.update(&rank, sizeof(uint32_t)); - return xxhash.digest(); -}; - uint32_t GroupManager::group_barrier(uint32_t size, uint32_t rank) { std::unique_lock lk{m_barrier_mtx}; if (m_barrier_set.empty()) { diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp index 6f1bd9d99..8fea82254 100644 --- a/src/opr-mm/impl/io_remote.cpp +++ b/src/opr-mm/impl/io_remote.cpp @@ -48,12 +48,11 @@ SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var, void RemoteSend::scn_do_execute() { if (!m_init) { - auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) - .cuda_env(); + auto&& comp_node = output(0)->comp_node(); // rank 0 for RemoteSend - auto hash = m_group_client->opr_register(m_peer.key, 2, 0, - reinterpret_cast(cuda_env.stream)); + auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false, + comp_node.get_uid()); auto megray_comm_builder = owner_graph() @@ -62,7 +61,7 @@ void RemoteSend::scn_do_execute() { .get_user_data_or_create(); m_megray_comm = megray_comm_builder->get_megray_comm( - hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); + reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client); m_init = true; } @@ -152,12 +151,12 @@ SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph, void RemoteRecv::scn_do_execute() { if (!m_init) { - auto&& cuda_env = CompNodeEnv::from_comp_node(output(0)->comp_node()) - .cuda_env(); + auto&& comp_node = output(0)->comp_node(); // rank 1 for RemoteRecv - auto hash = m_group_client->opr_register(m_peer.key, 2, 1, - 
reinterpret_cast(cuda_env.stream)); + auto reg_info = m_group_client->opr_register( + m_peer.key, 2, false, 1, + comp_node.get_uid()); auto megray_comm_builder = owner_graph() @@ -166,7 +165,7 @@ void RemoteRecv::scn_do_execute() { .get_user_data_or_create(); m_megray_comm = megray_comm_builder->get_megray_comm( - hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); + reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client); m_init = true; } diff --git a/src/opr-mm/impl/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp index 4f8e3bbe8..f60248fa4 100644 --- a/src/opr-mm/impl/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -68,9 +68,11 @@ private: void GroupServerProxy::opr_register(void* input_ptr, size_t input_len, std::string *output) { INFO_INIT(mm_handler, OprRegister); - uint64_t hash = m_mgr.opr_register(req.key(), req.nr_expected_devices(), - req.rank(), req.stream()); - rsp.set_hash(hash); + auto ret = m_mgr.opr_register(req.key(), req.nr_expected_devices(), + req.is_root(), req.rank(), req.comp_node_hash()); + rsp.set_hash(ret.hash); + rsp.set_rank(ret.rank); + rsp.set_root_rank(ret.root_rank); rsp.SerializeToString(output); } @@ -122,11 +124,11 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len, /* ======================== GroupClientProxy ========================== */ -#define INFO_INIT(space, f_name, name) \ +#define INFO_INIT(space, f_name, name) \ using Request = space::name##Request; \ using Response = space::name##Response; \ - std::string func_name = #f_name; \ - Request req; \ + std::string func_name = #f_name; \ + Request req; \ Response rsp; #define SOLVE_REQUEST(name, req, rsp) \ @@ -145,15 +147,18 @@ GroupClientProxy::GroupClientProxy(const std::string& server_addr) m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { } -uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, - uint32_t rank, uintptr_t stream) { +GroupManager::RegisterInfo GroupClientProxy::opr_register( + const std::string& key, size_t nr_devices, bool is_root, int rank, + uint64_t comp_node_hash) { INFO_INIT(mm_handler, opr_register, OprRegister) req.set_key(key); + req.set_is_root(is_root); req.set_rank(rank); - req.set_stream(stream); + req.set_comp_node_hash(comp_node_hash); req.set_nr_expected_devices(nr_devices); SOLVE_REQUEST(func_name, req, rsp); - return rsp.hash(); + GroupManager::RegisterInfo ret{rsp.hash(), rsp.rank(), rsp.root_rank()}; + return ret; } void GroupClientProxy::set_output_shape(const std::string& key, diff --git a/src/opr-mm/include/megbrain/opr/collective_comm.h b/src/opr-mm/include/megbrain/opr/collective_comm.h index 9edc37d8f..2c7a1dfd0 100644 --- a/src/opr-mm/include/megbrain/opr/collective_comm.h +++ b/src/opr-mm/include/megbrain/opr/collective_comm.h @@ -26,18 +26,19 @@ public: using Param = megdnn::param::CollectiveComm; - CollectiveComm(VarNodeArray inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const uint32_t rank, - const uint32_t root, std::shared_ptr group_client, - const Param& param, const DType& dtype, const std::string& backend, - const SmallVector>& dev_buffer_arr, - const OperatorNodeConfig& config, - const std::shared_ptr& disable); + CollectiveComm( + VarNodeArray inputs, ComputingGraph* const graph, + const std::string& key, const size_t nr_devices, const bool is_root, + const int rank, std::shared_ptr group_client, + const Param& param, const DType& dtype, const std::string& backend, + const SmallVector>& dev_buffer_arr, + 
const OperatorNodeConfig& config, + const std::shared_ptr& disable); static SymbolVarArray make( const SymbolVarArray& inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const uint32_t rank, - const uint32_t root, std::shared_ptr group_client, + const std::string& key, const size_t nr_devices, const bool is_root, + const int rank, std::shared_ptr group_client, const SmallVector>& dev_buffer_arr, const Param& param, const DType& dtype = {}, const std::string& backend = "nccl", @@ -45,15 +46,16 @@ public: const std::shared_ptr& disable = std::make_shared(0)); - static SymbolVarArray make( - const SymbolVarArray& inputs, ComputingGraph* const graph, - const std::string& key, const size_t nr_devices, const uint32_t rank, - const uint32_t root, std::shared_ptr group_client, - const Param& param, const DType& dtype = {}, - const std::string& backend = "nccl", - const OperatorNodeConfig& config = {}, - const std::shared_ptr& disable = - std::make_shared(0)); + static SymbolVarArray make(const SymbolVarArray& inputs, + ComputingGraph* const graph, + const std::string& key, const size_t nr_devices, + const bool is_root, const int rank, + std::shared_ptr group_client, + const Param& param, const DType& dtype = {}, + const std::string& backend = "nccl", + const OperatorNodeConfig& config = {}, + const std::shared_ptr& disable = + std::make_shared(0)); const Param& param() const { return m_param; } const DType& dtype() const { return m_dtype; } @@ -67,9 +69,9 @@ public: return m_dev_buffers; } - uint32_t rank() const { return m_rank; } - uint32_t root() const { return m_root; } - bool is_root() const { return m_rank == m_root; } + int rank() const { return m_rank; } + int root() const { return m_root; } + bool is_root() const { return m_is_root; } //! The key that identifies an NCCL clique. //! Operators with same keys belong to the same clique. @@ -108,12 +110,13 @@ private: std::shared_ptr m_group_client; size_t m_nr_devices = 0; - uint32_t m_rank; + bool m_is_root; + int m_rank; std::string m_key; //! XXHash generated from m_key size_t m_hash; //! root of BROADCAST and REDUCE operation - uint32_t m_root; + int m_root; //! rank of root of BROADCAST and REDUCE operation Maybe m_broadcast_output_shape = None; // Whether shape infer is enabled. 
This is only used by BROADCAST operation, diff --git a/src/opr-mm/include/megbrain/opr/group_manager.h b/src/opr-mm/include/megbrain/opr/group_manager.h index 46ec7d7ed..9e87a89d5 100644 --- a/src/opr-mm/include/megbrain/opr/group_manager.h +++ b/src/opr-mm/include/megbrain/opr/group_manager.h @@ -24,12 +24,13 @@ namespace opr { class GroupInfo { public: struct OprInfo { - uint32_t rank; - uintptr_t stream; + uint64_t comp_node_hash; + bool is_root; + int rank; }; void add_opr(const std::string& key, size_t nr_expected_devices, - uint32_t graph_id, uintptr_t stream); + bool is_root, int rank, uint64_t comp_node_hash); void set_output_shape(const std::string& key, const TensorShape& shape); @@ -37,15 +38,25 @@ class GroupInfo { void clear(); - const std::vector& opr_infos() const {return m_opr_infos; } + const std::vector& opr_infos() const { return m_opr_infos; } + + int get_root_rank() const { return m_root_rank; } + int get_rank(uint64_t hash) const { return m_rank_map.at(hash); } + uint64_t get_group_hash() const { return m_hash; } private: + void sort_opr_infos(); + void gen_infos_from_opr_infos(); + std::vector m_opr_infos; + std::unordered_map m_rank_map; + uint64_t m_hash; uint32_t m_nr_registered_devs; uint32_t m_nr_expected_devs; Maybe m_output_shape; uint32_t m_count = 0; + int m_root_rank = -1; std::mutex m_group_mtx; std::condition_variable m_register_cv; std::condition_variable m_clear_cv; @@ -61,10 +72,16 @@ class GroupManager { public: ~GroupManager() = default; + struct RegisterInfo + { + uint64_t hash; + int rank, root_rank; + }; + //! register oprs' info to server, return deduplicated hash - uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, - uintptr_t stream); - + RegisterInfo opr_register(const std::string& key, size_t nr_devices, + bool is_root, int rank, uint64_t comp_node_hash); + //! gather uids from all ranks std::vector gather_uid(const std::string& uid, const std::string& key, uint32_t size, uint32_t rank); @@ -80,9 +97,6 @@ class GroupManager { private: GroupInfo& get_group(const std::string& key); - - uint64_t get_hash_key(const std::vector& _infos, - uint32_t rank); //! key -> group info. std::unordered_map m_key2group_info; @@ -112,9 +126,11 @@ class GroupClient { virtual ~GroupClient() = default; public: - virtual uint64_t opr_register(const std::string& key, size_t nr_devices, - uint32_t rank, uintptr_t stream) = 0; - + virtual GroupManager::RegisterInfo opr_register(const std::string& key, + size_t nr_devices, + bool is_root, int rank, + uint64_t comp_node_hash) = 0; + virtual std::vector gather_uid(const std::string& uid, const std::string& key, uint32_t size, uint32_t rank) = 0; diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index fe80fb81f..ccc567d6b 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -14,6 +14,7 @@ #if MGB_ENABLE_OPR_MM #include "megbrain/opr/collective_comm.h" +#include "megbrain/opr/group_manager.h" using namespace mgb; using namespace opr; @@ -31,8 +32,10 @@ public: GroupClientProxy(const std::string& server_addr); //! graph registration, assign graph_id to worker. 
- uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, - uintptr_t stream) override; + GroupManager::RegisterInfo opr_register(const std::string& key, + size_t nr_devices, bool is_root, + int rank, + uint64_t comp_node_hash) override; std::vector gather_uid(const std::string& uid, const std::string& key, uint32_t size, uint32_t rank) override; diff --git a/src/opr-mm/proto/mm_handler.proto b/src/opr-mm/proto/mm_handler.proto index e52b016b2..4102b870a 100644 --- a/src/opr-mm/proto/mm_handler.proto +++ b/src/opr-mm/proto/mm_handler.proto @@ -4,13 +4,16 @@ package mm_handler; message OprRegisterRequest { string key = 1; - uint32 rank = 2; - uint64 stream = 3; - uint32 nr_expected_devices = 4; + bool is_root = 2; + int32 rank = 3; + uint64 comp_node_hash = 4; + uint32 nr_expected_devices = 5; } message OprRegisterResponse { - uint64 hash = 1; + uint64 hash = 1; + int32 rank = 2; + int32 root_rank = 3; } message GatherUidRequest { diff --git a/src/opr-mm/test/collective_comm.cpp b/src/opr-mm/test/collective_comm.cpp index 0c29ddadb..70910a576 100644 --- a/src/opr-mm/test/collective_comm.cpp +++ b/src/opr-mm/test/collective_comm.cpp @@ -45,9 +45,11 @@ class MockGroupClient final : public opr::GroupClient { public: ~MockGroupClient() override = default; - uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, - uintptr_t stream) { - return m_mgr.opr_register(key, nr_devices, rank, stream); + opr::GroupManager::RegisterInfo opr_register(const std::string& key, + size_t nr_devices, + bool is_root, int rank, + uintptr_t stream) { + return m_mgr.opr_register(key, nr_devices, is_root, rank, stream); } std::vector gather_uid(const std::string& uid, @@ -94,9 +96,9 @@ TEST(TestOprCollectiveComm, AllReduce) { auto x1c = opr::Copy::make(x1, cn1); auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce", - 2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce", - 2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; auto y_expect = make_all_reduce_output(mode, {x0, x1}); auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -130,7 +132,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0); auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", - 2, 0, 0, client, {mode}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; @@ -139,7 +141,7 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", - 2, 1, 0, client, {mode}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); }; @@ -192,7 +194,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client, 
{Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -211,7 +213,7 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client, {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -274,9 +276,9 @@ TEST(TestOprCollectiveComm, AllGather) { auto x1c = opr::Copy::make(x1, cn1); auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather", - 2, 0, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather", - 2, 1, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto y_expect = opr::Concat::make({x0, x1}, 0); auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -303,7 +305,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); @@ -312,7 +314,7 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); @@ -361,7 +363,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -380,7 +382,7 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -444,9 +446,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { auto x1c = opr::Copy::make(x1, cn1); auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum", - 2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; auto y1 = 
opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum", - 2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; auto y_expect = make_reduce_scatter_sum_output({x0, x1}); auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -475,7 +477,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", - 2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; @@ -484,7 +486,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", - 2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); }; @@ -534,7 +536,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", - 2, 0, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); @@ -553,7 +555,7 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", - 2, 1, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); @@ -616,9 +618,9 @@ TEST(TestOprCollectiveComm, ReduceSum) { auto x1c = opr::Copy::make(x1, cn1); auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum", - 2, 0, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + 2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum", - 2, 1, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + 2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto y_expect = x0 + x1; auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -644,7 +646,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); @@ -653,7 +655,7 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { auto run_1 = 
[&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({{y1, nullptr}}); func1->execute(); @@ -699,7 +701,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -718,7 +720,7 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -767,12 +769,12 @@ TEST(TestOprCollectiveComm, Broadcast) { auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", - 2, 0, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; + 2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; auto y_dev = std::make_shared(DeviceTensorND() .comp_node(cn1) .dtype(dtype::Float32()) .resize(host_x0->shape())); - auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, 1, 0, + auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1, client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -797,7 +799,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); @@ -809,7 +811,7 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { .comp_node(cn1) .dtype(dtype::Float32()) .resize(host_x0->shape())); - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); @@ -845,7 +847,7 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, 0, 0, client, + auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -863,11 +865,11 @@ TEST(TestOprCollectiveComm, 
BroadcastWithGrad) { auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, 1, 0, client, + auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); - auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, 1, 0, client, + auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client, Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; g.node()->owner_opr()->node_prop().attribute().priority = 1; diff --git a/src/opr-mm/test/io_remote.cpp b/src/opr-mm/test/io_remote.cpp index a25384fd2..b79178095 100644 --- a/src/opr-mm/test/io_remote.cpp +++ b/src/opr-mm/test/io_remote.cpp @@ -26,11 +26,13 @@ class MockGroupClient final : public opr::GroupClient { public: ~MockGroupClient() override = default; - uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, - uintptr_t stream) { - return m_mgr.opr_register(key, nr_devices, rank, stream); + opr::GroupManager::RegisterInfo opr_register(const std::string& key, + size_t nr_devices, + bool is_root, int rank, + uint64_t comp_node_hash) { + return m_mgr.opr_register(key, nr_devices, is_root, rank, comp_node_hash); } - + std::vector gather_uid(const std::string& uid, const std::string& key, uint32_t size, uint32_t rank) { return m_mgr.gather_uid(uid, key, size, rank); -- GitLab
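
Usage note (illustrative, not part of the patch): a minimal sketch of how the refactored Python API is called after this change. It assumes the distributed group has already been initialized in the worker process (as the tests do via _init_process_group_wrapper); the key strings, tensor shape, and variable names below are placeholders chosen for illustration.

    # Sketch only: runs inside a worker whose process group is already
    # initialized; keys and the input shape are made-up placeholders.
    import numpy as np
    import megengine.distributed as dist
    from megengine import tensor

    rank = dist.get_rank()
    x = tensor(np.random.random((2, 4)).astype("float32"))

    # Root-aware collectives now take is_root instead of (rank, root).
    y = dist.functional.broadcast(x, "bcast_key", is_root=(rank == 0))
    s = dist.functional.reduce_sum(x, "reduce_key", is_root=(rank == 0))

    # Gather/scatter collectives still take an explicit rank.
    g = dist.functional.all_gather(x, "gather_key", rank=rank)
    r = dist.functional.reduce_scatter_sum(x, "scatter_key", rank=rank)

    # The all_reduce_* variants no longer take a rank argument at all.
    a = dist.functional.all_reduce_sum(x, "allreduce_key")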