Unverified · Commit 776aef79 · Authored by Wen Sun, committed by GitHub

Move collective communication `all_gather` from collective.py (#48339)

* refactor: move all_gather
Parent cfd7ff8f
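For reference, a minimal usage sketch of the import surface after this move (assuming a multi-GPU launch via `paddle.distributed.launch`); the public entry points are unchanged, only the implementing module differs:

```python
# Hedged sketch: public API after the move (assumes 2 GPUs launched with
# `python -m paddle.distributed.launch --gpus 0,1 demo.py`).
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

# Tensor all_gather: still exposed as paddle.distributed.all_gather,
# now implemented by the new communication/all_gather.py module in this diff.
tensor_list = []
dist.all_gather(tensor_list, paddle.to_tensor([dist.get_rank()]))

# Object all_gather: still exposed as paddle.distributed.all_gather_object.
object_list = []
dist.all_gather_object(object_list, {"rank": dist.get_rank()})
```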
@@ -27,8 +27,6 @@ from paddle.distributed.fleet.dataset import InMemoryDataset # noqa: F401
from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401
from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import all_gather # noqa: F401
from .collective import all_gather_object # noqa: F401
from .collective import barrier # noqa: F401
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
@@ -37,6 +35,8 @@ from .collective import wait # noqa: F401
from .communication import (
stream,
ReduceOp,
all_gather,
all_gather_object,
all_reduce,
alltoall,
alltoall_single,
@@ -111,5 +111,4 @@ __all__ = [ # noqa
"isend",
"irecv",
"reduce_scatter",
"rpc",
]
@@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pickle
import io
import datetime
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import in_dygraph_mode
from ..fluid.framework import _non_static_mode
from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.layers.tensor import fill_constant
import paddle
import paddle.fluid.core as core
@@ -435,225 +431,3 @@ def _sync_comm_stream(tensor, ring_id=0):
outputs={'Out': [tensor]},
attrs={'ring_id': ring_id},
)
def all_gather(tensor_list, tensor, group=None, sync_op=True):
"""
Gather tensors from all participants, and every participant receives the full result. As shown
below, each process is started with one GPU and the data on that process is labeled
with its group rank. After the all_gather operation, every GPU holds the data
from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
:width: 800
:alt: all_gather
:align: center
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group, optional): The group instance returned by new_group, or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
None.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
tensor_list = []
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_gather(tensor_list, data)
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if group is not None and not group.is_member():
return
def convert_to_complex(list_of_tensor):
list_of_complex = []
for tensor in list_of_tensor:
list_of_complex.append(paddle.as_complex(tensor))
return list_of_complex
is_input_complex = (
tensor.dtype == paddle.complex64 or tensor.dtype == paddle.complex128
)
if is_input_complex:
tensor = paddle.as_real(tensor)
if in_dygraph_mode():
group = _get_default_group() if group is None else group
if len(tensor_list) == 0:
tensor_shape = list(tensor.shape)
tensor_shape[0] *= group.nranks
out = paddle.empty(tensor_shape, tensor.dtype)
else:
out = paddle.concat(tensor_list, axis=0)
task = group.process_group.all_gather_into_tensor(out, tensor, sync_op)
task.wait()
tensor_list.clear()
list_of_tensor = paddle.split(out, group.nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
return
use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
nranks = _get_global_group().nranks if group is None else group.nranks
if _non_static_mode():
out = _legacy_C_ops.c_allgather(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'nranks',
nranks,
)
else:
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if not isinstance(tensor_list, list):
raise ValueError(
"The type of 'tensor_list' for all_gather " "should be list."
)
for elem in tensor_list:
check_variable_and_dtype(
elem,
'tensor_list',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
'complex64',
'complex128',
],
'all_gather',
)
check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
'complex64',
'complex128',
],
'all_gather',
)
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
'nranks': nranks,
},
)
list_of_tensor = paddle.split(out, nranks, 0)
if is_input_complex:
tensor_list.extend(convert_to_complex(list_of_tensor))
else:
tensor_list.extend(list_of_tensor)
def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor, tensor.numel()
def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
def all_gather_object(object_list, obj, group=None):
"""
Gather picklable objects from all participants, and every participant receives the full result. Similar to all_gather(), but Python objects can be passed in.
Args:
object_list (list): A list of gathered objects. Every element in the list has the same type as the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
object_list = []
if dist.get_rank() == 0:
obj = {"foo": [1, 2, 3]}
else:
obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
print(object_list)
# [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
"""
assert (
in_dygraph_mode()
), "all_gather_object doesn't support static graph mode."
tensor, len_of_tensor = _convert_object_to_tensor(obj)
# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# resize the input tensor to the max length to avoid a hang in all_gather
# Note(liyurui): Maybe we should support variable-length all_gather?
# For now this numpy round trip is efficient enough, since resizing a tensor in place is not supported from Python.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)
tensor_list = []
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i])
)
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .all_gather import all_gather, all_gather_object
from .all_reduce import all_reduce
from .broadcast import broadcast
from .reduce import reduce, ReduceOp
@@ -24,6 +25,8 @@ from .group import is_initialized, destroy_process_group, get_group
__all__ = [
"ReduceOp",
"all_gather",
"all_gather_object",
"all_reduce",
"alltoall",
"alltoall_single",
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.fluid.framework as framework
import paddle.distributed.communication.stream as stream
def all_gather(tensor_list, tensor, group=None, sync_op=True):
"""
Gather tensors from all participants, and every participant receives the full result. As shown
below, each process is started with one GPU and the data on that process is labeled
with its group rank. After the all_gather operation, every GPU holds the data
from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
:width: 800
:alt: all_gather
:align: center
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
group (Group, optional): The group instance returned by new_group, or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
None.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
tensor_list = []
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_gather(tensor_list, data)
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.all_gather(tensor_list, tensor, group, sync_op)
# NOTE: uncomment code below when having fully complex support
# def convert_to_complex(list_of_tensor):
# list_of_complex = []
# for tensor in list_of_tensor:
# list_of_complex.append(paddle.as_complex(tensor))
# return list_of_complex
# is_input_complex = (tensor.dtype == paddle.complex64
# or tensor.dtype == paddle.complex128)
# if is_input_complex:
# tensor = paddle.as_real(tensor)
# the code below will be removed once the old dygraph mode is removed
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
nranks = dist.get_world_size()
out = paddle._legacy_C_ops.c_allgather(
tensor,
'use_calc_stream',
sync_op,
'ring_id',
ring_id,
'nranks',
nranks,
)
tensor_list.clear()
tensor_list.extend(paddle.split(out, nranks, 0))
def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor, tensor.numel()
def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
def all_gather_object(object_list, obj, group=None):
"""
Gather picklable objects from all participants, and every participant receives the full result. Similar to all_gather(), but Python objects can be passed in.
Args:
object_list (list): A list of gathered objects. Every element in the list has the same type as the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
object_list = []
if dist.get_rank() == 0:
obj = {"foo": [1, 2, 3]}
else:
obj = {"bar": [4, 5, 6]}
dist.all_gather_object(object_list, obj)
print(object_list)
# [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
"""
assert (
framework.in_dygraph_mode()
), "all_gather_object doesn't support static graph mode."
tensor, len_of_tensor = _convert_object_to_tensor(obj)
# gather len_of_tensor from all ranks
list_len_of_tensor = []
all_gather(list_len_of_tensor, len_of_tensor, group)
# get the max length from list
max_len_of_tensor = int(max(list_len_of_tensor).item())
# resize the input tensor to the max length to avoid a hang in all_gather
# Note(liyurui): Maybe we should support variable-length all_gather?
# For now this numpy round trip is efficient enough, since resizing a tensor in place is not supported from Python.
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_len_of_tensor])
input_tensor = paddle.to_tensor(numpy_data)
tensor_list = []
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i])
)
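As a side note, the pickle-and-pad round trip that `all_gather_object` relies on can be exercised in isolation. The sketch below is a single-process illustration using only NumPy and the standard library; the helper names are made up for the example and are not part of this diff:

```python
# Single-process sketch of the padding trick used by all_gather_object:
# objects are pickled to uint8 buffers, padded to a common length so the
# collective does not hang, and trimmed back with the recorded length
# before unpickling. Helper names are illustrative only.
import io
import pickle
import numpy as np

def to_padded_bytes(obj, max_len):
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    data = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    # np.resize pads by repeating the buffer; the extra bytes are ignored later.
    return np.resize(data, [max_len]), data.size

def from_padded_bytes(data, length):
    return pickle.Unpickler(io.BytesIO(data[:length].tobytes())).load()

padded, n = to_padded_bytes({"foo": [1, 2, 3]}, max_len=128)
assert from_padded_bytes(padded, n) == {"foo": [1, 2, 3]}
```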
@@ -13,14 +13,17 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid.framework as framework
from paddle.distributed import collective
import paddle.fluid.data_feeder as data_feeder
import paddle.fluid.layer_helper as layer_helper
from paddle.distributed.communication.group import _get_global_group
def _all_gather_into_tensor_in_dygraph(
out_tensor, in_tensor, group, sync_op, use_calc_stream
):
group = collective._get_default_group() if group is None else group
group = _get_global_group() if group is None else group
if use_calc_stream:
return group.process_group.all_gather_into_tensor_on_calc_stream(
@@ -40,7 +43,7 @@ def _all_gather_into_tensor_in_dygraph(
def _all_gather_in_dygraph(
tensor_list, tensor, group, sync_op, use_calc_stream
):
group = collective._get_default_group() if group is None else group
group = _get_global_group() if group is None else group
if len(tensor_list) == 0:
tensor_list += [paddle.empty_like(tensor) for _ in range(group.nranks)]
@@ -57,6 +60,58 @@ def _all_gather_in_dygraph(
return task
def _all_gather_in_static_mode(tensor_list, tensor, group, sync_op):
op_type = 'c_allgather'
helper = layer_helper.LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
for elem in tensor_list:
data_feeder.check_variable_and_dtype(
elem,
'tensor_list',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
],
'all_gather',
)
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'bool',
'int8',
'uint8',
],
'all_gather',
)
ring_id = 0 if group is None else group.id
nranks = dist.get_world_size()
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': ring_id,
'use_calc_stream': sync_op,
'nranks': nranks,
},
)
tensor_list.clear()
tensor_list.extend(paddle.split(out, nranks, 0))
def all_gather(
tensor_or_tensor_list,
tensor,
@@ -122,7 +177,13 @@ def all_gather(
return _all_gather_in_dygraph(
tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream
)
raise RuntimeError(
"paddle.distributed.stream.all_gather is only supported in dygraph mode now."
)
else:
assert group is None, "Group can not be used in static mode for now."
if paddle.is_tensor(tensor_or_tensor_list):
raise RuntimeError(
"Only support passing a tensor list to `all_gather` in static mode now."
)
else:
return _all_gather_in_static_mode(
tensor_or_tensor_list, tensor, group, sync_op
)
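For the stream-level entry point, a hedged usage sketch under a two-GPU launch is shown below; the `sync_op` and `use_calc_stream` keywords appear in the helpers above, but their exact defaults are assumed here:

```python
# Hedged sketch of paddle.distributed.communication.stream.all_gather
# (assumes 2 GPUs; keyword defaults are assumed, not taken from this diff).
import paddle
import paddle.distributed as dist
from paddle.distributed.communication import stream

dist.init_parallel_env()
data = paddle.to_tensor([dist.get_rank()])

# Variant 1: gather into a list of per-rank tensors (synchronous).
gathered = []
stream.all_gather(gathered, data, sync_op=True)

# Variant 2: gather into one preallocated tensor whose first dimension is
# nranks * data.shape[0], then wait on the returned task (asynchronous).
out = paddle.empty([dist.get_world_size() * data.shape[0]], dtype=data.dtype)
task = stream.all_gather(out, data, sync_op=False)
task.wait()
```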
@@ -35,8 +35,6 @@ class TestCollectiveAllgatherAPI(TestDistBase):
"int8",
"uint8",
"bool",
"complex64",
"complex128",
]
for dtype in dtypes_to_test:
self.check_with_place(
@@ -53,8 +51,6 @@ class TestCollectiveAllgatherAPI(TestDistBase):
"int8",
"uint8",
"bool",
"complex64",
"complex128",
]
for dtype in dtypes_to_test:
self.check_with_place(
@@ -65,7 +61,7 @@ class TestCollectiveAllgatherAPI(TestDistBase):
dtype=dtype,
)
def test_allgatther_nccl_dygraph(self):
def test_allgather_nccl_dygraph(self):
dtypes_to_test = [
"float16",
"float32",
@@ -75,8 +71,6 @@ class TestCollectiveAllgatherAPI(TestDistBase):
"int8",
"uint8",
"bool",
"complex64",
"complex128",
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
@@ -100,8 +94,6 @@ class TestCollectiveAllgatherAPI(TestDistBase):
"uint8",
"bool",
"bfloat16",
"complex64",
"complex128",
]
for dtype in dtypes_to_test:
self.check_with_place(
......