提交 f2ac4c34 编写于 作者: M Megvii Engine Team

docs(distributed.functional.all_reduce_sum): googlestring and examples

GitOrigin-RevId: a456dfde24ffab921cb91035b187dbb8bec3b694
上级 186bacfb
...@@ -251,12 +251,52 @@ def reduce_scatter_sum( ...@@ -251,12 +251,52 @@ def reduce_scatter_sum(
def all_reduce_sum( def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor: ) -> Tensor:
""" r"""
Create all_reduce_sum operator for collective communication. Create all_reduce_sum operator for collective communication.
:param inp: input tensor. This operator sums the tensor data by coordinates across the specified group and returns a tensor with the shape of the input tensor.
:param group: communication group.
:param device: execution device. Args:
inp: The tensor data to apply this operator on.
group: The communication node list instance of :class:'Group' to apply this operator across. The default group is WORLD which means all processes available.
Specify a list of process ranks to apply this operator on specific processes, e.g. [1, 3, 5].
device: The specific device type of :class:'str' to execute this operator. The default device is None which mean the device of inp will be used.
Specify "cpu" or "gpu" to execute this operator on specific devices.
Returns:
opt: The reduce sum tensor of the input tensor data across the specified group.
Examples:
.. code-block::
import megengine as mge
import megengine.distributed as dist
import numpy as np
from warnings import warn
def func(sum_value):
# get the rank of this process, the ranks shold be 0, 1, 2, 3 for a 4 gpu task
rank = dist.get_rank()
data = mge.tensor(rank)
# the result should be n * (n - 1) / 2 for all processes
result = mge.functional.distributed.all_reduce_sum(data).item()
assert result == sum_value
def main():
p_num = mge.device.get_device_count("gpu")
if p_num < 2:
warn('This opr only works on group with more than one gpu')
return
method = dist.launcher(func)
method(p_num * (p_num - 1) // 2)
if __name__ == '__main__':
main()
""" """
mode = CollectiveComm.Mode.ALL_REDUCE_SUM mode = CollectiveComm.Mode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册