diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index f0650b6fa8747dacaaaf34334214ed5e47a98f16..6c50fec2b80a3188041f04e3c63e2a752b01047b 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -251,12 +251,52 @@ def reduce_scatter_sum(
 def all_reduce_sum(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
-    """
+    r"""
     Create all_reduce_sum operator for collective communication.
 
-    :param inp: input tensor.
-    :param group: communication group.
-    :param device: execution device.
+    This operator sums the tensor data element-wise across the specified group and returns a tensor with the same shape as the input tensor.
+
+    Args:
+        inp: The tensor data to apply this operator on.
+        group: The communication group, an instance of :class:`Group`, across which this operator is applied. The default group is WORLD, which means all available processes.
+            Specify a list of process ranks to apply this operator only on specific processes, e.g. [1, 3, 5].
+        device: The device, given as a :class:`str`, on which to execute this operator. The default device is None, which means the device of inp will be used.
+            Specify "cpu" or "gpu" to execute this operator on a specific device.
+
+    Returns:
+        The sum of the input tensor data reduced across the specified group.
+
+    Examples:
+
+        .. code-block:: python
+
+            import megengine as mge
+            import megengine.distributed as dist
+            import numpy as np
+            from warnings import warn
+
+
+            def func(sum_value):
+                # get the rank of this process; the ranks should be 0, 1, 2, 3 for a 4-GPU task
+                rank = dist.get_rank()
+                data = mge.tensor(rank)
+                # the result should be n * (n - 1) / 2 for all processes
+                result = mge.functional.distributed.all_reduce_sum(data).item()
+                assert result == sum_value
+
+
+            def main():
+                p_num = mge.device.get_device_count("gpu")
+                if p_num < 2:
+                    warn('This example only works on a group with more than one GPU')
+                    return
+                method = dist.launcher(func)
+                method(p_num * (p_num - 1) // 2)
+
+
+            if __name__ == '__main__':
+                main()
+
     """
     mode = CollectiveComm.Mode.ALL_REDUCE_SUM
     return collective_comm(inp, mode, group, device)
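
Reviewer note: the new docstring says that `group` also accepts a plain list of process ranks, but the embedded example never exercises that form. Below is a minimal sketch (not part of this patch) of how it could look, assuming a host with at least two GPUs; the helper name `sub_group_sum` and the rank list `[0, 1]` are illustrative only, and the rank-list form is taken from the docstring's own description rather than verified here.

    import megengine as mge
    import megengine.distributed as dist


    def sub_group_sum():
        # ranks are 0 .. n_gpus-1; every process contributes its own rank
        rank = dist.get_rank()
        data = mge.tensor(rank)
        # per the new docstring, ``group`` may be a list of ranks; here only
        # ranks 0 and 1 take part, so on those processes the result is 0 + 1 = 1
        if rank in (0, 1):
            result = mge.functional.distributed.all_reduce_sum(data, group=[0, 1]).item()
            assert result == 1


    if __name__ == '__main__':
        if mge.device.get_device_count("gpu") >= 2:
            dist.launcher(sub_group_sum)()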