Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7a92e74b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7a92e74b
编写于
9月 06, 2022
作者:
W
Wen Sun
提交者:
GitHub
9月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Completes basic dtypes for collective api in eager mode (#45574)
上级
1137677a
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
679 addition
and
221 deletion
+679
-221
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+11
-2
paddle/phi/kernels/cpu/concat_kernel.cc
paddle/phi/kernels/cpu/concat_kernel.cc
+2
-0
paddle/phi/kernels/gpu/concat_kernel.cu
paddle/phi/kernels/gpu/concat_kernel.cu
+1
-0
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+183
-196
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
...on/paddle/fluid/tests/unittests/collective/CMakeLists.txt
+26
-3
python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py
...s/unittests/collective/collective_alltoall_api_dygraph.py
+3
-6
python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py
...ests/collective/collective_alltoall_single_api_dygraph.py
+36
-0
python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py
.../unittests/collective/collective_broadcast_api_dygraph.py
+36
-0
python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py
...nittests/collective/collective_isend_irecv_api_dygraph.py
+40
-0
python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py
...sts/unittests/collective/collective_reduce_api_dygraph.py
+36
-0
python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py
...tests/collective/collective_reduce_scatter_api_dygraph.py
+37
-0
python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py
...ts/unittests/collective/collective_scatter_api_dygraph.py
+42
-0
python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py
...ests/unittests/collective/test_collective_alltoall_api.py
+10
-4
python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py
...ittests/collective/test_collective_alltoall_single_api.py
+39
-0
python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py
...sts/unittests/collective/test_collective_broadcast_api.py
+25
-0
python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py
...s/unittests/collective/test_collective_isend_irecv_api.py
+39
-0
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py
.../tests/unittests/collective/test_collective_reduce_api.py
+25
-0
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py
...nittests/collective/test_collective_reduce_scatter_api.py
+39
-0
python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py
...tests/unittests/collective/test_collective_scatter_api.py
+25
-0
python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py
...ests/unittests/collective/test_collective_sendrecv_api.py
+7
-2
python/paddle/fluid/tests/unittests/collective/testslist.csv
python/paddle/fluid/tests/unittests/collective/testslist.csv
+6
-3
python/paddle/fluid/tests/unittests/test_collective_api_base.py
.../paddle/fluid/tests/unittests/test_collective_api_base.py
+6
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+5
-5
未找到文件。
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
浏览文件 @
7a92e74b
...
...
@@ -738,14 +738,23 @@ void* GetPointerByOffset(void* raw_pointer,
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT64
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
double
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
int16_t
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
INT32
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
int32_t
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
INT64
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
int64_t
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
int16_t
*>
(
raw_pointer
)
+
}
else
if
(
type
==
experimental
::
DataType
::
INT8
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
int8_t
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
UINT8
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uint8_t
*>
(
raw_pointer
)
+
offset
);
}
else
if
(
type
==
experimental
::
DataType
::
BOOL
)
{
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
bool
*>
(
raw_pointer
)
+
offset
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
...
...
paddle/phi/kernels/cpu/concat_kernel.cc
浏览文件 @
7a92e74b
...
...
@@ -124,6 +124,8 @@ PD_REGISTER_KERNEL(concat,
int64_t
,
int
,
uint8_t
,
int8_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
paddle/phi/kernels/gpu/concat_kernel.cu
浏览文件 @
7a92e74b
...
...
@@ -121,6 +121,7 @@ PD_REGISTER_KERNEL(concat,
int64_t
,
int
,
uint8_t
,
int8_t
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
complex
<
float
>
,
...
...
python/paddle/distributed/collective.py
浏览文件 @
7a92e74b
...
...
@@ -60,21 +60,18 @@ class ReduceOp:
Examples:
.. code-block:: python
import numpy as np
# required: distributed
import paddle
from paddle.distributed import ReduceOp
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.all_reduce(data, op=ReduceOp.SUM)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_reduce(data, op=dist.ReduceOp.SUM)
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
SUM
=
0
MAX
=
1
...
...
@@ -589,15 +586,16 @@ def destroy_process_group(group=None):
# required: distributed
import paddle
import paddle.distributed as dist
paddle.distributed
.init_parallel_env()
group =
paddle.distributed
.new_group([0, 1])
dist
.init_parallel_env()
group =
dist
.new_group([0, 1])
paddle.distributed
.destroy_process_group(group)
print(
paddle.distributed
.is_initialized())
dist
.destroy_process_group(group)
print(
dist
.is_initialized())
# True
paddle.distributed
.destroy_process_group()
print(
paddle.distributed
.is_initialized())
dist
.destroy_process_group()
print(
dist
.is_initialized())
# False
"""
...
...
@@ -690,8 +688,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
"""
Broadcast a tensor from the source to all others.
As shown below,
4 GPUs each start 4 processes
and GPU0 owns data 0. Through broadcast operator,
the
data 0 will be sent to all GPUs from GPU0.
As shown below,
one process is started with a GPU
and GPU0 owns data 0. Through broadcast operator,
data 0 will be sent to all GPUs from GPU0.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/broadcast.png
:width: 800
...
...
@@ -699,8 +697,8 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
:align: center
Args:
tensor (Tensor): The Tensor to send if current rank is the source, or the
t
ensor to receive otherwise. Its data type
should be float16, float32, float64, int32
or int64
.
tensor (Tensor): The Tensor to send if current rank is the source, or the
T
ensor to receive otherwise. Its data type
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
src (int): The source rank.
group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
...
...
@@ -713,20 +711,17 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.broadcast(data, 1)
out = data.numpy()
# [[1, 2, 3], [1, 2, 3]]
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.broadcast(data, src=1)
print(data)
# [[1, 2, 3], [1, 2, 3]] (2 GPUs)
"""
if
group
is
not
None
and
not
group
.
is_member
():
...
...
@@ -756,9 +751,10 @@ def broadcast(tensor, src, group=None, use_calc_stream=True):
'ring_id'
,
ring_id
)
op_type
=
'c_broadcast'
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'broadcast'
)
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
],
'broadcast'
)
helper
=
LayerHelper
(
op_type
,
**
locals
())
helper
.
append_op
(
type
=
op_type
,
...
...
@@ -800,15 +796,16 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
# required: distributed
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
paddle.distributed.all_reduce(data)
dist.all_reduce(data)
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -871,8 +868,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
def
reduce
(
tensor
,
dst
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
use_calc_stream
=
True
):
"""
Reduce a tensor to the destination from all others. As shown below,
4 GPUs each start 4 processes and the data on each GPU is respres
nted
by
the GPU number
. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
Reduce a tensor to the destination from all others. As shown below,
one process is started with a GPU and the data of this process is represe
nted
by
its group rank
. The destination of the reduce operator is GPU0 and the process is sum. Through reduce operator,
the GPU0 will owns the sum of all data from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/reduce.png
...
...
@@ -882,7 +879,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
group (Group): The group instance return by new_group or None for global default group.
...
...
@@ -896,20 +893,18 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
dist.init_parallel_env()
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array
([[1, 2, 3], [1, 2, 3]])
d
ata = paddle.to_tensor(np_data
)
p
addle.distributed.reduce(data, 0
)
out = data.numpy(
)
# [[
5, 7, 9], [5, 7, 9]]
data = paddle.to_tensor
([[1, 2, 3], [1, 2, 3]])
d
ist.reduce(data, dst=0
)
p
rint(data
)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0
)
# [[
1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -952,9 +947,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
raise
ValueError
(
"Unknown parameter: {}."
.
format
(
op
))
op_type
=
'c_reduce'
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'all_reduce'
)
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
],
'reduce'
)
if
op
==
ReduceOp
.
SUM
:
op_type
=
'c_reduce_sum'
...
...
@@ -980,8 +976,8 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
"""
Gather tensors from all participators and all get the result. As shown
below,
4 GPUs each starts 4 processes and the data on each GPU
is represented
by
the GPU number
. Through the all_gather operator, each GPU will have data
below,
one process is started with a GPU and the data of this process
is represented
by
its group rank
. Through the all_gather operator, each GPU will have data
from all GPUs.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allgather.png
...
...
@@ -1006,17 +1002,17 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
# required: distributed
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
dist.init_parallel_env()
tensor_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
data1 = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
paddle.distributed.all_gather(tensor_list, data1)
if dist.get_rank() == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data2 = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
paddle.distributed.all_gather(tensor_list, data2)
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
dist.all_gather(tensor_list, data)
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -1126,15 +1122,15 @@ def all_gather_object(object_list, obj, group=None):
import paddle
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
dist.init_parallel_env()
object_list = []
if
paddle.distributed.ParallelEnv().local_rank
== 0:
if
dist.get_rank()
== 0:
obj = {"foo": [1, 2, 3]}
paddle.distributed.all_gather_object(object_list, obj)
else:
obj = {"bar": [4, 5, 6]}
paddle.distributed.all_gather_object(object_list, obj)
dist.all_gather_object(object_list, obj)
print(object_list)
# [{'foo': [1, 2, 3]}, {'bar': [4, 5, 6]}] (2 GPUs)
"""
assert
in_dygraph_mode
(
),
"all_gather_object doesn't support static graph mode."
...
...
@@ -1163,7 +1159,7 @@ def all_gather_object(object_list, obj, group=None):
def
scatter
(
tensor
,
tensor_list
=
None
,
src
=
0
,
group
=
None
,
use_calc_stream
=
True
):
"""
Scatter a tensor to all participators. As shown below,
4 GPUs each start 4 processes
and the source of the scatter
Scatter a tensor to all participators. As shown below,
one process is started with a GPU
and the source of the scatter
is GPU0. Through scatter operator, the data in GPU0 will be sent to all GPUs averagely.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/scatter.png
...
...
@@ -1173,9 +1169,9 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The output Tensor. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
tensor_list (list|tuple): A list/tuple of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32
or int64
. Default value is None.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
. Default value is None.
src (int): The source rank id. Default value is 0.
group (Group): The group instance return by new_group or None for global default group.
use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
...
...
@@ -1188,25 +1184,21 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([7, 8, 9])
np_data2 = np.array([10, 11, 12])
else:
np_data1 = np.array([1, 2, 3])
np_data2 = np.array([4, 5, 6])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
if paddle.distributed.ParallelEnv().local_rank == 0:
paddle.distributed.scatter(data1, src=1)
dist.init_parallel_env()
if dist.get_rank() == 0:
data1 = paddle.to_tensor([7, 8, 9])
data2 = paddle.to_tensor([10, 11, 12])
dist.scatter(data1, src=1)
else:
paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
out = data1.numpy()
data1 = paddle.to_tensor([1, 2, 3])
data2 = paddle.to_tensor([4, 5, 6])
dist.scatter(data1, tensor_list=[data1, data2], src=1)
print(data1, data2)
# [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
# [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -1244,9 +1236,10 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
use_calc_stream
,
'ring_id'
,
ring_id
,
'nranks'
,
nranks
,
'root'
,
gsrc
)
op_type
=
'c_scatter'
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'scatter'
)
check_variable_and_dtype
(
tensor
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
],
'scatter'
)
helper
=
LayerHelper
(
op_type
,
**
locals
())
helper
.
append_op
(
type
=
op_type
,
inputs
=
{
'X'
:
[
temp
]},
...
...
@@ -2014,7 +2007,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
Args:
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
data type of the input Tensors.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
...
...
@@ -2027,29 +2020,29 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
.. code-block:: python
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
init_parallel_env()
import paddle.distributed as dist
dist.
init_parallel_env()
out_tensor_list = []
if
paddle.distributed.ParallelEnv().rank
== 0:
np_data1 = np.array
([[1, 2, 3], [4, 5, 6]])
np_data2 = np.array
([[7, 8, 9], [10, 11, 12]])
if
dist.get_rank()
== 0:
data1 = paddle.to_tensor
([[1, 2, 3], [4, 5, 6]])
data2 = paddle.to_tensor
([[7, 8, 9], [10, 11, 12]])
else:
np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.alltoall([data1, data2], out_tensor_list)
# out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
# out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
dist.alltoall([data1, data2], out_tensor_list)
print(out_tensor_list)
# [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
# [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
else
:
ring_id
=
0
if
group
is
None
else
group
.
id
...
...
@@ -2114,7 +2107,7 @@ def alltoall_single(in_tensor,
``alltoall_single`` is only supported in eager mode.
Args:
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32
or int64
.
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None.
...
...
@@ -2137,35 +2130,36 @@ def alltoall_single(in_tensor,
rank = dist.get_rank()
size = dist.get_world_size()
# case 1
input = paddle.arange(2, dtype='int64') + rank * 2
# input for rank 0: [0, 1]
# input for rank 1: [2, 3]
# case 1 (2 GPUs)
data = paddle.arange(2, dtype='int64') + rank * 2
# data for rank 0: [0, 1]
# data for rank 1: [2, 3]
output = paddle.empty([2], dtype='int64')
dist.alltoall_single(input, output)
dist.alltoall_single(data, output)
print(output)
# output for rank 0: [0, 2]
# output for rank 1: [1, 3]
# case 2
# case 2
(2 GPUs)
in_split_sizes = [i + 1 for i in range(size)]
# in_split_sizes for rank 0: [1, 2] and for rank 1: [1, 2]
# in_split_sizes for rank 0: [1, 2]
# in_split_sizes for rank 1: [1, 2]
out_split_sizes = [rank + 1 for i in range(size)]
# out_split_sizes for rank 0: [1, 1]
and for rank 1: [2, 2]
input
= paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
#
input
for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
#
input
for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
# out_split_sizes for rank 0: [1, 1]
# out_split_sizes for rank 1: [2, 2]
data
= paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
#
data
for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
#
data
for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
output = paddle.empty([(rank + 1) * size, size], dtype='float32')
group = dist.new_group([0, 1])
task = dist.alltoall_single(
input
,
task = dist.alltoall_single(
data
,
output,
in_split_sizes,
out_split_sizes,
use_calc_stream=False,
group=group)
task.wait()
print(output)
# output for rank 0: [[0., 0.], [1., 1.]]
# output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]
...
...
@@ -2177,6 +2171,9 @@ def alltoall_single(in_tensor,
# _check_single_tensor
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
in_split_sizes
=
[]
if
in_split_sizes
is
None
else
in_split_sizes
out_split_sizes
=
[]
if
out_split_sizes
is
None
else
out_split_sizes
...
...
@@ -2199,7 +2196,7 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
dst (int): The destination rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.
...
...
@@ -2212,22 +2209,25 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
# required: distributed
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
init_parallel_env()
if
paddle.distributed.ParallelEnv().rank
== 0:
dist.
init_parallel_env()
if
dist.get_rank()
== 0:
data = paddle.to_tensor([7, 8, 9])
paddle.distributed
.send(data, dst=1)
dist
.send(data, dst=1)
else:
data = paddle.to_tensor([1,2,3])
paddle.distributed.recv(data, src=0)
out = data.numpy()
data = paddle.to_tensor([1, 2, 3])
dist.recv(data, src=0)
print(data)
# [7, 8, 9] (2 GPUs)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
dst
=
_get_group_rank
(
dst
,
group
)
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
task
=
group
.
process_group
.
send
(
tensor
,
dst
)
if
use_calc_stream
:
task
.
wait
()
...
...
@@ -2261,7 +2261,7 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
use_calc_stream (bool, optional): Whether to use calculate stream or communication stream. Default: True.
...
...
@@ -2274,16 +2274,17 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
# required: distributed
import paddle
from paddle.distributed import init_parallel_env
import paddle.distributed as dist
init_parallel_env()
if
paddle.distributed.ParallelEnv().rank
== 0:
dist.
init_parallel_env()
if
dist.get_rank()
== 0:
data = paddle.to_tensor([7, 8, 9])
paddle.distributed
.send(data, dst=1)
dist
.send(data, dst=1)
else:
data = paddle.to_tensor([1,2,3])
paddle.distributed.recv(data, src=0)
out = data.numpy()
data = paddle.to_tensor([1, 2, 3])
dist.recv(data, src=0)
print(data)
# [7, 8, 9] (2 GPUs)
"""
if
group
is
not
None
and
not
group
.
is_member
():
return
...
...
@@ -2291,6 +2292,8 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
src
=
_get_group_rank
(
src
,
group
)
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
task
=
group
.
process_group
.
recv
(
tensor
,
src
)
if
use_calc_stream
:
task
.
wait
()
...
...
@@ -2340,7 +2343,7 @@ def isend(tensor, dst, group=None):
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
dst (int): The destination rank.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
...
...
@@ -2358,21 +2361,15 @@ def isend(tensor, dst, group=None):
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
if dist.get_rank() == 0:
data = paddle.to_tensor([7, 8, 9])
task =
paddle.distributed
.isend(data, dst=1)
task =
dist
.isend(data, dst=1)
else:
data = paddle.to_tensor([1, 2, 3])
task = paddle.distributed.irecv(data, src=0)
task = dist.irecv(data, src=0)
task.wait()
print(data)
# paddle.tensor([7, 8, 9]) # Rank-0
# paddle.tensor([7, 8, 9]) # Rank-1
# [7, 8, 9] (2 GPUs)
"""
_check_single_tensor
(
tensor
,
"tensor"
)
...
...
@@ -2381,6 +2378,8 @@ def isend(tensor, dst, group=None):
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
group_dst_rank
=
group
.
get_group_rank
(
dst
)
assert
group_dst_rank
>=
0
,
(
"dst rank out of group, need global rank"
)
return
group
.
process_group
.
send
(
tensor
,
group_dst_rank
)
...
...
@@ -2394,12 +2393,12 @@ def irecv(tensor, src=None, group=None):
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32
or int64
.
should be float16, float32, float64, int32
, int64, int8, uint8 or bool
.
src (int): The source rank id.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
Returns:
A distributed task object.
A distributed task object.
Warning:
This API only supports the dygraph mode.
...
...
@@ -2412,21 +2411,15 @@ def irecv(tensor, src=None, group=None):
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
if dist.get_rank() == 0:
data = paddle.to_tensor([7, 8, 9])
task =
paddle.distributed
.isend(data, dst=1)
task =
dist
.isend(data, dst=1)
else:
data = paddle.to_tensor([1, 2, 3])
task = paddle.distributed.irecv(data, src=0)
task = dist.irecv(data, src=0)
task.wait()
print(data)
# paddle.tensor([7, 8, 9]) # Rank-0
# paddle.tensor([7, 8, 9]) # Rank-1
# [7, 8, 9] (2 GPUs)
"""
_check_single_tensor
(
tensor
,
"tensor"
)
if
group
is
not
None
and
not
group
.
is_member
():
...
...
@@ -2434,6 +2427,8 @@ def irecv(tensor, src=None, group=None):
if
in_dygraph_mode
():
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
group_src_rank
=
group
.
get_group_rank
(
src
)
assert
group_src_rank
>=
0
,
(
"src rank out of group, need global rank"
)
return
group
.
process_group
.
recv
(
tensor
,
group_src_rank
)
...
...
@@ -2581,8 +2576,9 @@ def reduce_scatter(tensor,
Reduces, then scatters a list of tensors to all processes in a group
Args:
tensor (Tensor): Output tensor.
tensor_list (list[Tensor]): List of tensors to reduce and scatter.
tensor (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
tensor_list (list[Tensor]): List of tensors to reduce and scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (Group, optional): The group instance return by new_group or None for global
default group. Default: None.
...
...
@@ -2604,24 +2600,16 @@ def reduce_scatter(tensor,
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
t1 = paddle.to_tensor([0, 1])
t2 = paddle.to_tensor([2, 3])
if dist.get_rank() == 0:
data1 = paddle.to_tensor([0, 1])
data2 = paddle.to_tensor([2, 3])
else:
t1 = paddle.to_tensor([4, 5])
t2 = paddle.to_tensor([6, 7])
tensor_list = [t1, t2]
output = paddle.empty(shape=[2], dtype=tensor_list[0].dtype)
dist.reduce_scatter(output, tensor_list)
print(output)
# [4, 6] # Rank-0
# [8, 10] # Rank-1
data1 = paddle.to_tensor([4, 5])
data2 = paddle.to_tensor([6, 7])
dist.reduce_scatter(data1, [data1, data2])
print(data1)
# [4, 6] (2 GPUs, out for rank 0)
# [8, 10] (2 GPUs, out for rank 1)
"""
_check_single_tensor
(
tensor
,
"tensor"
)
...
...
@@ -2633,6 +2621,8 @@ def reduce_scatter(tensor,
if
in_dygraph_mode
():
op_type
=
_get_reduce_op
(
op
,
"reduce_scatter"
)
group
=
_get_default_group
()
if
group
is
None
else
group
backend
=
_group_map_backend
[
group
]
assert
backend
!=
'gloo'
,
(
"backend gloo is not supported yet"
)
temp
=
paddle
.
concat
(
tensor_list
,
axis
=
0
)
task
=
group
.
process_group
.
_reduce_scatter_base
(
tensor
,
temp
,
op_type
)
...
...
@@ -2654,8 +2644,9 @@ def _reduce_scatter_base(output,
Reduces, then scatters a flattened tensor to all processes in a group.
Args:
output (Tensor): Output tensor.
input (Tensor): Input tensor that is of size output tensor size times world size
output (Tensor): Output tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
input (Tensor): Input tensor that is of size output tensor size times world size. Its data type
should be float16, float32, float64, int32, int64, int8, uint8 or bool.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
...
...
@@ -2669,23 +2660,19 @@ def _reduce_scatter_base(output,
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
input = paddle.arange(4) + rank
# [0, 1, 2, 3] # Rank-0
# [1, 2, 3, 4] # Rank-1
output = paddle.empty(shape=[2], dtype=input.dtype)
paddle.distributed.collective._reduce_scatter_base(output, input)
data = paddle.arange(4) + rank
# [0, 1, 2, 3] (2 GPUs, for rank 0)
# [1, 2, 3, 4] (2 GPUs, for rank 1)
output = paddle.empty(shape=[2], dtype=data.dtype)
dist.collective._reduce_scatter_base(output, data)
print(output)
# [1, 3]
# Rank-0
# [5, 7]
# Rank-1
# [1, 3]
(2 GPUs, out for rank 0)
# [5, 7]
(2 GPUs, out for rank 1)
"""
_check_single_tensor
(
output
,
"output"
)
...
...
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
浏览文件 @
7a92e74b
...
...
@@ -78,7 +78,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_alltoall_api MODULES test_collective_alltoall_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_alltoall_api
PROPERTIES TIMEOUT
"
12
0"
LABELS
"RUN_TYPE=DIST"
)
PROPERTIES TIMEOUT
"
30
0"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
bash_test_modules
(
...
...
@@ -92,6 +92,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
)
set_tests_properties
(
test_collective_alltoall_single PROPERTIES TIMEOUT
"350"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_alltoall_single_api MODULES
test_collective_alltoall_single_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_alltoall_single_api
PROPERTIES TIMEOUT
"300"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_barrier_api MODULES test_collective_barrier_api ENVS
...
...
@@ -117,7 +125,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_broadcast_api MODULES test_collective_broadcast_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_broadcast_api
PROPERTIES TIMEOUT
"
12
0"
LABELS
"RUN_TYPE=DIST"
)
PROPERTIES TIMEOUT
"
30
0"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
...
...
@@ -141,6 +149,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties
(
test_collective_global_scatter
PROPERTIES TIMEOUT
"200"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api
ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_isend_irecv_api
PROPERTIES TIMEOUT
"300"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_optimizer MODULES test_collective_optimizer ENVS
...
...
@@ -186,6 +201,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
)
set_tests_properties
(
test_collective_reduce_scatter PROPERTIES TIMEOUT
"350"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_reduce_scatter_api MODULES
test_collective_reduce_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_reduce_scatter_api
PROPERTIES TIMEOUT
"300"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
test_collective_scatter MODULES test_collective_scatter ENVS
...
...
@@ -212,7 +235,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:
${
PADDLE_BINARY_DIR
}
/python"
)
set_tests_properties
(
test_collective_sendrecv_api
PROPERTIES TIMEOUT
"
12
0"
LABELS
"RUN_TYPE=DIST"
)
PROPERTIES TIMEOUT
"
30
0"
LABELS
"RUN_TYPE=DIST"
)
endif
()
if
((
WITH_GPU OR WITH_ROCM
)
AND
(
LINUX
))
py_test_modules
(
...
...
python/paddle/fluid/tests/unittests/collective/collective_alltoall_api_dygraph.py
浏览文件 @
7a92e74b
...
...
@@ -45,12 +45,9 @@ class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
tindata
=
paddle
.
split
(
tindata
,
2
,
axis
=
0
)
tout_data
=
[]
paddle
.
distributed
.
alltoall
(
tindata
,
tout_data
)
output_data
=
[]
for
data
in
tout_data
:
output_data
.
append
(
data
.
numpy
())
return
output_data
toutdata
=
[]
paddle
.
distributed
.
alltoall
(
tindata
,
toutdata
)
return
[
data
.
numpy
()
for
data
in
toutdata
]
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/collective/collective_alltoall_single_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
test_collective_api_base
as
test_base
class
TestCollectiveAllToAllSingleAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
toutdata
=
paddle
.
to_tensor
(
indata
)
paddle
.
distributed
.
alltoall_single
(
tindata
,
toutdata
)
return
[
toutdata
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveAllToAllSingleAPI
,
"alltoall"
)
python/paddle/fluid/tests/unittests/collective/collective_broadcast_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveBroadcastAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
paddle
.
distributed
.
broadcast
(
tindata
,
src
=
1
)
return
[
tindata
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveBroadcastAPI
,
"broadcast"
)
python/paddle/fluid/tests/unittests/collective/collective_isend_irecv_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveIsendIrecvAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
if
rank
==
0
:
task
=
paddle
.
distributed
.
isend
(
tindata
,
dst
=
1
)
else
:
task
=
paddle
.
distributed
.
irecv
(
tindata
,
src
=
0
)
task
.
wait
()
return
[
tindata
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveIsendIrecvAPI
,
"sendrecv"
)
python/paddle/fluid/tests/unittests/collective/collective_reduce_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveReduceAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
paddle
.
distributed
.
reduce
(
tindata
,
dst
=
0
)
return
[
tindata
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveReduceAPI
,
"reduce"
)
python/paddle/fluid/tests/unittests/collective/collective_reduce_scatter_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveReduceScatterAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
subdata1
,
subdata2
=
paddle
.
split
(
tindata
,
2
,
axis
=
0
)
paddle
.
distributed
.
reduce_scatter
(
subdata1
,
[
subdata1
,
subdata2
])
return
[
subdata1
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveReduceScatterAPI
,
"reduce_scatter"
)
python/paddle/fluid/tests/unittests/collective/collective_scatter_api_dygraph.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
unittest
import
test_collective_api_base
as
test_base
class
TestCollectiveScatterAPI
(
test_base
.
TestCollectiveAPIRunnerBase
):
def
__init__
(
self
):
self
.
global_ring_id
=
0
def
get_model
(
self
,
main_prog
,
startup_program
,
rank
,
indata
=
None
):
with
fluid
.
program_guard
(
main_prog
,
startup_program
):
tindata
=
paddle
.
to_tensor
(
indata
)
subdata1
,
subdata2
=
paddle
.
split
(
tindata
,
2
,
axis
=
0
)
if
rank
==
0
:
paddle
.
distributed
.
scatter
(
subdata1
,
src
=
1
)
else
:
paddle
.
distributed
.
scatter
(
subdata1
,
tensor_list
=
[
subdata1
,
subdata2
],
src
=
1
)
return
[
subdata1
.
numpy
()]
if
__name__
==
"__main__"
:
test_base
.
runtime_main
(
TestCollectiveScatterAPI
,
"scatter"
)
python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_api.py
浏览文件 @
7a92e74b
...
...
@@ -31,10 +31,16 @@ class TestCollectiveAllToAllAPI(TestDistBase):
self
.
check_with_place
(
"collective_alltoall_api.py"
,
"alltoall"
,
"nccl"
)
def
test_alltoall_nccl_dygraph
(
self
):
self
.
check_with_place
(
"collective_alltoall_api_dygraph.py"
,
"alltoall"
,
"nccl"
,
static_mode
=
"0"
)
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_alltoall_api_dygraph.py"
,
"alltoall"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/collective/test_collective_alltoall_single_api.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
test_collective_api_base
as
test_base
class
TestCollectiveAllToAllSingleAPI
(
test_base
.
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_alltooall_single_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_alltoall_single_api_dygraph.py"
,
"alltoall"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py
浏览文件 @
7a92e74b
...
...
@@ -35,6 +35,31 @@ class TestCollectiveBroadcastAPI(TestDistBase):
self
.
check_with_place
(
"collective_broadcast_api.py"
,
"broadcast"
,
"gloo"
,
"0"
)
def
test_broadcast_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_broadcast_api_dygraph.py"
,
"broadcast"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
def
test_broadcast_gloo_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_broadcast_api_dygraph.py"
,
"broadcast"
,
"gloo"
,
"0"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_isend_irecv_api.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
test_collective_api_base
as
test_base
class
TestCollectiveIsendIrecvAPI
(
test_base
.
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_isend_irecv_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_isend_irecv_api_dygraph.py"
,
"sendrecv"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py
浏览文件 @
7a92e74b
...
...
@@ -38,6 +38,31 @@ class TestCollectiveReduceAPI(TestDistBase):
def
test_reduce_gloo
(
self
):
self
.
check_with_place
(
"collective_reduce_api.py"
,
"reduce"
,
"gloo"
,
"1"
)
def
test_reduce_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_reduce_api_dygraph.py"
,
"reduce"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
def
test_reduce_gloo_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_reduce_api_dygraph.py"
,
"reduce"
,
"gloo"
,
"1"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_reduce_scatter_api.py
0 → 100644
浏览文件 @
7a92e74b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
test_collective_api_base
as
test_base
class
TestCollectiveReduceScatterAPI
(
test_base
.
TestDistBase
):
def
_setup_config
(
self
):
pass
def
test_reduce_scatter_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_reduce_scatter_api_dygraph.py"
,
"reduce_scatter"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_scatter_api.py
浏览文件 @
7a92e74b
...
...
@@ -34,6 +34,31 @@ class TestCollectiveScatterAPI(TestDistBase):
def
test_scatter_nccl
(
self
):
self
.
check_with_place
(
"collective_scatter_api.py"
,
"scatter"
,
"nccl"
)
def
test_scatter_nccl_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_scatter_api_dygraph.py"
,
"scatter"
,
"nccl"
,
static_mode
=
"0"
,
dtype
=
dtype
)
def
test_scatter_gloo_dygraph
(
self
):
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_scatter_api_dygraph.py"
,
"scatter"
,
"gloo"
,
"4"
,
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/collective/test_collective_sendrecv_api.py
浏览文件 @
7a92e74b
...
...
@@ -33,11 +33,16 @@ class TestCollectiveSendRecvAPI(TestDistBase):
# "nccl")
def
test_sendrecv_nccl_dygraph
(
self
):
if
paddle
.
fluid
.
core
.
is_compiled_with_cuda
():
dtypes_to_test
=
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'int8'
,
'uint8'
,
'bool'
]
for
dtype
in
dtypes_to_test
:
self
.
check_with_place
(
"collective_sendrecv_api_dygraph.py"
,
"sendrecv"
,
"nccl"
,
static_mode
=
'0'
)
static_mode
=
"0"
,
dtype
=
dtype
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/collective/testslist.csv
浏览文件 @
7a92e74b
...
...
@@ -8,23 +8,26 @@ test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHO
test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,
12
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,
30
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,
12
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,
30
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_scatter,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_isend_irecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,
12
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,
30
0,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
...
...
python/paddle/fluid/tests/unittests/test_collective_api_base.py
浏览文件 @
7a92e74b
...
...
@@ -335,6 +335,12 @@ class TestDistBase(unittest.TestCase):
need_result2
=
need_result
[
need_result
.
shape
[
0
]
//
2
:]
np
.
testing
.
assert_allclose
(
tr0_out
[
0
],
need_result1
,
rtol
=
1e-05
)
np
.
testing
.
assert_allclose
(
tr1_out
[
0
],
need_result2
,
rtol
=
1e-05
)
elif
col_type
==
"reduce_scatter"
:
need_result
=
input1
+
input2
need_result1
=
need_result
[
0
:
need_result
.
shape
[
0
]
//
2
]
need_result2
=
need_result
[
need_result
.
shape
[
0
]
//
2
:]
np
.
testing
.
assert_allclose
(
tr0_out
[
0
],
need_result1
,
rtol
=
1e-05
)
np
.
testing
.
assert_allclose
(
tr1_out
[
0
],
need_result2
,
rtol
=
1e-05
)
elif
col_type
==
"allreduce"
:
need_result
=
input1
+
input2
np
.
testing
.
assert_allclose
(
tr0_out
[
0
],
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
7a92e74b
...
...
@@ -1015,7 +1015,7 @@ def concat(x, axis=0, name=None):
Args:
x (list|tuple): ``x`` is a Tensor list or Tensor tuple which is with data type bool, float16,
float32, float64, int32, int64, uint8. All the Tensors in ``x`` must have same data type.
float32, float64, int32, int64,
int8,
uint8. All the Tensors in ``x`` must have same data type.
axis (int|Tensor, optional): Specify the axis to operate on the input Tensors.
It's a scalar with data type int or a Tensor with shape [1] and data type int32
or int64. The effective range is [-R, R), where R is Rank(x). When ``axis < 0``,
...
...
@@ -1073,10 +1073,10 @@ def concat(x, axis=0, name=None):
check_type
(
input
,
'input'
,
(
list
,
tuple
,
Variable
),
'concat'
)
if
not
isinstance
(
input
,
Variable
):
for
id
,
x
in
enumerate
(
input
):
check_variable_and_dtype
(
x
,
'input['
+
str
(
id
)
+
']
'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'concat'
)
check_variable_and_dtype
(
x
,
'input['
+
str
(
id
)
+
']'
,
[
'bool'
,
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64
'
,
'int8'
,
'unit8'
],
'concat'
)
if
x
.
dtype
!=
input
[
0
].
dtype
:
raise
TypeError
(
"All the Tensors in the input must have the same data type."
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录