Unverified commit 08d726f5 authored by 张春乔, committed by GitHub

[xdoctest] reformat example code with google style in 202 (#56171)

* input.py

* Update python/paddle/nn/functional/input.py

* Update input.py

* Update all_gather.py

* Update all_gather.py
Parent 61b2bb57
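The change converts the legacy comment-style examples (a `# required: distributed` marker, bare statements, and expected output in trailing comments) into interpreter-prompt examples: a `>>> # doctest: +REQUIRES(env: DISTRIBUTED)` directive, `>>>`/`...` prompts, and the expected output written as plain text under the call. As a minimal sketch of how such an example is discovered and gated, assuming a hypothetical `demo` function (not a Paddle API) and the environment-variable check that xdoctest's `REQUIRES` directive performs:

```python
# Minimal sketch (not part of this PR): a docstring in the converted prompt
# style plus a self-check with xdoctest. `demo` is a hypothetical function.
def demo():
    """Toy function whose example follows the converted format.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:DISTRIBUTED)
            >>> values = [1, 2, 3]
            >>> sum(values)
            6
    """


if __name__ == "__main__":
    import xdoctest

    # Collects and runs every doctest in this file; the REQUIRES directive is
    # intended to skip the example unless the DISTRIBUTED environment variable
    # is set, mirroring how the distributed snippets in this diff are gated.
    xdoctest.doctest_module(__file__, command="all")
```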
@@ -51,19 +51,19 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
Examples:
.. code-block:: python
-# required: distributed
-import paddle
-import paddle.distributed as dist
-dist.init_parallel_env()
-tensor_list = []
-if dist.get_rank() == 0:
-    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-else:
-    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-dist.all_gather(tensor_list, data)
-print(tensor_list)
-# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
+>>> # doctest: +REQUIRES(env: DISTRIBUTED)
+>>> import paddle
+>>> import paddle.distributed as dist
+>>> dist.init_parallel_env()
+>>> tensor_list = []
+>>> if dist.get_rank() == 0:
+...     data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
+>>> else:
+...     data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
+>>> dist.all_gather(tensor_list, data)
+>>> print(tensor_list)
+[[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
return stream.all_gather(tensor_list, tensor, group, sync_op)
@@ -87,19 +87,19 @@ def all_gather_object(object_list, obj, group=None):
Examples:
.. code-block:: python
-# required: distributed
-import paddle
-import paddle.distributed as dist
-dist.init_parallel_env()
-object_list = []
-if dist.get_rank() == 0:
-    obj = {"foo": [1, 2, 3]}
-else:
-    obj = {"bar": [4, 5, 6]}
-dist.all_gather_object(object_list, obj)
-print(object_list)
-# [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
+>>> # doctest: +REQUIRES(env: DISTRIBUTED)
+>>> import paddle
+>>> import paddle.distributed as dist
+>>> dist.init_parallel_env()
+>>> object_list = []
+>>> if dist.get_rank() == 0:
+...     obj = {"foo": [1, 2, 3]}
+>>> else:
+...     obj = {"bar": [4, 5, 6]}
+>>> dist.all_gather_object(object_list, obj)
+>>> print(object_list)
+[{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
"""
assert (
framework.in_dynamic_mode()
......
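The `(2 GPUs)` annotation on the expected output means each snippet above assumes two processes, one per device. A sketch of running the `all_gather_object` example as a standalone script, assuming a hypothetical file name `all_gather_demo.py` and the stock `paddle.distributed.launch` entry point:

```python
# Hypothetical standalone script (all_gather_demo.py) adapted from the example
# above; the file name and the launch command below are assumptions, not part
# of the PR.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
object_list = []
if dist.get_rank() == 0:
    obj = {"foo": [1, 2, 3]}
else:
    obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
print(object_list)

# Launch with two GPUs (one process per device), e.g.:
#   python -m paddle.distributed.launch --gpus "0,1" all_gather_demo.py
```

Each launched rank should print the same gathered list, matching the output shown in the docstring.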
@@ -145,21 +145,21 @@ def all_gather(
Examples:
.. code-block:: python
-# required: distributed
-import paddle
-import paddle.distributed as dist
-dist.init_parallel_env()
-local_rank = dist.get_rank()
-tensor_list = []
-if local_rank == 0:
-    data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-else:
-    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-task = dist.stream.all_gather(tensor_list, data, sync_op=False)
-task.wait()
-print(tensor_list)
-# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
+>>> # doctest: +REQUIRES(env: DISTRIBUTED)
+>>> import paddle
+>>> import paddle.distributed as dist
+>>> dist.init_parallel_env()
+>>> local_rank = dist.get_rank()
+>>> tensor_list = []
+>>> if local_rank == 0:
+...     data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
+>>> else:
+...     data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
+>>> task = dist.stream.all_gather(tensor_list, data, sync_op=False)
+>>> task.wait()
+>>> print(tensor_list)
+[[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if group is not None and not group.is_member():
raise RuntimeError(
......
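To check a conversion like this locally, xdoctest can be pointed at the module and asked to enumerate what it collects. A small sketch, assuming the `all_gather` docstrings shown above live at the `paddle.distributed.communication.all_gather` import path:

```python
# Sketch (not part of the PR): list the doctests xdoctest collects from a
# converted module. The import path is an assumption about where the
# all_gather docstrings shown above live in the installed paddle package.
import xdoctest

# command="list" only collects and prints the examples; command="all" would run
# them, with REQUIRES-gated examples skipped unless the named environment
# variables are set.
xdoctest.doctest_module("paddle.distributed.communication.all_gather", command="list")
```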