Unverified commit dec67d6d, authored by Wen Sun, committed by GitHub

Add collective communication APIs to improve completeness (#49252)

* feat: broadcast_object_list & scatter_object_list

* chore: update ut conf

* get_backend & is_available

* docs: update requirements

* fix: resolve conflicts
Co-authored-by: LiYuRio <liyuruijx@163.com>
Parent commit a72a0da0
......@@ -29,6 +29,7 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
from .collective import is_available # noqa: F401
from .communication import (
stream,
......@@ -39,9 +40,11 @@ from .communication import (
alltoall,
alltoall_single,
broadcast,
broadcast_object_list,
reduce,
send,
scatter,
scatter_object_list,
isend,
recv,
irecv,
......@@ -53,6 +56,7 @@ from .communication import (
get_group,
wait,
barrier,
get_backend,
) # noqa: F401
from .auto_parallel import shard_op # noqa: F401
......@@ -81,7 +85,9 @@ __all__ = [ # noqa
"spawn",
"launch",
"scatter",
"scatter_object_list",
"broadcast",
"broadcast_object_list",
"ParallelEnv",
"new_group",
"init_parallel_env",
......@@ -114,4 +120,6 @@ __all__ = [ # noqa
"isend",
"irecv",
"reduce_scatter",
"is_available",
"get_backend",
]
......@@ -307,3 +307,21 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
paddle.distributed.all_reduce(tmp, sync_op=True)
paddle.distributed.wait(tmp)
return gp
def is_available():
"""
Check whether the distributed package is available.
Returns:
Returns True if the distributed package is available, otherwise False.
Examples:
.. code-block:: python
import paddle
print(paddle.distributed.is_available())
"""
return core.is_compiled_with_dist()
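A minimal sketch of how `is_available()` can gate distributed setup; the fallback branch below is illustrative and not part of this diff:

import paddle.distributed as dist

if dist.is_available():
    dist.init_parallel_env()
else:
    # Paddle was built without distributed support; run single-process instead.
    print("distributed package not available, running single-process")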
......@@ -13,11 +13,11 @@
# limitations under the License.
from .all_gather import all_gather, all_gather_object
from .all_reduce import all_reduce
from .broadcast import broadcast
from .broadcast import broadcast, broadcast_object_list
from .reduce import reduce, ReduceOp
from .send import send, isend
from .recv import recv, irecv
from .scatter import scatter
from .scatter import scatter, scatter_object_list
from .batch_isend_irecv import batch_isend_irecv, P2POp
from .reduce_scatter import reduce_scatter
from .all_to_all import alltoall, alltoall_single
......@@ -27,4 +27,5 @@ from .group import (
get_group,
wait,
barrier,
get_backend,
)
......@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
import numpy as np
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)
def all_gather(tensor_list, tensor, group=None, sync_op=True):
"""
......@@ -66,20 +68,6 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
return stream.all_gather(tensor_list, tensor, group, sync_op)
def _convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor, tensor.numel()
def _convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
def all_gather_object(object_list, obj, group=None):
"""
......@@ -117,7 +105,7 @@ def all_gather_object(object_list, obj, group=None):
framework.in_dygraph_mode()
), "all_gather_object doesn't support static graph mode."
tensor, len_of_tensor = _convert_object_to_tensor(obj)
tensor, len_of_tensor = convert_object_to_tensor(obj)
# gather len_of_tensor from all ranks
list_len_of_tensor = []
......@@ -135,5 +123,5 @@ def all_gather_object(object_list, obj, group=None):
all_gather(tensor_list, input_tensor, group)
for i, tensor in enumerate(tensor_list):
object_list.append(
_convert_tensor_to_object(tensor, list_len_of_tensor[i])
convert_tensor_to_object(tensor, list_len_of_tensor[i])
)
......@@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)
def broadcast(tensor, src, group=None, sync_op=True):
......@@ -60,3 +68,70 @@ def broadcast(tensor, src, group=None, sync_op=True):
sync_op=sync_op,
use_calc_stream=False,
)
def broadcast_object_list(object_list, src, group=None):
"""
Broadcast picklable objects from the source to all others. Similar to broadcast(), but Python objects can be passed in.
Args:
object_list (list): The list of objects to send if the current rank is the source, or the list of objects to receive otherwise.
src (int): The source rank in the global view.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle.distributed as dist
dist.init_parallel_env()
if dist.get_rank() == 0:
object_list = [{"foo": [1, 2, 3]}]
else:
object_list = [{"bar": [4, 5, 6]}]
dist.broadcast_object_list(object_list, src=1)
print(object_list)
# [{"bar": [4, 5, 6]}] (2 GPUs)
"""
assert (
framework.in_dygraph_mode()
), "broadcast_object_list doesn't support static graph mode."
rank = dist.get_rank()
obj_tensors = []
obj_nums = len(object_list)
if rank == src:
obj_sizes = []
for obj in object_list:
obj_tensor, obj_size = convert_object_to_tensor(obj)
obj_tensors.append(obj_tensor)
obj_sizes.append(obj_size)
obj_size_tensor = paddle.concat(obj_sizes)
else:
obj_size_tensor = paddle.empty([obj_nums], dtype="int64")
broadcast(obj_size_tensor, src)
if rank == src:
# cast to uint8 to keep the same dtype
obj_data_tensor = paddle.concat(obj_tensors).cast("uint8")
else:
data_len = paddle.sum(obj_size_tensor).item()
obj_data_tensor = paddle.empty([data_len], dtype="uint8")
broadcast(obj_data_tensor, src)
offset = 0
for i in range(obj_nums):
data_len = obj_size_tensor[i]
object_list[i] = convert_tensor_to_object(
obj_data_tensor[offset : offset + data_len], data_len
)
offset += data_len
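The function above uses a two-step wire format: the pickled byte length of every object is broadcast first (as an int64 tensor), then the concatenated uint8 payload, which receivers slice and unpickle. A minimal single-process sketch of that round trip, assuming the serialization_utils helpers added in this diff live under paddle.distributed.communication:

import paddle
from paddle.distributed.communication.serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)

objs = [{"foo": [1, 2, 3]}, "bar"]
tensors = [convert_object_to_tensor(o)[0] for o in objs]
sizes = [int(t.shape[0]) for t in tensors]       # lengths sent in the first broadcast
payload = paddle.concat(tensors).cast("uint8")   # bytes sent in the second broadcast
offset, restored = 0, []
for n in sizes:
    restored.append(convert_tensor_to_object(payload[offset : offset + n], n))
    offset += n
assert restored == objs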
......@@ -19,7 +19,6 @@ import paddle.distributed as dist
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layer_helper as layer_helper
from paddle.fluid.framework import in_dygraph_mode
class Group:
......@@ -236,7 +235,7 @@ def get_group(id=0):
def _sync_calc_stream(tensor):
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
else:
op_type = 'c_sync_calc_stream'
......@@ -249,7 +248,7 @@ def _sync_calc_stream(tensor):
def _sync_comm_stream(tensor, ring_id=0):
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_comm_stream(
[tensor], [tensor], 'ring_id', ring_id
)
......@@ -337,7 +336,7 @@ def barrier(group=None):
ring_id = 0 if group is None else group.id
barrier_tensor = paddle.full([1], 1, dtype="int32")
if in_dygraph_mode():
if framework.in_dygraph_mode():
return paddle._legacy_C_ops.barrier(
barrier_tensor, barrier_tensor, 'ring_id', ring_id
)
......@@ -352,3 +351,29 @@ def barrier(group=None):
outputs={'Out': [barrier_tensor]},
attrs={'ring_id': ring_id},
)
def get_backend(group=None):
"""
Get the backend of the given group.
Args:
group (Group): The group to work on. Use the global group as default.
Returns:
Returns the name of the backend of the given group.
Examples:
.. code-block:: python
# required: distributed
import paddle
paddle.distributed.init_parallel_env()
paddle.distributed.get_backend() # NCCL
"""
if _warn_cur_rank_not_in_group(group):
raise RuntimeError("Invalid group specified")
group = _get_global_group() if group is None else group
return group.backend
......@@ -12,7 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from .serialization_utils import (
convert_object_to_tensor,
convert_tensor_to_object,
)
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
......@@ -59,3 +69,79 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
# [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
"""
return stream.scatter(tensor, tensor_list, src, group, sync_op)
def scatter_object_list(
out_object_list, in_object_list=None, src=0, group=None
):
"""
Scatter picklable objects from the source to all others. Similar to scatter(), but Python objects can be passed in.
Args:
out_object_list (list): The list used to store the scattered objects.
in_object_list (list): The list of objects to scatter. Only the objects on the src rank will be scattered.
src (int): The source rank in the global view.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle.distributed as dist
dist.init_parallel_env()
out_object_list = []
if dist.get_rank() == 0:
in_object_list = [{'foo': [1, 2, 3]}, {'foo': [4, 5, 6]}]
else:
in_object_list = [{'bar': [1, 2, 3]}, {'bar': [4, 5, 6]}]
dist.scatter_object_list(out_object_list, in_object_list, src=1)
print(out_object_list)
# [{'bar': [1, 2, 3]}] (2 GPUs, out for rank 0)
# [{'bar': [4, 5, 6]}] (2 GPUs, out for rank 1)
"""
assert (
framework.in_dygraph_mode()
), "scatter_object_list doesn't support static graph mode."
rank = dist.get_rank()
in_obj_tensors = []
in_obj_sizes = []
if rank == src:
for obj in in_object_list:
obj_tensor, obj_size = convert_object_to_tensor(obj)
in_obj_tensors.append(obj_tensor)
in_obj_sizes.append(obj_size)
max_obj_size_tensor = max(in_obj_sizes)
else:
# NOTE: shape can be [] after 0D tensor support
max_obj_size_tensor = paddle.empty([1], dtype="int64")
stream.broadcast(max_obj_size_tensor, src)
max_obj_size = int(max_obj_size_tensor.item())
# resize to the same size
in_tensor_list = []
for tensor in in_obj_tensors:
numpy_data = tensor.numpy()
numpy_data = np.resize(numpy_data, [max_obj_size])
in_tensor = paddle.to_tensor(numpy_data)
in_tensor_list.append(in_tensor)
out_tensor = paddle.empty([max_obj_size], dtype="uint8")
scatter(out_tensor, in_tensor_list if rank == src else None, src)
# NOTE: shape can be [] after 0D tensor support
out_tensor_size = paddle.empty([1], dtype="int64")
scatter(out_tensor_size, in_obj_sizes if rank == src else None, src)
out_object_list.clear()
out_object_list.append(
convert_tensor_to_object(out_tensor, out_tensor_size.item())
)
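Because `scatter()` requires tensors of identical shape, the pickled objects are padded (via `np.resize`) to the largest byte length, and the true lengths are scattered in a second pass so each receiver can trim the padding before unpickling. A small single-process sketch of that trim step, under the same serialization_utils assumption as above:

import numpy as np
import paddle
from paddle.distributed.communication.serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)

objs = [{"foo": [1, 2, 3]}, {"bar": list(range(100))}]
tensors = [convert_object_to_tensor(o)[0] for o in objs]
sizes = [int(t.shape[0]) for t in tensors]
max_size = max(sizes)

# pad every pickled tensor to max_size, as the src rank does before scatter()
padded = [paddle.to_tensor(np.resize(t.numpy(), [max_size])) for t in tensors]

# a receiver holds one padded tensor plus its true size; trimming recovers the object
for pad, n, obj in zip(padded, sizes, objs):
    assert convert_tensor_to_object(pad, n) == obj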
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
import numpy as np
import paddle
def convert_object_to_tensor(obj):
_pickler = pickle.Pickler
f = io.BytesIO()
_pickler(f).dump(obj)
data = np.frombuffer(f.getvalue(), dtype=np.uint8)
tensor = paddle.to_tensor(data)
return tensor, tensor.numel()
def convert_tensor_to_object(tensor, len_of_tensor):
_unpickler = pickle.Unpickler
return _unpickler(io.BytesIO(tensor.numpy()[:len_of_tensor])).load()
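A quick round-trip check for the two helpers above (runnable in the same module); note that the second return value is `tensor.numel()`, which the collective wrappers broadcast or scatter so the receiving side knows where the pickled bytes end:

obj = {"answer": 42}
tensor, length = convert_object_to_tensor(obj)
assert convert_tensor_to_object(tensor, int(length)) == obj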
......@@ -127,6 +127,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_collective_broadcast_api
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_broadcast_object_list_api MODULES
test_collective_broadcast_object_list_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_broadcast_object_list_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_cpu_barrier_with_gloo MODULES
......@@ -223,6 +231,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_collective_scatter_api
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_scatter_object_list_api MODULES
test_collective_scatter_object_list_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_scatter_object_list_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_sendrecv MODULES test_collective_sendrecv ENVS
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import test_collective_api_base as test_base
import paddle.distributed as dist
import paddle.fluid as fluid
class TestCollectiveBroadcastObjectListAPI(
test_base.TestCollectiveAPIRunnerBase
):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
object_list = [indata]
dist.broadcast_object_list(object_list, src=1)
return object_list
if __name__ == "__main__":
test_base.runtime_main(
TestCollectiveBroadcastObjectListAPI, "broadcast_object_list"
)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import test_collective_api_base as test_base
import paddle.distributed as dist
import paddle.fluid as fluid
class TestCollectiveScatterObjectListAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
data_len = len(indata) // 2
in_object_list = [indata[:data_len], indata[data_len:]]
out_object_list = []
dist.scatter_object_list(out_object_list, in_object_list, src=1)
return out_object_list
if __name__ == "__main__":
test_base.runtime_main(
TestCollectiveScatterObjectListAPI, "scatter_object_list"
)
......@@ -46,10 +46,14 @@ class TestProcessGroupFp32(unittest.TestCase):
device_id = paddle.distributed.ParallelEnv().dev_id
paddle.set_device('gpu:%d' % device_id)
assert paddle.distributed.is_available()
pg = init_process_group()
print("rank:", pg.rank(), "size:", pg.size(), "name:", pg.name())
print("test new group api ok")
assert paddle.distributed.get_backend() == "NCCL"
# test allreduce sum
# rank 0
x = np.random.random(self.shape).astype(self.dtype)
......
......@@ -27,14 +27,7 @@ class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):
"allgather_object",
"nccl",
static_mode="0",
dtype="pylist",
)
self.check_with_place(
"collective_allgather_object_api_dygraph.py",
"allgather_object",
"nccl",
static_mode="0",
dtype="pydict",
dtype="pyobject",
)
def test_allgather_gloo_dygraph(self):
......@@ -44,15 +37,7 @@ class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):
"gloo",
"3",
static_mode="0",
dtype="pylist",
)
self.check_with_place(
"collective_allgather_object_api_dygraph.py",
"allgather_object",
"gloo",
"3",
static_mode="0",
dtype="pydict",
dtype="pyobject",
)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import test_collective_api_base as test_base
class TestCollectiveBroadcastObjectListAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_broadcast_nccl(self):
self.check_with_place(
"collective_broadcast_object_list_api_dygraph.py",
"broadcast_object_list",
"nccl",
static_mode="0",
dtype="pyobject",
)
def test_broadcast_gloo_dygraph(self):
self.check_with_place(
"collective_broadcast_object_list_api_dygraph.py",
"broadcast_object_list",
"gloo",
"3",
static_mode="0",
dtype="pyobject",
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import test_collective_api_base as test_base
class TestCollectiveScatterObjectListAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_scatter_nccl(self):
self.check_with_place(
"collective_scatter_object_list_api_dygraph.py",
"scatter_object_list",
"nccl",
static_mode="0",
dtype="pyobject",
)
def test_scatter_gloo_dygraph(self):
self.check_with_place(
"collective_scatter_object_list_api_dygraph.py",
"scatter_object_list",
"gloo",
"3",
static_mode="0",
dtype="pyobject",
)
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@ test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,ht
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......@@ -26,6 +27,7 @@ test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_p
test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......
......@@ -58,22 +58,15 @@ def create_complex_test_data(shape=None, dtype=None, seed=None):
return data
def create_pylist_test_data(shape=None, seed=None):
def create_pyobject_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
# Generate random shape test case for xxx_object api
shape = np.random.randint(0, high=100, size=(2)).tolist()
data = np.random.random(shape).tolist()
return data
def create_pydict_test_data(shape=None, seed=None):
if seed:
np.random.seed(seed)
key = [i for i in range(0, shape[0])]
value = np.random.random(shape).tolist()
data = dict(zip(key, value))
return data
list_shape = np.random.randint(0, high=100, size=(2)).tolist()
list_data = np.random.random(shape).tolist()
dict_key = [i for i in range(0, shape[0])]
dict_val = np.random.random(shape).tolist()
dict_data = dict(zip(dict_key, dict_val))
return [list_data, dict_data]
def create_test_data(shape=None, dtype=None, seed=None):
......@@ -94,10 +87,8 @@ def create_test_data(shape=None, dtype=None, seed=None):
return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "complex64" or dtype == "complex128":
return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "pylist":
return create_pylist_test_data(shape=shape, seed=seed)
elif dtype == "pydict":
return create_pydict_test_data(shape=shape, seed=seed)
elif dtype == "pyobject":
return create_pyobject_test_data(shape=shape, seed=seed)
else:
raise NotImplementedError("Unsupported dtype for creating test data.")
......@@ -342,7 +333,7 @@ class TestDistBase(unittest.TestCase):
tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
np.testing.assert_allclose(tr_out0, need_result, rtol=1e-05)
np.testing.assert_allclose(tr_out1, need_result, rtol=1e-05)
if col_type == "allgather_object":
elif col_type == "allgather_object":
need_result = [input1, input2]
self.assertEqual(need_result, tr0_out)
self.assertEqual(need_result, tr1_out)
......@@ -350,6 +341,10 @@ class TestDistBase(unittest.TestCase):
need_result = input2
np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
elif col_type == "broadcast_object_list":
need_result = [input2]
self.assertEqual(need_result, tr0_out)
self.assertEqual(need_result, tr1_out)
elif col_type == "reduce":
need_result = input1 + input2
# bfloat16 precision loss comes from truncating the last 16 bits of float32,
......@@ -365,6 +360,12 @@ class TestDistBase(unittest.TestCase):
need_result2 = need_result[need_result.shape[0] // 2 :]
np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
elif col_type == "scatter_object_list":
need_result = input2
need_result1 = [need_result[0 : len(need_result) // 2]]
need_result2 = [need_result[len(need_result) // 2 :]]
self.assertEqual(need_result1, tr0_out)
self.assertEqual(need_result2, tr1_out)
elif col_type == "reduce_scatter":
need_result = input1 + input2
need_result1 = need_result[0 : need_result.shape[0] // 2]
......