Unverified commit 1b6c1d39, authored by kuizhiqing, committed by GitHub

fix doc preblem (#32010)

Parent 8460698b
@@ -142,7 +142,7 @@ def get_group(id=0):
     Get group instance by group id.
     Args:
-        id (int): the group id
+        id (int): the group id. Default value is 0.
     Returns:
         Group: the group instance.
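For context, a minimal usage sketch of get_group together with new_group. It is only a sketch: it assumes a multi-process run (e.g. started with paddle.distributed.launch) and that group id 0 refers to the global default group, consistent with the ring_id = 0 default used elsewhere in this file.

    import paddle.distributed as dist

    dist.init_parallel_env()
    gp = dist.new_group([0, 1])        # sub-group of global ranks 0 and 1
    same_gp = dist.get_group(gp.id)    # fetch the Group instance again by its id
    default_gp = dist.get_group(0)     # id 0: the global default group (assumption)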
@@ -163,26 +163,24 @@ def get_group(id=0):
 def new_group(ranks=None, backend=None):
     """
-    Creates a new distributed comminication group.
+    Creates a new distributed communication group.
     Args:
-        ranks (list): The global ranks of group members, list as sorted.
+        ranks (list): The global ranks of group members.
         backend (str): The backend used to create group, only nccl is supported now.
     Returns:
-        Group: The group instance. Nerver return None.
+        Group: The group instance.
     Examples:
         .. code-block:: python
-            import numpy as np
             import paddle
             paddle.distributed.init_parallel_env()
-            tindata = np.random.random([10, 1000]).astype('float32')
-            tindata = paddle.to_tensor(tindata)
-            gid = paddle.distributed.new_group([2,4,6])
-            paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
+            tindata = paddle.randn(shape=[2, 3])
+            gp = paddle.distributed.new_group([2,4,6])
+            paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
     """
@@ -221,7 +219,7 @@ def new_group(ranks=None, backend=None):
         place = core.CUDAPlace(genv.device_id)
         core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
     else:
-        assert False
+        assert False, ("no cuda device found")
     return gp
@@ -234,8 +232,8 @@ def wait(tensor, group=None, use_calc_stream=True):
     Args:
         tensor (Tensor): The Tensor used before sync.
         group (Group): The Group instance to perform sync.
-        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
-            default to False.
+        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
+            Default to True.
     Returns:
         None.
@@ -243,13 +241,10 @@ def wait(tensor, group=None, use_calc_stream=True):
     Examples:
         .. code-block:: python
-            import numpy as np
             import paddle
             paddle.distributed.init_parallel_env()
-            tindata = np.random.random([10, 1000]).astype('float32')
-            tindata = paddle.to_tensor(tindata)
+            tindata = paddle.randn(shape=[2, 3])
             paddle.distributed.all_reduce(tindata, use_calc_stream=True)
             paddle.distributed.wait(tindata)
...@@ -306,8 +301,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True): ...@@ -306,8 +301,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
should be float16, float32, float64, int32 or int64. should be float16, float32, float64, int32 or int64.
src (int): The source rank. src (int): The source rank.
group (Group): The group instance return by new_group or None for global default group. group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
default to True. Default to True.
Returns: Returns:
None. None.
@@ -339,6 +334,7 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
     ring_id = 0 if group is None else group.id
     gsrc = src if group is None else group.get_group_rank(src)
+    assert gsrc >= 0, ("src rank out of group, need global rank")
     if in_dygraph_mode():
         return core.ops.c_broadcast(tensor, tensor, 'root', gsrc,
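For reference, a minimal sketch of how the broadcast API documented above is typically called. It assumes two or more processes launched together (e.g. via paddle.distributed.launch); the tensor values are illustrative only.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    if dist.get_rank() == 0:
        data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
    else:
        data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    # After the call, every rank holds the tensor contributed by rank 1.
    dist.broadcast(data, src=1)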
@@ -370,10 +366,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     Args:
         tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
             should be float16, float32, float64, int32 or int64.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
         group (Group): The group instance return by new_group or None for global default group.
-        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
-            default to True.
+        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
+            Default to True.
     Returns:
         None.
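A minimal all_reduce sketch under the same multi-process assumption; the op argument is omitted so the documented default, ReduceOp.SUM, applies.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    tindata = paddle.randn(shape=[2, 3])
    # Sums tindata across all ranks of the default group, in place on every rank.
    dist.all_reduce(tindata)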
...@@ -453,10 +449,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): ...@@ -453,10 +449,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32 or int64. should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id. dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance return by new_group or None for global default group. group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
default to True. Default to True.
Returns: Returns:
None. None.
@@ -487,6 +483,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     ring_id = 0 if group is None else group.id
     gdst = dst if group is None else group.get_group_rank(dst)
+    assert gdst >= 0, ("dst rank out of group, need global rank")
     if in_dygraph_mode():
         if op == ReduceOp.SUM:
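A corresponding reduce sketch, again assuming a multi-process launch. Only the destination rank ends up with the reduced result (per the docstring, the tensor is the output on dst and the input elsewhere), and ReduceOp.SUM is the documented default.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    tindata = paddle.randn(shape=[2, 3])
    # Only rank 0 receives the summed tensor; other ranks keep their input.
    dist.reduce(tindata, dst=0)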
...@@ -548,8 +545,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): ...@@ -548,8 +545,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
tensor (Tensor): The Tensor to send. Its data type tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64. should be float16, float32, float64, int32 or int64.
group (Group): The group instance return by new_group or None for global default group. group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
default to True. Default to True.
Returns: Returns:
None. None.
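A minimal all_gather sketch under the same assumption; tensor_list is filled with one tensor per rank of the group.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    tensor_list = []
    tindata = paddle.randn(shape=[2, 3])
    # tensor_list ends up holding every rank's tindata, ordered by rank.
    dist.all_gather(tensor_list, tindata)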
@@ -624,11 +621,11 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
         tensor (Tensor): The output Tensor. Its data type
             should be float16, float32, float64, int32 or int64.
         tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
-        src (int): The source rank id.
+            should be float16, float32, float64, int32 or int64. Default value is None.
+        src (int): The source rank id. Default value is 0.
         group (Group): The group instance return by new_group or None for global default group.
-        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False),
-            default to True.
+        use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
+            Default to True.
     Returns:
         None.
@@ -664,6 +661,7 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
     ring_id = 0 if group is None else group.id
     gsrc = src if group is None else group.get_group_rank(src)
+    assert gsrc >= 0, ("src rank out of group, need global rank")
     rank = _get_global_group().rank if group is None else group.rank
     nranks = _get_global_group().nranks if group is None else group.nranks
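Finally, a scatter sketch consistent with the defaults documented above (tensor_list defaults to None, src defaults to 0). It assumes a multi-process launch; the zero-filled output buffer and the use of paddle.distributed.get_world_size() to size the list on the source rank are illustrative choices, not part of the API contract.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    outdata = paddle.zeros(shape=[2, 3], dtype='float32')
    if dist.get_rank() == 0:
        # The source rank supplies one tensor per rank in the group.
        tensor_list = [paddle.randn(shape=[2, 3]) for _ in range(dist.get_world_size())]
    else:
        tensor_list = None
    # Each rank receives its slice of tensor_list into outdata.
    dist.scatter(outdata, tensor_list, src=0)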
......