未验证 提交 e4eb8d36 编写于 作者: W Wen Sun 提交者: GitHub

Completes bfloat16 dtype for collective api in eager mode (#45844)

上级 f6a85db9
......@@ -88,6 +88,9 @@ namespace distributed {
case experimental::DataType::BOOL: \
func<bool>(args); \
break; \
case experimental::DataType::BFLOAT16: \
func<bfloat16>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
......
......@@ -996,6 +996,9 @@ void* GetPointerByOffset(void* raw_pointer,
} else if (type == experimental::DataType::BOOL) {
return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
offset);
} else if (type == experimental::DataType::BFLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"This datatype in nccl is not supported."));
......
......@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclUint8;
} else if (type == framework::proto::VarType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if NCCL_VERSION_CODE >= 21000
} else if (type == framework::proto::VarType::BF16) {
return ncclBfloat16;
#endif
......@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return ncclInt8;
} else if (type == experimental::DataType::BOOL) {
return ncclUint8;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if NCCL_VERSION_CODE >= 21000
} else if (type == experimental::DataType::BFLOAT16) {
return ncclBfloat16;
#endif
......
......@@ -478,7 +478,8 @@ def is_initialized():
Check whether the distributed environment has been initialized
Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.
Returns:
`True` if distributed environment has been initialized, otherwise `False`.
Examples:
.. code-block:: python
......@@ -626,7 +627,7 @@ def broadcast(tensor, src, group=None, sync_op=True):
Args:
tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank.
group (Group, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
......@@ -709,7 +710,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
Args:
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global default group.
......@@ -817,7 +818,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
group (Group, optional): The group instance return by new_group or None for global default group.
......@@ -999,9 +1000,9 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
Args:
tensor (Tensor): The output Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
src (int): The source rank id. Default value is 0.
group (Group, optional): The group instance return by new_group or None for global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
......@@ -1096,7 +1097,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
Args:
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
data type of the input Tensors.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
......@@ -1197,7 +1198,7 @@ def alltoall_single(in_tensor,
``alltoall_single`` is only supported in eager mode.
Args:
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
......@@ -1286,7 +1287,7 @@ def send(tensor, dst=0, group=None, sync_op=True):
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
......@@ -1352,7 +1353,7 @@ def recv(tensor, src=0, group=None, sync_op=True):
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
......@@ -1435,7 +1436,7 @@ def isend(tensor, dst, group=None):
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
dst (int): The destination rank.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
......@@ -1485,7 +1486,7 @@ def irecv(tensor, src=None, group=None):
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
......@@ -1594,7 +1595,7 @@ def batch_isend_irecv(p2p_op_list):
corresponding tasks. NCCL are currently supported.
Args:
p2p_op_list: A list of point-to-point operations(type of each operator is
p2p_op_list (List[P2POp]): A list of point-to-point operations(type of each operator is
``paddle.distributed.P2POp``). The order of the isend/irecv in the list
matters and it needs to match with corresponding isend/irecv on the
remote end.
......@@ -1668,9 +1669,9 @@ def reduce_scatter(tensor,
Reduces, then scatters a list of tensors to all processes in a group
Args:
tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global
default group. Default: None.
......@@ -1736,9 +1737,9 @@ def _reduce_scatter_base(output,
Reduces, then scatters a flattened tensor to all processes in a group.
Args:
output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
......
......@@ -71,14 +71,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_allreduce_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
......@@ -98,7 +98,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_alltoall_single_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_alltoall_single_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -125,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_broadcast_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -154,7 +154,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_isend_irecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -187,7 +187,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_api MODULES test_collective_reduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
......@@ -207,7 +207,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_scatter_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -221,7 +221,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_scatter_api MODULES test_collective_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_scatter_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -235,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,9 +25,17 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
tensor_list = []
paddle.distributed.all_gather(tensor_list, tindata)
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.all_gather(tensor_list, tindata)
return [
tensor.cast("float32").numpy() for tensor in tensor_list
]
else:
tindata = paddle.to_tensor(indata)
dist.all_gather(tensor_list, tindata)
return [tensor.numpy() for tensor in tensor_list]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,8 +25,14 @@ class TestCollectiveAllreduceAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.all_reduce(tindata)
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
paddle.distributed.all_reduce(tindata)
dist.all_reduce(tindata)
return [tindata.numpy()]
......
......@@ -13,23 +13,31 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import test_collective_api_base as test_base
class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
class TestCollectiveAllToAllAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
toutdata = []
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
tindata = paddle.split(tindata, 2, axis=0)
dist.alltoall(tindata, toutdata)
return [data.cast("float32").numpy() for data in toutdata]
else:
tindata = paddle.to_tensor(indata)
tindata = paddle.split(tindata, 2, axis=0)
toutdata = []
paddle.distributed.alltoall(tindata, toutdata)
dist.alltoall(tindata, toutdata)
return [data.numpy() for data in toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveAllToAllAPI, "alltoall")
test_base.runtime_main(TestCollectiveAllToAllAPI, "alltoall")
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,9 +25,16 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
toutdata = paddle.to_tensor(tindata, "float32").cast("uint16")
dist.alltoall_single(tindata, toutdata)
return [toutdata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
toutdata = paddle.to_tensor(indata)
paddle.distributed.alltoall_single(tindata, toutdata)
dist.alltoall_single(tindata, toutdata)
return [toutdata.numpy()]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,8 +25,14 @@ class TestCollectiveBroadcastAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.broadcast(tindata, src=1)
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
paddle.distributed.broadcast(tindata, src=1)
dist.broadcast(tindata, src=1)
return [tindata.numpy()]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,11 +25,21 @@ class TestCollectiveIsendIrecvAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
if rank == 0:
task = dist.isend(tindata, dst=1)
else:
task = dist.irecv(tindata, src=0)
task.wait()
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
if rank == 0:
task = paddle.distributed.isend(tindata, dst=1)
task = dist.isend(tindata, dst=1)
else:
task = paddle.distributed.irecv(tindata, src=0)
task = dist.irecv(tindata, src=0)
task.wait()
return [tindata.numpy()]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,8 +25,14 @@ class TestCollectiveReduceAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.reduce(tindata, dst=0)
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
paddle.distributed.reduce(tindata, dst=0)
dist.reduce(tindata, dst=0)
return [tindata.numpy()]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,9 +25,16 @@ class TestCollectiveReduceScatterAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
dist.reduce_scatter(subdata1, [subdata1, subdata2])
return [subdata1.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
paddle.distributed.reduce_scatter(subdata1, [subdata1, subdata2])
dist.reduce_scatter(subdata1, [subdata1, subdata2])
return [subdata1.numpy()]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
import test_collective_api_base as test_base
......@@ -24,12 +25,24 @@ class TestCollectiveScatterAPI(test_base.TestCollectiveAPIRunnerBase):
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
if rank == 0:
dist.scatter(subdata1, src=1)
else:
dist.scatter(subdata1,
tensor_list=[subdata1, subdata2],
src=1)
return [subdata1.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
if rank == 0:
paddle.distributed.scatter(subdata1, src=1)
dist.scatter(subdata1, src=1)
else:
paddle.distributed.scatter(subdata1,
dist.scatter(subdata1,
tensor_list=[subdata1, subdata2],
src=1)
return [subdata1.numpy()]
......
......@@ -13,24 +13,34 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid as fluid
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import test_collective_api_base as test_base
class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
class TestCollectiveSendRecvAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
if rank == 0:
dist.send(tindata, dst=1)
else:
dist.recv(tindata, src=0)
return [tindata.cast("float32").numpy()]
else:
tindata = paddle.to_tensor(indata)
if rank == 0:
paddle.distributed.send(tindata, dst=1)
dist.send(tindata, dst=1)
else:
paddle.distributed.recv(tindata, src=0)
dist.recv(tindata, src=0)
return [tindata.numpy()]
if __name__ == "__main__":
runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
test_base.runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
......@@ -26,213 +26,55 @@ class TestCollectiveAllgatherAPI(TestDistBase):
pass
def test_allgather_nccl(self):
dtypes_to_test = [
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "complex64", "complex128"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float16")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="float64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="bool")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="uint8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="int64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="complex64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"nccl",
dtype="complex128")
dtype=dtype)
def test_allgather_gloo(self):
dtypes_to_test = [
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "complex64", "complex128"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float16")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="float64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="bool")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="uint8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int8")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int32")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="int64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="complex64")
self.check_with_place("collective_allgather_api.py",
"allgather",
"gloo",
"3",
dtype="complex128")
dtype=dtype)
def test_allgatther_nccl_dygraph(self):
dtypes_to_test = [
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "complex64", "complex128"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float16")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="float64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="bool")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="uint8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="int64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="complex64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"nccl",
static_mode="0",
dtype="complex128")
dtype=dtype)
def test_allgather_gloo_dygraph(self):
dtypes_to_test = [
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "bfloat16", "complex64", "complex128"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float16")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="float64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="bool")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="uint8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int8")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int32")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="int64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="complex64")
self.check_with_place("collective_allgather_api_dygraph.py",
"allgather",
"gloo",
"3",
static_mode="0",
dtype="complex128")
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -41,9 +41,11 @@ class TestCollectiveAllreduceAPI(TestDistBase):
def test_allreduce_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_allreduce_api_dygraph.py",
"allreduce",
......@@ -53,8 +55,8 @@ class TestCollectiveAllreduceAPI(TestDistBase):
def test_allreduce_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "bfloat16"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_allreduce_api_dygraph.py",
......@@ -65,5 +67,5 @@ class TestCollectiveAllreduceAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -30,9 +30,11 @@ class TestCollectiveAllToAllAPI(TestDistBase):
def test_alltoall_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_alltoall_api_dygraph.py",
"alltoall",
......@@ -41,5 +43,5 @@ class TestCollectiveAllToAllAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -23,9 +23,11 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
def test_alltooall_single_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_alltoall_single_api_dygraph.py",
"alltoall",
......@@ -34,5 +36,5 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -35,9 +35,11 @@ class TestCollectiveBroadcastAPI(TestDistBase):
def test_broadcast_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_broadcast_api_dygraph.py",
"broadcast",
......@@ -47,8 +49,8 @@ class TestCollectiveBroadcastAPI(TestDistBase):
def test_broadcast_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "bfloat16"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_broadcast_api_dygraph.py",
......@@ -59,5 +61,5 @@ class TestCollectiveBroadcastAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -23,9 +23,11 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
def test_isend_irecv_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_isend_irecv_api_dygraph.py",
"sendrecv",
......@@ -34,5 +36,5 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -38,9 +38,11 @@ class TestCollectiveReduceAPI(TestDistBase):
def test_reduce_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_api_dygraph.py",
"reduce",
......@@ -50,8 +52,8 @@ class TestCollectiveReduceAPI(TestDistBase):
def test_reduce_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "bfloat16"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_api_dygraph.py",
......@@ -62,5 +64,5 @@ class TestCollectiveReduceAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -23,9 +23,11 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
def test_reduce_scatter_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_reduce_scatter_api_dygraph.py",
"reduce_scatter",
......@@ -34,5 +36,5 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -34,9 +34,11 @@ class TestCollectiveScatterAPI(TestDistBase):
def test_scatter_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_scatter_api_dygraph.py",
"scatter",
......@@ -46,8 +48,8 @@ class TestCollectiveScatterAPI(TestDistBase):
def test_scatter_gloo_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool", "bfloat16"
]
for dtype in dtypes_to_test:
self.check_with_place("collective_scatter_api_dygraph.py",
......@@ -58,5 +60,5 @@ class TestCollectiveScatterAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -32,9 +32,11 @@ class TestCollectiveSendRecvAPI(TestDistBase):
def test_sendrecv_nccl_dygraph(self):
dtypes_to_test = [
'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
'bool'
"float16", "float32", "float64", "int32", "int64", "int8", "uint8",
"bool"
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place("collective_sendrecv_api_dygraph.py",
"sendrecv",
......@@ -43,5 +45,5 @@ class TestCollectiveSendRecvAPI(TestDistBase):
dtype=dtype)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
......@@ -7,27 +7,27 @@ test_c_split,linux,gpu;rocm,120,DIST,test_runner.py,2,,PYTHONPATH=..;http_proxy=
test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_isend_irecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_isend_irecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......
......@@ -23,6 +23,7 @@ from contextlib import closing
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle_bfloat import bfloat16
def create_bool_test_data(shape=None, seed=None):
......@@ -76,6 +77,9 @@ def create_test_data(shape=None, dtype=None, seed=None):
assert shape, "Shape should be specified"
if dtype == "float32" or dtype == "float16" or dtype == "float64":
return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
elif dtype == "bfloat16":
# since numpy does not support bfloat16 yet, use `paddle_bfloat` to replace
return create_float_test_data(shape=shape, dtype=bfloat16, seed=seed)
elif dtype == "bool":
return create_bool_test_data(shape=shape, seed=seed)
elif dtype == "int32" or dtype == "int64" or dtype == "int8" or dtype == "uint8":
......@@ -167,6 +171,15 @@ class TestDistBase(unittest.TestCase):
self.temp_dir = tempfile.TemporaryDirectory()
# NOTE: this is a hack to get int format nccl version, like 2134
# if current platform is not linux, version number will be 0
nccl_version_str = subprocess.check_output(
r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
stderr=subprocess.DEVNULL,
shell=True).decode('utf-8')
self._nccl_version = int("".join(
nccl_version_str.split("."))) if nccl_version_str else 0
def tearDown(self):
self.temp_dir.cleanup()
......@@ -305,6 +318,10 @@ class TestDistBase(unittest.TestCase):
model_file, required_envs)
input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
# cast bfloat16 to float32 for numeric comparison
if dtype == "bfloat16":
input1 = input1.astype("float32")
input2 = input2.astype("float32")
if col_type == "allgather":
need_result = np.vstack((input1, input2))
tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
......@@ -321,7 +338,13 @@ class TestDistBase(unittest.TestCase):
np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
elif col_type == "reduce":
need_result = input1 + input2
np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
# bfloat16 precision loss comes from truncating the last 16 bits of float32,
# which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078
if dtype == "bfloat16":
rtol = 8e-03
else:
rtol = 1e-05
np.testing.assert_allclose(tr0_out[0], need_result, rtol=rtol)
elif col_type == "scatter":
need_result = input2
need_result1 = need_result[0:need_result.shape[0] // 2]
......@@ -332,18 +355,28 @@ class TestDistBase(unittest.TestCase):
need_result = input1 + input2
need_result1 = need_result[0:need_result.shape[0] // 2]
need_result2 = need_result[need_result.shape[0] // 2:]
np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
if dtype == "bfloat16":
rtol = 8e-03
else:
rtol = 1e-05
np.testing.assert_allclose(tr0_out[0], need_result1, rtol=rtol)
np.testing.assert_allclose(tr1_out[0], need_result2, rtol=rtol)
elif col_type == "allreduce":
need_result = input1 + input2
if dtype == "bfloat16":
rtol = 8e-03
atol = 8e-03
else:
rtol = 1e-05
atol = 1e-05
np.testing.assert_allclose(tr0_out[0],
need_result,
rtol=1e-05,
atol=1e-05)
rtol=rtol,
atol=atol)
np.testing.assert_allclose(tr1_out[0],
need_result,
rtol=1e-05,
atol=1e-05)
rtol=rtol,
atol=atol)
elif col_type == "parallel_embedding":
result_data = tr0_out[0]
np.random.seed(2020)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册