Unverified commit dec67d6d, authored by Wen Sun, committed by GitHub

Add collective communication APIs to improve completeness (#49252)

* feat: broadcast_object_list & scatter_object_list

* chore: update ut conf

* get_backend & is_available

* docs: update requirements

* fix: resolve conflicts
Co-authored-by: NLiYuRio <liyuruijx@163.com>
Parent a72a0da0
@@ -29,6 +29,7 @@ from paddle.distributed.fleet.base.topology import ParallelMode  # noqa: F401
from .collective import split  # noqa: F401
from .collective import new_group  # noqa: F401
from .collective import is_available  # noqa: F401
from .communication import (
    stream,
@@ -39,9 +40,11 @@ from .communication import (
    alltoall,
    alltoall_single,
    broadcast,
    broadcast_object_list,
    reduce,
    send,
    scatter,
    scatter_object_list,
    isend,
    recv,
    irecv,
@@ -53,6 +56,7 @@ from .communication import (
    get_group,
    wait,
    barrier,
    get_backend,
)  # noqa: F401
from .auto_parallel import shard_op  # noqa: F401
@@ -81,7 +85,9 @@ __all__ = [  # noqa
    "spawn",
    "launch",
    "scatter",
    "scatter_object_list",
    "broadcast",
    "broadcast_object_list",
    "ParallelEnv",
    "new_group",
    "init_parallel_env",
@@ -114,4 +120,6 @@ __all__ = [  # noqa
    "isend",
    "irecv",
    "reduce_scatter",
    "is_available",
    "get_backend",
]
@@ -307,3 +307,21 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    paddle.distributed.all_reduce(tmp, sync_op=True)
    paddle.distributed.wait(tmp)
    return gp


def is_available():
    """
    Check whether the distributed package is available.

    Returns:
        Returns True if the distributed package is available, otherwise False.

    Examples:
        .. code-block:: python

            import paddle

            print(paddle.distributed.is_available())

    """
    return core.is_compiled_with_dist()
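is_available() only reports whether Paddle was compiled with distributed support, so it is cheap to call before any initialization. A minimal guard sketch (hypothetical driver, not part of this patch; it assumes the process is started by paddle.distributed.launch on NCCL-capable GPUs):

import paddle.distributed as dist

if dist.is_available():
    # The build has distributed support, so parallel init is safe to attempt.
    dist.init_parallel_env()
    print("rank:", dist.get_rank(), "backend:", dist.get_backend())
else:
    # Fall back to plain single-process execution.
    print("paddle.distributed is not available in this build")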
@@ -13,11 +13,11 @@
# limitations under the License.
from .all_gather import all_gather, all_gather_object
from .all_reduce import all_reduce
from .broadcast import broadcast, broadcast_object_list
from .reduce import reduce, ReduceOp
from .send import send, isend
from .recv import recv, irecv
from .scatter import scatter, scatter_object_list
from .batch_isend_irecv import batch_isend_irecv, P2POp
from .reduce_scatter import reduce_scatter
from .all_to_all import alltoall, alltoall_single
@@ -27,4 +27,5 @@ from .group import (
    get_group,
    wait,
    barrier,
    get_backend,
)
@@ -12,15 +12,17 @@
# limitations under the License.

import numpy as np

import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)


def all_gather(tensor_list, tensor, group=None, sync_op=True):
    """
@@ -66,20 +68,6 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
    return stream.all_gather(tensor_list, tensor, group, sync_op)


def all_gather_object(object_list, obj, group=None):
    """
@@ -117,7 +105,7 @@ def all_gather_object(object_list, obj, group=None):
        framework.in_dygraph_mode()
    ), "all_gather_object doesn't support static graph mode."

    tensor, len_of_tensor = convert_object_to_tensor(obj)

    # gather len_of_tensor from all ranks
    list_len_of_tensor = []
@@ -135,5 +123,5 @@ def all_gather_object(object_list, obj, group=None):
    all_gather(tensor_list, input_tensor, group)

    for i, tensor in enumerate(tensor_list):
        object_list.append(
            convert_tensor_to_object(tensor, list_len_of_tensor[i])
        )
@@ -12,7 +12,15 @@
# limitations under the License.

import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)


def broadcast(tensor, src, group=None, sync_op=True):
@@ -60,3 +68,70 @@ def broadcast(tensor, src, group=None, sync_op=True):
        sync_op=sync_op,
        use_calc_stream=False,
    )


def broadcast_object_list(object_list, src, group=None):
    """

    Broadcast picklable objects from the source to all others. Similar to broadcast(), but Python objects can be passed in.

    Args:
        object_list (list): The list of objects to send if the current rank is the source, or the list of objects to receive otherwise.
        src (int): The source rank in global view.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                object_list = [{"foo": [1, 2, 3]}]
            else:
                object_list = [{"bar": [4, 5, 6]}]
            dist.broadcast_object_list(object_list, src=1)
            print(object_list)
            # [{"bar": [4, 5, 6]}] (2 GPUs)
    """
    assert (
        framework.in_dygraph_mode()
    ), "broadcast_object_list doesn't support static graph mode."

    rank = dist.get_rank()
    obj_tensors = []
    obj_nums = len(object_list)

    if rank == src:
        obj_sizes = []
        for obj in object_list:
            obj_tensor, obj_size = convert_object_to_tensor(obj)
            obj_tensors.append(obj_tensor)
            obj_sizes.append(obj_size)
        obj_size_tensor = paddle.concat(obj_sizes)
    else:
        obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
    broadcast(obj_size_tensor, src)

    if rank == src:
        # cast to uint8 to keep the same dtype
        obj_data_tensor = paddle.concat(obj_tensors).cast("uint8")
    else:
        data_len = paddle.sum(obj_size_tensor).item()
        obj_data_tensor = paddle.empty([data_len], dtype="uint8")
    broadcast(obj_data_tensor, src)

    offset = 0
    for i in range(obj_nums):
        data_len = obj_size_tensor[i]
        object_list[i] = convert_tensor_to_object(
            obj_data_tensor[offset : offset + data_len], data_len
        )
        offset += data_len
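broadcast_object_list above frames its payload as two broadcasts: an int64 tensor of per-object pickled sizes, then one flat uint8 tensor holding all pickled bytes, which receivers slice apart by offset. A standalone sketch of the same framing with plain pickle/numpy (no communication, purely illustrative):

import pickle

import numpy as np


def pack(objects):
    # Pickle each object to raw uint8 bytes and record its length.
    buffers = [np.frombuffer(pickle.dumps(o), dtype=np.uint8) for o in objects]
    sizes = np.array([b.size for b in buffers], dtype=np.int64)
    return sizes, np.concatenate(buffers)


def unpack(sizes, data):
    # Walk the flat byte buffer, cutting one object per recorded size.
    objects, offset = [], 0
    for n in sizes:
        objects.append(pickle.loads(data[offset : offset + n].tobytes()))
        offset += n
    return objects


sizes, data = pack([{"foo": [1, 2, 3]}, "bar"])
print(unpack(sizes, data))  # [{'foo': [1, 2, 3]}, 'bar']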
@@ -19,7 +19,6 @@ import paddle.distributed as dist
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layer_helper as layer_helper


class Group:
@@ -236,7 +235,7 @@ def get_group(id=0):
def _sync_calc_stream(tensor):
    if framework.in_dygraph_mode():
        return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
    else:
        op_type = 'c_sync_calc_stream'
@@ -249,7 +248,7 @@ def _sync_calc_stream(tensor):
def _sync_comm_stream(tensor, ring_id=0):
    if framework.in_dygraph_mode():
        return paddle._legacy_C_ops.c_sync_comm_stream(
            [tensor], [tensor], 'ring_id', ring_id
        )
@@ -337,7 +336,7 @@ def barrier(group=None):
    ring_id = 0 if group is None else group.id
    barrier_tensor = paddle.full([1], 1, dtype="int32")
    if framework.in_dygraph_mode():
        return paddle._legacy_C_ops.barrier(
            barrier_tensor, barrier_tensor, 'ring_id', ring_id
        )
@@ -352,3 +351,29 @@ def barrier(group=None):
        outputs={'Out': [barrier_tensor]},
        attrs={'ring_id': ring_id},
    )


def get_backend(group=None):
    """
    Get the backend of the given group.

    Args:
        group (Group): The group to work on. Use the global group as default.

    Returns:
        Returns the backend name of the given group.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle

            paddle.distributed.init_parallel_env()
            paddle.distributed.get_backend()  # NCCL
    """
    if _warn_cur_rank_not_in_group(group):
        raise RuntimeError("Invalid group specified")

    group = _get_global_group() if group is None else group
    return group.backend
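Since get_backend() exposes the group's transport, callers can branch on it, for example to decide where staging buffers should live. A hedged sketch (hypothetical helper, not part of this patch; it assumes init_parallel_env() has already run):

import paddle
import paddle.distributed as dist


def staging_place(group=None):
    # NCCL shuttles buffers through GPU memory; gloo works on CPU tensors.
    if dist.get_backend(group) == "NCCL":
        return paddle.CUDAPlace(dist.ParallelEnv().dev_id)
    return paddle.CPUPlace()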
@@ -12,7 +12,17 @@
# limitations under the License.

import numpy as np

import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework

from .serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)


def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
@@ -59,3 +69,79 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
            # [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
    """
    return stream.scatter(tensor, tensor_list, src, group, sync_op)


def scatter_object_list(
    out_object_list, in_object_list=None, src=0, group=None
):
    """

    Scatter picklable objects from the source to all others. Similar to scatter(), but Python objects can be passed in.

    Args:
        out_object_list (list): The list of objects to store the scattered objects.
        in_object_list (list): The list of objects to scatter. Only objects on the src rank will be scattered.
        src (int): The source rank in global view.
        group (Group): The group instance returned by new_group or None for the global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_object_list = []
            if dist.get_rank() == 0:
                in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}]
            else:
                in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}]
            dist.scatter_object_list(out_object_list, in_object_list, src=1)
            print(out_object_list)
            # [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0)
            # [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1)
    """
    assert (
        framework.in_dygraph_mode()
    ), "scatter_object_list doesn't support static graph mode."

    rank = dist.get_rank()
    in_obj_tensors = []
    in_obj_sizes = []
    if rank == src:
        for obj in in_object_list:
            obj_tensor, obj_size = convert_object_to_tensor(obj)
            in_obj_tensors.append(obj_tensor)
            in_obj_sizes.append(obj_size)
        max_obj_size_tensor = max(in_obj_sizes)
    else:
        # NOTE: shape can be [] after 0D tensor support
        max_obj_size_tensor = paddle.empty([1], dtype="int64")
    stream.broadcast(max_obj_size_tensor, src)
    max_obj_size = int(max_obj_size_tensor.item())

    # resize to the same size
    in_tensor_list = []
    for tensor in in_obj_tensors:
        numpy_data = tensor.numpy()
        numpy_data = np.resize(numpy_data, [max_obj_size])
        in_tensor = paddle.to_tensor(numpy_data)
        in_tensor_list.append(in_tensor)
    out_tensor = paddle.empty([max_obj_size], dtype="uint8")
    scatter(out_tensor, in_tensor_list if rank == src else None, src)

    # NOTE: shape can be [] after 0D tensor support
    out_tensor_size = paddle.empty([1], dtype="int64")
    scatter(out_tensor_size, in_obj_sizes if rank == src else None, src)

    out_object_list.clear()
    out_object_list.append(
        convert_tensor_to_object(out_tensor, out_tensor_size.item())
    )
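An end-to-end sketch of driving the new API (hypothetical script, not part of this patch; it assumes two visible GPUs and uses paddle.distributed.spawn to start the workers):

import paddle.distributed as dist


def worker():
    dist.init_parallel_env()
    in_object_list = None
    if dist.get_rank() == 0:
        # Objects may pickle to different sizes; they are padded to the max size internally.
        in_object_list = [{"rank": 0, "msg": "hello"}, {"rank": 1}]
    out_object_list = []
    dist.scatter_object_list(out_object_list, in_object_list, src=0)
    # Each rank ends up with exactly one object from the source list.
    print(dist.get_rank(), out_object_list)


if __name__ == "__main__":
    dist.spawn(worker, nprocs=2)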
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
import numpy as np
import paddle


def convert_object_to_tensor(obj):
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor, tensor.numel()


def convert_tensor_to_object(tensor, len_of_tensor):
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
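A quick round-trip check of the two helpers above (a sketch; the import path assumes the new module lands at paddle.distributed.communication.serialization_utils, next to the callers shown in this diff):

from paddle.distributed.communication.serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)

obj = {"foo": [1, 2, 3], "bar": "baz"}
tensor, length = convert_object_to_tensor(obj)  # uint8 tensor plus element count
# The recorded length lets a padded or oversized buffer still decode correctly.
assert convert_tensor_to_object(tensor, int(length)) == obj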
@@ -127,6 +127,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  set_tests_properties(test_collective_broadcast_api
    PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_broadcast_object_list_api MODULES
    test_collective_broadcast_object_list_api ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_collective_broadcast_object_list_api
    PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_cpu_barrier_with_gloo MODULES
@@ -223,6 +231,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  set_tests_properties(test_collective_scatter_api
    PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_scatter_object_list_api MODULES
    test_collective_scatter_object_list_api ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_collective_scatter_object_list_api
    PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_sendrecv MODULES test_collective_sendrecv ENVS
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import test_collective_api_base as test_base

import paddle.distributed as dist
import paddle.fluid as fluid


class TestCollectiveBroadcastObjectListAPI(
    test_base.TestCollectiveAPIRunnerBase
):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        with fluid.program_guard(main_prog, startup_program):
            object_list = [indata]
            dist.broadcast_object_list(object_list, src=1)
            return object_list


if __name__ == "__main__":
    test_base.runtime_main(
        TestCollectiveBroadcastObjectListAPI, "broadcast_object_list"
    )
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import test_collective_api_base as test_base

import paddle.distributed as dist
import paddle.fluid as fluid


class TestCollectiveScatterObjectListAPI(test_base.TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        with fluid.program_guard(main_prog, startup_program):
            data_len = len(indata) // 2
            in_object_list = [indata[:data_len], indata[data_len:]]
            out_object_list = []
            dist.scatter_object_list(out_object_list, in_object_list, src=1)
            return out_object_list


if __name__ == "__main__":
    test_base.runtime_main(
        TestCollectiveScatterObjectListAPI, "scatter_object_list"
    )
@@ -46,10 +46,14 @@ class TestProcessGroupFp32(unittest.TestCase):
        device_id = paddle.distributed.ParallelEnv().dev_id
        paddle.set_device('gpu:%d' % device_id)

        assert paddle.distributed.is_available()

        pg = init_process_group()
        print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
        print("test new group api ok")

        assert paddle.distributed.get_backend() == "NCCL"

        # test allreduce sum
        # rank 0
        x = np.random.random(self.shape).astype(self.dtype)
...
@@ -27,14 +27,7 @@ class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):
            "allgather_object",
            "nccl",
            static_mode="0",
            dtype="pyobject",
        )

    def test_allgather_gloo_dygraph(self):
@@ -44,15 +37,7 @@ class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):
            "gloo",
            "3",
            static_mode="0",
            dtype="pyobject",
        )
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import test_collective_api_base as test_base


class TestCollectiveBroadcastObjectListAPI(test_base.TestDistBase):
    def _setup_config(self):
        pass

    def test_broadcast_nccl(self):
        self.check_with_place(
            "collective_broadcast_object_list_api_dygraph.py",
            "broadcast_object_list",
            "nccl",
            static_mode="0",
            dtype="pyobject",
        )

    def test_broadcast_gloo_dygraph(self):
        self.check_with_place(
            "collective_broadcast_object_list_api_dygraph.py",
            "broadcast_object_list",
            "gloo",
            "3",
            static_mode="0",
            dtype="pyobject",
        )


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import test_collective_api_base as test_base


class TestCollectiveScatterObjectListAPI(test_base.TestDistBase):
    def _setup_config(self):
        pass

    def test_scatter_nccl(self):
        self.check_with_place(
            "collective_scatter_object_list_api_dygraph.py",
            "scatter_object_list",
            "nccl",
            static_mode="0",
            dtype="pyobject",
        )

    def test_scatter_gloo_dygraph(self):
        self.check_with_place(
            "collective_scatter_object_list_api_dygraph.py",
            "scatter_object_list",
            "gloo",
            "3",
            static_mode="0",
            dtype="pyobject",
        )


if __name__ == '__main__':
    unittest.main()
@@ -14,6 +14,7 @@ test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,ht
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
@@ -26,6 +27,7 @@ test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_p
test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
...
@@ -58,22 +58,15 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
    return data


def create_pyobject_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    list_shape = np.random.randint(0, high=100, size=(2)).tolist()
    list_data = np.random.random(shape).tolist()
    dict_key = [i for i in range(0, shape[0])]
    dict_val = np.random.random(shape).tolist()
    dict_data = dict(zip(dict_key, dict_val))
    return [list_data, dict_data]


def create_test_data(shape=None, dtype=None, seed=None):
@@ -94,10 +87,8 @@ def create_test_data(shape=None, dtype=None, seed=None):
        return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "complex64" or dtype == "complex128":
        return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "pyobject":
        return create_pyobject_test_data(shape=shape, seed=seed)
    else:
        raise NotImplementedError("Unsupported dtype for creating test data.")
@@ -342,7 +333,7 @@ class TestDistBase(unittest.TestCase):
            tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
            np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05)
            np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05)
        elif col_type == "allgather_object":
            need_result = [input1, input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
@@ -350,6 +341,10 @@ class TestDistBase(unittest.TestCase):
            need_result = input2
            np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
        elif col_type == "broadcast_object_list":
            need_result = [input2]
            self.assertEqual(need_result, tr0_out)
            self.assertEqual(need_result, tr1_out)
        elif col_type == "reduce":
            need_result = input1 + input2
            # bfloat16 precision loss comes from truncating the last 16 bits of float32,
@@ -365,6 +360,12 @@ class TestDistBase(unittest.TestCase):
            need_result2 = need_result[need_result.shape[0] // 2 :]
            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
        elif col_type == "scatter_object_list":
            need_result = input2
            need_result1 = [need_result[0 : len(need_result) // 2]]
            need_result2 = [need_result[len(need_result) // 2 :]]
            self.assertEqual(need_result1, tr0_out)
            self.assertEqual(need_result2, tr1_out)
        elif col_type == "reduce_scatter":
            need_result = input1 + input2
            need_result1 = need_result[0 : need_result.shape[0] // 2]
...