BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit e4eb8d36 (unverified), authored Oct 11, 2022 by Wen Sun, committed via GitHub on Oct 11, 2022

Completes bfloat16 dtype for collective api in eager mode (#45844)
Parent: f6a85db9

Showing 27 changed files with 317 additions and 329 deletions (+317 −329)
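
In eager (dynamic graph) mode the collective API now accepts bfloat16 tensors end to end. A minimal usage sketch, mirroring the pattern the dygraph tests in this commit use (assumptions not in the diff: two GPUs, an NCCL build >= 2.10, and the file name demo_bf16_allreduce.py):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
# Paddle stores bfloat16 in uint16 buffers, so the payload is built by
# casting a float32 tensor to "uint16" (the trick the tests below rely on).
data = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], "float32").cast("uint16")
dist.all_reduce(data)  # elementwise sum across ranks, carried out in bfloat16
print(data.cast("float32").numpy())

Launched with, e.g., python -m paddle.distributed.launch --gpus=0,1 demo_bf16_allreduce.py, each rank prints the summed matrix.
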
Changed files:

  paddle/fluid/distributed/collective/ProcessGroupGloo.cc (+3 −0)
  paddle/fluid/distributed/collective/ProcessGroupNCCL.cc (+3 −0)
  paddle/fluid/platform/device/gpu/nccl_helper.h (+2 −2)
  python/paddle/distributed/collective.py (+18 −17)
  python/paddle/fluid/tests/unittests/collective/CMakeLists.txt (+9 −9)
  python/paddle/fluid/tests/unittests/collective/collective_allgather_api_dygraph.py (+12 −3)
  python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py (+10 −3)
  python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py (+15 −7)
  python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py (+12 −4)
  python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py (+10 −3)
  python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py (+17 −6)
  python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py (+10 −3)
  python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py (+12 −4)
  python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py (+21 −8)
  python/paddle/fluid/tests/unittests/collective/collective_sendrecv_api_dygraph.py (+18 −8)
  python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py (+43 −201)
  python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py (+7 −5)
  python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py (+5 −3)
  python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py (+5 −3)
  python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py (+7 −5)
  python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py (+5 −3)
  python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py (+7 −5)
  python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py (+5 −3)
  python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py (+7 −5)
  python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py (+5 −3)
  python/paddle/fluid/tests/unittests/collective/testslist.csv (+9 −9)
  python/paddle/fluid/tests/unittests/test_collective_api_base.py (+40 −7)

paddle/fluid/distributed/collective/ProcessGroupGloo.cc

@@ -88,6 +88,9 @@ namespace distributed {
     case experimental::DataType::BOOL:       \
       func<bool>(args);                      \
       break;                                 \
+    case experimental::DataType::BFLOAT16:   \
+      func<bfloat16>(args);                  \
+      break;                                 \
     default:                                 \
       VLOG(0) << "Error: Unknown DataType."; \
       exit(-1);

paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

@@ -996,6 +996,9 @@ void* GetPointerByOffset(void* raw_pointer,
   } else if (type == experimental::DataType::BOOL) {
     return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
                                    offset);
+  } else if (type == experimental::DataType::BFLOAT16) {
+    return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
+                                   offset);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "This datatype in nccl is not supported."));

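The new BFLOAT16 branch advances the raw pointer in uint16_t units because a bfloat16 element occupies the same 2 bytes as a uint16. The same arithmetic sketched in numpy terms, purely for illustration (not Paddle internals; buf and offset are made-up names):

import numpy as np

buf = np.arange(8, dtype=np.uint16)  # stand-in for a bfloat16 buffer
offset = 3                           # offset counted in elements, not bytes
view = buf[offset:]                  # plays the role of pointer + offset
# each element is 2 bytes wide, so the byte offset is 2 * offset
assert view.tobytes() == buf.tobytes()[2 * offset:]
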
paddle/fluid/platform/device/gpu/nccl_helper.h

@@ -59,7 +59,7 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclUint8;
   } else if (type == framework::proto::VarType::BOOL) {
     return ncclUint8;
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
   } else if (type == framework::proto::VarType::BF16) {
     return ncclBfloat16;
 #endif
@@ -86,7 +86,7 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
     return ncclInt8;
   } else if (type == experimental::DataType::BOOL) {
     return ncclUint8;
-#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
+#if NCCL_VERSION_CODE >= 21000
   } else if (type == experimental::DataType::BFLOAT16) {
     return ncclBfloat16;
 #endif

python/paddle/distributed/collective.py

@@ -478,7 +478,8 @@ def is_initialized():
     Check whether the distributed environment has been initialized
-    Returns (bool): `True` if distributed environment has been initialized, otherwise `False`.
+    Returns:
+        `True` if distributed environment has been initialized, otherwise `False`.
 
     Examples:
         .. code-block:: python
@@ -626,7 +627,7 @@ def broadcast(tensor, src, group=None, sync_op=True):
     Args:
         tensor (Tensor): The Tensor to send if current rank is the source, or the Tensor to receive otherwise. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         src (int): The source rank.
         group (Group, optional): The group instance return by new_group or None for global default group.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -709,7 +710,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
     Args:
         tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         dst (int): The destination rank id.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
         group (Group, optional): The group instance return by new_group or None for global default group.
@@ -817,7 +818,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
     Args:
         tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool, bfloat16, complex64 or complex128.
         tensor (Tensor): The Tensor to send. Its data type
             should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
         group (Group, optional): The group instance return by new_group or None for global default group.
@@ -999,9 +1000,9 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
     Args:
         tensor (Tensor): The output Tensor. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool. Default value is None.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
         src (int): The source rank id. Default value is 0.
         group (Group, optional): The group instance return by new_group or None for global default group.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1096,7 +1097,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
     Args:
         in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
             data type of the input Tensors.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
@@ -1197,7 +1198,7 @@ def alltoall_single(in_tensor,
     ``alltoall_single`` is only supported in eager mode.
 
     Args:
-        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+        in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
         in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
             must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
@@ -1286,7 +1287,7 @@ def send(tensor, dst=0, group=None, sync_op=True):
     Args:
         tensor (Tensor): The Tensor to send. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         dst (int): The destination rank id.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1352,7 +1353,7 @@ def recv(tensor, src=0, group=None, sync_op=True):
     Args:
         tensor (Tensor): The Tensor to receive. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         src (int): The source rank id.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
         sync_op (bool, optional): Whether this op is a sync op. The default value is True.
@@ -1435,7 +1436,7 @@ def isend(tensor, dst, group=None):
     Args:
         tensor (Tensor): The Tensor to send. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         dst (int): The destination rank.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
@@ -1485,7 +1486,7 @@ def irecv(tensor, src=None, group=None):
     Args:
         tensor (Tensor): The Tensor to receive. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         src (int): The source rank id.
         group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
@@ -1594,7 +1595,7 @@ def batch_isend_irecv(p2p_op_list):
     corresponding tasks. NCCL are currently supported.
 
     Args:
-        p2p_op_list: A list of point-to-point operations(type of each operator is
+        p2p_op_list (List[P2POp]): A list of point-to-point operations(type of each operator is
             ``paddle.distributed.P2POp``). The order of the isend/irecv in the list
             matters and it needs to match with corresponding isend/irecv on the
             remote end.
@@ -1668,9 +1669,9 @@ def reduce_scatter(tensor,
     Reduces, then scatters a list of tensors to all processes in a group
 
     Args:
-        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+        tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
         group (Group, optional): The group instance return by new_group or None for global
             default group. Default: None.
@@ -1736,9 +1737,9 @@ def _reduce_scatter_base(output,
     Reduces, then scatters a flattened tensor to all processes in a group.
 
     Args:
-        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+        output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.

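With the docstring updates above, bfloat16 joins the accepted dtypes of these collective APIs. A hypothetical two-rank broadcast under the same uint16-storage convention (a sketch, not from this commit; the rank count and values are assumptions):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    # the source rank holds the bfloat16 payload
    data = paddle.to_tensor([1.0, 2.0, 3.0], "float32").cast("uint16")
else:
    # receivers provide a same-shape bfloat16 buffer
    data = paddle.zeros([3], dtype="float32").cast("uint16")
dist.broadcast(data, src=0)
print(data.cast("float32").numpy())  # every rank now prints [1. 2. 3.]
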
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt

@@ -71,14 +71,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_allreduce_api
-                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
     test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_alltoall_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   bash_test_modules(
@@ -98,7 +98,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_alltoall_single_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_alltoall_single_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
@@ -125,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_broadcast_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
@@ -154,7 +154,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
     ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_isend_irecv_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
@@ -187,7 +187,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_reduce_api MODULES test_collective_reduce_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_reduce_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   bash_test_modules(
@@ -207,7 +207,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_reduce_scatter_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_reduce_scatter_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
@@ -221,7 +221,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_scatter_api MODULES test_collective_scatter_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_scatter_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(
@@ -235,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
     test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
   set_tests_properties(test_collective_sendrecv_api
-                       PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
+                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
 endif()
 if((WITH_GPU OR WITH_ROCM) AND (LINUX))
   py_test_modules(

python/paddle/fluid/tests/unittests/collective/collective_allgather_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,10 +25,18 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
             tensor_list = []
-            paddle.distributed.all_gather(tensor_list, tindata)
-            return [tensor.numpy() for tensor in tensor_list]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                dist.all_gather(tensor_list, tindata)
+                return [
+                    tensor.cast("float32").numpy() for tensor in tensor_list
+                ]
+            else:
+                tindata = paddle.to_tensor(indata)
+                dist.all_gather(tensor_list, tindata)
+                return [tensor.numpy() for tensor in tensor_list]
 
 
 if __name__ == "__main__":

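The NOTE in this hunk refers to Paddle's use of uint16 as the storage dtype for bfloat16: casting a float32 tensor to "uint16" performs the bfloat16 rounding. A single-process sketch of the round trip (an illustration, no distributed launch needed):

import numpy as np
import paddle

indata = np.random.random((2, 2)).astype("float32")
# cast("uint16") converts the values to bfloat16, stored in a uint16 tensor
# (the undocumented behavior the tests' NOTE mentions)
t = paddle.to_tensor(indata, "float32").cast("uint16")
out = t.cast("float32").numpy()
# bfloat16 keeps only 8 bits of mantissa, so the round trip loses precision
print(np.max(np.abs(out - indata)))  # small, bounded by bfloat16 rounding
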
python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,9 +25,15 @@ class TestCollectiveAllreduceAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            paddle.distributed.all_reduce(tindata)
-            return [tindata.numpy()]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                dist.all_reduce(tindata)
+                return [tindata.cast("float32").numpy()]
+            else:
+                tindata = paddle.to_tensor(indata)
+                dist.all_reduce(tindata)
+                return [tindata.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py

@@ -13,23 +13,31 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
-from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+import test_collective_api_base as test_base
 
 
-class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
+class TestCollectiveAllToAllAPI(test_base.TestCollectiveAPIRunnerBase):
 
     def __init__(self):
         self.global_ring_id = 0
 
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            tindata = paddle.split(tindata, 2, axis=0)
             toutdata = []
-            paddle.distributed.alltoall(tindata, toutdata)
-            return [data.numpy() for data in toutdata]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                tindata = paddle.split(tindata, 2, axis=0)
+                dist.alltoall(tindata, toutdata)
+                return [data.cast("float32").numpy() for data in toutdata]
+            else:
+                tindata = paddle.to_tensor(indata)
+                tindata = paddle.split(tindata, 2, axis=0)
+                dist.alltoall(tindata, toutdata)
+                return [data.numpy() for data in toutdata]
 
 
 if __name__ == "__main__":
-    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
+    test_base.runtime_main(TestCollectiveAllToAllAPI, "alltoall")

python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,10 +25,17 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            toutdata = paddle.to_tensor(indata)
-            paddle.distributed.alltoall_single(tindata, toutdata)
-            return [toutdata.numpy()]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                toutdata = paddle.to_tensor(tindata, "float32").cast("uint16")
+                dist.alltoall_single(tindata, toutdata)
+                return [toutdata.cast("float32").numpy()]
+            else:
+                tindata = paddle.to_tensor(indata)
+                toutdata = paddle.to_tensor(indata)
+                dist.alltoall_single(tindata, toutdata)
+                return [toutdata.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,9 +25,15 @@ class TestCollectiveBroadcastAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            paddle.distributed.broadcast(tindata, src=1)
-            return [tindata.numpy()]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                dist.broadcast(tindata, src=1)
+                return [tindata.cast("float32").numpy()]
+            else:
+                tindata = paddle.to_tensor(indata)
+                dist.broadcast(tindata, src=1)
+                return [tindata.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,13 +25,23 @@ class TestCollectiveIsendIrecvAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            if rank == 0:
-                task = paddle.distributed.isend(tindata, dst=1)
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                if rank == 0:
+                    task = dist.isend(tindata, dst=1)
+                else:
+                    task = dist.irecv(tindata, src=0)
+                task.wait()
+                return [tindata.cast("float32").numpy()]
             else:
-                task = paddle.distributed.irecv(tindata, src=0)
-            task.wait()
-            return [tindata.numpy()]
+                tindata = paddle.to_tensor(indata)
+                if rank == 0:
+                    task = dist.isend(tindata, dst=1)
+                else:
+                    task = dist.irecv(tindata, src=0)
+                task.wait()
+                return [tindata.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,9 +25,15 @@ class TestCollectiveReduceAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            paddle.distributed.reduce(tindata, dst=0)
-            return [tindata.numpy()]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                dist.reduce(tindata, dst=0)
+                return [tindata.cast("float32").numpy()]
+            else:
+                tindata = paddle.to_tensor(indata)
+                dist.reduce(tindata, dst=0)
+                return [tindata.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,10 +25,17 @@ class TestCollectiveReduceScatterAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
-            paddle.distributed.reduce_scatter(subdata1, [subdata1, subdata2])
-            return [subdata1.numpy()]
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
+                dist.reduce_scatter(subdata1, [subdata1, subdata2])
+                return [subdata1.cast("float32").numpy()]
+            else:
+                tindata = paddle.to_tensor(indata)
+                subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
+                dist.reduce_scatter(subdata1, [subdata1, subdata2])
+                return [subdata1.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 import test_collective_api_base as test_base
@@ -24,15 +25,27 @@ class TestCollectiveScatterAPI(test_base.TestCollectiveAPIRunnerBase):
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
-            if rank == 0:
-                paddle.distributed.scatter(subdata1, src=1)
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
+                if rank == 0:
+                    dist.scatter(subdata1, src=1)
+                else:
+                    dist.scatter(subdata1,
+                                 tensor_list=[subdata1, subdata2],
+                                 src=1)
+                return [subdata1.cast("float32").numpy()]
             else:
-                paddle.distributed.scatter(subdata1,
-                                           tensor_list=[subdata1, subdata2],
-                                           src=1)
-            return [subdata1.numpy()]
+                tindata = paddle.to_tensor(indata)
+                subdata1, subdata2 = paddle.split(tindata, 2, axis=0)
+                if rank == 0:
+                    dist.scatter(subdata1, src=1)
+                else:
+                    dist.scatter(subdata1,
+                                 tensor_list=[subdata1, subdata2],
+                                 src=1)
+                return [subdata1.numpy()]
 
 
 if __name__ == "__main__":

python/paddle/fluid/tests/unittests/collective/collective_sendrecv_api_dygraph.py

@@ -13,24 +13,34 @@
 # limitations under the License.
 
 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
-from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+import test_collective_api_base as test_base
 
 
-class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
+class TestCollectiveSendRecvAPI(test_base.TestCollectiveAPIRunnerBase):
 
     def __init__(self):
         self.global_ring_id = 0
 
     def get_model(self, main_prog, startup_program, rank, indata=None):
         with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.to_tensor(indata)
-            if rank == 0:
-                paddle.distributed.send(tindata, dst=1)
+            # NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
+            if indata.dtype == "bfloat16":
+                tindata = paddle.to_tensor(indata, "float32").cast("uint16")
+                if rank == 0:
+                    dist.send(tindata, dst=1)
+                else:
+                    dist.recv(tindata, src=0)
+                return [tindata.cast("float32").numpy()]
             else:
-                paddle.distributed.recv(tindata, src=0)
-            return [tindata.numpy()]
+                tindata = paddle.to_tensor(indata)
+                if rank == 0:
+                    dist.send(tindata, dst=1)
+                else:
+                    dist.recv(tindata, src=0)
+                return [tindata.numpy()]
 
 
 if __name__ == "__main__":
-    runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
+    test_base.runtime_main(TestCollectiveSendRecvAPI, "sendrecv")

python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py

@@ -26,213 +26,55 @@ class TestCollectiveAllgatherAPI(TestDistBase):
         pass
 
     def test_allgather_nccl(self):
-        self.check_with_place("collective_allgather_api.py",
-                              "allgather",
-                              "nccl",
-                              dtype="float16")
-        self.check_with_place("collective_allgather_api.py",
-                              "allgather",
-                              "nccl",
-                              dtype="float32")
        ... (similar deleted calls for float64, bool, uint8, int8, int32, int64, complex64 and complex128)
+        dtypes_to_test = [
+            "float16", "float32", "float64", "int32", "int64", "int8",
+            "uint8", "bool", "complex64", "complex128"
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allgather_api.py",
+                                  "allgather",
+                                  "nccl",
+                                  dtype=dtype)
 
     def test_allgather_gloo(self):
-        self.check_with_place("collective_allgather_api.py",
-                              "allgather",
-                              "gloo",
-                              "3",
-                              dtype="float16")
        ... (similar deleted calls, one per dtype as above)
+        dtypes_to_test = [
+            "float16", "float32", "float64", "int32", "int64", "int8",
+            "uint8", "bool", "complex64", "complex128"
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allgather_api.py",
+                                  "allgather",
+                                  "gloo",
+                                  "3",
+                                  dtype=dtype)
 
     def test_allgatther_nccl_dygraph(self):
-        self.check_with_place("collective_allgather_api_dygraph.py",
-                              "allgather",
-                              "nccl",
-                              static_mode="0",
-                              dtype="float16")
        ... (similar deleted calls, one per dtype as above)
+        dtypes_to_test = [
+            "float16", "float32", "float64", "int32", "int64", "int8",
+            "uint8", "bool", "complex64", "complex128"
+        ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allgather_api_dygraph.py",
+                                  "allgather",
+                                  "nccl",
+                                  static_mode="0",
+                                  dtype=dtype)
 
     def test_allgather_gloo_dygraph(self):
-        self.check_with_place("collective_allgather_api_dygraph.py",
-                              "allgather",
-                              "gloo",
-                              "3",
-                              static_mode="0",
-                              dtype="float16")
        ... (similar deleted calls, one per dtype as above)
+        dtypes_to_test = [
+            "float16", "float32", "float64", "int32", "int64", "int8",
+            "uint8", "bool", "bfloat16", "complex64", "complex128"
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allgather_api_dygraph.py",
+                                  "allgather",
+                                  "gloo",
+                                  "3",
+                                  static_mode="0",
+                                  dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py

@@ -41,9 +41,11 @@ class TestCollectiveAllreduceAPI(TestDistBase):
     def test_allreduce_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_allreduce_api_dygraph.py",
                                   "allreduce",
@@ -53,8 +55,8 @@ class TestCollectiveAllreduceAPI(TestDistBase):
     def test_allreduce_gloo_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool", "bfloat16"
         ]
         for dtype in dtypes_to_test:
             self.check_with_place("collective_allreduce_api_dygraph.py",
@@ -65,5 +67,5 @@ class TestCollectiveAllreduceAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py

@@ -30,9 +30,11 @@ class TestCollectiveAllToAllAPI(TestDistBase):
     def test_alltoall_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_alltoall_api_dygraph.py",
                                   "alltoall",
@@ -41,5 +43,5 @@ class TestCollectiveAllToAllAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py

@@ -23,9 +23,11 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
     def test_alltooall_single_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_alltoall_single_api_dygraph.py",
                                   "alltoall",
@@ -34,5 +36,5 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py

@@ -35,9 +35,11 @@ class TestCollectiveBroadcastAPI(TestDistBase):
     def test_broadcast_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_broadcast_api_dygraph.py",
                                   "broadcast",
@@ -47,8 +49,8 @@ class TestCollectiveBroadcastAPI(TestDistBase):
     def test_broadcast_gloo_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool", "bfloat16"
         ]
         for dtype in dtypes_to_test:
             self.check_with_place("collective_broadcast_api_dygraph.py",
@@ -59,5 +61,5 @@ class TestCollectiveBroadcastAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py
@@ -23,9 +23,11 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
 
     def test_isend_irecv_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_isend_irecv_api_dygraph.py",
                                   "sendrecv",
@@ -34,5 +36,5 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py
@@ -38,9 +38,11 @@ class TestCollectiveReduceAPI(TestDistBase):
 
     def test_reduce_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_reduce_api_dygraph.py",
                                   "reduce",
@@ -50,8 +52,8 @@ class TestCollectiveReduceAPI(TestDistBase):
 
     def test_reduce_gloo_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool", "bfloat16"
         ]
         for dtype in dtypes_to_test:
             self.check_with_place("collective_reduce_api_dygraph.py",
@@ -62,5 +64,5 @@ class TestCollectiveReduceAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py
@@ -23,9 +23,11 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
 
     def test_reduce_scatter_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_reduce_scatter_api_dygraph.py",
                                   "reduce_scatter",
@@ -34,5 +36,5 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py
@@ -34,9 +34,11 @@ class TestCollectiveScatterAPI(TestDistBase):
 
     def test_scatter_nccl_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_scatter_api_dygraph.py",
                                   "scatter",
@@ -46,8 +48,8 @@ class TestCollectiveScatterAPI(TestDistBase):
 
     def test_scatter_gloo_dygraph(self):
         dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool", "bfloat16"
         ]
         for dtype in dtypes_to_test:
             self.check_with_place("collective_scatter_api_dygraph.py",
@@ -58,5 +60,5 @@ class TestCollectiveScatterAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py
@@ -32,9 +32,11 @@ class TestCollectiveSendRecvAPI(TestDistBase):
 
     def test_sendrecv_nccl_dygraph(self):
        dtypes_to_test = [
-            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-            'bool'
+            "float16", "float32", "float64", "int32", "int64", "int8", "uint8",
+            "bool"
         ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place("collective_sendrecv_api_dygraph.py",
                                   "sendrecv",
@@ -43,5 +45,5 @@ class TestCollectiveSendRecvAPI(TestDistBase):
                                   dtype=dtype)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/collective/testslist.csv
@@ -7,27 +7,27 @@ test_c_split,linux,gpu;rocm,120,DIST,test_runner.py,2,,PYTHONPATH=..;http_proxy=
 test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
 test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_allreduce_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_allreduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_alltoall_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_alltoall_single_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_isend_irecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_isend_irecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_reduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_reduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_reduce_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
-test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
+test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
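Reading one row of testslist.csv: the columns appear to be test name, OS, device list, timeout in seconds, label, launcher, process count, a reserved field, and environment variables; the field names below are inferred from context, not defined anywhere in this diff. The timeout column is what this commit retunes:

# illustrative parse of a testslist.csv row; column names are inferred
row = ("test_collective_alltoall_api,linux,gpu;rocm,120,DIST,"
       "test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,")
(name, os_name, devices, timeout, label,
 launcher, num_procs, _, env_vars, _) = row.split(",")
print(name, timeout)  # test_collective_alltoall_api 120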
python/paddle/fluid/tests/unittests/test_collective_api_base.py
@@ -23,6 +23,7 @@ from contextlib import closing
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
+from paddle_bfloat import bfloat16
 
 
 def create_bool_test_data(shape=None, seed=None):
@@ -76,6 +77,9 @@ def create_test_data(shape=None, dtype=None, seed=None):
     assert shape, "Shape should be specified"
     if dtype == "float32" or dtype == "float16" or dtype == "float64":
         return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
+    elif dtype == "bfloat16":
+        # since numpy does not support bfloat16 yet, use `paddle_bfloat` to replace
+        return create_float_test_data(shape=shape, dtype=bfloat16, seed=seed)
     elif dtype == "bool":
         return create_bool_test_data(shape=shape, seed=seed)
     elif dtype == "int32" or dtype == "int64" or dtype == "int8" or dtype == "uint8":
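NumPy has no built-in bfloat16, so the harness pulls in the third-party paddle_bfloat package, which registers a bfloat16 scalar type that NumPy can cast to and from. A minimal sketch of what the new branch yields; the uniform-sampling details of create_float_test_data are assumed here, since that helper is outside this hunk:

import numpy as np
from paddle_bfloat import bfloat16  # registers bfloat16 as a NumPy-compatible type

rng = np.random.RandomState(2022)  # seed value is illustrative
data = rng.uniform(-10.0, 10.0, size=(10, 1000)).astype(bfloat16)
# the harness later casts back to float32 before numeric comparison
data_f32 = data.astype("float32")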
@@ -167,6 +171,15 @@ class TestDistBase(unittest.TestCase):
         self.temp_dir = tempfile.TemporaryDirectory()
+        # NOTE: this is a hack to get int format nccl version, like 2134
+        # if current platform is not linux, version number will be 0
+        nccl_version_str = subprocess.check_output(
+            r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
+            stderr=subprocess.DEVNULL,
+            shell=True).decode('utf-8')
+        self._nccl_version = int("".join(
+            nccl_version_str.split("."))) if nccl_version_str else 0
 
     def tearDown(self):
         self.temp_dir.cleanup()
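A quick worked example of the version hack above: NCCL 2.13.4 makes ldconfig report a shared-library suffix of 2.13.4, and stripping the dots yields the integer 2134, so a plain integer comparison against 2100 (NCCL 2.10.0) works:

# worked example of the nccl version encoding (shell output is simulated)
nccl_version_str = "2.13.4"  # what the ldconfig | grep | sed pipeline prints
version = int("".join(nccl_version_str.split(".")))  # "2134" -> 2134
assert version >= 2100  # bfloat16 tests are enabled on NCCL >= 2.10.0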
@@ -305,6 +318,10 @@ class TestDistBase(unittest.TestCase):
                                                      model_file, required_envs)
         input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
         input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
+        # cast bfloat16 to float32 for numeric comparison
+        if dtype == "bfloat16":
+            input1 = input1.astype("float32")
+            input2 = input2.astype("float32")
         if col_type == "allgather":
             need_result = np.vstack((input1, input2))
             tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
@@ -321,7 +338,13 @@ class TestDistBase(unittest.TestCase):
             np.testing.assert_allclose(tr1_out[0], need_result, rtol=1e-05)
         elif col_type == "reduce":
             need_result = input1 + input2
-            np.testing.assert_allclose(tr0_out[0], need_result, rtol=1e-05)
+            # bfloat16 precision loss comes from truncating the last 16 bits of float32,
+            # which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078
+            if dtype == "bfloat16":
+                rtol = 8e-03
+            else:
+                rtol = 1e-05
+            np.testing.assert_allclose(tr0_out[0], need_result, rtol=rtol)
         elif col_type == "scatter":
             need_result = input2
             need_result1 = need_result[0:need_result.shape[0] // 2]
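The 8e-03 tolerance in that comment can be sanity-checked directly: converting float32 to bfloat16 drops mantissa bits 2^-8 through 2^-23, and their sum sits just under the chosen bound:

# arithmetic behind the rtol = 8e-03 choice for bfloat16
loss = sum(2.0**-i for i in range(8, 24))  # 2^-8 + ... + 2^-23
print(loss)  # ~0.0078124, i.e. about 0.0078 < 8e-03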
@@ -332,18 +355,28 @@ class TestDistBase(unittest.TestCase):
             need_result = input1 + input2
             need_result1 = need_result[0:need_result.shape[0] // 2]
             need_result2 = need_result[need_result.shape[0] // 2:]
-            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=1e-05)
-            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=1e-05)
+            if dtype == "bfloat16":
+                rtol = 8e-03
+            else:
+                rtol = 1e-05
+            np.testing.assert_allclose(tr0_out[0], need_result1, rtol=rtol)
+            np.testing.assert_allclose(tr1_out[0], need_result2, rtol=rtol)
         elif col_type == "allreduce":
             need_result = input1 + input2
+            if dtype == "bfloat16":
+                rtol = 8e-03
+                atol = 8e-03
+            else:
+                rtol = 1e-05
+                atol = 1e-05
             np.testing.assert_allclose(tr0_out[0],
                                        need_result,
-                                       rtol=1e-05,
-                                       atol=1e-05)
+                                       rtol=rtol,
+                                       atol=atol)
             np.testing.assert_allclose(tr1_out[0],
                                        need_result,
-                                       rtol=1e-05,
-                                       atol=1e-05)
+                                       rtol=rtol,
+                                       atol=atol)
         elif col_type == "parallel_embedding":
             result_data = tr0_out[0]
             np.random.seed(2020)