Commit 6d367454 authored by Megvii Engine Team

feat(mge/opr-mm): add param local_grad for collective_comm opr

GitOrigin-RevId: cc120cfb55d67a48dc126d1fd8773fa08a860d32
Parent 0ccb965c
@@ -11,10 +11,13 @@ from .functional import (
     all_reduce_max,
     all_reduce_min,
     all_reduce_sum,
+    all_to_all,
     bcast_param,
     broadcast,
+    gather,
     reduce_scatter_sum,
     reduce_sum,
+    scatter,
 )
 from .util import (
     get_backend,
......
@@ -9,7 +9,7 @@
 from typing import Optional, Union

 import megengine._internal as mgb
-from megengine._internal.opr_param_defs import CollectiveComm as CollParam
+from megengine._internal.opr_param_defs import CollectiveComm as Param

 from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
 from ..functional import add_update
@@ -22,9 +22,16 @@ def _collective_comm(*args, **kargs):
     return collective_comm_symvar(*args, **kargs)


+def _group_check(*args):
+    """Return True when arguments are all None or all not None
+    """
+    l = [val is None for val in args]
+    return len(set(l)) <= 1
+
+
 def reduce_sum(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:
@@ -35,14 +42,17 @@ def reduce_sum(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
+        tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
    )
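Reviewer note on the hunk above: `_group_check` enforces that the optional group arguments are either all supplied or all omitted, so a call cannot mix an explicit `key` with a defaulted `nr_ranks`. A minimal standalone sketch of the same predicate (names are illustrative, not part of the library):

```python
def group_check(*args):
    """True when the arguments are all None or all not None."""
    return len({val is None for val in args}) <= 1

assert group_check(None, None, None)     # everything defaulted: allowed
assert group_check("x", 2, True)         # everything explicit: allowed
assert not group_check("x", None, True)  # mixed: the new asserts reject this
```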
 def gather(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
@@ -55,20 +65,17 @@ def gather(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     return _collective_comm(
-        tensor,
-        key,
-        CollParam.Mode.GATHER,
-        nr_ranks,
-        is_root,
-        rank,
-        device=tensor.device,
+        tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
     )


 def broadcast(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> Tensor:
@@ -79,11 +86,12 @@ def broadcast(
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this is a root node
     """
-    if key is None:
-        key = tensor._symvar.name
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     if is_root is None:
         is_root = get_rank() == 0
     if is_root:
         inp = tensor
     else:
@@ -92,7 +100,7 @@
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.BROADCAST,
+        Param.Mode.BROADCAST,
         nr_ranks,
         is_root,
         dtype=tensor.dtype,
@@ -102,7 +110,7 @@

 def scatter(
     tensor: Tensor,
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
@@ -115,6 +123,9 @@ def scatter(
     :param is_root: whether this is a root node
     :param rank: rank of this node
     """
+    assert _group_check(
+        key, nr_ranks, is_root, rank
+    ), "key, nr_ranks, is_root, rank should be set at the same time"
     if key is None:
         key = tensor._symvar.name
     if is_root is None:
@@ -128,7 +139,7 @@ def scatter(
     return _collective_comm(
         inp,
         key,
-        CollParam.Mode.SCATTER,
+        Param.Mode.SCATTER,
         nr_ranks,
         is_root,
         rank,
@@ -138,7 +149,11 @@ def scatter(

 def all_to_all(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_to_all operator for collective communication
@@ -146,12 +161,22 @@ def all_to_all(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
+    )
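With every group argument now optional, call sites can shrink to a single tensor argument (the updated tests below do exactly that); key, world size, and rank are filled in by `collective_comm_symvar`. A hedged usage sketch, assuming a process group has already been initialized on each rank:

```python
import numpy as np
import megengine.distributed as dist
from megengine import tensor

# assumes init_process_group(...) has already run on every rank
inp = tensor(np.random.rand(4, 5).astype("float32"))
out = dist.functional.all_to_all(inp)                         # all defaults
out_local = dist.functional.all_to_all(inp, local_grad=True)  # rank-local backward
```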
 def all_gather(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create all_gather operator for collective communication
@@ -159,12 +184,22 @@ def all_gather(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
+    )


 def reduce_scatter_sum(
-    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
 ) -> Tensor:
     """Create reduce_scatter_sum operator for collective communication
@@ -172,45 +207,81 @@ def reduce_scatter_sum(
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param rank: rank of this node
+    :param local_grad: whether use local grad
     """
+    assert _group_check(
+        key, nr_ranks, rank
+    ), "key, nr_ranks, rank should be set at the same time"
     return _collective_comm(
-        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
+        tensor,
+        key,
+        Param.Mode.REDUCE_SCATTER_SUM,
+        nr_ranks,
+        rank=rank,
+        local_grad=local_grad,
     )


-def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_sum(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_sum operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
+    )
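What `local_grad` means for the all-reduce family: the default backward inserts a second collective (an ALL_REDUCE_SUM over the output gradients), while `local_grad=True` keeps each rank's own output gradient and skips the communication (see `AllReduceBase::local_grad` further down, which simply returns `out_grad`). A numpy sketch of the two conventions:

```python
import numpy as np

# per-rank gradients of the (replicated) all_reduce_sum output
dy = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]

full_grad = [sum(dy) for _ in dy]  # default backward: all_reduce_sum of dy
local_grad = dy                    # local_grad=True: no communication
print(full_grad, local_grad)
```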
-def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_max(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_max operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
+    )


-def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
+def all_reduce_min(
+    tensor: Tensor,
+    key: Optional[str] = None,
+    nr_ranks: Optional[int] = None,
+    local_grad: Optional[bool] = False,
+) -> Tensor:
     """Create all_reduce_min operator for collective communication

     :param tensor: input tensor
     :param key: unique identifier for collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
+    :param local_grad: whether use local grad
     """
-    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
+    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
+    return _collective_comm(
+        tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
+    )


 def bcast_param(
     inp: Union[Buffer, Parameter],
-    key: str,
+    key: Optional[str] = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
 ) -> None:
@@ -223,6 +294,9 @@ def bcast_param(
     """
     if not is_distributed():
         return
+    assert _group_check(
+        key, nr_ranks, is_root
+    ), "key, nr_ranks, is_root should be set at the same time"
     assert isinstance(inp, (Buffer, Parameter))
     bcast_res = broadcast(inp, key, nr_ranks, is_root)
     add_update(inp, bcast_res, alpha=0)
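`bcast_param` overwrites the parameter in place: `add_update(inp, bcast_res, alpha=0)` computes `inp = 0 * inp + bcast_res`, so every rank ends up with the root's value. A usage sketch under the same assumption of an initialized process group:

```python
import numpy as np
import megengine.distributed as dist
from megengine import Parameter
from megengine.distributed.util import get_rank

p = Parameter(np.full((3,), float(get_rank()), dtype="float32"))
dist.functional.bcast_param(p)  # key/nr_ranks/is_root all defaulted now
# after the call every rank holds rank 0's values (all zeros here)
```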
@@ -11,16 +11,24 @@ from typing import Optional, Union

 import megengine._internal as mgb
 from megengine._internal.opr_param_defs import CollectiveComm as CollParam

-from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size
+from .util import (
+    get_backend,
+    get_group_id,
+    get_master_ip,
+    get_master_port,
+    get_rank,
+    get_world_size,
+)


 def collective_comm_symvar(
     inp: Union[mgb.SymbolVar, mgb.CompGraph],
-    key: str,
-    op: CollParam.Mode,
+    key: Optional[str] = None,
+    op: CollParam.Mode = None,
     nr_ranks: Optional[int] = None,
     is_root: Optional[bool] = None,
     rank: Optional[int] = None,
+    local_grad: Optional[bool] = False,
     dtype: Optional[type] = None,
     device: Optional[mgb.CompNode] = None,
     comp_graph: Optional[mgb.CompGraph] = None,
@@ -32,16 +40,19 @@ def collective_comm_symvar(
     :param op: mode of collective communication
     :param nr_ranks: number of ranks, use util.get_world_size() as default
     :param is_root: whether this node is root node
+    :param rank: rank of this node
+    :param local_grad: whether use local grad
     :param dtype: output data type, use dtype of inp as default
     :param device: output comp node, use comp node of inp as default
     :param comp_graph: output comp graph, use comp graph of inp as default
     """
     return mgb.opr.collective_comm(
         inp,
-        key=str(key),
+        key=key if key is not None else ("collective_comm_" + str(get_group_id())),
         nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
         is_root=is_root if is_root is not None else (get_rank() == 0),
-        rank=rank if rank is not None else -1,
+        rank=rank if rank is not None else get_rank(),
+        local_grad=local_grad,
         server_addr=get_master_ip(),
         port=get_master_port(),
         param=CollParam(mode=op),
......
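The helper now resolves every omitted group argument itself: the key falls back to `"collective_comm_" + get_group_id()`, and `rank` falls back to `get_rank()` where it used to be `-1` ("let the group server assign one"). A self-contained restatement of the resolution logic, with stand-ins for the `util` getters:

```python
def resolve_group_args(key=None, nr_ranks=None, is_root=None, rank=None,
                       group_id=0, world_size=2, my_rank=0):
    """Stand-in for the defaults applied by collective_comm_symvar."""
    return dict(
        key=key if key is not None else "collective_comm_" + str(group_id),
        nr_devices=nr_ranks if nr_ranks is not None else world_size,
        is_root=is_root if is_root is not None else (my_rank == 0),
        rank=rank if rank is not None else my_rank,
    )

print(resolve_group_args())  # all defaults
print(resolve_group_args(key="x", nr_ranks=4, is_root=False, rank=3))
```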
@@ -182,7 +182,9 @@ class Optimizer(metaclass=ABCMeta):
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
-                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
+                        all_reduce_sum(
+                            grad, "grad_" + str(get_group_id()), get_world_size()
+                        )
                         / get_world_size()
                     )
                 with opr_priority_scope(cg, -(2 ** 31)):
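The optimizer still averages gradients rather than just summing them: the ALL_REDUCE_SUM result is divided by the world size, and `get_world_size()` is now passed explicitly so `all_reduce_sum`'s new all-or-nothing argument check is satisfied. The identity being relied on, in numpy:

```python
import numpy as np

world_size = 4
per_rank_grads = [np.full((2,), float(r)) for r in range(world_size)]

summed = sum(per_rank_grads)  # what all_reduce_sum produces
mean = summed / world_size    # what the optimizer actually applies
assert np.allclose(mean, np.mean(per_rank_grads, axis=0))
```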
@@ -229,7 +231,7 @@ class Optimizer(metaclass=ABCMeta):
         for group in self.param_groups:
             for param in group["params"]:
                 bcast_param(
-                    param, "bcast_param_" + str(key), is_root=(get_rank() == 0),
+                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
                 )
                 key += 1
......
@@ -94,9 +94,9 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 SymbolVar _Opr::collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs(1, inpvar);
     ComputingGraph* graph = inpvar.node()->owner_graph();
@@ -111,15 +111,15 @@ SymbolVar _Opr::collective_comm_with_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, graph, key, nr_devices, is_root, rank,
-                                group_mgr, dev_buffer_arr, param, _dtype,
-                                backend, config, disable.get_val())[0];
+                                local_grad, group_mgr, dev_buffer_arr, param,
+                                _dtype, backend, config, disable.get_val())[0];
 }

 SymbolVar _Opr::collective_comm_without_input(
         CompGraph& cg, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr,
-        const int port, PyObject* params, PyObject* dtype,
-        const std::string& backend, SharedND* output_buf,
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
     SymbolVarArray inputs;
     auto& graph = cg.get();
@@ -134,8 +134,8 @@ SymbolVar _Opr::collective_comm_without_input(
         _dtype = npy::dtype_np2mgb(dtype);
     }
     return CollectiveComm::make(inputs, &graph, key, nr_devices, is_root, rank,
-                                group_mgr, dev_buffer_arr, param, _dtype,
-                                backend, config, disable.get_val())[0];
+                                local_grad, group_mgr, dev_buffer_arr, param,
+                                _dtype, backend, config, disable.get_val())[0];
 }

 #else
@@ -171,8 +171,8 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
 }

 SymbolVar _Opr::collective_comm_with_input(
-        SymbolVar inpvar, const std::string& key,
-        const size_t nr_devices, const bool is_root, const int rank,
+        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
@@ -180,8 +180,8 @@ SymbolVar _Opr::collective_comm_with_input(
 }

 SymbolVar _Opr::collective_comm_without_input(
-        CompGraph& cg, const std::string& key,
-        const size_t nr_devices, const bool is_root, const int rank,
+        CompGraph& cg, const std::string& key, const size_t nr_devices,
+        const bool is_root, const int rank, const bool local_grad,
         const std::string& server_addr, const int port, PyObject* params,
         PyObject* dtype, const std::string& backend, SharedND* output_buf,
         const OperatorNodeConfig& config, const SharedScalar& disable) {
......
@@ -94,17 +94,17 @@ static SymbolVar remote_recv(const std::string& server_addr, const int port,
 static SymbolVar collective_comm_with_input(
         SymbolVar inpvar, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype, const std::string& backend,
-        SharedND* output_buf, const OperatorNodeConfig& config,
-        const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);

 static SymbolVar collective_comm_without_input(
         CompGraph& graph, const std::string& key, const size_t nr_devices,
-        const bool is_root, const int rank, const std::string& server_addr, const int port,
-        PyObject* params, PyObject* dtype, const std::string& backend,
-        SharedND* output_buf, const OperatorNodeConfig& config,
-        const SharedScalar& disable);
+        const bool is_root, const int rank, const bool local_grad,
+        const std::string& server_addr, const int port, PyObject* params,
+        PyObject* dtype, const std::string& backend, SharedND* output_buf,
+        const OperatorNodeConfig& config, const SharedScalar& disable);

 // misc
 static SymbolVarArray extern_c_opr_placeholder(
......
@@ -34,7 +34,7 @@ def test_reduce_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_sum(inp, "x")
+        output = dist.functional.reduce_sum(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:
@@ -70,7 +70,7 @@ def test_gather():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.gather(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.gather(inp)
         if rank == 0:
             assert np.allclose(output.numpy(), expect)
         else:
@@ -106,7 +106,7 @@ def test_broadcast():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.broadcast(inp, "x")
+        output = dist.functional.broadcast(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -138,7 +138,7 @@ def test_scatter():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.scatter(inp, "x", is_root=(rank == 0), rank=rank)
+        output = dist.functional.scatter(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -174,7 +174,7 @@ def test_all_to_all():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_to_all(inp, "x", rank=rank)
+        output = dist.functional.all_to_all(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -208,7 +208,7 @@ def test_all_gather():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_gather(inp, "x", rank=rank)
+        output = dist.functional.all_gather(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -241,7 +241,7 @@ def test_reduce_scatter_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.reduce_scatter_sum(inp, "x", rank=rank)
+        output = dist.functional.reduce_scatter_sum(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -278,7 +278,7 @@ def test_all_reduce_sum():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_sum(inp, "x")
+        output = dist.functional.all_reduce_sum(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -311,7 +311,7 @@ def test_all_reduce_max():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_max(inp, "x")
+        output = dist.functional.all_reduce_max(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -344,7 +344,7 @@ def test_all_reduce_min():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = tensor(data)
-        output = dist.functional.all_reduce_min(inp, "x")
+        output = dist.functional.all_reduce_min(inp)
         assert np.allclose(output.numpy(), expect)

     def check(shape, backend):
@@ -377,7 +377,7 @@ def test_bcast_param():
             return
         _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
         inp = Parameter(data)
-        dist.functional.bcast_param(inp, "x")
+        dist.functional.bcast_param(inp)
         assert np.allclose(inp.numpy(), expect)

     def check(shape, backend):
@@ -688,6 +688,7 @@ bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
     if (!opr->same_type<opr::CollectiveComm>()) return false;
     auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
     if (comm.param().mode != opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM) return false;
+    if (comm.local_grad()) return false;
     if (comm.input().size() != 1) return false;

     auto grad = comm.input(0)->owner_opr();
@@ -839,7 +840,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
     std::string key = ssprintf("grad_pack_%zu", pack_id);
     auto param = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
     SymbolVar allreduce = opr::CollectiveComm::make({concat}, graph,
-            key, info->nr_devices, info->is_root, info->rank,
+            key, info->nr_devices, info->is_root, info->rank, false,
             info->group_client, param, info->dtype, info->backend)[0];

     // split according to recorded partition
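The scan pass fuses many small ALL_REDUCE_SUM oprs into one packed collective, and the packed replacement is built with `local_grad = false` (the literal `false` added to `insert_packed_oprs` above); excluding local-grad oprs from packing keeps their backward behavior unchanged. A Python restatement of the updated `check_pattern` early-outs, over a hypothetical `comm` record:

```python
def can_pack(comm):
    """Mirror of check_pattern's predicate (illustrative only)."""
    if comm["mode"] != "ALL_REDUCE_SUM":
        return False
    if comm["local_grad"]:  # packed replacement always uses local_grad=false
        return False
    if comm["n_inputs"] != 1:
        return False
    return True

print(can_pack({"mode": "ALL_REDUCE_SUM", "local_grad": False, "n_inputs": 1}))  # True
print(can_pack({"mode": "ALL_REDUCE_SUM", "local_grad": True, "n_inputs": 1}))   # False
```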
......
@@ -438,14 +438,14 @@ TEST_PASS(PackAllReduceScanPass, Basic) {
     auto grad3 = opr::VirtualGrad::make(y1, x1);

     auto mode = opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM;
-    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(),
-            "grad0", 2, 0, 0, client, mode)[0];
-    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(),
-            "grad1", 2, 0, 0, client, mode)[0];
-    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(),
-            "grad2", 2, 0, 0, client, mode)[0];
-    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(),
-            "grad3", 2, 0, 0, client, mode)[0];
+    auto comm0 = opr::CollectiveComm::make({grad0}, graph.get(), "grad0", 2,
+                                           false, 0, false, client, mode)[0];
+    auto comm1 = opr::CollectiveComm::make({grad1}, graph.get(), "grad1", 2,
+                                           false, 0, false, client, mode)[0];
+    auto comm2 = opr::CollectiveComm::make({grad2}, graph.get(), "grad2", 2,
+                                           false, 0, false, client, mode)[0];
+    auto comm3 = opr::CollectiveComm::make({grad3}, graph.get(), "grad3", 2,
+                                           false, 0, false, client, mode)[0];

     gopt::GraphOptimizer()
             .add_pass<gopt::PackAllReduceScanPass>()
@@ -488,10 +488,12 @@ TEST_PASS(PackAllReduceReplacePass, CollectGroups) {
         auto grad = opr::VirtualGrad::make(target, wrt);

-        auto comm = opr::CollectiveComm::make(
-                {grad}, graph.get(), "key", 2, 0, 0, client,
-                opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
-                .node()->owner_opr();
+        auto comm =
+                opr::CollectiveComm::make(
+                        {grad}, graph.get(), "key", 2, false, 0, false, client,
+                        opr::CollectiveComm::Param::Mode::ALL_REDUCE_SUM)[0]
+                        .node()
+                        ->owner_opr();

         comm->cast_final_safe<opr::CollectiveComm>().set_pack_hash(extra_hash);
@@ -543,8 +545,8 @@ TEST_PASS(PackAllReduceReplacePass, DividePacks) {
     auto insert_opr = [&] (size_t size) {
         auto dev = std::make_shared<DeviceTensorND>(cn, TensorShape{size / sizeof(float)});
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
-                "key", 2, 0, 0, client, mode)[0];
+        auto symvar = opr::CollectiveComm::make(
+                {sd}, graph.get(), "key", 2, false, 0, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);
@@ -596,7 +598,6 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     size_t nr_devices = 2;
     uint32_t rank = 0;
-    uint32_t root = 0;

     using GroupInfo = gopt::PackAllReduceReplacePass::GroupInfo;
     ThinHashMap<uint64_t, std::shared_ptr<GroupInfo>> group_info;
@@ -605,8 +606,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto insert_opr = [&] (const TensorShape& shape) {
         auto dev = std::make_shared<DeviceTensorND>(cn, shape);
         auto sd = opr::SharedDeviceTensor::make(*graph, dev);
-        auto symvar = opr::CollectiveComm::make({sd}, graph.get(),
-                "key", nr_devices, rank, root, client, mode)[0];
+        auto symvar =
+                opr::CollectiveComm::make({sd}, graph.get(), "key", nr_devices,
+                                          false, rank, false, client, mode)[0];
         auto opr = symvar.node()->owner_opr();
         auto& comm = opr->cast_final_safe<opr::CollectiveComm>();
         comm.set_pack_hash(1);
@@ -634,8 +636,9 @@ TEST_PASS(PackAllReduceReplacePass, InsertPackedOprs) {
     auto concat = opr::Concat::make({grad_x.flatten(), grad_y.flatten()}, 0);

     std::string key = ssprintf("grad_pack_%zu", pack_id);
-    auto allreduce = opr::CollectiveComm::make({concat}, graph.get(),
-            key, nr_devices, rank, root, client, mode)[0];
+    auto allreduce =
+            opr::CollectiveComm::make({concat}, graph.get(), key, nr_devices,
+                                      false, rank, false, client, mode)[0];

     std::vector<size_t> partition;
     partition.push_back(shape_x.total_nr_elems());
@@ -683,10 +686,14 @@ TEST_PASS(PackAllReduceReplacePass, Equivalence) {
         using Mode = opr::CollectiveComm::Param::Mode;
         bool is_root = (rank == 0);
-        auto reduced_x = opr::CollectiveComm::make({grad_x}, graph.get(),
-                "x", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
-        auto reduced_y = opr::CollectiveComm::make({grad_y}, graph.get(),
-                "y", 2, is_root, rank, client, Mode::ALL_REDUCE_SUM)[0] / 2;
+        auto reduced_x = opr::CollectiveComm::make(
+                                 {grad_x}, graph.get(), "x", 2, is_root, rank,
+                                 false, client, Mode::ALL_REDUCE_SUM)[0] /
+                         2;
+        auto reduced_y = opr::CollectiveComm::make(
+                                 {grad_y}, graph.get(), "y", 2, is_root, rank,
+                                 false, client, Mode::ALL_REDUCE_SUM)[0] /
+                         2;

         graph->options().allreduce_pack_max_size = 5000;
         graph->options().allreduce_pack_ignore_first = 0;
@@ -14,6 +14,8 @@
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/event.h"
 #include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/opr/group_manager.h"
@@ -77,6 +79,8 @@ cudaStream_t get_stream(VarNode* var) {
     }
 }  // anonymous namespace

+/* ================= ModeTrait ================= */
+
 class CollectiveComm::ModeTrait {
     class BROADCAST;
     class REDUCE_SUM;
@@ -132,6 +136,42 @@ public:
         return None;
     }

+    VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
+        SymbolVarArray og_syms;
+        og_syms.push_back(out_grad);
+        auto&& cn = opr->output(0)->comp_node();
+        auto gvar = CollectiveComm::make(
+                og_syms, opr->owner_graph(), opr->key() + ":grad",
+                opr->nr_devices(), opr->is_root(), opr->rank(), false,
+                opr->group_client(), mode, opr->dtype(), opr->backend(), {cn});
+        return gvar[0].node();
+    }
+
+    virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        mgb_throw(MegBrainError,
+                  "only all_reduce all_to_all all_gather reduce_scatter "
+                  "support local_grad");
+    }
+
+    virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const {
+        if (opr->local_grad()){
+            return local_grad(out_grad, opr);
+        } else {
+            return full_grad(out_grad, opr);
+        }
+    }
+
+    VarNode* zeros(mgb::cg::ComputingGraph &graph, CompNode node, const SymbolVar& shape,
+                   DType dtype) const {
+        auto zero = SymbolVar::make_scalar(0, graph, node);
+        auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape);
+        return zero_tensor.node();
+    }
+
     virtual void get_output_var_shape(const CollectiveComm* opr,
                                       const TensorShapeArray& ishp,
                                       TensorShapeArray& oshp) = 0;
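The backward logic now lives on `ModeTrait`: `grad` dispatches to the mode's `local_grad` when the opr was built with `local_grad = true`, and otherwise to `full_grad`, which inserts a second CollectiveComm keyed `"<key>:grad"` running the mode's `grad_mode`. Modes without a rank-local rule inherit the throwing default. A compact Python restatement of the dispatch (all names here are stand-ins):

```python
def collective(out_grad, key, mode):
    """Stand-in for inserting a second CollectiveComm opr."""
    return ("collective", mode, key, out_grad)

class ModeTrait:
    grad_mode = None  # set per mode

    def full_grad(self, out_grad, opr):
        return collective(out_grad, key=opr["key"] + ":grad", mode=self.grad_mode)

    def local_grad(self, out_grad, opr):
        raise RuntimeError("only all_reduce all_to_all all_gather "
                           "reduce_scatter support local_grad")

    def grad(self, out_grad, opr):
        if opr["local_grad"]:
            return self.local_grad(out_grad, opr)
        return self.full_grad(out_grad, opr)

class AllGather(ModeTrait):
    grad_mode = "REDUCE_SCATTER_SUM"

print(AllGather().grad("dy", {"key": "k", "local_grad": False}))
```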
@@ -174,6 +214,17 @@ class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }
+
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto grad = opr::Subtensor::make(out_grad, axis);
+        return grad.node();
+    }
 };
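With `local_grad`, ALL_GATHER's backward avoids the REDUCE_SCATTER_SUM collective entirely: the Subtensor above just slices out this rank's rows `[n*r/k, n*(r+1)/k)` of the incoming gradient. The same computation in numpy:

```python
import numpy as np

def all_gather_local_grad(out_grad, rank, nr_devices):
    """Rank-local backward of all_gather: this rank's slice of the grad."""
    n = out_grad.shape[0]
    return out_grad[n * rank // nr_devices : n * (rank + 1) // nr_devices]

g = np.arange(8.0).reshape(4, 2)  # gradient of the gathered output
print(all_gather_local_grad(g, rank=1, nr_devices=2))  # rows 2..3
```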
 class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
@@ -211,9 +262,23 @@ class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::ALL_GATHER; }
-};

-/* ================= ModeTrait impls ================= */
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      opr::GetVarShape::make(out_grad), out_grad->dtype());
+        for (size_t i = 0;i < opr->nr_devices();i++) {
+            if (i == opr->rank()) {
+                grads.push_back(out_grad);
+            } else {
+                grads.push_back(zeros_tensor);
+            }
+        }
+        auto grad = opr::Concat::make(grads, 0);
+        return grad.node();
+    }
+};
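REDUCE_SCATTER_SUM's rank-local backward is the reverse embedding: the local `out_grad` is concatenated into its own slot of a full-size tensor, with zeros standing in for every other rank, instead of running the ALL_GATHER of the full backward. In numpy:

```python
import numpy as np

def reduce_scatter_sum_local_grad(out_grad, rank, nr_devices):
    """Concat out_grad into its own slot, zeros for the other ranks."""
    parts = [out_grad if i == rank else np.zeros_like(out_grad)
             for i in range(nr_devices)]
    return np.concatenate(parts, axis=0)

g = np.ones((2, 3))
print(reduce_scatter_sum_local_grad(g, rank=1, nr_devices=3).shape)  # (6, 3)
```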
 class CollectiveComm::ModeTrait::ReducedBasedTrait {
 protected:
@@ -250,6 +315,12 @@ class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
     }

     Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }
+
+public:
+    VarNode* local_grad(VarNode* out_grad,
+                        const CollectiveComm* opr) const override {
+        return out_grad;
+    }
 };
 class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
@@ -258,10 +329,38 @@ class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {

 class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }
+
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->output(0), opr->input(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };

 class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }
+
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* grad;
+        if (opr->local_grad()) {
+            grad = local_grad(out_grad, opr);
+        } else {
+            grad = full_grad(out_grad, opr);
+        }
+        grad = opr::Elemwise::make({opr->input(0), opr->output(0), grad},
+                                   Elemwise::Mode::COND_LEQ_MOV)
+                       .node();
+        return grad;
+    }
 };
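For ALL_REDUCE_MAX/MIN, the gradient (whether local or full) is additionally masked with Elemwise COND_LEQ_MOV so it only flows where this rank's input attained the reduced extremum; the operand order flips between the two modes. Assuming `COND_LEQ_MOV(a, b, g)` means "pass `g` where `a <= b`, else 0", a numpy restatement:

```python
import numpy as np

def cond_leq_mov(a, b, grad):
    """grad where a <= b, else 0 (numpy stand-in for COND_LEQ_MOV)."""
    return np.where(a <= b, grad, np.zeros_like(grad))

x = np.array([1.0, 5.0, 2.0])  # this rank's input
y = np.array([4.0, 5.0, 3.0])  # all_reduce_max output
g = np.ones_like(x)
print(cond_leq_mov(y, x, g))   # MAX: grad passes only where x == max
# ALL_REDUCE_MIN swaps the operands: cond_leq_mov(x, y, g)
```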
 class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
@@ -448,6 +547,24 @@ class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
     }

     Mode grad_mode() override { return Mode::ALL_TO_ALL; }
+
+    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNodeArray grads;
+        auto grad_shape = opr::GetVarShape::make(out_grad);
+        auto zeros_tensor =
+                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
+                      grad_shape, out_grad->dtype());
+        auto nr_devices = opr->nr_devices();
+        auto rank = opr->rank();
+        opr::Subtensor::IndexDesc axis;
+        auto shape0 = opr::GetVarShape::make(out_grad, 0);
+        axis.push_back({0, shape0 * rank / (int)nr_devices,
+                        shape0 * (rank + 1) / (int)nr_devices});
+        auto sub_grad = opr::Subtensor::make(out_grad, axis);
+        return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node();
+    }
 };
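ALL_TO_ALL's rank-local backward keeps only this rank's slice of the gradient and zeroes the rest: Subtensor extracts the slice and SetSubtensor writes it back into a zero tensor of the full shape. In numpy:

```python
import numpy as np

def all_to_all_local_grad(out_grad, rank, nr_devices):
    """Zero everything except this rank's slice of the gradient."""
    n = out_grad.shape[0]
    lo, hi = n * rank // nr_devices, n * (rank + 1) // nr_devices
    grad = np.zeros_like(out_grad)  # zeros of the full shape
    grad[lo:hi] = out_grad[lo:hi]   # SetSubtensor of the Subtensor slice
    return grad

g = np.arange(6.0).reshape(6, 1)
print(all_to_all_local_grad(g, rank=1, nr_devices=3).ravel())  # rows 2..3 kept
```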
 CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
@@ -469,8 +586,9 @@ CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
 CollectiveComm::CollectiveComm(
         VarNodeArray inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable)
@@ -482,6 +600,7 @@ CollectiveComm::CollectiveComm(
           m_nr_devices(nr_devices),
           m_is_root(is_root),
           m_rank(rank),
+          m_local_grad(local_grad),
           m_key(key),
           m_dev_buffers(dev_buffer_arr),
           m_disable{disable} {
@@ -523,28 +642,31 @@ CollectiveComm::CollectiveComm(
 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
-        const Param& param, const DType& dtype, const std::string& backend,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client, const Param& param,
+        const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
                                                                 nullptr);
-    return make(inputs, graph, key, nr_devices, is_root, rank, group_client,
-                dev_buffer_arr, param, dtype, backend, config);
+    return make(inputs, graph, key, nr_devices, is_root, rank, local_grad,
+                group_client, dev_buffer_arr, param, dtype, backend, config);
 }

 SymbolVarArray CollectiveComm::make(
         const SymbolVarArray& inputs, ComputingGraph* const graph,
         const std::string& key, const size_t nr_devices, const bool is_root,
-        const int rank, std::shared_ptr<GroupClient> group_client,
+        const int rank, const bool local_grad,
+        std::shared_ptr<GroupClient> group_client,
         const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
         const Param& param, const DType& dtype, const std::string& backend,
         const OperatorNodeConfig& config,
         const std::shared_ptr<DTypeScalar>& disable) {
     auto inpvars = cg::to_var_node_array(inputs);
     auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
-            inpvars, graph, key, nr_devices, is_root, rank, std::move(group_client),
-            param, dtype, backend, dev_buffer_arr, config, disable));
+            inpvars, graph, key, nr_devices, is_root, rank, local_grad,
+            std::move(group_client), param, dtype, backend, dev_buffer_arr,
+            config, disable));
     mgb_assert(!opr->output().empty());
     return cg::to_symbol_var_array(opr->output());
 }
@@ -647,93 +769,12 @@ void CollectiveComm::do_execute(ExecEnv& env) {
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
         trait.exec(this);
         owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
-#if CUDART_VERSION < 9000
-#pragma message "legacy CUDA; use sync to avoid blocking"
-        // nccl hangs occasionally without this sync()
-        cn.sync();
-#endif
     };
     env.dispatch_on_comp_node(cn, runner);
 }

 void CollectiveComm::on_output_comp_node_stream_changed() {}
-VarNodeArray CollectiveComm::grad(const VarNodeArray& out_grads) const {
-    auto mode = ModeTrait::from_mode(m_param.mode).grad_mode();
-    SymbolVarArray og_syms;
-
-    if (m_param.mode == Param::Mode::REDUCE_SUM) {
-        for (size_t i = 0; i < output().size(); i++) {
-            if (out_grads[i])
-                og_syms.push_back(out_grads[i]);
-        }
-        mgb_assert(og_syms.size() == 1);
-    } else {
-        for (size_t i = 0; i < output().size(); i++) {
-            if (!out_grads[i]) {
-                mgb_assert(m_param.mode != Param::Mode::REDUCE_SCATTER_SUM,
-                           "null out grad in CollctiveCommMM currently "
-                           "unsupported when the forward mode is "
-                           "Reduce_Scatter_Sum.");
-                DTypeScalar dval{output(i)->dtype()};
-                dval.set_retain_dtype(0);
-                auto zeros =
-                        SymbolVar::make_scalar(dval, *output(i)->owner_graph(),
-                                               output(i)->comp_node())
-                                .broadcast(SymbolVar(output(i)).symshape());
-                og_syms.push_back(zeros);
-            } else {
-                og_syms.push_back(out_grads[i]);
-            }
-        }
-    }
-
-    OperatorNodeConfig::CompNodeArray cn_arr;
-    if (m_param.mode == Param::Mode::REDUCE_SUM) {
-        for (auto i : input()) {
-            cn_arr.push_back(i->comp_node());
-        }
-    } else if (m_param.mode == Param::Mode::BROADCAST) {
-        if (!input().empty()) {
-            cn_arr.push_back(input(0)->comp_node());
-        }
-    }
-
-    auto gvar = CollectiveComm::make(
-            og_syms, owner_graph(), m_key + ":grad", m_nr_devices, m_is_root,
-            m_rank, m_group_client, mode, m_dtype, m_backend,
-            OperatorNodeConfig{}.comp_node_arr(cn_arr));
-
-    if (m_param.mode == Param::Mode::ALL_REDUCE_MAX) {
-        for (size_t i = 0; i < input().size(); ++i) {
-            gvar[i] = Elemwise::make({output(i), input(i), gvar[i]},
-                                     Elemwise::Mode::COND_LEQ_MOV);
-        }
-    } else if (m_param.mode == Param::Mode::ALL_REDUCE_MIN) {
-        for (size_t i = 0; i < input().size(); ++i) {
-            gvar[i] = Elemwise::make({input(i), output(i), gvar[i]},
-                                     Elemwise::Mode::COND_LEQ_MOV);
-        }
-    } else if (m_param.mode == Param::Mode::BROADCAST) {
-        if (!input().empty()) {
-            CompNode&& master_out_cn = input(0)->comp_node();
-            SymbolVarArray rst;
-            for (auto i : gvar) {
-                if (i.node()->comp_node() == master_out_cn) {
-                    mgb_assert(rst.empty());
-                    rst.push_back(i);
-                }
-            }
-            gvar = rst;
-        }
-    }
-    return cg::to_var_node_array(gvar);
-}
-
-MGB_IMPL_OPR_GRAD(CollectiveComm) {
-    return opr.grad(out_grad);
-}

 void CollectiveComm::init_output_dtype() {
     if (m_dtype.valid()) {
         for (size_t i = 0; i < input().size(); ++i) {
@@ -797,6 +838,15 @@ void CollectiveComm::init_output_static_infer_desc() {
     }
 }

+VarNode* CollectiveComm::grad(VarNode* out_grad) const {
+    return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
+}
+
+MGB_IMPL_OPR_GRAD(CollectiveComm) {
+    mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
+    return opr.grad(out_grad[0]);
+}
+
 /* ===================== shallow copy ===================== */
 namespace mgb {
@@ -807,13 +857,14 @@ cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
-    auto new_opr = CollectiveComm::make(
-            to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
-            opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
-            opr.group_client(), opr.dev_buffers(), opr.param(),
-            opr.dtype(), opr.backend(), config)[0]
-            .node()
-            ->owner_opr();
+    auto new_opr =
+            CollectiveComm::make(
+                    to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
+                    opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
+                    opr.local_grad(), opr.group_client(), opr.dev_buffers(),
+                    opr.param(), opr.dtype(), opr.backend(), config)[0]
+                    .node()
+                    ->owner_opr();
     new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash());
     return new_opr;
 }
......
@@ -8,6 +8,7 @@ decl_raw_opr(
             'operation to which this operator belongs.', 'int'),
         Doc('is_root', 'whether this node is root node', 'bool'),
         Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
+        Doc('local_grad', 'whether use local grad', 'bool'),
         Doc('server_addr', 'rpc server ip address'),
         Doc('port', 'server rpc listening port'),
         Doc('param', 'The only component of *param* is *mode*, which refers to '
@@ -28,12 +29,12 @@ decl_raw_opr(
     body = [
         'if isinstance(input, _mgb.SymbolVar):',
         ('    output = _mgb._Opr.collective_comm_with_input(input, key, '
-         'nr_devices, is_root, rank, server_addr, port, '
+         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
        'else:',
        '    assert isinstance(input, _mgb.CompGraph)',
        ('    output = _mgb._Opr.collective_comm_without_input(input, key, '
-         'nr_devices, is_root, rank, server_addr, port, '
+         'nr_devices, is_root, rank, local_grad, server_addr, port, '
        '[param.serialize()], dtype, backend, output_buffer, config, disable)')
    ],
    desc = ('collective communication between multiple CompNodes on multiple '
......
@@ -29,8 +29,9 @@ public:
     CollectiveComm(
             VarNodeArray inputs, ComputingGraph* const graph,
             const std::string& key, const size_t nr_devices, const bool is_root,
-            const int rank, std::shared_ptr<GroupClient> group_client,
-            const Param& param, const DType& dtype, const std::string& backend,
+            const int rank, const bool local_grad,
+            std::shared_ptr<GroupClient> group_client, const Param& param,
+            const DType& dtype, const std::string& backend,
             const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
             const OperatorNodeConfig& config,
             const std::shared_ptr<DTypeScalar>& disable);
@@ -38,7 +39,8 @@ public:
     static SymbolVarArray make(
             const SymbolVarArray& inputs, ComputingGraph* const graph,
             const std::string& key, const size_t nr_devices, const bool is_root,
-            const int rank, std::shared_ptr<GroupClient> group_client,
+            const int rank, const bool local_grad,
+            std::shared_ptr<GroupClient> group_client,
             const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
             const Param& param, const DType& dtype = {},
             const std::string& backend = "nccl",
@@ -50,6 +52,7 @@ public:
             ComputingGraph* const graph,
             const std::string& key, const size_t nr_devices,
             const bool is_root, const int rank,
+            const bool local_grad,
             std::shared_ptr<GroupClient> group_client,
             const Param& param, const DType& dtype = {},
             const std::string& backend = "nccl",
@@ -72,6 +75,7 @@ public:
     int rank() const { return m_rank; }
     int root() const { return m_root; }
     bool is_root() const { return m_is_root; }
+    bool local_grad() const { return m_local_grad; }

     //! The key that identifies an NCCL clique.
     //! Operators with same keys belong to the same clique.
@@ -89,7 +93,7 @@ public:
         return m_megray_ctx;
     }

-    VarNodeArray grad(const VarNodeArray& out_grad) const;
+    VarNode* grad(VarNode* out_grad) const;

 private:
     Barrier m_exec_barrier;
@@ -116,6 +120,7 @@ private:
     size_t m_nr_devices = 0;
     bool m_is_root;
     int m_rank;
+    bool m_local_grad;
     std::string m_key;
     //! XXHash generated from m_key
     size_t m_hash;
This diff is collapsed.
@@ -46,7 +46,7 @@ pdef('PersistentOutputStorage').add_fields(
 (pdef('CollectiveComm', 'collective communication between multiple computing '
       'nodes on localhost')
- .add_enum('Mode',
+ .add_enum(Doc('Mode', 'mode of collective communication'),
           Doc('REDUCE_SUM', 'reduce by sum to output computing node'),
           Doc('BROADCAST', 'copy input value to each output computing node'),
           Doc('ALL_GATHER', 'each output comp node gets the concatenated '
@@ -59,7 +59,8 @@ pdef('PersistentOutputStorage').add_fields(
           Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'),
           Doc('GATHER', 'concat inputs to one node'),
           Doc('SCATTER', 'scatter input to each output computing node'),
-          Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node')))
+          Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'),
+          name_field='mode'))

 (pdef('FakeSerializedDType',
       'HACK: The tag of this param def is actually used for another '