Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d4cf02bc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d4cf02bc
编写于
7月 28, 2022
作者:
L
LiYuRio
提交者:
GitHub
7月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete the dtypes for all_gather, add all_gather_object api (#44417)
上级
768e50c9
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
594 addition
and
48 deletion
+594
-48
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
+9
-0
paddle/fluid/operators/collective/c_allgather_op.cc
paddle/fluid/operators/collective/c_allgather_op.cc
+3
-0
paddle/fluid/operators/collective/c_allgather_op.cu.cc
paddle/fluid/operators/collective/c_allgather_op.cu.cc
+3
-0
paddle/fluid/platform/device/gpu/nccl_helper.h
paddle/fluid/platform/device/gpu/nccl_helper.h
+10
-0
paddle/phi/kernels/cpu/split_kernel.cc
paddle/phi/kernels/cpu/split_kernel.cc
+2
-0
paddle/phi/kernels/gpu/split_kernel.cu
paddle/phi/kernels/gpu/split_kernel.cu
+2
-0
python/paddle/distributed/__init__.py
python/paddle/distributed/__init__.py
+5
-3
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+93
-20
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+5
-1
python/paddle/fluid/tests/unittests/collective_allgather_api.py
.../paddle/fluid/tests/unittests/collective_allgather_api.py
+43
-7
python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py
...fluid/tests/unittests/collective_allgather_api_dygraph.py
+37
-0
python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py
...ests/unittests/collective_allgather_object_api_dygraph.py
+35
-0
python/paddle/fluid/tests/unittests/test_collective_allgather_api.py
...le/fluid/tests/unittests/test_collective_allgather_api.py
+204
-4
python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py
...d/tests/unittests/test_collective_allgather_object_api.py
+53
-0
python/paddle/fluid/tests/unittests/test_collective_api_base.py
.../paddle/fluid/tests/unittests/test_collective_api_base.py
+85
-9
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+5
-4
未找到文件。
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
浏览文件 @
d4cf02bc
...
@@ -79,6 +79,15 @@ namespace distributed {
...
@@ -79,6 +79,15 @@ namespace distributed {
case experimental::DataType::INT64: \
case experimental::DataType::INT64: \
func<int64_t>(args); \
func<int64_t>(args); \
break; \
break; \
case experimental::DataType::INT8: \
func<int8_t>(args); \
break; \
case experimental::DataType::UINT8: \
func<uint8_t>(args); \
break; \
case experimental::DataType::BOOL: \
func<bool>(args); \
break; \
default: \
default: \
VLOG(0) << "Error: Unknown DataType."; \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
exit(-1); \
...
...
paddle/fluid/operators/collective/c_allgather_op.cc
浏览文件 @
d4cf02bc
...
@@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather,
...
@@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather,
ops
::
CAllGatherOpCPUKernel
<
double
>
,
ops
::
CAllGatherOpCPUKernel
<
double
>
,
ops
::
CAllGatherOpCPUKernel
<
int
>
,
ops
::
CAllGatherOpCPUKernel
<
int
>
,
ops
::
CAllGatherOpCPUKernel
<
int64_t
>
,
ops
::
CAllGatherOpCPUKernel
<
int64_t
>
,
ops
::
CAllGatherOpCPUKernel
<
uint8_t
>
,
ops
::
CAllGatherOpCPUKernel
<
int8_t
>
,
ops
::
CAllGatherOpCPUKernel
<
bool
>
,
ops
::
CAllGatherOpCPUKernel
<
plat
::
float16
>
);
ops
::
CAllGatherOpCPUKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_allgather_op.cu.cc
浏览文件 @
d4cf02bc
...
@@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather,
...
@@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather,
ops
::
CAllGatherOpCUDAKernel
<
plat
::
bfloat16
>
,
ops
::
CAllGatherOpCUDAKernel
<
plat
::
bfloat16
>
,
#endif
#endif
ops
::
CAllGatherOpCUDAKernel
<
int
>
,
ops
::
CAllGatherOpCUDAKernel
<
int
>
,
ops
::
CAllGatherOpCUDAKernel
<
uint8_t
>
,
ops
::
CAllGatherOpCUDAKernel
<
int8_t
>
,
ops
::
CAllGatherOpCUDAKernel
<
int64_t
>
,
ops
::
CAllGatherOpCUDAKernel
<
int64_t
>
,
ops
::
CAllGatherOpCUDAKernel
<
bool
>
,
ops
::
CAllGatherOpCUDAKernel
<
plat
::
float16
>
);
ops
::
CAllGatherOpCUDAKernel
<
plat
::
float16
>
);
paddle/fluid/platform/device/gpu/nccl_helper.h
浏览文件 @
d4cf02bc
...
@@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
...
@@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return
ncclFloat16
;
return
ncclFloat16
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT8
)
{
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT8
)
{
return
ncclInt8
;
return
ncclInt8
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
UINT8
)
{
return
ncclUint8
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
BOOL
)
{
return
ncclUint8
;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
}
else
if
(
type
==
framework
::
proto
::
VarType
::
BF16
)
{
}
else
if
(
type
==
framework
::
proto
::
VarType
::
BF16
)
{
return
ncclBfloat16
;
return
ncclBfloat16
;
...
@@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
...
@@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
return
ncclInt64
;
return
ncclInt64
;
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
return
ncclFloat16
;
return
ncclFloat16
;
}
else
if
(
type
==
experimental
::
DataType
::
UINT8
)
{
return
ncclUint8
;
}
else
if
(
type
==
experimental
::
DataType
::
INT8
)
{
return
ncclInt8
;
}
else
if
(
type
==
experimental
::
DataType
::
BOOL
)
{
return
ncclUint8
;
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
#if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
}
else
if
(
type
==
experimental
::
DataType
::
BFLOAT16
)
{
}
else
if
(
type
==
experimental
::
DataType
::
BFLOAT16
)
{
return
ncclBfloat16
;
return
ncclBfloat16
;
...
...
paddle/phi/kernels/cpu/split_kernel.cc
浏览文件 @
d4cf02bc
...
@@ -72,5 +72,7 @@ PD_REGISTER_KERNEL(split,
...
@@ -72,5 +72,7 @@ PD_REGISTER_KERNEL(split,
int64_t
,
int64_t
,
int
,
int
,
bool
,
bool
,
uint8_t
,
int8_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/gpu/split_kernel.cu
浏览文件 @
d4cf02bc
...
@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split,
...
@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split,
int64_t
,
int64_t
,
int
,
int
,
bool
,
bool
,
uint8_t
,
int8_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
phi
::
dtype
::
bfloat16
)
{}
python/paddle/distributed/__init__.py
浏览文件 @
d4cf02bc
...
@@ -31,6 +31,7 @@ from .collective import broadcast # noqa: F401
...
@@ -31,6 +31,7 @@ from .collective import broadcast # noqa: F401
from
.collective
import
all_reduce
# noqa: F401
from
.collective
import
all_reduce
# noqa: F401
from
.collective
import
reduce
# noqa: F401
from
.collective
import
reduce
# noqa: F401
from
.collective
import
all_gather
# noqa: F401
from
.collective
import
all_gather
# noqa: F401
from
.collective
import
all_gather_object
# noqa: F401
from
.collective
import
scatter
# noqa: F401
from
.collective
import
scatter
# noqa: F401
from
.collective
import
barrier
# noqa: F401
from
.collective
import
barrier
# noqa: F401
from
.collective
import
ReduceOp
# noqa: F401
from
.collective
import
ReduceOp
# noqa: F401
...
@@ -71,7 +72,8 @@ __all__ = [ # noqa
...
@@ -71,7 +72,8 @@ __all__ = [ # noqa
"init_parallel_env"
,
"gloo_init_parallel_env"
,
"gloo_barrier"
,
"init_parallel_env"
,
"gloo_init_parallel_env"
,
"gloo_barrier"
,
"gloo_release"
,
"QueueDataset"
,
"split"
,
"CountFilterEntry"
,
"gloo_release"
,
"QueueDataset"
,
"split"
,
"CountFilterEntry"
,
"ShowClickEntry"
,
"get_world_size"
,
"get_group"
,
"all_gather"
,
"ShowClickEntry"
,
"get_world_size"
,
"get_group"
,
"all_gather"
,
"InMemoryDataset"
,
"barrier"
,
"all_reduce"
,
"alltoall"
,
"send"
,
"reduce"
,
"all_gather_object"
,
"InMemoryDataset"
,
"barrier"
,
"all_reduce"
,
"alltoall"
,
"recv"
,
"ReduceOp"
,
"wait"
,
"get_rank"
,
"ProbabilityEntry"
,
"ParallelMode"
,
"send"
,
"reduce"
,
"recv"
,
"ReduceOp"
,
"wait"
,
"get_rank"
,
"is_initialized"
,
"isend"
,
"irecv"
,
"reduce_scatter"
"ProbabilityEntry"
,
"ParallelMode"
,
"is_initialized"
,
"isend"
,
"irecv"
,
"reduce_scatter"
]
]
python/paddle/distributed/collective.py
浏览文件 @
d4cf02bc
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
pickle
import
io
from
datetime
import
timedelta
from
datetime
import
timedelta
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.framework
import
Variable
from
..fluid.framework
import
Variable
...
@@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
Args:
Args:
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8, bool, complex64 or complex128
.
tensor (Tensor): The Tensor to send. Its data type
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8, bool, complex64 or complex128
.
group (Group): The group instance return by new_group or None for global default group.
group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
Default to True.
Default to True.
...
@@ -941,7 +943,6 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -941,7 +943,6 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
.. code-block:: python
.. code-block:: python
# required: distributed
# required: distributed
import numpy as np
import paddle
import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed import init_parallel_env
...
@@ -949,21 +950,26 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -949,21 +950,26 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
init_parallel_env()
init_parallel_env()
tensor_list = []
tensor_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data1)
paddle.distributed.all_gather(tensor_list, data1)
else:
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data2)
paddle.distributed.all_gather(tensor_list, data2)
"""
"""
if
group
is
not
None
and
not
group
.
is_member
():
if
group
is
not
None
and
not
group
.
is_member
():
return
return
def
convert_to_complex
(
list_of_tensor
):
list_of_complex
=
[]
for
tensor
in
list_of_tensor
:
list_of_complex
.
append
(
paddle
.
as_complex
(
tensor
))
return
list_of_complex
is_input_complex
=
(
tensor
.
dtype
==
paddle
.
complex64
or
tensor
.
dtype
==
paddle
.
complex128
)
if
is_input_complex
:
tensor
=
paddle
.
as_real
(
tensor
)
if
in_dygraph_mode
():
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
group
=
_get_default_group
()
if
group
is
None
else
group
if
len
(
tensor_list
)
==
0
:
if
len
(
tensor_list
)
==
0
:
...
@@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
task
=
group
.
process_group
.
all_gather
(
tensor
,
out
)
task
=
group
.
process_group
.
all_gather
(
tensor
,
out
)
task
.
wait
()
task
.
wait
()
tensor_list
.
clear
()
tensor_list
.
clear
()
tensor_list
.
extend
(
paddle
.
split
(
out
,
group
.
nranks
,
0
))
list_of_tensor
=
paddle
.
split
(
out
,
group
.
nranks
,
0
)
if
is_input_complex
:
tensor_list
.
extend
(
convert_to_complex
(
list_of_tensor
))
else
:
tensor_list
.
extend
(
list_of_tensor
)
return
return
ring_id
=
0
if
group
is
None
else
group
.
id
ring_id
=
0
if
group
is
None
else
group
.
id
...
@@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
raise
ValueError
(
"The type of 'tensor_list' for all_gather "
raise
ValueError
(
"The type of 'tensor_list' for all_gather "
"should be list."
)
"should be list."
)
for
elem
in
tensor_list
:
for
elem
in
tensor_list
:
check_variable_and_dtype
(
check_variable_and_dtype
(
elem
,
'tensor_list'
,
[
elem
,
'tensor_list'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'bool'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'int8'
,
'uint8'
,
'complex64'
,
'complex128'
'all_gather'
)
],
'all_gather'
)
check_variable_and_dtype
(
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
tensor
,
'tensor'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'bool'
,
'int8'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'all_gather'
)
'uint8'
,
'complex64'
,
'complex128'
],
'all_gather'
)
helper
.
append_op
(
type
=
op_type
,
helper
.
append_op
(
type
=
op_type
,
inputs
=
{
'X'
:
[
tensor
]},
inputs
=
{
'X'
:
[
tensor
]},
outputs
=
{
'Out'
:
[
out
]},
outputs
=
{
'Out'
:
[
out
]},
...
@@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
...
@@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
'nranks'
:
nranks
'nranks'
:
nranks
})
})
tensor_list
.
extend
(
paddle
.
split
(
out
,
nranks
,
0
))
list_of_tensor
=
paddle
.
split
(
out
,
nranks
,
0
)
if
is_input_complex
:
tensor_list
.
extend
(
convert_to_complex
(
list_of_tensor
))
else
:
tensor_list
.
extend
(
list_of_tensor
)
def
_convert_object_to_tensor
(
obj
):
_pickler
=
pickle
.
Pickler
f
=
io
.
BytesIO
()
_pickler
(
f
).
dump
(
obj
)
data
=
np
.
frombuffer
(
f
.
getvalue
(),
dtype
=
np
.
uint8
)
tensor
=
paddle
.
to_tensor
(
data
)
return
tensor
def
_convert_tensor_to_object
(
tensor
):
_unpickler
=
pickle
.
Unpickler
return
_unpickler
(
io
.
BytesIO
(
tensor
.
numpy
())).
load
()
def
all_gather_object
(
object_list
,
obj
,
group
=
None
):
"""
Gather picklable objects from all participators and all get the result. Similiar to all_gather(), but python object can be passed in.
Args:
object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
obj (Any): The picklable object to send.
group (Group): The group instance return by new_group or None for global default group.
Returns:
None.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
dist.init_parallel_env()
object_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
obj = {"foo": [1, 2, 3]}
paddle.distributed.all_gather_object(object_list, obj)
else:
obj = {"bar": [4, 5, 6]}
paddle.distributed.all_gather_object(object_list, obj)
"""
assert
in_dygraph_mode
(
),
"all_gather_object doesn't support static graph mode."
tensor
=
_convert_object_to_tensor
(
obj
)
tensor_list
=
[]
all_gather
(
tensor_list
,
tensor
,
group
)
for
tensor
in
tensor_list
:
object_list
.
append
(
_convert_tensor_to_object
(
tensor
))
def
scatter
(
tensor
,
tensor_list
=
None
,
src
=
0
,
group
=
None
,
use_calc_stream
=
True
):
def
scatter
(
tensor
,
tensor_list
=
None
,
src
=
0
,
group
=
None
,
use_calc_stream
=
True
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
d4cf02bc
...
@@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
...
@@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list
(
REMOVE_ITEM TEST_OPS test_new_group_api
)
list
(
REMOVE_ITEM TEST_OPS test_new_group_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_broadcast_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_broadcast_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_allgather_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_allgather_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_allgather_object_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_alltoall_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_alltoall_api
)
list
(
REMOVE_ITEM TEST_OPS test_collective_global_gather
)
list
(
REMOVE_ITEM TEST_OPS test_collective_global_gather
)
list
(
REMOVE_ITEM TEST_OPS test_collective_global_scatter
)
list
(
REMOVE_ITEM TEST_OPS test_collective_global_scatter
)
...
@@ -1598,7 +1599,9 @@ if(APPLE)
...
@@ -1598,7 +1599,9 @@ if(APPLE)
endif
()
endif
()
if
((
WITH_ROCM OR WITH_GPU
)
AND NOT WIN32
)
if
((
WITH_ROCM OR WITH_GPU
)
AND NOT WIN32
)
set_tests_properties
(
test_collective_allgather_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_allgather_api PROPERTIES TIMEOUT 300
)
set_tests_properties
(
test_collective_allgather_object_api PROPERTIES TIMEOUT
120
)
set_tests_properties
(
test_collective_alltoall_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_alltoall_api PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_global_gather PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_collective_global_gather PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_collective_global_scatter PROPERTIES TIMEOUT 200
)
set_tests_properties
(
test_collective_global_scatter PROPERTIES TIMEOUT 200
)
...
@@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
...
@@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_new_group_api
test_new_group_api
test_collective_broadcast_api
test_collective_broadcast_api
test_collective_allgather_api
test_collective_allgather_api
test_collective_allgather_object_api
test_collective_alltoall_api
test_collective_alltoall_api
test_collective_global_gather
test_collective_global_gather
test_collective_global_scatter
test_collective_global_scatter
...
...
python/paddle/fluid/tests/unittests/collective_allgather_api.py
浏览文件 @
d4cf02bc
...
@@ -30,28 +30,64 @@ import paddle.fluid.profiler as profiler
...
@@ -30,28 +30,64 @@ import paddle.fluid.profiler as profiler
import
paddle.fluid.unique_name
as
nameGen
import
paddle.fluid.unique_name
as
nameGen
from
paddle.fluid
import
core
from
paddle.fluid
import
core
import
unittest
import
unittest
import
pickle
from
multiprocessing
import
Process
from
multiprocessing
import
Process
import
paddle.fluid.layers
as
layers
import
paddle.fluid.layers
as
layers
from
functools
import
reduce
from
functools
import
reduce
from
test_collective_api_base
import
TestCollectiveAPIRunnerBase
,
runtime_main
import
test_collective_api_base
as
test_base
paddle
.
enable_static
()
paddle
.
enable_static
()
class
TestCollectiveAllgatherAPI
(
TestCollectiveAPIRunnerBase
):
class
TestCollectiveAllgatherAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
):
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
dtype
=
None
):
dtype
=
"float32"
if
dtype
is
None
else
dtype
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tensor_list
=
[]
tensor_list
=
[]
tindata
=
layers
.
data
(
name
=
"tindata"
,
tindata
=
layers
.
data
(
name
=
"tindata"
,
shape
=
[
10
,
1000
],
dtype
=
dtype
)
shape
=
[
10
,
1000
],
dtype
=
'float32'
)
paddle
.
distributed
.
all_gather
(
tensor_list
,
tindata
)
paddle
.
distributed
.
all_gather
(
tensor_list
,
tindata
)
return
tensor_list
return
tensor_list
def
run_trainer
(
self
,
args
):
train_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
endpoints
=
args
[
"endpoints"
].
split
(
","
)
rank
=
args
[
"trainerid"
]
current_endpoint
=
args
[
"currentendpoint"
]
nranks
=
2
paddle
.
distributed
.
init_parallel_env
()
if
args
[
'backend'
]
==
'nccl'
:
device_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
place
=
fluid
.
CUDAPlace
(
device_id
)
#if args.use_gpu else fluid.CPUPlace()
elif
args
[
'backend'
]
==
'bkcl'
:
device_id
=
int
(
os
.
getenv
(
"FLAGS_selected_xpus"
,
"0"
))
place
=
fluid
.
XPUPlace
(
device_id
)
else
:
place
=
fluid
.
CPUPlace
()
indata
=
test_base
.
create_test_data
(
shape
=
(
10
,
1000
),
dtype
=
args
[
"dtype"
],
seed
=
os
.
getpid
())
assert
args
[
'static_mode'
]
==
1
,
"collective_allgather_api only support static mode"
result
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
,
dtype
=
args
[
"dtype"
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
fetch_list
=
[]
for
elem
in
result
:
fetch_list
.
append
(
elem
.
name
)
out
=
exe
.
run
(
train_prog
,
feed
=
{
'tindata'
:
indata
},
fetch_list
=
fetch_list
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
runtime_main
(
TestCollectiveAllgatherAPI
,
"allgather"
)
test_base
.
runtime_main
(
TestCollectiveAllgatherAPI
,
"allgather"
)
python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py
0 → 100644
浏览文件 @
d4cf02bc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveAllgatherAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
tensor_list
=
[]
paddle
.
distributed
.
all_gather
(
tensor_list
,
tindata
)
return
[
tensor
.
numpy
()
for
tensor
in
tensor_list
]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveAllgatherAPI
,
"allgather"
)
python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py
0 → 100644
浏览文件 @
d4cf02bc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
test_collective_api_base
as
test_base
class
TestCollectiveAllgatherObjectAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
object_list
=
[]
paddle
.
distributed
.
all_gather_object
(
object_list
,
indata
)
return
object_list
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveAllgatherObjectAPI
,
"allgather_object"
)
python/paddle/fluid/tests/unittests/test_collective_allgather_api.py
浏览文件 @
d4cf02bc
...
@@ -28,12 +28,212 @@ class TestCollectiveAllgatherAPI(TestDistBase):
...
@@ -28,12 +28,212 @@ class TestCollectiveAllgatherAPI(TestDistBase):
pass
pass
def
test_allgather_nccl
(
self
):
def
test_allgather_nccl
(
self
):
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
self
.
check_with_place
(
"collective_allgather_api.py"
,
"nccl"
)
"allgather"
,
"nccl"
,
dtype
=
"float16"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"float32"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"float64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"bool"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"uint8"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"int8"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"int32"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"int64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"complex64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"nccl"
,
dtype
=
"complex128"
)
def
test_allgather_gloo
(
self
):
def
test_allgather_gloo
(
self
):
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
self
.
check_with_place
(
"collective_allgather_api.py"
,
"gloo"
,
"3"
)
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"float16"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"float32"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"float64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"bool"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"uint8"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"int8"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"int32"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"int64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"complex64"
)
self
.
check_with_place
(
"collective_allgather_api.py"
,
"allgather"
,
"gloo"
,
"3"
,
dtype
=
"complex128"
)
def
test_allgatther_nccl_dygraph
(
self
):
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"float16"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"float32"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"float64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"bool"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"uint8"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"int8"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"int32"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"int64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"complex64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"complex128"
)
def
test_allgather_gloo_dygraph
(
self
):
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"float16"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"float32"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"float64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"bool"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"uint8"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"int8"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"int32"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"int64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"complex64"
)
self
.
check_with_place
(
"collective_allgather_api_dygraph.py"
,
"allgather"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"complex128"
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py
0 → 100644
浏览文件 @
d4cf02bc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
test_collective_api_base
as
test_base
class
TestCollectiveAllgatherObjectAPI
(
test_base
.
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_allgather_nccl
(
self
):
self
.
check_with_place
(
"collective_allgather_object_api_dygraph.py"
,
"allgather_object"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"pylist"
)
self
.
check_with_place
(
"collective_allgather_object_api_dygraph.py"
,
"allgather_object"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
"pydict"
)
def
test_allgather_gloo_dygraph
(
self
):
self
.
check_with_place
(
"collective_allgather_object_api_dygraph.py"
,
"allgather_object"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"pylist"
)
self
.
check_with_place
(
"collective_allgather_object_api_dygraph.py"
,
"allgather_object"
,
"gloo"
,
"3"
,
static_mode
=
"0"
,
dtype
=
"pydict"
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_collective_api_base.py
浏览文件 @
d4cf02bc
...
@@ -31,9 +31,77 @@ import paddle.fluid.unique_name as nameGen
...
@@ -31,9 +31,77 @@ import paddle.fluid.unique_name as nameGen
from
paddle.fluid
import
core
from
paddle.fluid
import
core
def
create_bool_test_data
(
shape
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
data
=
np
.
random
.
choice
([
True
,
False
],
size
=
shape
)
return
data
def
create_float_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
data
=
np
.
random
.
random
(
shape
).
astype
(
dtype
)
return
data
def
create_int_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
data
=
np
.
random
.
randint
(
0
,
high
=
100
,
size
=
shape
).
astype
(
dtype
)
return
data
def
create_complex_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
data
=
np
.
random
.
random
(
shape
).
astype
(
dtype
)
data
.
imag
=
np
.
random
.
random
(
shape
)
return
data
def
create_pylist_test_data
(
shape
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
data
=
np
.
random
.
random
(
shape
).
tolist
()
return
data
def
create_pydict_test_data
(
shape
=
None
,
seed
=
None
):
if
seed
:
np
.
random
.
seed
(
seed
)
key
=
[
i
for
i
in
range
(
0
,
shape
[
0
])]
value
=
np
.
random
.
random
(
shape
).
tolist
()
data
=
dict
(
zip
(
key
,
value
))
return
data
def
create_test_data
(
shape
=
None
,
dtype
=
None
,
seed
=
None
):
assert
shape
,
"Shape should be specified"
if
dtype
==
"float32"
or
dtype
==
"float16"
or
dtype
==
"float64"
:
return
create_float_test_data
(
shape
=
shape
,
dtype
=
dtype
,
seed
=
seed
)
elif
dtype
==
"bool"
:
return
create_bool_test_data
(
shape
=
shape
,
seed
=
seed
)
elif
dtype
==
"int32"
or
dtype
==
"int64"
or
dtype
==
"int8"
or
dtype
==
"uint8"
:
return
create_int_test_data
(
shape
=
shape
,
dtype
=
dtype
,
seed
=
seed
)
elif
dtype
==
"complex64"
or
dtype
==
"complex128"
:
return
create_complex_test_data
(
shape
=
shape
,
dtype
=
dtype
,
seed
=
seed
)
elif
dtype
==
"pylist"
:
return
create_pylist_test_data
(
shape
=
shape
,
seed
=
seed
)
elif
dtype
==
"pydict"
:
return
create_pydict_test_data
(
shape
=
shape
,
seed
=
seed
)
else
:
raise
NotImplementedError
(
"Unsupported dtype for creating test data."
)
class
TestCollectiveAPIRunnerBase
(
object
):
class
TestCollectiveAPIRunnerBase
(
object
):
def
get_model
(
self
,
train_prog
,
startup_prog
,
rank
,
indata
=
None
):
def
get_model
(
self
,
train_prog
,
startup_prog
,
rank
,
indata
=
None
,
dtype
=
None
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"get model should be implemented by child class."
)
"get model should be implemented by child class."
)
...
@@ -54,8 +122,9 @@ class TestCollectiveAPIRunnerBase(object):
...
@@ -54,8 +122,9 @@ class TestCollectiveAPIRunnerBase(object):
place
=
fluid
.
XPUPlace
(
device_id
)
place
=
fluid
.
XPUPlace
(
device_id
)
else
:
else
:
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
np
.
random
.
seed
(
os
.
getpid
())
indata
=
create_test_data
(
shape
=
(
10
,
1000
),
indata
=
np
.
random
.
random
((
10
,
1000
)).
astype
(
"float32"
)
dtype
=
args
[
"dtype"
],
seed
=
os
.
getpid
())
if
args
[
'static_mode'
]:
if
args
[
'static_mode'
]:
result
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
)
result
=
self
.
get_model
(
train_prog
,
startup_prog
,
rank
)
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
...
@@ -83,6 +152,7 @@ def runtime_main(test_class, col_type):
...
@@ -83,6 +152,7 @@ def runtime_main(test_class, col_type):
args
[
"backend"
]
=
os
.
getenv
(
"BACKEND"
)
args
[
"backend"
]
=
os
.
getenv
(
"BACKEND"
)
args
[
"path_id"
]
=
int
(
os
.
getenv
(
"PATH_ID"
))
args
[
"path_id"
]
=
int
(
os
.
getenv
(
"PATH_ID"
))
args
[
"static_mode"
]
=
int
(
os
.
getenv
(
"STATIC_MODE"
))
args
[
"static_mode"
]
=
int
(
os
.
getenv
(
"STATIC_MODE"
))
args
[
"dtype"
]
=
os
.
getenv
(
"DTYPE"
)
model
.
run_trainer
(
args
)
model
.
run_trainer
(
args
)
...
@@ -203,18 +273,22 @@ class TestDistBase(unittest.TestCase):
...
@@ -203,18 +273,22 @@ class TestDistBase(unittest.TestCase):
static_mode
=
"1"
,
static_mode
=
"1"
,
check_error_log
=
False
,
check_error_log
=
False
,
need_envs
=
{},
need_envs
=
{},
eager_mode
=
True
):
eager_mode
=
True
,
dtype
=
None
):
if
backend
==
"nccl"
or
backend
==
"bkcl"
:
if
backend
==
"nccl"
or
backend
==
"bkcl"
:
with_gloo
=
'0'
with_gloo
=
'0'
else
:
else
:
with_gloo
=
'1'
with_gloo
=
'1'
required_envs
=
os
.
environ
.
copy
()
required_envs
=
os
.
environ
.
copy
()
dtype
=
"float32"
if
dtype
is
None
else
dtype
additional_envs
=
{
additional_envs
=
{
"NCCL_P2P_DISABLE"
:
"1"
,
"NCCL_P2P_DISABLE"
:
"1"
,
"STATIC_MODE"
:
static_mode
,
"STATIC_MODE"
:
static_mode
,
"PADDLE_WITH_GLOO"
:
with_gloo
,
"PADDLE_WITH_GLOO"
:
with_gloo
,
"PADDLE_DISTRI_BACKEND"
:
backend
,
"BACKEND"
:
backend
,
"BACKEND"
:
backend
,
"PATH_ID"
:
path_id
"PATH_ID"
:
path_id
,
"DTYPE"
:
dtype
}
}
required_envs
.
update
(
additional_envs
)
required_envs
.
update
(
additional_envs
)
required_envs
.
update
(
need_envs
)
required_envs
.
update
(
need_envs
)
...
@@ -234,16 +308,18 @@ class TestDistBase(unittest.TestCase):
...
@@ -234,16 +308,18 @@ class TestDistBase(unittest.TestCase):
tr0_out
,
tr1_out
,
pid0
,
pid1
=
self
.
_run_cluster
(
tr0_out
,
tr1_out
,
pid0
,
pid1
=
self
.
_run_cluster
(
model_file
,
required_envs
)
model_file
,
required_envs
)
np
.
random
.
seed
(
pid0
)
input1
=
create_test_data
(
shape
=
(
10
,
1000
),
dtype
=
dtype
,
seed
=
pid0
)
input1
=
np
.
random
.
random
((
10
,
1000
))
input2
=
create_test_data
(
shape
=
(
10
,
1000
),
dtype
=
dtype
,
seed
=
pid1
)
np
.
random
.
seed
(
pid1
)
input2
=
np
.
random
.
random
((
10
,
1000
))
if
col_type
==
"allgather"
:
if
col_type
==
"allgather"
:
need_result
=
np
.
vstack
((
input1
,
input2
))
need_result
=
np
.
vstack
((
input1
,
input2
))
tr_out0
=
np
.
vstack
((
tr0_out
[
0
],
tr0_out
[
1
]))
tr_out0
=
np
.
vstack
((
tr0_out
[
0
],
tr0_out
[
1
]))
tr_out1
=
np
.
vstack
((
tr1_out
[
0
],
tr1_out
[
1
]))
tr_out1
=
np
.
vstack
((
tr1_out
[
0
],
tr1_out
[
1
]))
self
.
assertTrue
(
np
.
allclose
(
tr_out0
,
need_result
))
self
.
assertTrue
(
np
.
allclose
(
tr_out0
,
need_result
))
self
.
assertTrue
(
np
.
allclose
(
tr_out1
,
need_result
))
self
.
assertTrue
(
np
.
allclose
(
tr_out1
,
need_result
))
if
col_type
==
"allgather_object"
:
need_result
=
[
input1
,
input2
]
self
.
assertEqual
(
need_result
,
tr0_out
)
self
.
assertEqual
(
need_result
,
tr1_out
)
elif
col_type
==
"broadcast"
:
elif
col_type
==
"broadcast"
:
need_result
=
input2
need_result
=
input2
self
.
assertTrue
(
np
.
allclose
(
tr0_out
,
need_result
))
self
.
assertTrue
(
np
.
allclose
(
tr0_out
,
need_result
))
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
d4cf02bc
...
@@ -1737,7 +1737,7 @@ def split(x, num_or_sections, axis=0, name=None):
...
@@ -1737,7 +1737,7 @@ def split(x, num_or_sections, axis=0, name=None):
Split the input tensor into multiple sub-Tensors.
Split the input tensor into multiple sub-Tensors.
Args:
Args:
x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64.
x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64,
uint8, int8,
int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
...
@@ -1814,9 +1814,10 @@ def split(x, num_or_sections, axis=0, name=None):
...
@@ -1814,9 +1814,10 @@ def split(x, num_or_sections, axis=0, name=None):
_C_ops
.
split
(
input
,
out
,
*
attrs
)
_C_ops
.
split
(
input
,
out
,
*
attrs
)
return
out
return
out
check_variable_and_dtype
(
check_variable_and_dtype
(
input
,
'input'
,
[
input
,
'input'
,
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint8'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'split'
)
'int8'
],
'split'
)
check_type
(
num_or_sections
,
'num_or_sections'
,
(
list
,
int
,
tuple
),
'split'
)
check_type
(
num_or_sections
,
'num_or_sections'
,
(
list
,
int
,
tuple
),
'split'
)
check_type
(
dim
,
'dim'
,
(
int
,
Variable
),
'split'
)
check_type
(
dim
,
'dim'
,
(
int
,
Variable
),
'split'
)
if
isinstance
(
dim
,
Variable
):
if
isinstance
(
dim
,
Variable
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录