From 6d367454cf1859fd2644517ec44ee28d642c162d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 18 Jun 2020 18:03:36 +0800 Subject: [PATCH] feat(mge/opr-mm): add param local_grad for collective_comm opr GitOrigin-RevId: cc120cfb55d67a48dc126d1fd8773fa08a860d32 --- .../megengine/distributed/__init__.py | 3 + .../megengine/distributed/functional.py | 136 ++- python_module/megengine/distributed/helper.py | 21 +- .../megengine/optimizer/optimizer.py | 6 +- python_module/src/cpp/opr_defs.cpp | 28 +- python_module/src/cpp/opr_defs.h | 16 +- .../test/unit/distributed/test_functional.py | 22 +- src/gopt/impl/misc.cpp | 3 +- src/gopt/test/misc.cpp | 53 +- src/opr-mm/impl/collective_comm.cpp | 249 +++-- src/opr-mm/impl/collective_comm.oprdecl | 5 +- .../include/megbrain/opr/collective_comm.h | 13 +- src/opr-mm/test/collective_comm.cpp | 849 +++++++++++++----- tools/param_defs/mgb_opr_param_defs.py | 5 +- 14 files changed, 999 insertions(+), 410 deletions(-) diff --git a/python_module/megengine/distributed/__init__.py b/python_module/megengine/distributed/__init__.py index 63974cd37..1416e82cc 100644 --- a/python_module/megengine/distributed/__init__.py +++ b/python_module/megengine/distributed/__init__.py @@ -11,10 +11,13 @@ from .functional import ( all_reduce_max, all_reduce_min, all_reduce_sum, + all_to_all, bcast_param, broadcast, + gather, reduce_scatter_sum, reduce_sum, + scatter, ) from .util import ( get_backend, diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py index 5a2e85f66..56ad089fe 100644 --- a/python_module/megengine/distributed/functional.py +++ b/python_module/megengine/distributed/functional.py @@ -9,7 +9,7 @@ from typing import Optional, Union import megengine._internal as mgb -from megengine._internal.opr_param_defs import CollectiveComm as CollParam +from megengine._internal.opr_param_defs import CollectiveComm as Param from ..core import Buffer, Parameter, Tensor, wrap_io_tensor from ..functional import add_update @@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs): return collective_comm_symvar(*args, **kargs) +def _group_check(*args): + """Return True when arguments are all None or all not None + """ + l = [val is None for val in args] + return len(set(l)) <= 1 + + def reduce_sum( tensor: Tensor, - key: str, + key: Optional[str] = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, ) -> Tensor: @@ -35,14 +42,17 @@ def reduce_sum( :param nr_ranks: number of ranks, use util.get_world_size() as default :param is_root: whether this is a root node """ + assert _group_check( + key, nr_ranks, is_root + ), "key, nr_ranks, is_root should be set at the same time" return _collective_comm( - tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, + tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device, ) def gather( tensor: Tensor, - key: str, + key: Optional[str] = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, rank: Optional[int] = None, @@ -55,20 +65,17 @@ def gather( :param is_root: whether this is a root node :param rank: rank of this node """ + assert _group_check( + key, nr_ranks, is_root, rank + ), "key, nr_ranks, is_root, rank should be set at the same time" return _collective_comm( - tensor, - key, - CollParam.Mode.GATHER, - nr_ranks, - is_root, - rank, - device=tensor.device, + tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device, ) def broadcast( tensor: Tensor, - key: str, + key: Optional[str] = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, ) -> Tensor: @@ -79,11 +86,12 @@ def broadcast( :param nr_ranks: number of ranks, use util.get_world_size() as default :param is_root: whether this is a root node """ - if key is None: - key = tensor._symvar.name + assert _group_check( + key, nr_ranks, is_root + ), "key, nr_ranks, is_root should be set at the same time" + if is_root is None: is_root = get_rank() == 0 - if is_root: inp = tensor else: @@ -92,7 +100,7 @@ def broadcast( return _collective_comm( inp, key, - CollParam.Mode.BROADCAST, + Param.Mode.BROADCAST, nr_ranks, is_root, dtype=tensor.dtype, @@ -102,7 +110,7 @@ def broadcast( def scatter( tensor: Tensor, - key: str, + key: Optional[str] = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, rank: Optional[int] = None, @@ -115,6 +123,9 @@ def scatter( :param is_root: whether this is a root node :param rank: rank of this node """ + assert _group_check( + key, nr_ranks, is_root, rank + ), "key, nr_ranks, is_root, rank should be set at the same time" if key is None: key = tensor._symvar.name if is_root is None: @@ -128,7 +139,7 @@ def scatter( return _collective_comm( inp, key, - CollParam.Mode.SCATTER, + Param.Mode.SCATTER, nr_ranks, is_root, rank, @@ -138,7 +149,11 @@ def scatter( def all_to_all( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + rank: Optional[int] = None, + local_grad: Optional[bool] = False, ) -> Tensor: """Create all_to_all operator for collective communication @@ -146,12 +161,22 @@ def all_to_all( :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default :param rank: rank of this node + :param local_grad: whether use local grad """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank) + assert _group_check( + key, nr_ranks, rank + ), "key, nr_ranks, rank should be set at the same time" + return _collective_comm( + tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad, + ) def all_gather( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + rank: Optional[int] = None, + local_grad: Optional[bool] = False, ) -> Tensor: """Create all_gather operator for collective communication @@ -159,12 +184,22 @@ def all_gather( :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default :param rank: rank of this node + :param local_grad: whether use local grad """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank) + assert _group_check( + key, nr_ranks, rank + ), "key, nr_ranks, rank should be set at the same time" + return _collective_comm( + tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad + ) def reduce_scatter_sum( - tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + rank: Optional[int] = None, + local_grad: Optional[bool] = False, ) -> Tensor: """Create reduce_scatter_sum operator for collective communication @@ -172,45 +207,81 @@ def reduce_scatter_sum( :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default :param rank: rank of this node + :param local_grad: whether use local grad """ + assert _group_check( + key, nr_ranks, rank + ), "key, nr_ranks, rank should be set at the same time" return _collective_comm( - tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank, + tensor, + key, + Param.Mode.REDUCE_SCATTER_SUM, + nr_ranks, + rank=rank, + local_grad=local_grad, ) -def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: +def all_reduce_sum( + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + local_grad: Optional[bool] = False, +) -> Tensor: """Create all_reduce_sum operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default + :param local_grad: whether use local grad """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks) + assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" + return _collective_comm( + tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad + ) -def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: +def all_reduce_max( + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + local_grad: Optional[bool] = False, +) -> Tensor: """Create all_reduce_max operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default + :param local_grad: whether use local grad """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks) + assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" + return _collective_comm( + tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad + ) -def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor: +def all_reduce_min( + tensor: Tensor, + key: Optional[str] = None, + nr_ranks: Optional[int] = None, + local_grad: Optional[bool] = False, +) -> Tensor: """Create all_reduce_min operator for collective communication :param tensor: input tensor :param key: unique identifier for collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default + :param local_grad: whether use local grad """ - return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks) + assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time" + return _collective_comm( + tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad + ) def bcast_param( inp: Union[Buffer, Parameter], - key: str, + key: Optional[str] = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, ) -> None: @@ -223,6 +294,9 @@ def bcast_param( """ if not is_distributed(): return + assert _group_check( + key, nr_ranks, is_root + ), "key, nr_ranks, is_root should be set at the same time" assert isinstance(inp, (Buffer, Parameter)) bcast_res = broadcast(inp, key, nr_ranks, is_root) add_update(inp, bcast_res, alpha=0) diff --git a/python_module/megengine/distributed/helper.py b/python_module/megengine/distributed/helper.py index 0010b3735..7d2d84bda 100644 --- a/python_module/megengine/distributed/helper.py +++ b/python_module/megengine/distributed/helper.py @@ -11,16 +11,24 @@ from typing import Optional, Union import megengine._internal as mgb from megengine._internal.opr_param_defs import CollectiveComm as CollParam -from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size +from .util import ( + get_backend, + get_group_id, + get_master_ip, + get_master_port, + get_rank, + get_world_size, +) def collective_comm_symvar( inp: Union[mgb.SymbolVar, mgb.CompGraph], - key: str, - op: CollParam.Mode, + key: Optional[str] = None, + op: CollParam.Mode = None, nr_ranks: Optional[int] = None, is_root: Optional[bool] = None, rank: Optional[int] = None, + local_grad: Optional[bool] = False, dtype: Optional[type] = None, device: Optional[mgb.CompNode] = None, comp_graph: Optional[mgb.CompGraph] = None, @@ -32,16 +40,19 @@ def collective_comm_symvar( :param op: mode of collective communication :param nr_ranks: number of ranks, use util.get_world_size() as default :param is_root: whether this node is root node + :param rank: rank of this node + :param local_grad: whether use local grad :param dtype: output data type, use dtype of inp as default :param device: output comp node, use comp node of inp as default :param comp_graph: output comp graph, use comp graph of inp as default """ return mgb.opr.collective_comm( inp, - key=str(key), + key=key if key is not None else ("collective_comm_" + str(get_group_id())), nr_devices=nr_ranks if nr_ranks is not None else get_world_size(), is_root=is_root if is_root is not None else (get_rank() == 0), - rank=rank if rank is not None else -1, + rank=rank if rank is not None else get_rank(), + local_grad=local_grad, server_addr=get_master_ip(), port=get_master_port(), param=CollParam(mode=op), diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py index b89eb1d1d..64a53ad2f 100644 --- a/python_module/megengine/optimizer/optimizer.py +++ b/python_module/megengine/optimizer/optimizer.py @@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta): with opr_priority_scope(cg, -(2 ** 30)): # always run all_reduce_mean first except add_update grad = ( - all_reduce_sum(grad, "grad_" + str(get_group_id())) + all_reduce_sum( + grad, "grad_" + str(get_group_id()), get_world_size() + ) / get_world_size() ) with opr_priority_scope(cg, -(2 ** 31)): @@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta): for group in self.param_groups: for param in group["params"]: bcast_param( - param, "bcast_param_" + str(key), is_root=(get_rank() == 0), + param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0, ) key += 1 diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index b2a7bc43b..d17ef9726 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, SymbolVar _Opr::collective_comm_with_input( SymbolVar inpvar, const std::string& key, const size_t nr_devices, - const bool is_root, const int rank, const std::string& server_addr, - const int port, PyObject* params, PyObject* dtype, - const std::string& backend, SharedND* output_buf, + const bool is_root, const int rank, const bool local_grad, + const std::string& server_addr, const int port, PyObject* params, + PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { SymbolVarArray inputs(1, inpvar); ComputingGraph* graph = inpvar.node()->owner_graph(); @@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input( _dtype = npy::dtype_np2mgb(dtype); } return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank, - group_mgr, dev_buffer_arr, param, _dtype, - backend, config, disable.get_val())[0]; + local_grad, group_mgr, dev_buffer_arr, param, + _dtype, backend, config, disable.get_val())[0]; } SymbolVar _Opr::collective_comm_without_input( CompGraph& cg, const std::string& key, const size_t nr_devices, - const bool is_root, const int rank, const std::string& server_addr, - const int port, PyObject* params, PyObject* dtype, - const std::string& backend, SharedND* output_buf, + const bool is_root, const int rank, const bool local_grad, + const std::string& server_addr, const int port, PyObject* params, + PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { SymbolVarArray inputs; auto& graph = cg.get(); @@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input( _dtype = npy::dtype_np2mgb(dtype); } return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank, - group_mgr, dev_buffer_arr, param, _dtype, - backend, config, disable.get_val())[0]; + local_grad, group_mgr, dev_buffer_arr, param, + _dtype, backend, config, disable.get_val())[0]; } #else @@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port, } SymbolVar _Opr::collective_comm_with_input( - SymbolVar inpvar, const std::string& key, - const size_t nr_devices, const bool is_root, const int rank, + SymbolVar inpvar, const std::string& key, const size_t nr_devices, + const bool is_root, const int rank, const bool local_grad, const std::string& server_addr, const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { @@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input( } SymbolVar _Opr::collective_comm_without_input( - CompGraph& cg, const std::string& key, - const size_t nr_devices, const bool is_root, const int rank, + CompGraph& cg, const std::string& key, const size_t nr_devices, + const bool is_root, const int rank, const bool local_grad, const std::string& server_addr, const int port, PyObject* params, PyObject* dtype, const std::string& backend, SharedND* output_buf, const OperatorNodeConfig& config, const SharedScalar& disable) { diff --git a/python_module/src/cpp/opr_defs.h b/python_module/src/cpp/opr_defs.h index 920f2d032..2998d545e 100644 --- a/python_module/src/cpp/opr_defs.h +++ b/python_module/src/cpp/opr_defs.h @@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port, static SymbolVar collective_comm_with_input( SymbolVar inpvar, const std::string& key, const size_t nr_devices, - const bool is_root, const int rank, const std::string& server_addr, const int port, - PyObject* params, PyObject* dtype, const std::string& backend, - SharedND* output_buf, const OperatorNodeConfig& config, - const SharedScalar& disable); + const bool is_root, const int rank, const bool local_grad, + const std::string& server_addr, const int port, PyObject* params, + PyObject* dtype, const std::string& backend, SharedND* output_buf, + const OperatorNodeConfig& config, const SharedScalar& disable); static SymbolVar collective_comm_without_input( CompGraph& graph, const std::string& key, const size_t nr_devices, - const bool is_root, const int rank, const std::string& server_addr, const int port, - PyObject* params, PyObject* dtype, const std::string& backend, - SharedND* output_buf, const OperatorNodeConfig& config, - const SharedScalar& disable); + const bool is_root, const int rank, const bool local_grad, + const std::string& server_addr, const int port, PyObject* params, + PyObject* dtype, const std::string& backend, SharedND* output_buf, + const OperatorNodeConfig& config, const SharedScalar& disable); // misc static SymbolVarArray extern_c_opr_placeholder( diff --git a/python_module/test/unit/distributed/test_functional.py b/python_module/test/unit/distributed/test_functional.py index f34cb5e62..2b901d660 100644 --- a/python_module/test/unit/distributed/test_functional.py +++ b/python_module/test/unit/distributed/test_functional.py @@ -34,7 +34,7 @@ def test_reduce_sum(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.reduce_sum(inp, "x") + output = dist.functional.reduce_sum(inp) if rank == 0: assert np.allclose(output.numpy(), expect) else: @@ -70,7 +70,7 @@ def test_gather(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank) + output = dist.functional.gather(inp) if rank == 0: assert np.allclose(output.numpy(), expect) else: @@ -106,7 +106,7 @@ def test_broadcast(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.broadcast(inp, "x") + output = dist.functional.broadcast(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -138,7 +138,7 @@ def test_scatter(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank) + output = dist.functional.scatter(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -174,7 +174,7 @@ def test_all_to_all(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_to_all(inp, "x", rank=rank) + output = dist.functional.all_to_all(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -208,7 +208,7 @@ def test_all_gather(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_gather(inp, "x", rank=rank) + output = dist.functional.all_gather(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -241,7 +241,7 @@ def test_reduce_scatter_sum(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank) + output = dist.functional.reduce_scatter_sum(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -278,7 +278,7 @@ def test_all_reduce_sum(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_reduce_sum(inp, "x") + output = dist.functional.all_reduce_sum(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -311,7 +311,7 @@ def test_all_reduce_max(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_reduce_max(inp, "x") + output = dist.functional.all_reduce_max(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -344,7 +344,7 @@ def test_all_reduce_min(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = tensor(data) - output = dist.functional.all_reduce_min(inp, "x") + output = dist.functional.all_reduce_min(inp) assert np.allclose(output.numpy(), expect) def check(shape, backend): @@ -377,7 +377,7 @@ def test_bcast_param(): return _init_process_group_wrapper(world_size, rank, rank, backend, port_queue) inp = Parameter(data) - dist.functional.bcast_param(inp, "x") + dist.functional.bcast_param(inp) assert np.allclose(inp.numpy(), expect) def check(shape, backend): diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index 136b0227d..af47b9246 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) { if (!opr->same_type()) return false; auto& comm = opr->cast_final_safe(); if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false; + if (comm.local_grad()) return false; if (comm.input().size() != 1) return false; auto grad = comm.input(0)->owner_opr(); @@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs( std::string key = ssprintf("grad_pack_%zu", pack_id); auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph, - key, info->nr_devices, info->is_root, info->rank, + key, info->nr_devices, info->is_root, info->rank, false, info->group_client, param, info->dtype, info->backend)[0]; // split according to recorded partition diff --git a/src/gopt/test/misc.cpp b/src/gopt/test/misc.cpp index 9f7ef0884..9cdeaf38e 100644 --- a/src/gopt/test/misc.cpp +++ b/src/gopt/test/misc.cpp @@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) { auto grad3 = opr::VirtualGrad::make(y1, x1); auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM; - auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), - "grad0", 2, 0, 0, client, mode)[0]; - auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), - "grad1", 2, 0, 0, client, mode)[0]; - auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), - "grad2", 2, 0, 0, client, mode)[0]; - auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), - "grad3", 2, 0, 0, client, mode)[0]; + auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2, + false, 0, false, client, mode)[0]; + auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2, + false, 0, false, client, mode)[0]; + auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2, + false, 0, false, client, mode)[0]; + auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2, + false, 0, false, client, mode)[0]; gopt::GraphOptimizer() .add_pass() @@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) { auto grad = opr::VirtualGrad::make(target, wrt); - auto comm = opr::CollectiveComm::make( - {grad}, graph.get(), "key", 2, 0, 0, client, - opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] - .node()->owner_opr(); + auto comm = + opr::CollectiveComm::make( + {grad}, graph.get(), "key", 2, false, 0, false, client, + opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0] + .node() + ->owner_opr(); comm->cast_final_safe().set_pack_hash(extra_hash); @@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) { auto insert_opr = [&] (size_t size) { auto dev = std::make_shared(cn, TensorShape{size / sizeof(float)}); auto sd = opr::SharedDeviceTensor::make(*graph, dev); - auto symvar = opr::CollectiveComm::make({sd}, graph.get(), - "key", 2, 0, 0, client, mode)[0]; + auto symvar = opr::CollectiveComm::make( + {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0]; auto opr = symvar.node()->owner_opr(); auto& comm = opr->cast_final_safe(); comm.set_pack_hash(1); @@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { size_t nr_devices = 2; uint32_t rank = 0; - uint32_t root = 0; using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo; ThinHashMap> group_info; @@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { auto insert_opr = [&] (const TensorShape& shape) { auto dev = std::make_shared(cn, shape); auto sd = opr::SharedDeviceTensor::make(*graph, dev); - auto symvar = opr::CollectiveComm::make({sd}, graph.get(), - "key", nr_devices, rank, root, client, mode)[0]; + auto symvar = + opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices, + false, rank, false, client, mode)[0]; auto opr = symvar.node()->owner_opr(); auto& comm = opr->cast_final_safe(); comm.set_pack_hash(1); @@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) { auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0); std::string key = ssprintf("grad_pack_%zu", pack_id); - auto allreduce = opr::CollectiveComm::make({concat}, graph.get(), - key, nr_devices, rank, root, client, mode)[0]; + auto allreduce = + opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices, + false, rank, false, client, mode)[0]; std::vector partition; partition.push_back(shape_x.total_nr_elems()); @@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) { using Mode = opr::CollectiveComm::Param::Mode; bool is_root = (rank == 0); - auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(), - "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; - auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(), - "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2; + auto reduced_x = opr::CollectiveComm::make( + {grad_x}, graph.get(), "x", 2, is_root, rank, + false, client, Mode::ALL_REDUCE_SUM)[0] / + 2; + auto reduced_y = opr::CollectiveComm::make( + {grad_y}, graph.get(), "y", 2, is_root, rank, + false, client, Mode::ALL_REDUCE_SUM)[0] / + 2; graph->options().allreduce_pack_max_size = 5000; graph->options().allreduce_pack_ignore_first = 0; diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index c04c813ab..46e2b5417 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -14,6 +14,8 @@ #include "megbrain/comp_node_env.h" #include "megbrain/graph/event.h" #include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/megray_helper.h" #include "megbrain/opr/group_manager.h" @@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) { } } // anonymous namespace +/* ================= ModeTrait ================= */ + class CollectiveComm::ModeTrait { class BROADCAST; class REDUCE_SUM; @@ -132,6 +136,42 @@ public: return None; } + VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { + auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); + SymbolVarArray og_syms; + og_syms.push_back(out_grad); + + auto&& cn = opr->output(0)->comp_node(); + + auto gvar = CollectiveComm::make( + og_syms, opr->owner_graph(), opr->key() + ":grad", + opr->nr_devices(), opr->is_root(), opr->rank(), false, + opr->group_client(), mode, opr->dtype(), opr->backend(), {cn}); + + return gvar[0].node(); + } + + virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const { + mgb_throw(MegBrainError, + "only all_reduce all_to_all all_gather reduce_scatter " + "support local_grad"); + } + + virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const { + if (opr->local_grad()){ + return local_grad(out_grad, opr); + } else { + return full_grad(out_grad, opr); + } + } + + VarNode* zeros(mgb::cg::ComputingGraph &graph, CompNode node, const SymbolVar& shape, + DType dtype) const { + auto zero = SymbolVar::make_scalar(0, graph, node); + auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape); + return zero_tensor.node(); + } + virtual void get_output_var_shape(const CollectiveComm* opr, const TensorShapeArray& ishp, TensorShapeArray& oshp) = 0; @@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait { } Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; } + + VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { + auto nr_devices = opr->nr_devices(); + auto rank = opr->rank(); + opr::Subtensor::IndexDesc axis; + auto shape0 = opr::GetVarShape::make(out_grad, 0); + axis.push_back({0, shape0 * rank / (int)nr_devices, + shape0 * (rank + 1) / (int)nr_devices}); + auto grad = opr::Subtensor::make(out_grad, axis); + return grad.node(); + } }; class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { @@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait { } Mode grad_mode() override { return Mode::ALL_GATHER; } -}; -/* ================= ModeTrait impls ================= */ + VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNodeArray grads; + auto zeros_tensor = + zeros(*out_grad->owner_graph(), out_grad->comp_node(), + opr::GetVarShape::make(out_grad), out_grad->dtype()); + for (size_t i = 0;i < opr->nr_devices();i++) { + if (i == opr->rank()) { + grads.push_back(out_grad); + } else { + grads.push_back(zeros_tensor); + } + } + auto grad = opr::Concat::make(grads, 0); + return grad.node(); + } +}; class CollectiveComm::ModeTrait::ReducedBasedTrait { protected: @@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait, } Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; } + +public: + VarNode* local_grad(VarNode* out_grad, + const CollectiveComm* opr) const override { + return out_grad; + } }; class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase { @@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase { class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase { MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; } + + VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNode* grad; + if (opr->local_grad()) { + grad = local_grad(out_grad, opr); + } else { + grad = full_grad(out_grad, opr); + } + + grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad}, + Elemwise::Mode::COND_LEQ_MOV) + .node(); + return grad; + } }; class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase { MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; } + + VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNode* grad; + if (opr->local_grad()) { + grad = local_grad(out_grad, opr); + } else { + grad = full_grad(out_grad, opr); + } + + grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad}, + Elemwise::Mode::COND_LEQ_MOV) + .node(); + return grad; + } }; class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, @@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait { } Mode grad_mode() override { return Mode::ALL_TO_ALL; } + + VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNodeArray grads; + auto grad_shape = opr::GetVarShape::make(out_grad); + auto zeros_tensor = + zeros(*out_grad->owner_graph(), out_grad->comp_node(), + grad_shape, out_grad->dtype()); + + auto nr_devices = opr->nr_devices(); + auto rank = opr->rank(); + opr::Subtensor::IndexDesc axis; + auto shape0 = opr::GetVarShape::make(out_grad, 0); + axis.push_back({0, shape0 * rank / (int)nr_devices, + shape0 * (rank + 1) / (int)nr_devices}); + auto sub_grad = opr::Subtensor::make(out_grad, axis); + + return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node(); + } }; CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { @@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) { CollectiveComm::CollectiveComm( VarNodeArray inputs, ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, - const int rank, std::shared_ptr group_client, - const Param& param, const DType& dtype, const std::string& backend, + const int rank, const bool local_grad, + std::shared_ptr group_client, const Param& param, + const DType& dtype, const std::string& backend, const SmallVector>& dev_buffer_arr, const OperatorNodeConfig& config, const std::shared_ptr& disable) @@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm( m_nr_devices(nr_devices), m_is_root(is_root), m_rank(rank), + m_local_grad(local_grad), m_key(key), m_dev_buffers(dev_buffer_arr), m_disable{disable} { @@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm( SymbolVarArray CollectiveComm::make( const SymbolVarArray& inputs, ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, - const int rank, std::shared_ptr group_client, - const Param& param, const DType& dtype, const std::string& backend, + const int rank, const bool local_grad, + std::shared_ptr group_client, const Param& param, + const DType& dtype, const std::string& backend, const OperatorNodeConfig& config, const std::shared_ptr& disable) { SmallVector> dev_buffer_arr(nr_devices, nullptr); - return make(inputs, graph, key, nr_devices, is_root, rank, group_client, - dev_buffer_arr, param, dtype, backend, config); + return make(inputs, graph, key, nr_devices, is_root, rank, local_grad, + group_client, dev_buffer_arr, param, dtype, backend, config); } SymbolVarArray CollectiveComm::make( const SymbolVarArray& inputs, ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, - const int rank, std::shared_ptr group_client, + const int rank, const bool local_grad, + std::shared_ptr group_client, const SmallVector>& dev_buffer_arr, const Param& param, const DType& dtype, const std::string& backend, const OperatorNodeConfig& config, const std::shared_ptr& disable) { auto inpvars = cg::to_var_node_array(inputs); auto opr = graph->insert_opr(std::make_unique( - inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client), - param, dtype, backend, dev_buffer_arr, config, disable)); + inpvars, graph, key, nr_devices, is_root, rank, local_grad, + std::move(group_client), param, dtype, backend, dev_buffer_arr, + config, disable)); mgb_assert(!opr->output().empty()); return cg::to_symbol_var_array(opr->output()); } @@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) { owner_graph()->event().signal_inplace(this, cn); trait.exec(this); owner_graph()->event().signal_inplace(this, cn); - -#if CUDART_VERSION < 9000 -#pragma message "legacy CUDA; use sync to avoid blocking" - // nccl hangs occasionally without this sync() - cn.sync(); -#endif }; env.dispatch_on_comp_node(cn, runner); } void CollectiveComm::on_output_comp_node_stream_changed() {} -VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const { - auto mode = ModeTrait::from_mode(m_param.mode).grad_mode(); - SymbolVarArray og_syms; - if (m_param.mode == Param::Mode::REDUCE_SUM) { - for (size_t i = 0; i < output().size(); i++) { - if (out_grads[i]) - og_syms.push_back(out_grads[i]); - } - mgb_assert(og_syms.size() == 1); - } else { - for (size_t i = 0; i < output().size(); i++) { - if (!out_grads[i]) { - mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM, - "null out grad in CollctiveCommMM currently " - "unsupported when the forward mode is " - "Reduce_Scatter_Sum."); - DTypeScalar dval{output(i)->dtype()}; - dval.set_retain_dtype(0); - auto zeros = - SymbolVar::make_scalar(dval, *output(i)->owner_graph(), - output(i)->comp_node()) - .broadcast(SymbolVar(output(i)).symshape()); - og_syms.push_back(zeros); - } else { - og_syms.push_back(out_grads[i]); - } - } - } - - OperatorNodeConfig::CompNodeArray cn_arr; - if (m_param.mode == Param::Mode::REDUCE_SUM) { - for (auto i : input()) { - cn_arr.push_back(i->comp_node()); - } - } else if (m_param.mode == Param::Mode::BROADCAST) { - if (!input().empty()) { - cn_arr.push_back(input(0)->comp_node()); - } - } - - auto gvar = CollectiveComm::make( - og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root, - m_rank, m_group_client, mode, m_dtype, m_backend, - OperatorNodeConfig{}.comp_node_arr(cn_arr)); - - if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) { - for (size_t i = 0; i < input().size(); ++i) { - gvar[i] = Elemwise::make({output(i), input(i), gvar[i]}, - Elemwise::Mode::COND_LEQ_MOV); - } - } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) { - for (size_t i = 0; i < input().size(); ++i) { - gvar[i] = Elemwise::make({input(i), output(i), gvar[i]}, - Elemwise::Mode::COND_LEQ_MOV); - } - } else if (m_param.mode == Param::Mode::BROADCAST) { - if (!input().empty()) { - CompNode&& master_out_cn = input(0)->comp_node(); - SymbolVarArray rst; - for (auto i : gvar) { - if (i.node()->comp_node() == master_out_cn) { - mgb_assert(rst.empty()); - rst.push_back(i); - } - } - gvar = rst; - } - } - return cg::to_var_node_array(gvar); -} - -MGB_IMPL_OPR_GRAD(CollectiveComm) { - return opr.grad(out_grad); -} - void CollectiveComm::init_output_dtype() { if (m_dtype.valid()) { for (size_t i = 0; i < input().size(); ++i) { @@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() { } } +VarNode* CollectiveComm::grad(VarNode* out_grad) const { + return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); +} + +MGB_IMPL_OPR_GRAD(CollectiveComm) { + mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); + return opr.grad(out_grad[0]); +} + /* ===================== shallow copy ===================== */ namespace mgb { @@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm( const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, const OperatorNodeConfig& config) { auto&& opr = opr_.cast_final_safe(); - auto new_opr = CollectiveComm::make( - to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), - opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), - opr.group_client(), opr.dev_buffers(), opr.param(), - opr.dtype(), opr.backend(), config)[0] - .node() - ->owner_opr(); + auto new_opr = + CollectiveComm::make( + to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs), + opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(), + opr.local_grad(), opr.group_client(), opr.dev_buffers(), + opr.param(), opr.dtype(), opr.backend(), config)[0] + .node() + ->owner_opr(); new_opr->cast_final_safe().set_pack_hash(opr.pack_hash()); return new_opr; } diff --git a/src/opr-mm/impl/collective_comm.oprdecl b/src/opr-mm/impl/collective_comm.oprdecl index 996b6e150..a5fc48b8e 100644 --- a/src/opr-mm/impl/collective_comm.oprdecl +++ b/src/opr-mm/impl/collective_comm.oprdecl @@ -8,6 +8,7 @@ decl_raw_opr( 'operation to which this operator belongs.', 'int'), Doc('is_root', 'whether this node is root node', 'bool'), Doc('rank', 'rank of this node, if is -1, generate one', 'int'), + Doc('local_grad', 'whether use local grad', 'bool'), Doc('server_addr', 'rpc server ip address'), Doc('port', 'server rpc listening port'), Doc('param', 'The only component of *param* is *mode*, which refers to ' @@ -28,12 +29,12 @@ decl_raw_opr( body = [ 'if isinstance(input, _mgb.SymbolVar):', (' output = _mgb._Opr.collective_comm_with_input(input, key, ' - 'nr_devices, is_root, rank, server_addr, port, ' + 'nr_devices, is_root, rank, local_grad, server_addr, port, ' '[param.serialize()], dtype, backend, output_buffer, config, disable)'), 'else:', ' assert isinstance(input, _mgb.CompGraph)', (' output = _mgb._Opr.collective_comm_without_input(input, key, ' - 'nr_devices, is_root, rank, server_addr, port, ' + 'nr_devices, is_root, rank, local_grad, server_addr, port, ' '[param.serialize()], dtype, backend, output_buffer, config, disable)') ], desc = ('collective communication between multiple CompNodes on multiple ' diff --git a/src/opr-mm/include/megbrain/opr/collective_comm.h b/src/opr-mm/include/megbrain/opr/collective_comm.h index 791691ca2..677dab77d 100644 --- a/src/opr-mm/include/megbrain/opr/collective_comm.h +++ b/src/opr-mm/include/megbrain/opr/collective_comm.h @@ -29,8 +29,9 @@ public: CollectiveComm( VarNodeArray inputs, ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, - const int rank, std::shared_ptr group_client, - const Param& param, const DType& dtype, const std::string& backend, + const int rank, const bool local_grad, + std::shared_ptr group_client, const Param& param, + const DType& dtype, const std::string& backend, const SmallVector>& dev_buffer_arr, const OperatorNodeConfig& config, const std::shared_ptr& disable); @@ -38,7 +39,8 @@ public: static SymbolVarArray make( const SymbolVarArray& inputs, ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, - const int rank, std::shared_ptr group_client, + const int rank, const bool local_grad, + std::shared_ptr group_client, const SmallVector>& dev_buffer_arr, const Param& param, const DType& dtype = {}, const std::string& backend = "nccl", @@ -50,6 +52,7 @@ public: ComputingGraph* const graph, const std::string& key, const size_t nr_devices, const bool is_root, const int rank, + const bool local_grad, std::shared_ptr group_client, const Param& param, const DType& dtype = {}, const std::string& backend = "nccl", @@ -72,6 +75,7 @@ public: int rank() const { return m_rank; } int root() const { return m_root; } bool is_root() const { return m_is_root; } + bool local_grad() const { return m_local_grad; } //! The key that identifies an NCCL clique. //! Operators with same keys belong to the same clique. @@ -89,7 +93,7 @@ public: return m_megray_ctx; } - VarNodeArray grad(const VarNodeArray& out_grad) const; + VarNode* grad(VarNode* out_grad) const; private: Barrier m_exec_barrier; @@ -116,6 +120,7 @@ private: size_t m_nr_devices = 0; bool m_is_root; int m_rank; + bool m_local_grad; std::string m_key; //! XXHash generated from m_key size_t m_hash; diff --git a/src/opr-mm/test/collective_comm.cpp b/src/opr-mm/test/collective_comm.cpp index 41747d03d..5b9755e8c 100644 --- a/src/opr-mm/test/collective_comm.cpp +++ b/src/opr-mm/test/collective_comm.cpp @@ -10,13 +10,13 @@ */ #include "megbrain/opr/collective_comm.h" +#include "megbrain/graph.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/blas.h" #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" #include "megbrain/test/helper.h" -#include "megbrain/graph.h" #include "mock_client.h" using namespace mgb; @@ -46,30 +46,33 @@ TEST(TestOprCollectiveComm, AllReduce) { auto run_mode = [](const Mode mode) { auto cn0 = CompNode::load("gpu0"); auto cn1 = CompNode::load("gpu1"); - + HostTensorGenerator<> gen; auto host_x0 = gen({28, 28}); auto host_x1 = gen({28, 28}); HostTensorND host_y0, host_y1, host_y_expect; - + auto client = std::make_shared(); auto graph = ComputingGraph::make(); - + auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x1c = opr::Copy::make(x1, cn1); - - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce", - 2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce", - 2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; + + auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_reduce", 2, + false, 0, false, client, {mode}, + dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_reduce", 2, + false, 1, false, client, {mode}, + dtype::Float32(), "nccl")[0]; auto y_expect = make_all_reduce_output(mode, {x0, x1}); - - auto func = graph->compile({make_callback_copy(y0, host_y0), - make_callback_copy(y1, host_y1), - make_callback_copy(y_expect, host_y_expect)}); + + auto func = + graph->compile({make_callback_copy(y0, host_y0), + make_callback_copy(y1, host_y1), + make_callback_copy(y_expect, host_y_expect)}); func->execute(); - + MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0); MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y1); }; @@ -95,8 +98,9 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto run_0 = [&]() { auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", - 2, false, 0, client, {mode}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_reduce", 2, false, 0, false, + client, {mode}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; @@ -104,8 +108,9 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto run_1 = [&]() { auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", - 2, false, 1, client, {mode}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_reduce", 2, false, 1, false, + client, {mode}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); }; @@ -115,7 +120,8 @@ TEST(TestOprCollectiveComm, AllReduceMultiThread) { auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y_expect = make_all_reduce_output(mode, {x0, x1}); - auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)}); + auto func2 = graph2->compile( + {make_callback_copy(y_expect, host_y_expect)}); func2->execute(); }; @@ -153,12 +159,13 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_reduce", 2, false, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_reduce", 2, false, 0, false, client, {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -166,18 +173,18 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { auto loss = opr::Dot::make(y0, grad0); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad0)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_reduce", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_reduce", 2, false, 1, false, client, {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -185,13 +192,12 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { auto loss = opr::Dot::make(y1, grad1); auto g = opr::VirtualGrad::make(loss, x1); - auto func1 = graph1->compile( - {make_callback_copy(y1, host_y1), - make_callback_copy(g, host_out_grad1)}); + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); @@ -200,11 +206,12 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0); auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0); - auto out_grad_expect = make_all_reduce_output(Mode::ALL_REDUCE_SUM, {grad0, grad1}); + auto out_grad_expect = + make_all_reduce_output(Mode::ALL_REDUCE_SUM, {grad0, grad1}); auto func2 = graph2->compile( - {make_callback_copy(y_expect, host_y_expect), - make_callback_copy(out_grad_expect, host_out_grad_expect)}); + {make_callback_copy(y_expect, host_y_expect), + make_callback_copy(out_grad_expect, host_out_grad_expect)}); func2->execute(); }; @@ -222,6 +229,87 @@ TEST(TestOprCollectiveComm, AllReduceWithGrad) { MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad1); } +TEST(TestOprCollectiveComm, AllReduceWithGradThisNodeOnly) { + REQUIRE_GPU(2); + auto cn0 = CompNode::load("gpu0"); + auto cn1 = CompNode::load("gpu1"); + + HostTensorGenerator<> gen; + TensorShape shape({10}); + auto host_x0 = gen(shape); + auto host_x1 = gen(shape); + auto host_grad0 = gen(shape); + auto host_grad1 = gen(shape); + + HostTensorND host_y0, host_y1, host_y_expect; + HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect; + + auto client = std::make_shared(); + + auto run_0 = [&]() { // rank 0 + auto graph0 = ComputingGraph::make(); + graph0->options().graph_opt_level = 0; + + auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_reduce", 2, false, 0, true, client, + {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + y0.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); + auto loss = opr::Dot::make(y0, grad0); + auto g = opr::VirtualGrad::make(loss, x0); + + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); + func0->execute(); + }; + + auto run_1 = [&]() { // rank 1 + auto graph1 = ComputingGraph::make(); + graph1->options().graph_opt_level = 0; + + auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_reduce", 2, false, 1, true, client, + {Mode::ALL_REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + y1.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); + auto loss = opr::Dot::make(y1, grad1); + auto g = opr::VirtualGrad::make(loss, x1); + + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); + func1->execute(); + }; + + auto run_2 = [&]() { // check + auto graph2 = ComputingGraph::make(); + + auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); + auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); + auto y_expect = make_all_reduce_output(Mode::ALL_REDUCE_SUM, {x0, x1}); + + auto func2 = + graph2->compile({make_callback_copy(y_expect, host_y_expect)}); + func2->execute(); + }; + + std::thread t0(run_0); + std::thread t1(run_1); + std::thread t2(run_2); + + t0.join(); + t1.join(); + t2.join(); + + MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0); + MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y1); + MGB_ASSERT_TENSOR_EQ(*host_grad0, host_out_grad0); + MGB_ASSERT_TENSOR_EQ(*host_grad1, host_out_grad1); +} + TEST(TestOprCollectiveComm, AllGather) { REQUIRE_GPU(2); auto cn0 = CompNode::load("gpu0"); @@ -239,10 +327,12 @@ TEST(TestOprCollectiveComm, AllGather) { auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x1c = opr::Copy::make(x1, cn1); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "all_gather", - 2, false, 0, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "all_gather", - 2, false, 1, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph.get(), "all_gather", 2, false, 0, false, client, + {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1c}, graph.get(), "all_gather", 2, false, 1, false, client, + {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto y_expect = opr::Concat::make({x0, x1}, 0); auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -266,30 +356,33 @@ TEST(TestOprCollectiveComm, AllGatherMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_gather", 2, false, 0, false, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_gather", 2, false, 1, false, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y_expect = opr::Concat::make({x0, x1}, 0); - auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)}); + auto func2 = + graph2->compile({make_callback_copy(y_expect, host_y_expect)}); func2->execute(); }; @@ -322,12 +415,13 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "all_gather", 2, false, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_gather", 2, false, 0, false, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -335,18 +429,18 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { auto loss = opr::Dot::make(y0, grad0); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad0)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "all_gather", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_gather", 2, false, 1, false, client, {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -354,13 +448,12 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { auto loss = opr::Dot::make(y1, grad1); auto g = opr::VirtualGrad::make(loss, x1); - auto func1 = graph1->compile( - {make_callback_copy(y1, host_y1), - make_callback_copy(g, host_out_grad1)}); + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); @@ -372,9 +465,105 @@ TEST(TestOprCollectiveComm, AllGatherWithGrad) { auto out_grad_expect = make_reduce_scatter_sum_output({grad0, grad1}); auto func2 = graph2->compile( - {make_callback_copy(y_expect, host_y_expect), - make_callback_copy(out_grad_expect[0], host_out_grad0_expect), - make_callback_copy(out_grad_expect[1], host_out_grad1_expect)}); + {make_callback_copy(y_expect, host_y_expect), + make_callback_copy(out_grad_expect[0], host_out_grad0_expect), + make_callback_copy(out_grad_expect[1], + host_out_grad1_expect)}); + func2->execute(); + }; + + std::thread t0(run_0); + std::thread t1(run_1); + std::thread t2(run_2); + + t0.join(); + t1.join(); + t2.join(); + + MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y0); + MGB_ASSERT_TENSOR_EQ(host_y_expect, host_y1); + MGB_ASSERT_TENSOR_EQ(host_out_grad0_expect, host_out_grad0); + MGB_ASSERT_TENSOR_EQ(host_out_grad1_expect, host_out_grad1); +} + +TEST(TestOprCollectiveComm, AllGatherWithGradThisNodeOnly) { + REQUIRE_GPU(2); + auto cn0 = CompNode::load("gpu0"); + auto cn1 = CompNode::load("gpu1"); + + HostTensorGenerator<> gen; + auto host_x0 = gen({10}); + auto host_x1 = gen({10}); + auto host_grad0 = gen({20}); + auto host_grad1 = gen({20}); + + HostTensorND host_y0, host_y1, host_y_expect; + HostTensorND host_out_grad0, host_out_grad1; + HostTensorND host_out_grad0_expect, host_out_grad1_expect; + + auto client = std::make_shared(); + + auto run_0 = [&]() { // rank 0 + auto graph0 = ComputingGraph::make(); + graph0->options().graph_opt_level = 0; + + auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "all_gather", 2, false, 0, true, client, + {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + y0.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); + auto loss = opr::Dot::make(y0, grad0); + auto g = opr::VirtualGrad::make(loss, x0); + + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); + func0->execute(); + }; + + auto run_1 = [&]() { // rank 1 + auto graph1 = ComputingGraph::make(); + graph1->options().graph_opt_level = 0; + + auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "all_gather", 2, false, 1, true, client, + {Mode::ALL_GATHER}, dtype::Float32(), "nccl")[0]; + y1.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); + auto loss = opr::Dot::make(y1, grad1); + auto g = opr::VirtualGrad::make(loss, x1); + + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); + func1->execute(); + }; + + auto run_2 = [&]() { // check + auto graph2 = ComputingGraph::make(); + + auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); + auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); + auto y_expect = opr::Concat::make({x0, x1}, 0); + + auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0); + auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0); + + opr::Subtensor::IndexDesc axis0; + auto shape0 = opr::GetVarShape::make(grad0, 0); + axis0.push_back({0, 0, shape0 / 2}); + auto out_grad0_expect = opr::Subtensor::make(grad0, axis0); + + opr::Subtensor::IndexDesc axis1; + axis1.push_back({0, shape0 / 2}); + auto out_grad1_expect = opr::Subtensor::make(grad1, axis1); + + auto func2 = graph2->compile( + {make_callback_copy(y_expect, host_y_expect), + make_callback_copy(out_grad0_expect, host_out_grad0_expect), + make_callback_copy(out_grad1_expect, host_out_grad1_expect)}); func2->execute(); }; @@ -409,16 +598,18 @@ TEST(TestOprCollectiveComm, ReduceScatterSum) { auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x1c = opr::Copy::make(x1, cn1); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_scatter_sum", - 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_scatter_sum", - 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph.get(), "reduce_scatter_sum", 2, false, 0, false, client, + {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1c}, graph.get(), "reduce_scatter_sum", 2, false, 1, false, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; auto y_expect = make_reduce_scatter_sum_output({x0, x1}); - auto func = graph->compile({make_callback_copy(y0, host_y0), - make_callback_copy(y1, host_y1), - make_callback_copy(y_expect[0], host_y0_expect), - make_callback_copy(y_expect[1], host_y1_expect)}); + auto func = graph->compile( + {make_callback_copy(y0, host_y0), make_callback_copy(y1, host_y1), + make_callback_copy(y_expect[0], host_y0_expect), + make_callback_copy(y_expect[1], host_y1_expect)}); func->execute(); MGB_ASSERT_TENSOR_EQ(host_y0_expect, host_y0); @@ -437,32 +628,36 @@ TEST(TestOprCollectiveComm, ReduceScatterSumMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", - 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "reduce_scatter_sum", 2, false, 0, false, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", - 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "reduce_scatter_sum", 2, false, 1, false, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y_expect = make_reduce_scatter_sum_output({x0, x1}); auto func = graph2->compile( - {make_callback_copy(y_expect[0], host_y0_expect), - make_callback_copy(y_expect[1], host_y1_expect)}); + {make_callback_copy(y_expect[0], host_y0_expect), + make_callback_copy(y_expect[1], host_y1_expect)}); func->execute(); }; @@ -494,45 +689,47 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce_scatter_sum", - 2, false, 0, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "reduce_scatter_sum", 2, false, 0, false, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); auto loss = opr::Dot::make(y0, grad0); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad0)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce_scatter_sum", - 2, false, 1, client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "reduce_scatter_sum", 2, false, 1, false, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); auto loss = opr::Dot::make(y1, grad1); auto g = opr::VirtualGrad::make(loss, x1); - auto func1 = graph1->compile( - {make_callback_copy(y1, host_y1), - make_callback_copy(g, host_out_grad1)}); + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); @@ -544,9 +741,9 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { auto out_grad_expect = opr::Concat::make({grad0, grad1}, 0); auto func2 = graph2->compile( - {make_callback_copy(y_expect[0], host_y0_expect), - make_callback_copy(y_expect[1], host_y1_expect), - make_callback_copy(out_grad_expect, host_out_grad_expect)}); + {make_callback_copy(y_expect[0], host_y0_expect), + make_callback_copy(y_expect[1], host_y1_expect), + make_callback_copy(out_grad_expect, host_out_grad_expect)}); func2->execute(); }; @@ -564,6 +761,101 @@ TEST(TestOprCollectiveComm, ReduceScatterSumWithGrad) { MGB_ASSERT_TENSOR_EQ(host_out_grad_expect, host_out_grad1); } +TEST(TestOprCollectiveComm, ReduceScatterSumWithGradThisNodeOnly) { + REQUIRE_GPU(2); + auto cn0 = CompNode::load("gpu0"); + auto cn1 = CompNode::load("gpu1"); + + HostTensorGenerator<> gen; + HostTensorGenerator<> zeros(0, 0); + auto host_x0 = gen({20}); + auto host_x1 = gen({20}); + auto host_grad0 = gen({10}); + auto host_grad1 = gen({10}); + auto host_zero_grad = zeros({10}); + + HostTensorND host_y0, host_y1, host_y0_expect, host_y1_expect; + HostTensorND host_out_grad0, host_out_grad1, host_out_grad_expect0, + host_out_grad_expect1; + + auto client = std::make_shared(); + + auto run_0 = [&]() { // rank 0 + auto graph0 = ComputingGraph::make(); + graph0->options().graph_opt_level = 0; + + auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "reduce_scatter_sum", 2, false, 0, true, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; + y0.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad0 = opr::Host2DeviceCopy::make(*graph0, host_grad0, cn0); + auto loss = opr::Dot::make(y0, grad0); + auto g = opr::VirtualGrad::make(loss, x0); + + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); + func0->execute(); + }; + + auto run_1 = [&]() { // rank 1 + auto graph1 = ComputingGraph::make(); + graph1->options().graph_opt_level = 0; + + auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "reduce_scatter_sum", 2, false, 1, true, + client, {Mode::REDUCE_SCATTER_SUM}, dtype::Float32(), + "nccl")[0]; + y1.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); + auto loss = opr::Dot::make(y1, grad1); + auto g = opr::VirtualGrad::make(loss, x1); + + auto func1 = graph1->compile({make_callback_copy(y1, host_y1), + make_callback_copy(g, host_out_grad1)}); + func1->execute(); + }; + + auto run_2 = [&]() { // check + auto graph2 = ComputingGraph::make(); + + auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); + auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); + auto y_expect = make_reduce_scatter_sum_output({x0, x1}); + + auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0); + auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0); + auto zero_grad = + opr::Host2DeviceCopy::make(*graph2, host_zero_grad, cn0); + auto out_grad_expect0 = opr::Concat::make({grad0, zero_grad}, 0); + auto out_grad_expect1 = opr::Concat::make({zero_grad, grad1}, 0); + + auto func2 = graph2->compile( + {make_callback_copy(y_expect[0], host_y0_expect), + make_callback_copy(y_expect[1], host_y1_expect), + make_callback_copy(out_grad_expect0, host_out_grad_expect0), + make_callback_copy(out_grad_expect1, host_out_grad_expect1)}); + func2->execute(); + }; + + std::thread t0(run_0); + std::thread t1(run_1); + std::thread t2(run_2); + + t0.join(); + t1.join(); + t2.join(); + + MGB_ASSERT_TENSOR_EQ(host_y0_expect, host_y0); + MGB_ASSERT_TENSOR_EQ(host_y1_expect, host_y1); + MGB_ASSERT_TENSOR_EQ(host_out_grad_expect0, host_out_grad0); + MGB_ASSERT_TENSOR_EQ(host_out_grad_expect1, host_out_grad1); +} + TEST(TestOprCollectiveComm, ReduceSum) { REQUIRE_GPU(2); auto cn0 = CompNode::load("gpu0"); @@ -581,10 +873,12 @@ TEST(TestOprCollectiveComm, ReduceSum) { auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x1c = opr::Copy::make(x1, cn1); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "reduce_sum", - 2, true, 0, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "reduce_sum", - 2, false, 1, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make( + {x0}, graph.get(), "reduce_sum", 2, true, 0, false, client, + {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make( + {x1c}, graph.get(), "reduce_sum", 2, false, 1, false, client, + {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto y_expect = x0 + x1; auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -607,30 +901,33 @@ TEST(TestOprCollectiveComm, ReduceSumMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "reduce", 2, true, 0, false, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "reduce", 2, false, 1, false, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({{y1, nullptr}}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y_expect = x0 + x1; - auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)}); + auto func2 = + graph2->compile({make_callback_copy(y_expect, host_y_expect)}); func2->execute(); }; @@ -660,12 +957,13 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "reduce", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "reduce", 2, true, 0, false, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -673,18 +971,18 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { auto loss = opr::Dot::make(y0, grad); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad0)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "reduce", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "reduce", 2, false, 1, false, client, {Mode::REDUCE_SUM}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -692,17 +990,18 @@ TEST(TestOprCollectiveComm, ReduceSumWithGrad) { auto loss = opr::Dot::make(y1, grad); auto g = opr::VirtualGrad::make(loss, x1); - auto func1 = graph1->compile({{y1, nullptr}, make_callback_copy(g, host_out_grad1)}); + auto func1 = graph1->compile( + {{y1, nullptr}, make_callback_copy(g, host_out_grad1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y0_expect = x0 + x1; - auto func2 = graph2->compile({ - make_callback_copy(y0_expect, host_y0_expect)}); + auto func2 = graph2->compile( + {make_callback_copy(y0_expect, host_y0_expect)}); func2->execute(); }; @@ -736,10 +1035,12 @@ TEST(TestOprCollectiveComm, Gather) { auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x1c = opr::Copy::make(x1, cn1); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "gather", - 2, true, 0, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "gather", - 2, false, 1, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "gather", 2, true, 0, + false, client, {Mode::GATHER}, + dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make({x1c}, graph.get(), "gather", 2, false, + 1, false, client, {Mode::GATHER}, + dtype::Float32(), "nccl")[0]; auto y_expect = opr::Concat::make({x0, x1}, 0); auto func = graph->compile({make_callback_copy(y0, host_y0), @@ -762,30 +1063,33 @@ TEST(TestOprCollectiveComm, GatherMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "gather", 2, true, 0, false, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "gather", 2, false, 1, false, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile({{y1, nullptr}}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y_expect = opr::Concat::make({x0, x1}, 0); - auto func2 = graph2->compile({make_callback_copy(y_expect, host_y_expect)}); + auto func2 = + graph2->compile({make_callback_copy(y_expect, host_y_expect)}); func2->execute(); }; @@ -816,12 +1120,13 @@ TEST(TestOprCollectiveComm, GatherWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "gather", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "gather", 2, true, 0, false, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -831,18 +1136,18 @@ TEST(TestOprCollectiveComm, GatherWithGrad) { auto loss = opr::Dot::make(y0, grad); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad0)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; auto x1 = opr::Host2DeviceCopy::make(*graph1, host_x1, cn1); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "gather", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "gather", 2, false, 1, false, client, {Mode::GATHER}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -850,17 +1155,18 @@ TEST(TestOprCollectiveComm, GatherWithGrad) { auto loss = opr::Dot::make(y1, grad); auto g = opr::VirtualGrad::make(loss, x1); - auto func1 = graph1->compile({{y1, nullptr}, make_callback_copy(g, host_out_grad1)}); + auto func1 = graph1->compile( + {{y1, nullptr}, make_callback_copy(g, host_out_grad1)}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph2, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph2, host_x1, cn0); auto y0_expect = opr::Concat::make({x0, x1}, 0); - auto func2 = graph2->compile({ - make_callback_copy(y0_expect, host_y0_expect)}); + auto func2 = graph2->compile( + {make_callback_copy(y0_expect, host_y0_expect)}); func2->execute(); }; @@ -890,17 +1196,20 @@ TEST(TestOprCollectiveComm, Broadcast) { auto graph = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", - 2, true, 0, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; - auto y_dev = std::make_shared(DeviceTensorND() - .comp_node(cn1) - .dtype(dtype::Float32()) - .resize(host_x0->shape())); - auto y1 = opr::CollectiveComm::make({}, graph.get(), "broadcast", 2, false, 1, - client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; - - auto func = graph->compile({make_callback_copy(y0, host_y0), - make_callback_copy(y1, host_y1)}); + auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "broadcast", 2, true, + 0, false, client, {Mode::BROADCAST}, + dtype::Float32(), "nccl")[0]; + auto y_dev = + std::make_shared(DeviceTensorND() + .comp_node(cn1) + .dtype(dtype::Float32()) + .resize(host_x0->shape())); + auto y1 = opr::CollectiveComm::make( + {}, graph.get(), "broadcast", 2, false, 1, false, client, {y_dev}, + {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; + + auto func = graph->compile( + {make_callback_copy(y0, host_y0), make_callback_copy(y1, host_y1)}); func->execute(); MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0); @@ -918,22 +1227,25 @@ TEST(TestOprCollectiveComm, BroadcastMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "broadcast", 2, true, 0, false, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); - auto y_dev = std::make_shared(DeviceTensorND() - .comp_node(cn1) - .dtype(dtype::Float32()) - .resize(host_x0->shape())); - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, + auto y_dev = std::make_shared( + DeviceTensorND() + .comp_node(cn1) + .dtype(dtype::Float32()) + .resize(host_x0->shape())); + auto y1 = opr::CollectiveComm::make( + {}, graph1.get(), "broadcast", 2, false, 1, false, client, {y_dev}, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); @@ -964,12 +1276,13 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "broadcast", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "broadcast", 2, true, 0, false, client, {Mode::BROADCAST}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -977,35 +1290,37 @@ TEST(TestOprCollectiveComm, BroadcastWithGrad) { auto loss = opr::Dot::make(y0, grad0); auto g = opr::VirtualGrad::make(loss, x0); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "broadcast", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {}, graph1.get(), "broadcast", 2, false, 1, false, client, {Mode::BROADCAST}, dtype::Float32(), "nccl", {cn1})[0]; auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); - auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "broadcast:grad", 2, false, 1, client, - Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; + auto g = opr::CollectiveComm::make( + {grad1}, graph1.get(), "broadcast:grad", 2, false, 1, false, + client, Mode::REDUCE_SUM, dtype::Float32(), "nccl")[0]; g.node()->owner_opr()->node_prop().attribute().priority = 1; - auto func1 = graph1->compile({make_callback_copy(y1, host_y1), {g, nullptr}}); + auto func1 = graph1->compile( + {make_callback_copy(y1, host_y1), {g, nullptr}}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0); auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0); auto out_grad_expect = grad0 + grad1; - auto func2 = graph2->compile({ - make_callback_copy(out_grad_expect, host_out_grad_expect)}); + auto func2 = graph2->compile( + {make_callback_copy(out_grad_expect, host_out_grad_expect)}); func2->execute(); }; @@ -1038,13 +1353,15 @@ TEST(TestOprCollectiveComm, Scatter) { auto x0 = opr::Host2DeviceCopy::make(*graph, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1, cn0); auto x = opr::Concat::make({x0, x1}, 0); - auto y0 = opr::CollectiveComm::make({x}, graph.get(), "scatter", - 2, true, 0, client, {Mode::SCATTER}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make({x}, graph.get(), "scatter", 2, true, 0, + false, client, {Mode::SCATTER}, + dtype::Float32(), "nccl")[0]; auto y1 = opr::CollectiveComm::make({}, graph.get(), "scatter", 2, false, 1, - client, {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0]; + false, client, {Mode::SCATTER}, + dtype::Float32(), "nccl", {cn1})[0]; - auto func = graph->compile({make_callback_copy(y0, host_y0), - make_callback_copy(y1, host_y1)}); + auto func = graph->compile( + {make_callback_copy(y0, host_y0), make_callback_copy(y1, host_y1)}); func->execute(); MGB_ASSERT_TENSOR_EQ(*host_x0, host_y0); @@ -1063,20 +1380,22 @@ TEST(TestOprCollectiveComm, ScatterMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0); auto x = opr::Concat::make({x0, x1}, 0); - auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x}, graph0.get(), "scatter", 2, true, 0, false, client, {Mode::SCATTER}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile({make_callback_copy(y0, host_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {}, graph1.get(), "scatter", 2, false, 1, false, client, {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0]; auto func1 = graph1->compile({make_callback_copy(y1, host_y1)}); func1->execute(); @@ -1108,14 +1427,15 @@ TEST(TestOprCollectiveComm, ScatterWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; auto x0 = opr::Host2DeviceCopy::make(*graph0, host_x0, cn0); auto x1 = opr::Host2DeviceCopy::make(*graph0, host_x1, cn0); auto x = opr::Concat::make({x0, x1}, 0); - auto y0 = opr::CollectiveComm::make({x}, graph0.get(), "scatter", 2, true, 0, client, + auto y0 = opr::CollectiveComm::make( + {x}, graph0.get(), "scatter", 2, true, 0, false, client, {Mode::SCATTER}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -1123,35 +1443,37 @@ TEST(TestOprCollectiveComm, ScatterWithGrad) { auto loss = opr::Dot::make(y0, grad0); auto g = opr::VirtualGrad::make(loss, x); - auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_out_grad)}); + auto func0 = graph0->compile({make_callback_copy(y0, host_y0), + make_callback_copy(g, host_out_grad)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; - auto y1 = opr::CollectiveComm::make({}, graph1.get(), "scatter", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {}, graph1.get(), "scatter", 2, false, 1, false, client, {Mode::SCATTER}, dtype::Float32(), "nccl", {cn1})[0]; auto grad1 = opr::Host2DeviceCopy::make(*graph1, host_grad1, cn1); - auto g = opr::CollectiveComm::make({grad1}, graph1.get(), "scatter:grad", 2, false, 1, client, - Mode::GATHER, dtype::Float32(), "nccl")[0]; + auto g = opr::CollectiveComm::make( + {grad1}, graph1.get(), "scatter:grad", 2, false, 1, false, + client, Mode::GATHER, dtype::Float32(), "nccl")[0]; g.node()->owner_opr()->node_prop().attribute().priority = 1; - auto func1 = graph1->compile({make_callback_copy(y1, host_y1), {g, nullptr}}); + auto func1 = graph1->compile( + {make_callback_copy(y1, host_y1), {g, nullptr}}); func1->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto grad0 = opr::Host2DeviceCopy::make(*graph2, host_grad0, cn0); auto grad1 = opr::Host2DeviceCopy::make(*graph2, host_grad1, cn0); auto out_grad_expect = opr::Concat::make({grad0, grad1}, 0); - auto func2 = graph2->compile({ - make_callback_copy(out_grad_expect, host_out_grad_expect)}); + auto func2 = graph2->compile( + {make_callback_copy(out_grad_expect, host_out_grad_expect)}); func2->execute(); }; @@ -1197,10 +1519,12 @@ TEST(TestOprCollectiveComm, AllToAll) { auto expect_y0 = opr::Concat::make({x00, x10c}, 0); auto expect_y1 = opr::Concat::make({x01c, x11}, 0); - auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "alltoall", - 2, false, 0, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; - auto y1 = opr::CollectiveComm::make({x1}, graph.get(), "alltoall", 2, false, 1, - client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; + auto y0 = opr::CollectiveComm::make({x0}, graph.get(), "alltoall", 2, false, + 0, false, client, {Mode::ALL_TO_ALL}, + dtype::Float32(), "nccl")[0]; + auto y1 = opr::CollectiveComm::make({x1}, graph.get(), "alltoall", 2, false, + 1, false, client, {Mode::ALL_TO_ALL}, + dtype::Float32(), "nccl")[0]; auto func = graph->compile({make_callback_copy(y0, host_y0), make_callback_copy(y1, host_y1), @@ -1227,14 +1551,15 @@ TEST(TestOprCollectiveComm, AllToAllMultiThread) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0); auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0); auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0); auto x0 = opr::Concat::make({x00, x01}, 0); auto expect_y0 = opr::Concat::make({x00, x10}, 0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "alltoall", 2, false, 0, false, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; auto func0 = graph0->compile( {make_callback_copy(y0, host_y0), @@ -1242,14 +1567,15 @@ TEST(TestOprCollectiveComm, AllToAllMultiThread) { func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1); auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1); auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1); auto x1 = opr::Concat::make({x10, x11}, 0); auto expect_y1 = opr::Concat::make({x01, x11}, 0); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "alltoall", 2, false, 1, false, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; auto func1 = graph1->compile( {make_callback_copy(y1, host_y1), @@ -1288,7 +1614,7 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto client = std::make_shared(); - auto run_0 = [&]() { // rank 0 + auto run_0 = [&]() { // rank 0 auto graph0 = ComputingGraph::make(); graph0->options().graph_opt_level = 0; @@ -1297,7 +1623,8 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0); auto x0 = opr::Concat::make({x00, x01}, 0); auto expect_y0 = opr::Concat::make({x00, x10}, 0); - auto y0 = opr::CollectiveComm::make({x0}, graph0.get(), "alltoall", 2, false, 0, client, + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "alltoall", 2, false, 0, false, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; y0.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -1308,13 +1635,13 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto g = opr::VirtualGrad::make(loss, x0); auto func0 = graph0->compile( - {make_callback_copy(y0, host_y0), - make_callback_copy(g, host_grad0), - make_callback_copy(expect_y0, host_expect_y0)}); + {make_callback_copy(y0, host_y0), + make_callback_copy(g, host_grad0), + make_callback_copy(expect_y0, host_expect_y0)}); func0->execute(); }; - auto run_1 = [&]() { // rank 1 + auto run_1 = [&]() { // rank 1 auto graph1 = ComputingGraph::make(); graph1->options().graph_opt_level = 0; @@ -1323,7 +1650,8 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1); auto x1 = opr::Concat::make({x10, x11}, 0); auto expect_y1 = opr::Concat::make({x01, x11}, 0); - auto y1 = opr::CollectiveComm::make({x1}, graph1.get(), "alltoall", 2, false, 1, client, + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "alltoall", 2, false, 1, false, client, {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; y1.node()->owner_opr()->node_prop().attribute().priority = -1; @@ -1334,13 +1662,13 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto g = opr::VirtualGrad::make(loss, x1); auto func0 = graph1->compile( - {make_callback_copy(y1, host_y1), - make_callback_copy(g, host_grad1), - make_callback_copy(expect_y1, host_expect_y1)}); + {make_callback_copy(y1, host_y1), + make_callback_copy(g, host_grad1), + make_callback_copy(expect_y1, host_expect_y1)}); func0->execute(); }; - auto run_2 = [&]() { // check + auto run_2 = [&]() { // check auto graph2 = ComputingGraph::make(); auto grad00 = opr::Host2DeviceCopy::make(*graph2, host_grad00, cn0); auto grad01 = opr::Host2DeviceCopy::make(*graph2, host_grad01, cn0); @@ -1348,9 +1676,114 @@ TEST(TestOprCollectiveComm, AllToAllWithGrad) { auto grad11 = opr::Host2DeviceCopy::make(*graph2, host_grad11, cn0); auto out_grad0_expect = opr::Concat::make({grad00, grad01}, 0); auto out_grad1_expect = opr::Concat::make({grad10, grad11}, 0); - auto func2 = graph2->compile({ - make_callback_copy(out_grad0_expect, host_expect_grad0), - make_callback_copy(out_grad1_expect, host_expect_grad1)}); + auto func2 = graph2->compile( + {make_callback_copy(out_grad0_expect, host_expect_grad0), + make_callback_copy(out_grad1_expect, host_expect_grad1)}); + func2->execute(); + }; + + std::thread t0(run_0); + std::thread t1(run_1); + std::thread t2(run_2); + + t0.join(); + t1.join(); + t2.join(); + + MGB_ASSERT_TENSOR_EQ(host_expect_y0, host_y0); + MGB_ASSERT_TENSOR_EQ(host_expect_y1, host_y1); + MGB_ASSERT_TENSOR_EQ(host_expect_grad0, host_grad0); + MGB_ASSERT_TENSOR_EQ(host_expect_grad1, host_grad1); +} + +TEST(TestOprCollectiveComm, AllToAllWithGradThisNodeOnly) { + REQUIRE_GPU(2); + auto cn0 = CompNode::load("gpu0"); + auto cn1 = CompNode::load("gpu1"); + + HostTensorGenerator<> gen; + HostTensorGenerator<> zeros(0, 0); + TensorShape shape({10}); + auto host_x00 = gen(shape); + auto host_x01 = gen(shape); + auto host_x10 = gen(shape); + auto host_x11 = gen(shape); + auto host_grad00 = gen(shape); + auto host_grad01 = gen(shape); + auto host_grad10 = gen(shape); + auto host_grad11 = gen(shape); + auto host_zero_grad = zeros(shape); + + HostTensorND host_y0, host_y1, host_expect_y0, host_expect_y1, host_grad0, + host_grad1, host_expect_grad0, host_expect_grad1; + + auto client = std::make_shared(); + + auto run_0 = [&]() { // rank 0 + auto graph0 = ComputingGraph::make(); + graph0->options().graph_opt_level = 0; + + auto x00 = opr::Host2DeviceCopy::make(*graph0, host_x00, cn0); + auto x01 = opr::Host2DeviceCopy::make(*graph0, host_x01, cn0); + auto x10 = opr::Host2DeviceCopy::make(*graph0, host_x10, cn0); + auto x0 = opr::Concat::make({x00, x01}, 0); + auto expect_y0 = opr::Concat::make({x00, x10}, 0); + auto y0 = opr::CollectiveComm::make( + {x0}, graph0.get(), "alltoall", 2, false, 0, true, client, + {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; + y0.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad00 = opr::Host2DeviceCopy::make(*graph0, host_grad00, cn0); + auto grad10 = opr::Host2DeviceCopy::make(*graph0, host_grad10, cn0); + auto grad_y0 = opr::Concat::make({grad00, grad10}, 0); + auto loss = opr::Dot::make(y0, grad_y0); + auto g = opr::VirtualGrad::make(loss, x0); + + auto func0 = graph0->compile( + {make_callback_copy(y0, host_y0), + make_callback_copy(g, host_grad0), + make_callback_copy(expect_y0, host_expect_y0)}); + func0->execute(); + }; + + auto run_1 = [&]() { // rank 1 + auto graph1 = ComputingGraph::make(); + graph1->options().graph_opt_level = 0; + + auto x10 = opr::Host2DeviceCopy::make(*graph1, host_x10, cn1); + auto x11 = opr::Host2DeviceCopy::make(*graph1, host_x11, cn1); + auto x01 = opr::Host2DeviceCopy::make(*graph1, host_x01, cn1); + auto x1 = opr::Concat::make({x10, x11}, 0); + auto expect_y1 = opr::Concat::make({x01, x11}, 0); + auto y1 = opr::CollectiveComm::make( + {x1}, graph1.get(), "alltoall", 2, false, 1, true, client, + {Mode::ALL_TO_ALL}, dtype::Float32(), "nccl")[0]; + y1.node()->owner_opr()->node_prop().attribute().priority = -1; + + auto grad01 = opr::Host2DeviceCopy::make(*graph1, host_grad01, cn1); + auto grad11 = opr::Host2DeviceCopy::make(*graph1, host_grad11, cn1); + auto grad_y1 = opr::Concat::make({grad01, grad11}, 0); + auto loss = opr::Dot::make(y1, grad_y1); + auto g = opr::VirtualGrad::make(loss, x1); + + auto func0 = graph1->compile( + {make_callback_copy(y1, host_y1), + make_callback_copy(g, host_grad1), + make_callback_copy(expect_y1, host_expect_y1)}); + func0->execute(); + }; + + auto run_2 = [&]() { // check + auto graph2 = ComputingGraph::make(); + auto grad00 = opr::Host2DeviceCopy::make(*graph2, host_grad00, cn0); + auto grad11 = opr::Host2DeviceCopy::make(*graph2, host_grad11, cn0); + auto zero_grad = + opr::Host2DeviceCopy::make(*graph2, host_zero_grad, cn0); + auto out_grad0_expect = opr::Concat::make({grad00, zero_grad}, 0); + auto out_grad1_expect = opr::Concat::make({zero_grad, grad11}, 0); + auto func2 = graph2->compile( + {make_callback_copy(out_grad0_expect, host_expect_grad0), + make_callback_copy(out_grad1_expect, host_expect_grad1)}); func2->execute(); }; diff --git a/tools/param_defs/mgb_opr_param_defs.py b/tools/param_defs/mgb_opr_param_defs.py index 2399178b2..c1ad2fe65 100644 --- a/tools/param_defs/mgb_opr_param_defs.py +++ b/tools/param_defs/mgb_opr_param_defs.py @@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields( (pdef('CollectiveComm', 'collective communication between multiple computing ' 'nodes on localhost') - .add_enum('Mode', + .add_enum(Doc('Mode', 'mode of collective communication'), Doc('REDUCE_SUM', 'reduce by sum to output computing node'), Doc('BROADCAST', 'copy input value to each output computing node'), Doc('ALL_GATHER', 'each output comp node gets the concatenated ' @@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields( Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), Doc('GATHER', 'concat inputs to one node'), Doc('SCATTER', 'scatter input to each output computing node'), - Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'))) + Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'), + name_field='mode')) (pdef('FakeSerializedDType', 'HACK: The tag of this param def is actually used for another ' -- GitLab