Commit d4cf02bc (unverified)
Authored Jul 28, 2022 by LiYuRio; committed by GitHub on Jul 28, 2022

Complete the dtypes for all_gather, add all_gather_object api (#44417)

Parent: 768e50c9
Showing 16 changed files with 594 additions and 48 deletions.
paddle/fluid/distributed/collective/ProcessGroupGloo.cc (+9, -0)
paddle/fluid/operators/collective/c_allgather_op.cc (+3, -0)
paddle/fluid/operators/collective/c_allgather_op.cu.cc (+3, -0)
paddle/fluid/platform/device/gpu/nccl_helper.h (+10, -0)
paddle/phi/kernels/cpu/split_kernel.cc (+2, -0)
paddle/phi/kernels/gpu/split_kernel.cu (+2, -0)
python/paddle/distributed/__init__.py (+5, -3)
python/paddle/distributed/collective.py (+93, -20)
python/paddle/fluid/tests/unittests/CMakeLists.txt (+5, -1)
python/paddle/fluid/tests/unittests/collective_allgather_api.py (+43, -7)
python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py (+37, -0)
python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py (+35, -0)
python/paddle/fluid/tests/unittests/test_collective_allgather_api.py (+204, -4)
python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py (+53, -0)
python/paddle/fluid/tests/unittests/test_collective_api_base.py (+85, -9)
python/paddle/tensor/manipulation.py (+5, -4)
paddle/fluid/distributed/collective/ProcessGroupGloo.cc

@@ -79,6 +79,15 @@ namespace distributed {
     case experimental::DataType::INT64:       \
       func<int64_t>(args);                    \
       break;                                  \
+    case experimental::DataType::INT8:        \
+      func<int8_t>(args);                     \
+      break;                                  \
+    case experimental::DataType::UINT8:       \
+      func<uint8_t>(args);                    \
+      break;                                  \
+    case experimental::DataType::BOOL:        \
+      func<bool>(args);                       \
+      break;                                  \
     default:                                  \
       VLOG(0) << "Error: Unknown DataType.";  \
       exit(-1);
paddle/fluid/operators/collective/c_allgather_op.cc

@@ -94,4 +94,7 @@ REGISTER_OP_CPU_KERNEL(c_allgather,
                        ops::CAllGatherOpCPUKernel<double>,
                        ops::CAllGatherOpCPUKernel<int>,
                        ops::CAllGatherOpCPUKernel<int64_t>,
+                       ops::CAllGatherOpCPUKernel<uint8_t>,
+                       ops::CAllGatherOpCPUKernel<int8_t>,
+                       ops::CAllGatherOpCPUKernel<bool>,
                        ops::CAllGatherOpCPUKernel<plat::float16>);
paddle/fluid/operators/collective/c_allgather_op.cu.cc

@@ -100,5 +100,8 @@ REGISTER_OP_CUDA_KERNEL(c_allgather,
                         ops::CAllGatherOpCUDAKernel<plat::bfloat16>,
 #endif
                         ops::CAllGatherOpCUDAKernel<int>,
+                        ops::CAllGatherOpCUDAKernel<uint8_t>,
+                        ops::CAllGatherOpCUDAKernel<int8_t>,
                         ops::CAllGatherOpCUDAKernel<int64_t>,
+                        ops::CAllGatherOpCUDAKernel<bool>,
                         ops::CAllGatherOpCUDAKernel<plat::float16>);
paddle/fluid/platform/device/gpu/nccl_helper.h

(NCCL has no native bool type, so BOOL is mapped to ncclUint8 in both overloads below.)

@@ -55,6 +55,10 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
     return ncclFloat16;
   } else if (type == framework::proto::VarType::INT8) {
     return ncclInt8;
+  } else if (type == framework::proto::VarType::UINT8) {
+    return ncclUint8;
+  } else if (type == framework::proto::VarType::BOOL) {
+    return ncclUint8;
 #if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
   } else if (type == framework::proto::VarType::BF16) {
     return ncclBfloat16;

@@ -76,6 +80,12 @@ inline ncclDataType_t ToNCCLDataType(experimental::DataType type) {
     return ncclInt64;
   } else if (type == experimental::DataType::FLOAT16) {
     return ncclFloat16;
+  } else if (type == experimental::DataType::UINT8) {
+    return ncclUint8;
+  } else if (type == experimental::DataType::INT8) {
+    return ncclInt8;
+  } else if (type == experimental::DataType::BOOL) {
+    return ncclUint8;
 #if CUDNN_VERSION_MIN(8, 1, 0) && NCCL_VERSION_CODE >= 21000
   } else if (type == experimental::DataType::BFLOAT16) {
     return ncclBfloat16;
paddle/phi/kernels/cpu/split_kernel.cc

@@ -72,5 +72,7 @@ PD_REGISTER_KERNEL(split,
                    int64_t,
                    int,
                    bool,
+                   uint8_t,
+                   int8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
paddle/phi/kernels/gpu/split_kernel.cu

@@ -71,5 +71,7 @@ PD_REGISTER_KERNEL(split,
                    int64_t,
                    int,
                    bool,
+                   uint8_t,
+                   int8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
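The split kernels need the same dtype coverage because all_gather first materializes one concatenated output tensor and then slices it into per-rank shards with paddle.split. A minimal single-process sketch of that post-processing step (values are illustrative, not from the commit):

    import paddle

    nranks = 2
    out = paddle.arange(8, dtype='int32').reshape([4, 2])  # stand-in for the gathered buffer
    shards = paddle.split(out, nranks, 0)  # one shard per rank, split along axis 0
    assert len(shards) == nranks
    assert shards[0].shape == [2, 2]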
python/paddle/distributed/__init__.py

@@ -31,6 +31,7 @@ from .collective import broadcast  # noqa: F401
 from .collective import all_reduce  # noqa: F401
 from .collective import reduce  # noqa: F401
 from .collective import all_gather  # noqa: F401
+from .collective import all_gather_object  # noqa: F401
 from .collective import scatter  # noqa: F401
 from .collective import barrier  # noqa: F401
 from .collective import ReduceOp  # noqa: F401

@@ -71,7 +72,8 @@ __all__ = [  # noqa
     "init_parallel_env", "gloo_init_parallel_env", "gloo_barrier",
     "gloo_release", "QueueDataset", "split", "CountFilterEntry",
     "ShowClickEntry", "get_world_size", "get_group", "all_gather",
-    "InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce",
-    "recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry",
-    "ParallelMode", "is_initialized", "isend", "irecv", "reduce_scatter"
+    "all_gather_object", "InMemoryDataset", "barrier", "all_reduce",
+    "alltoall", "send", "reduce", "recv", "ReduceOp", "wait", "get_rank",
+    "ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv",
+    "reduce_scatter"
 ]
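With the new import and __all__ entry, the object API becomes reachable from the package root. A trivial hedged check (assuming a Paddle build that includes this commit):

    import paddle.distributed as dist

    assert hasattr(dist, "all_gather_object")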
python/paddle/distributed/collective.py

@@ -14,6 +14,8 @@
 import numpy as np
 import os
+import pickle
+import io
 from datetime import timedelta
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.framework import Variable

@@ -927,9 +929,9 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
     Args:
         tensor_list (list): A list of output Tensors. Every element in the list must be a Tensor whose data type
-            should be float16, float32, float64, int32 or int64.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
         tensor (Tensor): The Tensor to send. Its data type
-            should be float16, float32, float64, int32 or int64.
+            should be float16, float32, float64, int32, int64, int8, uint8, bool, complex64 or complex128.
         group (Group): The group instance return by new_group or None for global default group.
         use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
             Default to True.

@@ -941,7 +943,6 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
         .. code-block:: python

             # required: distributed
-            import numpy as np
             import paddle
             from paddle.distributed import init_parallel_env

@@ -949,21 +950,26 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
             init_parallel_env()
             tensor_list = []
             if paddle.distributed.ParallelEnv().local_rank == 0:
-                np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
-                np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
-                data1 = paddle.to_tensor(np_data1)
-                data2 = paddle.to_tensor(np_data2)
+                data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
                 paddle.distributed.all_gather(tensor_list, data1)
             else:
-                np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
-                np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
-                data1 = paddle.to_tensor(np_data1)
-                data2 = paddle.to_tensor(np_data2)
+                data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
                 paddle.distributed.all_gather(tensor_list, data2)
     """
     if group is not None and not group.is_member():
         return

+    def convert_to_complex(list_of_tensor):
+        list_of_complex = []
+        for tensor in list_of_tensor:
+            list_of_complex.append(paddle.as_complex(tensor))
+        return list_of_complex
+
+    is_input_complex = (tensor.dtype == paddle.complex64
+                        or tensor.dtype == paddle.complex128)
+    if is_input_complex:
+        tensor = paddle.as_real(tensor)
+
     if in_dygraph_mode():
         group = _get_default_group() if group is None else group
         if len(tensor_list) == 0:

@@ -975,7 +981,11 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
         task = group.process_group.all_gather(tensor, out)
         task.wait()
         tensor_list.clear()
-        tensor_list.extend(paddle.split(out, group.nranks, 0))
+        list_of_tensor = paddle.split(out, group.nranks, 0)
+        if is_input_complex:
+            tensor_list.extend(convert_to_complex(list_of_tensor))
+        else:
+            tensor_list.extend(list_of_tensor)
         return

     ring_id = 0 if group is None else group.id

@@ -992,13 +1002,14 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
         raise ValueError("The type of 'tensor_list' for all_gather "
                          "should be list.")
     for elem in tensor_list:
-        check_variable_and_dtype(
-            elem, 'tensor_list',
-            ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
-    check_variable_and_dtype(
-        tensor, 'tensor',
-        ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
+        check_variable_and_dtype(
+            elem, 'tensor_list', [
+                'float16', 'float32', 'float64', 'int32', 'int64', 'bool',
+                'int8', 'uint8', 'complex64', 'complex128'
+            ], 'all_gather')
+    check_variable_and_dtype(
+        tensor, 'tensor', [
+            'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'int8',
+            'uint8', 'complex64', 'complex128'
+        ], 'all_gather')
     helper.append_op(
         type=op_type,
         inputs={'X': [tensor]},
         outputs={'Out': [out]},

@@ -1008,7 +1019,69 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
             'nranks': nranks
         })
-    tensor_list.extend(paddle.split(out, nranks, 0))
+    list_of_tensor = paddle.split(out, nranks, 0)
+    if is_input_complex:
+        tensor_list.extend(convert_to_complex(list_of_tensor))
+    else:
+        tensor_list.extend(list_of_tensor)
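For intuition, the complex path gathers the real view of the tensor and converts each gathered shard back to complex. A minimal single-process sketch of that round trip (not from the commit; variable names are illustrative):

    import paddle

    # as_real views complex64 as float32 with a trailing [real, imag] axis;
    # this real view is what the collective actually transports.
    x = paddle.to_tensor([1 + 2j, 3 + 4j], dtype='complex64')
    r = paddle.as_real(x)       # shape [2, 2], dtype float32
    y = paddle.as_complex(r)    # back to shape [2], dtype complex64
    assert (x.numpy() == y.numpy()).all()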
def _convert_object_to_tensor(obj):
    _pickler = pickle.Pickler
    f = io.BytesIO()
    _pickler(f).dump(obj)
    data = np.frombuffer(f.getvalue(), dtype=np.uint8)
    tensor = paddle.to_tensor(data)
    return tensor


def _convert_tensor_to_object(tensor):
    _unpickler = pickle.Unpickler
    return _unpickler(io.BytesIO(tensor.numpy())).load()
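These two helpers serialize an arbitrary picklable object into a uint8 tensor and back. A hedged standalone sketch of the same round trip (standard library plus numpy only; names are illustrative):

    import io
    import pickle

    import numpy as np

    obj = {"foo": [1, 2, 3]}
    buf = io.BytesIO()
    pickle.Pickler(buf).dump(obj)
    # The raw pickle bytes become a uint8 array, which paddle.to_tensor
    # would wrap as the payload for the collective call.
    data = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    restored = pickle.Unpickler(io.BytesIO(data.tobytes())).load()
    assert restored == obj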
def all_gather_object(object_list, obj, group=None):
    """
    Gather picklable objects from all participators and all get the result. Similiar to all_gather(), but python object can be passed in.

    Args:
        object_list (list): A list of output object. The datatype of every element in the list is same as the input obj.
        obj (Any): The picklable object to send.
        group (Group): The group instance return by new_group or None for global default group.

    Returns:
        None.

    Warning:
        This API only supports the dygraph mode.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            dist.init_parallel_env()
            object_list = []
            if paddle.distributed.ParallelEnv().local_rank == 0:
                obj = {"foo": [1, 2, 3]}
                paddle.distributed.all_gather_object(object_list, obj)
            else:
                obj = {"bar": [4, 5, 6]}
                paddle.distributed.all_gather_object(object_list, obj)
    """
    assert in_dygraph_mode(
    ), "all_gather_object doesn't support static graph mode."

    tensor = _convert_object_to_tensor(obj)

    tensor_list = []
    all_gather(tensor_list, tensor, group)
    for tensor in tensor_list:
        object_list.append(_convert_tensor_to_object(tensor))


def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -183,6 +183,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
   list(REMOVE_ITEM TEST_OPS test_new_group_api)
   list(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
   list(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
+  list(REMOVE_ITEM TEST_OPS test_collective_allgather_object_api)
   list(REMOVE_ITEM TEST_OPS test_collective_alltoall_api)
   list(REMOVE_ITEM TEST_OPS test_collective_global_gather)
   list(REMOVE_ITEM TEST_OPS test_collective_global_scatter)

@@ -1598,7 +1599,9 @@ if(APPLE)
 endif()
 if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
-  set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 300)
+  set_tests_properties(test_collective_allgather_object_api PROPERTIES TIMEOUT
+                                                            120)
   set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
   set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200)
   set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200)

@@ -1629,6 +1632,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
       test_new_group_api
       test_collective_broadcast_api
       test_collective_allgather_api
+      test_collective_allgather_object_api
       test_collective_alltoall_api
       test_collective_global_gather
       test_collective_global_scatter
python/paddle/fluid/tests/unittests/collective_allgather_api.py

@@ -30,28 +30,64 @@ import paddle.fluid.profiler as profiler
 import paddle.fluid.unique_name as nameGen
 from paddle.fluid import core
 import unittest
 import pickle
 from multiprocessing import Process
 import paddle.fluid.layers as layers
 from functools import reduce
-from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+import test_collective_api_base as test_base

 paddle.enable_static()


-class TestCollectiveAllgatherAPI(TestCollectiveAPIRunnerBase):
+class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):

     def __init__(self):
         self.global_ring_id = 0

-    def get_model(self, main_prog, startup_program, rank):
+    def get_model(self, main_prog, startup_program, rank, dtype=None):
+        dtype = "float32" if dtype is None else dtype
         with fluid.program_guard(main_prog, startup_program):
             tensor_list = []
-            tindata = layers.data(name="tindata", shape=[10, 1000], dtype='float32')
+            tindata = layers.data(name="tindata", shape=[10, 1000], dtype=dtype)
             paddle.distributed.all_gather(tensor_list, tindata)
             return tensor_list

+    def run_trainer(self, args):
+        train_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        endpoints = args["endpoints"].split(",")
+        rank = args["trainerid"]
+        current_endpoint = args["currentendpoint"]
+        nranks = 2
+        paddle.distributed.init_parallel_env()
+        if args['backend'] == 'nccl':
+            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+            place = fluid.CUDAPlace(device_id)  #if args.use_gpu else fluid.CPUPlace()
+        elif args['backend'] == 'bkcl':
+            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
+            place = fluid.XPUPlace(device_id)
+        else:
+            place = fluid.CPUPlace()
+        indata = test_base.create_test_data(shape=(10, 1000),
+                                            dtype=args["dtype"],
+                                            seed=os.getpid())
+        assert args['static_mode'] == 1, "collective_allgather_api only support static mode"
+        result = self.get_model(train_prog, startup_prog, rank, dtype=args["dtype"])
+        exe = fluid.Executor(place)
+        exe.run(startup_prog)
+        fetch_list = []
+        for elem in result:
+            fetch_list.append(elem.name)
+        out = exe.run(train_prog, feed={'tindata': indata}, fetch_list=fetch_list)
+        sys.stdout.buffer.write(pickle.dumps(out))


 if __name__ == "__main__":
-    runtime_main(TestCollectiveAllgatherAPI, "allgather")
+    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")
python/paddle/fluid/tests/unittests/collective_allgather_api_dygraph.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import paddle
import paddle.fluid as fluid
import unittest
import test_collective_api_base as test_base


class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):

    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        with fluid.program_guard(main_prog, startup_program):
            tindata = paddle.to_tensor(indata)
            tensor_list = []
            paddle.distributed.all_gather(tensor_list, tindata)
            return [tensor.numpy() for tensor in tensor_list]


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherAPI, "allgather")
python/paddle/fluid/tests/unittests/collective_allgather_object_api_dygraph.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import paddle
import paddle.fluid as fluid
import test_collective_api_base as test_base


class TestCollectiveAllgatherObjectAPI(test_base.TestCollectiveAPIRunnerBase):

    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        with fluid.program_guard(main_prog, startup_program):
            object_list = []
            paddle.distributed.all_gather_object(object_list, indata)
            return object_list


if __name__ == "__main__":
    test_base.runtime_main(TestCollectiveAllgatherObjectAPI, "allgather_object")
python/paddle/fluid/tests/unittests/test_collective_allgather_api.py

@@ -28,12 +28,212 @@ class TestCollectiveAllgatherAPI(TestDistBase):
         pass

     def test_allgather_nccl(self):
-        self.check_with_place("collective_allgather_api.py", "allgather", "nccl")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="float16")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="float32")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="float64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="bool")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="uint8")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="int8")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="int32")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="int64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="complex64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "nccl", dtype="complex128")

     def test_allgather_gloo(self):
-        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="float16")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="float32")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="float64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="bool")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="uint8")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="int8")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="int32")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="int64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="complex64")
+        self.check_with_place("collective_allgather_api.py", "allgather", "gloo", "3", dtype="complex128")

+    def test_allgatther_nccl_dygraph(self):
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="float16")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="float32")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="float64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="bool")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="uint8")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="int8")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="int32")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="int64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="complex64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "nccl", static_mode="0", dtype="complex128")

+    def test_allgather_gloo_dygraph(self):
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="float16")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="float32")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="float64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="bool")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="uint8")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="int8")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="int32")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="int64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="complex64")
+        self.check_with_place("collective_allgather_api_dygraph.py", "allgather", "gloo", "3", static_mode="0", dtype="complex128")

 if __name__ == '__main__':
python/paddle/fluid/tests/unittests/test_collective_allgather_object_api.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
import test_collective_api_base as test_base


class TestCollectiveAllgatherObjectAPI(test_base.TestDistBase):

    def _setup_config(self):
        pass

    def test_allgather_nccl(self):
        self.check_with_place("collective_allgather_object_api_dygraph.py",
                              "allgather_object",
                              "nccl",
                              static_mode="0",
                              dtype="pylist")
        self.check_with_place("collective_allgather_object_api_dygraph.py",
                              "allgather_object",
                              "nccl",
                              static_mode="0",
                              dtype="pydict")

    def test_allgather_gloo_dygraph(self):
        self.check_with_place("collective_allgather_object_api_dygraph.py",
                              "allgather_object",
                              "gloo",
                              "3",
                              static_mode="0",
                              dtype="pylist")
        self.check_with_place("collective_allgather_object_api_dygraph.py",
                              "allgather_object",
                              "gloo",
                              "3",
                              static_mode="0",
                              dtype="pydict")


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_collective_api_base.py

@@ -31,9 +31,77 @@ import paddle.fluid.unique_name as nameGen
 from paddle.fluid import core


def create_bool_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.choice([True, False], size=shape)
    return data


def create_float_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    return data


def create_int_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.randint(0, high=100, size=shape).astype(dtype)
    return data


def create_complex_test_data(shape=None, dtype=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).astype(dtype)
    data.imag = np.random.random(shape)
    return data


def create_pylist_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    data = np.random.random(shape).tolist()
    return data


def create_pydict_test_data(shape=None, seed=None):
    if seed:
        np.random.seed(seed)
    key = [i for i in range(0, shape[0])]
    value = np.random.random(shape).tolist()
    data = dict(zip(key, value))
    return data


def create_test_data(shape=None, dtype=None, seed=None):
    assert shape, "Shape should be specified"
    if dtype == "float32" or dtype == "float16" or dtype == "float64":
        return create_float_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "bool":
        return create_bool_test_data(shape=shape, seed=seed)
    elif dtype == "int32" or dtype == "int64" or dtype == "int8" or dtype == "uint8":
        return create_int_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "complex64" or dtype == "complex128":
        return create_complex_test_data(shape=shape, dtype=dtype, seed=seed)
    elif dtype == "pylist":
        return create_pylist_test_data(shape=shape, seed=seed)
    elif dtype == "pydict":
        return create_pydict_test_data(shape=shape, seed=seed)
    else:
        raise NotImplementedError("Unsupported dtype for creating test data.")


 class TestCollectiveAPIRunnerBase(object):

-    def get_model(self, train_prog, startup_prog, rank, indata=None):
+    def get_model(self, train_prog, startup_prog, rank, indata=None, dtype=None):
         raise NotImplementedError(
             "get model should be implemented by child class.")
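A quick, hedged sanity check of the dispatch above (standalone; shape and seed values are illustrative):

    import numpy as np

    complex_data = create_test_data(shape=(4, 8), dtype="complex64", seed=7)
    assert complex_data.dtype == np.complex64
    # "pydict" keys come from range(0, shape[0]), so a (4, 8) shape yields keys 0..3.
    py_dict = create_test_data(shape=(4, 8), dtype="pydict", seed=7)
    assert sorted(py_dict.keys()) == [0, 1, 2, 3]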
@@ -54,8 +122,9 @@ class TestCollectiveAPIRunnerBase(object):
             place = fluid.XPUPlace(device_id)
         else:
             place = fluid.CPUPlace()
-        np.random.seed(os.getpid())
-        indata = np.random.random((10, 1000)).astype("float32")
+        indata = create_test_data(shape=(10, 1000),
+                                  dtype=args["dtype"],
+                                  seed=os.getpid())
         if args['static_mode']:
             result = self.get_model(train_prog, startup_prog, rank)
             exe = fluid.Executor(place)

@@ -83,6 +152,7 @@ def runtime_main(test_class, col_type):
         args["backend"] = os.getenv("BACKEND")
         args["path_id"] = int(os.getenv("PATH_ID"))
         args["static_mode"] = int(os.getenv("STATIC_MODE"))
+        args["dtype"] = os.getenv("DTYPE")
         model.run_trainer(args)

@@ -203,18 +273,22 @@ class TestDistBase(unittest.TestCase):
                          static_mode="1",
                          check_error_log=False,
                          need_envs={},
-                         eager_mode=True):
+                         eager_mode=True,
+                         dtype=None):
         if backend == "nccl" or backend == "bkcl":
             with_gloo = '0'
         else:
             with_gloo = '1'
         required_envs = os.environ.copy()
+        dtype = "float32" if dtype is None else dtype
         additional_envs = {
             "NCCL_P2P_DISABLE": "1",
             "STATIC_MODE": static_mode,
             "PADDLE_WITH_GLOO": with_gloo,
             "PADDLE_DISTRI_BACKEND": backend,
             "BACKEND": backend,
-            "PATH_ID": path_id
+            "PATH_ID": path_id,
+            "DTYPE": dtype
         }
         required_envs.update(additional_envs)
         required_envs.update(need_envs)

@@ -234,16 +308,18 @@ class TestDistBase(unittest.TestCase):
         tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, required_envs)
-        np.random.seed(pid0)
-        input1 = np.random.random((10, 1000))
-        np.random.seed(pid1)
-        input2 = np.random.random((10, 1000))
+        input1 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid0)
+        input2 = create_test_data(shape=(10, 1000), dtype=dtype, seed=pid1)
         if col_type == "allgather":
             need_result = np.vstack((input1, input2))
             tr_out0 = np.vstack((tr0_out[0], tr0_out[1]))
             tr_out1 = np.vstack((tr1_out[0], tr1_out[1]))
             self.assertTrue(np.allclose(tr_out0, need_result))
             self.assertTrue(np.allclose(tr_out1, need_result))
+        if col_type == "allgather_object":
+            need_result = [input1, input2]
+            self.assertEqual(need_result, tr0_out)
+            self.assertEqual(need_result, tr1_out)
         elif col_type == "broadcast":
             need_result = input2
             self.assertTrue(np.allclose(tr0_out, need_result))
python/paddle/tensor/manipulation.py

@@ -1737,7 +1737,7 @@ def split(x, num_or_sections, axis=0, name=None):
     Split the input tensor into multiple sub-Tensors.

     Args:
-        x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64.
+        x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
         num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
             indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
             If ``num_or_sections`` is a list or tuple, the length of it indicates the number of

@@ -1814,9 +1814,10 @@ def split(x, num_or_sections, axis=0, name=None):
             _C_ops.split(input, out, *attrs)
         return out

-    check_variable_and_dtype(
-        input, 'input',
-        ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'split')
+    check_variable_and_dtype(input, 'input', [
+        'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
+        'int8'
+    ], 'split')
     check_type(num_or_sections, 'num_or_sections', (list, int, tuple), 'split')
     check_type(dim, 'dim', (int, Variable), 'split')
     if isinstance(dim, Variable):