Commit f94edc3b (unverified) in BaiXuePrincess/Paddle, a fork of PaddlePaddle/Paddle.
Authored by Wen Sun on Oct 11, 2022; committed via GitHub on Oct 11, 2022.
Support both use_calc_stream and sync_op in collective communication API (#46761)
Parent: e4eb8d36

Showing 35 changed files with 3,014 additions and 144 deletions (+3014 −144).
Changed files:

  paddle/fluid/distributed/collective/ProcessGroup.h (+75 −17)
  paddle/fluid/distributed/collective/ProcessGroupCustom.cc (+2 −2)
  paddle/fluid/distributed/collective/ProcessGroupCustom.h (+4 −4)
  paddle/fluid/distributed/collective/ProcessGroupNCCL.cc (+282 −10)
  paddle/fluid/distributed/collective/ProcessGroupNCCL.h (+57 −15)
  paddle/fluid/distributed/collective/ProcessGroupStream.cc (+144 −12)
  paddle/fluid/distributed/collective/ProcessGroupStream.h (+90 −12)
  paddle/fluid/distributed/collective/Utils.h (+126 −18)
  paddle/fluid/pybind/distributed_py.cc (+517 −39)
  python/paddle/distributed/communication/stream/__init__.py (+11 −2)
  python/paddle/distributed/communication/stream/all_gather.py (+9 −7)
  python/paddle/distributed/communication/stream/all_reduce.py (+1 −1)
  python/paddle/distributed/communication/stream/alltoall.py (+157 −0)
  python/paddle/distributed/communication/stream/alltoall_single.py (+128 −0)
  python/paddle/distributed/communication/stream/broadcast.py (+83 −0)
  python/paddle/distributed/communication/stream/recv.py (+1 −1)
  python/paddle/distributed/communication/stream/reduce.py (+93 −0)
  python/paddle/distributed/communication/stream/reduce_scatter.py (+216 −0)
  python/paddle/distributed/communication/stream/scatter.py (+162 −0)
  python/paddle/distributed/communication/stream/send.py (+1 −1)
  python/paddle/fluid/tests/unittests/collective/CMakeLists.txt (+48 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_api_dygraph.py (+113 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_single_api_dygraph.py (+74 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_broadcast_api_dygraph.py (+54 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_api_dygraph.py (+66 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_scatter_api_dygraph.py (+94 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_scatter_api_dygraph.py (+84 −0)
  python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py (+6 −3)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_api.py (+51 −0)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_single_api.py (+53 −0)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_broadcast_api.py (+51 −0)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_api.py (+51 −0)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_scatter_api.py (+53 −0)
  python/paddle/fluid/tests/unittests/collective/test_communication_stream_scatter_api.py (+51 −0)
  python/paddle/fluid/tests/unittests/collective/testslist.csv (+6 −0)
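Before the per-file diffs, a minimal sketch of the two flags this commit wires through the stack, written against the Python paddle.distributed.stream API that the commit extends (the 2-GPU launch via paddle.distributed.launch is an assumption for illustration):

    # required: distributed -- a minimal sketch, assuming a 2-GPU launch
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([dist.get_rank()], dtype='float32')

    # sync_op=False returns immediately with a task; the caller chooses
    # when to block on completion.
    task = dist.stream.all_reduce(data, sync_op=False)
    task.wait()

    # use_calc_stream=True runs the collective directly on the calculation
    # stream; the Python wrappers only allow it together with sync_op=True
    # ("use_calc_stream can only be true in sync op behavior").
    dist.stream.all_reduce(data, sync_op=True, use_calc_stream=True)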
paddle/fluid/distributed/collective/ProcessGroup.h

@@ -125,6 +125,16 @@ class ProcessGroup {
         "ProcessGroup%s does not support broadcast", GetBackendName()));
   }

+  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
+      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
+      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
+      const BroadcastOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support broadcast with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) {
     PADDLE_THROW(platform::errors::InvalidArgument(

@@ -160,14 +170,14 @@ class ProcessGroup {
   virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
       phi::DenseTensor&,  // NOLINT
       int,
-      int,
-      int) {
+      int64_t,
+      int64_t) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support send_partial", GetBackendName()));
   }

   virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
-      phi::DenseTensor&, int, int, int, bool) {  // NOLINT
+      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support send_partial with sync_op flag",
         GetBackendName()));

@@ -176,14 +186,14 @@ class ProcessGroup {
   virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
       phi::DenseTensor&,  // NOLINT
       int,
-      int,
-      int) {
+      int64_t,
+      int64_t) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support recv_partial", GetBackendName()));
   }

   virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
-      phi::DenseTensor&, int, int, int, bool) {  // NOLINT
+      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support recv_partial with sync_op flag",
         GetBackendName()));

@@ -208,8 +218,8 @@ class ProcessGroup {
   virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      int offset,
-      int length) {  // NOLINT
+      int64_t offset,
+      int64_t length) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support AllGather_Partial",
         GetBackendName()));
   }

@@ -217,9 +227,9 @@ class ProcessGroup {
   virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      int offset,
-      int length,
-      bool) {  // NOLINT
+      int64_t offset,
+      int64_t length,
+      bool) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support AllGather_Partial",
         GetBackendName()));
   }

@@ -231,6 +241,14 @@ class ProcessGroup {
         "ProcessGroup%s does not support AllToAll", GetBackendName()));
   }

+  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support alltoall", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
       std::vector<phi::DenseTensor>&,  // NOLINT
       std::vector<phi::DenseTensor>&,  // NOLINT

@@ -240,26 +258,66 @@ class ProcessGroup {
         "ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
   }

+  virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<int64_t>&,
+      std::vector<int64_t>&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support alltoall_single", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Reduce(
       std::vector<phi::DenseTensor>&,  // NOLINT
       std::vector<phi::DenseTensor>&,  // NOLINT
       const ReduceOptions& opts) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support Reduce", GetBackendName()));
+        "ProcessGroup%s does not support reduce", GetBackendName()));
   }

+  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
+      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
+      const ReduceOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support reduce with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> Scatter(
       std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
-      const ScatterOptions&) {  // NOLINT
+      const ScatterOptions&) {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "ProcessGroup%s does not support Scatter", GetBackendName()));
+        "ProcessGroup%s does not support scatter", GetBackendName()));
   }

+  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      const ScatterOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support scatter with sync_op flag",
+        GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      std::vector<phi::DenseTensor>&,  // NOLINT
+      const ReduceScatterOptions&,
+      bool) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support reduce_scatter with sync_op flag",
+        GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
       phi::DenseTensor&,  // NOLINT
       phi::DenseTensor&,  // NOLINT
-      const ReduceScatterOptions&) {  // NOLINT
+      const ReduceScatterOptions&) {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "ProcessGroup%s does not support ReduceScatter", GetBackendName()));
   }
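The base class only declares these overloads and throws by default; backends override them. A usage sketch of the corresponding Python entry point follows; the exact broadcast wrapper signature (in particular the src parameter name) is an assumption here, since broadcast.py is not shown on this page:

    # required: distributed -- sketch only; assumes 2 GPUs and that the new
    # broadcast wrapper takes (tensor, src, ..., sync_op, use_calc_stream)
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    if dist.get_rank() == 0:
        data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
    else:
        data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    task = dist.stream.broadcast(data, src=0, sync_op=False)
    task.wait()
    print(data)  # both ranks now hold [[4, 5, 6], [4, 5, 6]]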
paddle/fluid/distributed/collective/ProcessGroupCustom.cc

@@ -267,8 +267,8 @@ void* XcclGetPointerByOffset(void* raw_pointer,
 std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length) {
+    int64_t offset,
+    int64_t length) {
   PADDLE_ENFORCE_EQ(
       CheckTensorsInCustomPlace(in_tensors, device_type_),
       true,
paddle/fluid/distributed/collective/ProcessGroupCustom.h

@@ -80,8 +80,8 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      int offset,
-      int length) override;
+      int64_t offset,
+      int64_t length) override;

   std::shared_ptr<ProcessGroup::Task> AllReduce(
       std::vector<phi::DenseTensor>& in_tensors,

@@ -117,8 +117,8 @@ class ProcessGroupCustom : public ProcessGroup {
   std::set<int> used_place_ids_;

  private:
-  void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,
-                     int root,  // NOLINT
+  void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,  // NOLINT
+                     int root,
                      int server_fd);

   void BroadcastUniqueCustomID(
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

@@ -628,6 +628,40 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
       CommType::BROADCAST);
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        const auto root =
+            opts.source_rank * in_tensors.size() + opts.source_root;
+        return platform::dynload::ncclBroadcast(
+            input.data(), output.data(), input.numel(),
+            platform::ToNCCLDataType(input.type()), root, comm, stream);
+      },
+      CommType::BROADCAST, sync_op, use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
     const BarrierOptions& opts) {
   // Only support single card single process

@@ -782,7 +816,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
 }

 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
-    phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
+    phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) {
   // CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
   phi::DenseTensor flatten_tensor;

@@ -813,8 +847,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
     phi::DenseTensor& tensors,
     int dst_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   phi::DenseTensor flatten_tensor;

@@ -845,7 +879,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
 }

 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
-    phi::DenseTensor& tensors, int src_rank, int offset, int length) {
+    phi::DenseTensor& tensors, int src_rank, int64_t offset, int64_t length) {
   // phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
   phi::DenseTensor flatten_tensor;

@@ -876,8 +910,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
     phi::DenseTensor& tensors,
     int src_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   phi::DenseTensor flatten_tensor;

@@ -1009,8 +1043,8 @@ void* GetPointerByOffset(void* raw_pointer,
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length) {
+    int64_t offset,
+    int64_t length) {
   PADDLE_ENFORCE_EQ(
       CheckTensorsInCudaPlace(in_tensors),
       true,

@@ -1040,8 +1074,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   PADDLE_ENFORCE_EQ(

@@ -1114,6 +1148,52 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
       CommType::ALLTOALL);
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        size_t offset = 0;
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+        for (auto i = 0; i < size_; i++) {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+              GetPointerByOffset(input.data(), offset, input.dtype()),
+              input.numel() / size_, platform::ToNCCLDataType(input.dtype()),
+              i, comm, stream));
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              GetPointerByOffset(output.data(), offset, input.dtype()),
+              input.numel() / size_, platform::ToNCCLDataType(input.dtype()),
+              i, comm, stream));
+          offset += input.numel() / size_;
+        }
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+      },
+      CommType::ALLTOALL, sync_op, use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,

@@ -1176,6 +1256,72 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
       CommType::ALLTOALL_SINGLE);
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAllSingle(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    std::vector<int64_t>& in_sizes, std::vector<int64_t>& out_sizes,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(), true,
+                          platform::errors::InvalidArgument(
+                              "The dtypes of input and output must be equal."));
+
+        std::vector<int64_t> in_dims = phi::vectorize(input.dims());
+        std::vector<int64_t> out_dims = phi::vectorize(output.dims());
+        CheckSplitSizes(&in_sizes, in_dims);
+        CheckSplitSizes(&out_sizes, out_dims);
+
+        size_t in_offset = 0, out_offset = 0;
+        size_t in_length = 0, out_length = 0;
+        size_t in_row_size = input.numel() / in_dims[0];
+        size_t out_row_size = output.numel() / out_dims[0];
+
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+        for (auto i = 0; i < size_; i++) {
+          in_length = in_sizes[i] * in_row_size;
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+              GetPointerByOffset(input.data(), in_offset, input.dtype()),
+              in_length, platform::ToNCCLDataType(input.dtype()), i, comm,
+              stream));
+          in_offset += in_length;
+
+          out_length = out_sizes[i] * out_row_size;
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              GetPointerByOffset(output.data(), out_offset, input.dtype()),
+              out_length, platform::ToNCCLDataType(input.dtype()), i, comm,
+              stream));
+          out_offset += out_length;
+        }
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+      },
+      CommType::ALLTOALL_SINGLE, sync_op, use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,

@@ -1204,6 +1350,70 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
       CommType::REDUCE);
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](const phi::DenseTensor& input, phi::DenseTensor& output,
+          ncclComm_t comm, const gpuStream_t& stream) {
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
+            input.data(), output.data(), input.numel(),
+            platform::ToNCCLDataType(input.dtype()),
+            ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, stream));
+      },
+      CommType::REDUCE, sync_op, use_calc_stream);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) {
+  return Collective(
+      in_tensors, out_tensors,
+      [&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        if (FLAGS_use_stream_safe_cuda_allocator) {
+          platform::CUDADeviceGuard cuda_guard;
+          cuda_guard.SetDevice(output.place());
+          memory::RecordStream(output.Holder(), stream);
+        }
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
+            input.data(), output.data(), output.numel(),
+            platform::ToNCCLDataType(input.dtype()),
+            ToNCCLRedType(opts.reduce_op), comm, stream));
+      },
+      CommType::REDUCE_SCATTER, sync_op, use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,

@@ -1257,6 +1467,68 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
       CommType::SCATTER);
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(in_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  PADDLE_ENFORCE_EQ(
+      CheckTensorsInCudaPlace(out_tensors), true,
+      platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
+  return Collective(
+      in_tensors, out_tensors,
+      [&](phi::DenseTensor& input, phi::DenseTensor& output, ncclComm_t comm,
+          const gpuStream_t& stream) {
+        PADDLE_ENFORCE_EQ(output.numel(), input.numel() / size_,
+                          platform::errors::InvalidArgument(
+                              "Input and output tensors should have the same "
+                              "shape."));
+        size_t offset = 0;
+        if (rank_ == opts.root_rank) {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+          for (auto i = 0; i < size_; i++) {
+            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+                GetPointerByOffset(input.data(), offset, input.dtype()),
+                input.numel() / size_,
+                platform::ToNCCLDataType(input.dtype()), i, comm, stream));
+            offset += input.numel() / size_;
+          }
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              output.data(), input.numel() / size_,
+              platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
+              stream));
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+        } else {
+          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+              output.data(), input.numel() / size_,
+              platform::ToNCCLDataType(input.dtype()), opts.root_rank, comm,
+              stream));
+        }
+      },
+      CommType::SCATTER, sync_op, use_calc_stream);
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
     phi::DenseTensor& out_tensor,
     phi::DenseTensor& in_tensor,
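The new NCCL Scatter overload above implements scatter as a grouped ncclSend loop on the root plus one ncclRecv per rank, walking the flat input in equal chunks of numel / size_. A pure-Python sketch of just that offset arithmetic (no Paddle required; the function name is illustrative only):

    def scatter_plan(numel, world_size, root):
        # (src, dst, offset, count) tuples the root's send loop would issue
        chunk = numel // world_size
        return [(root, dst, dst * chunk, chunk) for dst in range(world_size)]

    print(scatter_plan(numel=8, world_size=2, root=0))
    # [(0, 0, 0, 4), (0, 1, 4, 4)]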
paddle/fluid/distributed/collective/ProcessGroupNCCL.h

@@ -119,6 +119,13 @@ class ProcessGroupNCCL : public ProcessGroupStream {
       std::vector<phi::DenseTensor>& out_tensors,
       const BroadcastOptions& = BroadcastOptions()) override;

+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      const BroadcastOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;
+
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;

@@ -142,27 +149,27 @@ class ProcessGroupNCCL : public ProcessGroupStream {
   std::shared_ptr<ProcessGroup::Task> Send_Partial(
       phi::DenseTensor& tensors, int dst_rank,
-      int offset, int length) override;
+      int64_t offset, int64_t length) override;

   std::shared_ptr<ProcessGroup::Task> Send_Partial(
       phi::DenseTensor& tensors, int dst_rank,
-      int offset, int length,
+      int64_t offset, int64_t length,
       bool sync_op, bool use_calc_stream) override;

   std::shared_ptr<ProcessGroup::Task> Recv_Partial(
       phi::DenseTensor& tensors, int src_rank,
-      int offset, int length) override;
+      int64_t offset, int64_t length) override;

   std::shared_ptr<ProcessGroup::Task> Recv_Partial(
      phi::DenseTensor& tensors, int src_rank,
-      int offset, int length,
+      int64_t offset, int64_t length,
       bool sync_op, bool use_calc_stream) override;

@@ -179,20 +186,26 @@ class ProcessGroupNCCL : public ProcessGroupStream {
   std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      int offset, int length) override;
+      int64_t offset, int64_t length) override;

   std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      int offset, int length,
+      int64_t offset, int64_t length,
       bool sync_op, bool use_calc_stream) override;

   std::shared_ptr<ProcessGroup::Task> AllToAll(
-      std::vector<phi::DenseTensor>& in,
-      std::vector<phi::DenseTensor>& out) override;
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors) override;
+
+  std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      bool sync_op,
+      bool use_calc_stream) override;

   std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
       std::vector<phi::DenseTensor>& in,

@@ -200,15 +213,44 @@ class ProcessGroupNCCL : public ProcessGroupStream {
       std::vector<int64_t>& in_sizes,
       std::vector<int64_t>& out_sizes) override;

+  std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      std::vector<int64_t>& in_sizes,
+      std::vector<int64_t>& out_sizes,
+      bool sync_op,
+      bool use_calc_stream) override;
+
   std::shared_ptr<ProcessGroup::Task> Reduce(
       std::vector<phi::DenseTensor>& tensors,
       std::vector<phi::DenseTensor>& out_tensors,
       const ReduceOptions& opts) override;

+  std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      const ReduceOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;
+
+  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      const ReduceScatterOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;
+
   std::shared_ptr<ProcessGroup::Task> Scatter(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      const ScatterOptions&) override;
+      const ScatterOptions& opts) override;
+
+  std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<phi::DenseTensor>& in_tensors,
+      std::vector<phi::DenseTensor>& out_tensors,
+      const ScatterOptions& opts,
+      bool sync_op,
+      bool use_calc_stream) override;

   std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
       phi::DenseTensor&,  // NOLINT
paddle/fluid/distributed/collective/ProcessGroupStream.cc

@@ -70,6 +70,138 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
       "ProcessGroup%s does not support do all_reduce", GetBackendName()));
 }

+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, bool sync_op) {
+  return AllToAll(in_tensors, out_tensors, sync_op,
+                  /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do alltoall", GetBackendName()));
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    std::vector<int64_t>& in_sizes, std::vector<int64_t>& out_sizes,
+    bool sync_op) {
+  return AllToAllSingle(in_tensors, out_tensors, in_sizes, out_sizes, sync_op,
+                        /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    std::vector<int64_t>& in_sizes, std::vector<int64_t>& out_sizes,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do alltoall_single",
+      GetBackendName()));
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts,
+    bool sync_op) {
+  return Broadcast(in_tensors, out_tensors, opts, sync_op,
+                   /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const BroadcastOptions& opts,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do broadcast", GetBackendName()));
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts,
+    bool sync_op) {
+  return Reduce(in_tensors, out_tensors, opts, sync_op,
+                /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ReduceOptions& opts,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do reduce", GetBackendName()));
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const ReduceScatterOptions& opts, bool sync_op) {
+  return ReduceScatter(in_tensors, out_tensors, opts, sync_op,
+                       /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors,
+    const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do reduce_scatter",
+      GetBackendName()));
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts,
+    bool sync_op) {
+  return Scatter(in_tensors, out_tensors, opts, sync_op,
+                 /*use_calc_stream*/ false);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
+    std::vector<phi::DenseTensor>& in_tensors,
+    std::vector<phi::DenseTensor>& out_tensors, const ScatterOptions& opts,
+    bool sync_op, bool use_calc_stream) {
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "ProcessGroup%s does not support do scatter", GetBackendName()));
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
     std::vector<phi::DenseTensor>& tensors, int dst_rank, bool sync_op) {
   return Send(tensors,

@@ -90,8 +222,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
     phi::DenseTensor& tensors,
     int dst_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op) {
   return Send_Partial(tensors,
                       dst_rank,

@@ -104,8 +236,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send_Partial(
     phi::DenseTensor& tensors,
     int dst_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   PADDLE_THROW(platform::errors::InvalidArgument(

@@ -132,8 +264,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
     phi::DenseTensor& tensors,
     int src_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op) {
   return Recv_Partial(tensors,
                       src_rank,

@@ -146,8 +278,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
     phi::DenseTensor& tensors,
     int src_rank,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   PADDLE_THROW(platform::errors::InvalidArgument(

@@ -157,8 +289,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op) {
   return AllGather_Partial(in_tensors,
                            out_tensors,

@@ -171,8 +303,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather_Partial(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors,
-    int offset,
-    int length,
+    int64_t offset,
+    int64_t length,
     bool sync_op,
     bool use_calc_stream) {
   PADDLE_THROW(platform::errors::InvalidArgument(
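Every collective in ProcessGroupStream follows the same shape: the sync_op-only overload forwards to the full overload with use_calc_stream fixed to false, and the full overload throws unless a concrete backend (here, ProcessGroupNCCL) overrides it. A rough Python sketch of that forwarding pattern (names are illustrative, not Paddle API; C++ overloading is mimicked with a sentinel default):

    class StreamBase:
        def reduce_scatter(self, in_t, out_t, opts, sync_op,
                           use_calc_stream=None):
            if use_calc_stream is None:
                # sync_op-only form: delegate with use_calc_stream=False
                return self.reduce_scatter(in_t, out_t, opts, sync_op, False)
            raise NotImplementedError("backend does not support reduce_scatter")

    class NCCLBackend(StreamBase):
        def reduce_scatter(self, in_t, out_t, opts, sync_op,
                           use_calc_stream=False):
            # the real backend launches the collective here
            return "comm stream" if not use_calc_stream else "calc stream"

    print(NCCLBackend().reduce_scatter(None, None, None, sync_op=True))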
paddle/fluid/distributed/collective/ProcessGroupStream.h

@@ -81,6 +81,84 @@ class ProcessGroupStream : public ProcessGroup {
       bool sync_op,
       bool use_calc_stream);

+  std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      std::vector<int64_t>& in_sizes,   // NOLINT
+      std::vector<int64_t>& out_sizes,  // NOLINT
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      std::vector<int64_t>& in_sizes,   // NOLINT
+      std::vector<int64_t>& out_sizes,  // NOLINT
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> Broadcast(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const BroadcastOptions& opts,
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const BroadcastOptions& opts,
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ReduceOptions& opts,
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ReduceOptions& opts,
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ReduceScatterOptions& opts,
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ReduceScatterOptions& opts,
+      bool sync_op,
+      bool use_calc_stream);
+
+  std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ScatterOptions& opts,
+      bool sync_op) override;
+
+  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
+      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
+      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
+      const ScatterOptions& opts,
+      bool sync_op,
+      bool use_calc_stream);
+
   std::shared_ptr<ProcessGroup::Task> Send(
       std::vector<phi::DenseTensor>& tensors,  // NOLINT
       int dst_rank,

@@ -95,15 +173,15 @@ class ProcessGroupStream : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Send_Partial(
       phi::DenseTensor& tensors,  // NOLINT
       int dst_rank,
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op) override;

   virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
       phi::DenseTensor& tensors,  // NOLINT
       int dst_rank,
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op,
       bool use_calc_stream);

@@ -121,30 +199,30 @@ class ProcessGroupStream : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Recv_Partial(
       phi::DenseTensor& tensors,  // NOLINT
       int src_rank,
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op) override;

   virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
       phi::DenseTensor& tensors,  // NOLINT
       int src_rank,
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op,
       bool use_calc_stream);

   std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors,
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op) override;

   virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
       std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
       std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
-      int offset,
-      int length,
+      int64_t offset,
+      int64_t length,
       bool sync_op,
       bool use_calc_stream);
 };
paddle/fluid/distributed/collective/Utils.h

@@ -14,14 +14,26 @@
 #pragma once

 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/backends/device_guard.h"
 #include "paddle/phi/backends/device_manager.h"
+#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

 namespace paddle {
 namespace distributed {

+template <typename DeviceContext, typename T>
+struct ConcatDenseTensor {
+  void operator()(const DeviceContext *context,
+                  const std::vector<phi::DenseTensor> &in,
+                  phi::DenseTensor *out,
+                  int axis = 0) {
+    phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
+    concat_functor(*context, in, axis, out);
+  }
+};
+
 template <typename DeviceContext, typename T>
 struct SplitDenseTensor {
   void operator()(const DeviceContext *context,

@@ -33,17 +45,36 @@ struct SplitDenseTensor {
     for (auto *p_tensor : *out) {
       shape_refer.emplace_back(p_tensor);
     }
-    operators::math::SplitFunctor<DeviceContext, T> split_functor_;
-    split_functor_(*context, in, shape_refer, axis, out);
+    phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
+    split_functor(*context, in, shape_refer, axis, out);
   }
 };

 #ifdef PADDLE_WITH_CUSTOM_DEVICE
+template <typename T>
+struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
+  void operator()(const platform::CustomDeviceContext *context,
+                  const std::vector<phi::DenseTensor> &in,
+                  phi::DenseTensor *out,
+                  int axis = 0) {
+    auto *out_data = out->data<T>();
+    auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
+    size_t offset = 0;
+    for (const auto &tensor : in) {
+      const auto *in_data = tensor.data<T>();
+      auto sz = tensor.numel() * sizeof(T);
+      device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr);
+      offset += sz;
+    }
+  }
+};
+
 template <typename T>
 struct SplitDenseTensor<platform::CustomDeviceContext, T> {
   void operator()(const platform::CustomDeviceContext *context,
                   const phi::DenseTensor &in,
-                  std::vector<phi::DenseTensor *> *out) {
+                  std::vector<phi::DenseTensor *> *out,
+                  int axis = 0) {
     auto *in_data = in.data<T>();
     auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
     size_t offset = 0;

@@ -57,42 +88,119 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
 };
 #endif

+template <typename DeviceContext>
+void ConcatDenseTensorWithType(const DeviceContext *dev_ctx,
+                               const std::vector<phi::DenseTensor> &t_list,
+                               phi::DenseTensor *p_out,
+                               phi::DataType type) {
+  switch (type) {
+    case phi::DataType::BOOL:
+      ConcatDenseTensor<DeviceContext, bool>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::UINT8:
+      ConcatDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::INT8:
+      ConcatDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::INT32:
+      ConcatDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::INT64:
+      ConcatDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::FLOAT16:
+      ConcatDenseTensor<DeviceContext, platform::float16>()(dev_ctx, t_list,
+                                                            p_out);
+      break;
+    case phi::DataType::FLOAT32:
+      ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
+      break;
+    case phi::DataType::FLOAT64:
+      ConcatDenseTensor<DeviceContext, double>()(dev_ctx, t_list, p_out);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type (%s) is not supported when it concats tensors.", type));
+  }
+}
+
 template <typename DeviceContext>
 void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
-                              const phi::DenseTensor &p_dense,
+                              const phi::DenseTensor &t_in,
                               std::vector<phi::DenseTensor *> *p_list,
                               phi::DataType type) {
   switch (type) {
     case phi::DataType::BOOL:
-      SplitDenseTensor<DeviceContext, bool>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, bool>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::UINT8:
-      SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::INT8:
-      SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::INT32:
-      SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::INT64:
-      SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::FLOAT16:
-      SplitDenseTensor<DeviceContext, platform::float16>()(dev_ctx, p_dense,
-                                                           p_list);
+      SplitDenseTensor<DeviceContext, platform::float16>()(dev_ctx, t_in,
+                                                           p_list);
       break;
     case phi::DataType::FLOAT32:
-      SplitDenseTensor<DeviceContext, float>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
       break;
     case phi::DataType::FLOAT64:
-      SplitDenseTensor<DeviceContext, double>()(dev_ctx, p_dense, p_list);
+      SplitDenseTensor<DeviceContext, double>()(dev_ctx, t_in, p_list);
       break;
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
-          "Data type (%s) is not supported when it splits tensors for "
-          "allgather.",
-          type));
+          "Data type (%s) is not supported when it splits tensors.", type));
   }
 }

+void ConcatTensor(const phi::DeviceContext *dev_ctx,
+                  const std::vector<phi::DenseTensor> &tensor_list,
+                  const experimental::Tensor *tensor) {
+  auto *dense_tensor =
+      std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();
+
+  const auto &place = dev_ctx->GetPlace();
+  if (platform::is_gpu_place(place)) {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    ConcatDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
+                              tensor_list, dense_tensor, tensor->dtype());
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't concat tensor since it's not support GPU, please "
+        "recompile or reinstall Paddle with GPU support."));
+#endif
+  } else if (platform::is_custom_place(place)) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+    ConcatDenseTensorWithType(
+        static_cast<const platform::CustomDeviceContext *>(dev_ctx),
+        tensor_list, dense_tensor, tensor->dtype());
+#else
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "Paddle can't concat tensor since it's not compiled with "
+        "CUSTOM_DEVICE, please recompile or reinstall Paddle with "
+        "CUSTOM_DEVICE support."));
+#endif
+  } else if (platform::is_cpu_place(place)) {
+    ConcatDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
+                              tensor_list, dense_tensor, tensor->dtype());
+  } else {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Concat tensor not supported on place (%s)", place));
+  }
+}

@@ -115,8 +223,8 @@ void SplitTensor(const phi::DeviceContext *dev_ctx,
                              tensor.dtype());
 #else
     PADDLE_THROW(platform::errors::PermissionDenied(
-        "Paddle can't split tensor since it's not support NCCL/RCCL, please "
-        "recompile or reinstall Paddle with NCCL/RCCL support."));
+        "Paddle can't split tensor since it's not support GPU, please "
+        "recompile or reinstall Paddle with GPU support."));
 #endif
   } else if (platform::is_custom_place(place)) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
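ConcatTensor/SplitTensor let a collective that only understands one flat buffer serve list-of-tensors APIs: inputs are concatenated into a flat tensor before communication and the flat result is split back afterwards, with the *WithType helpers dispatching on dtype. A rough CPU-only analogy at the Python level (plain paddle, no distributed setup required):

    import paddle

    parts = [paddle.ones([2, 3]), paddle.zeros([2, 3])]
    flat = paddle.concat(parts, axis=0)                    # ConcatDenseTensor
    back = paddle.split(flat, num_or_sections=2, axis=0)   # SplitDenseTensor
    print(flat.shape, [t.shape for t in back])  # [4, 3] [[2, 3], [2, 3]]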
paddle/fluid/pybind/distributed_py.cc

(This diff is collapsed in the original view; contents not shown.)
python/paddle/distributed/communication/stream/__init__.py

@@ -14,7 +14,16 @@
 from .all_gather import all_gather
 from .all_reduce import all_reduce
-from .send import send
+from .alltoall import alltoall
+from .alltoall_single import alltoall_single
+from .broadcast import broadcast
+from .reduce import reduce
+from .reduce_scatter import _reduce_scatter_base, reduce_scatter
 from .recv import recv
+from .scatter import scatter
+from .send import send

-__all__ = ["all_gather", "all_reduce", "send", "recv"]
+__all__ = [
+    "_reduce_scatter_base", "all_reduce", "alltoall", "alltoall_single",
+    "broadcast", "reduce", "reduce_scatter", "recv", "scatter", "send"
+]
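A quick smoke test of the widened import surface (note that all_gather remains importable even though it is absent from the new __all__ list):

    import paddle.distributed.communication.stream as stream

    print(stream.__all__)
    # ['_reduce_scatter_base', 'all_reduce', 'alltoall', 'alltoall_single',
    #  'broadcast', 'reduce', 'reduce_scatter', 'recv', 'scatter', 'send']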
python/paddle/distributed/communication/stream/all_gather.py

@@ -34,17 +34,18 @@ def _check_tensor_list_shape(tensor_list, shape, nranks=1):
             'The tensor_list for all_gather is not correctly-sized.')


-def _all_gather_base_in_dygraph(out_tensor, in_tensor, group, sync_op,
-                                use_calc_stream):
+def _all_gather_into_tensor_in_dygraph(out_tensor, in_tensor, group, sync_op,
+                                       use_calc_stream):
     group = collective._get_default_group() if group is None else group

     _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)

     if use_calc_stream:
-        return group.process_group.allgather_base_on_calc_stream(
+        return group.process_group.allgather_into_tensor_on_calc_stream(
             in_tensor, out_tensor)

-    task = group.process_group.allgather_base(in_tensor, out_tensor, sync_op)
+    task = group.process_group.allgather_into_tensor(in_tensor, out_tensor,
+                                                     sync_op)
     if sync_op:
         task.wait()

@@ -83,7 +84,7 @@ def all_gather(tensor_or_tensor_list,
         tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. If it is a list, it
             should be empty or contain correctly-sized tensors.
         tensor (Tensor): The input tensor on each rank. The result will overwrite this tenor after communication. Support
-            float16, float32, float64, int32 or int64 as the input data type.
+            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
         group (Group, optional): Communicate in which group. If none is given, use the global group as default.
         sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
         use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This

@@ -125,8 +126,9 @@ def all_gather(tensor_or_tensor_list,
     if framework.in_dygraph_mode():
         if paddle.is_tensor(tensor_or_tensor_list):
-            return _all_gather_base_in_dygraph(tensor_or_tensor_list, tensor,
-                                               group, sync_op, use_calc_stream)
+            return _all_gather_into_tensor_in_dygraph(tensor_or_tensor_list,
+                                                      tensor, group, sync_op,
+                                                      use_calc_stream)
         else:
             return _all_gather_in_dygraph(tensor_or_tensor_list, tensor, group,
                                           sync_op, use_calc_stream)
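The rename tracks the two call paths shown above: a pre-sized tensor output goes through allgather_into_tensor, while a (possibly empty) list goes through the plain allgather path. A usage sketch, assuming a 2-GPU launch:

    # required: distributed
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    local = paddle.full([2, 3], fill_value=dist.get_rank(), dtype='float32')

    # Tensor form: output must already be nranks times larger.
    tensor_out = paddle.empty([4, 3], dtype='float32')
    dist.stream.all_gather(tensor_out, local, sync_op=True)

    # List form: an empty list is filled with one tensor per rank.
    list_out = []
    dist.stream.all_gather(list_out, local, sync_op=True)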
python/paddle/distributed/communication/stream/all_reduce.py

@@ -70,7 +70,7 @@ def all_reduce(tensor,
     Args:
         tensor (Tensor): The input tensor on each rank. The result will overwrite this tenor after communication. Support
-            float16, float32, float64, int32 or int64 as the input data type.
+            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
         op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
         group (Group, optional): Communicate in which group. If none is given, use the global group as default.
         sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
python/paddle/distributed/communication/stream/alltoall.py
0 → 100644
浏览文件 @
f94edc3b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid.framework as framework
from paddle.distributed import collective


def _check_tensor_shape(tensor, shape, nranks=1):
    if tensor.shape != shape:
        raise RuntimeError('The tensor for alltoall is not correctly-sized.')


def _check_tensor_list_shape(tensor_list, shape, nranks=1):
    if len(tensor_list) != nranks:
        raise RuntimeError(
            'The tensor_list for alltoall is not correctly-sized.')
    for tensor in tensor_list:
        if tensor.shape != shape:
            raise RuntimeError(
                'The tensor_list for alltoall is not correctly-sized.')


def _alltoall_tensor_in_dygraph(out_tensor, in_tensor, group, sync_op,
                                use_calc_stream):
    group = collective._get_default_group() if group is None else group

    _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)

    if use_calc_stream:
        return group.process_group.alltoall_tensor_on_calc_stream(
            in_tensor, out_tensor)

    task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op)
    if sync_op:
        task.wait()

    return task


def _alltoall_in_dygraph(out_tensor_list, in_tensor_list, group, sync_op,
                         use_calc_stream):
    group = collective._get_default_group() if group is None else group

    if len(in_tensor_list) == 0:
        raise RuntimeError("The input tensor_list should not be empty.")

    if len(out_tensor_list) == 0:
        out_tensor_list += [
            paddle.empty_like(tensor) for tensor in in_tensor_list
        ]
    else:
        _check_tensor_list_shape(out_tensor_list, in_tensor_list[0].shape,
                                 group.nranks)

    if use_calc_stream:
        return group.process_group.alltoall_on_calc_stream(
            in_tensor_list, out_tensor_list)

    task = group.process_group.alltoall(in_tensor_list, out_tensor_list,
                                        sync_op)
    if sync_op:
        task.wait()

    return task


def alltoall(out_tensor_or_tensor_list,
             in_tensor_or_tensor_list,
             group=None,
             sync_op=True,
             use_calc_stream=False):
    """

    Scatter a tensor (or a tensor list) across devices and gather the outputs into another tensor (or a tensor list, respectively).

    Args:
        out_tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized.
            If it is a list, it should be empty or contain correctly-sized tensors. Its data type should be the same as the input.
        in_tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to scatter.
            If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors. Support
            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            out_tensor_list = []
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
                data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
            else:
                data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]])
                data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]])
            task = dist.stream.alltoall(out_tensor_list, [data1, data2], sync_op=False)
            task.wait()
            print(out_tensor_list)
            # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
            # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if out_tensor_or_tensor_list is None:
        raise RuntimeError("The output should be specified.")
    if in_tensor_or_tensor_list is None:
        raise RuntimeError("The input should be specified.")

    if framework.in_dygraph_mode():
        out_is_tensor = paddle.is_tensor(out_tensor_or_tensor_list)
        in_is_tensor = paddle.is_tensor(in_tensor_or_tensor_list)
        if out_is_tensor and in_is_tensor:
            return _alltoall_tensor_in_dygraph(out_tensor_or_tensor_list,
                                               in_tensor_or_tensor_list,
                                               group, sync_op,
                                               use_calc_stream)
        elif not out_is_tensor and not in_is_tensor:
            return _alltoall_in_dygraph(out_tensor_or_tensor_list,
                                        in_tensor_or_tensor_list, group,
                                        sync_op, use_calc_stream)
        else:
            raise RuntimeError(
                "The output and input should be both tensor or tensor list.")

    raise RuntimeError(
        "paddle.distributed.stream.alltoall is only supported in dygraph mode now."
    )
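
For intuition, alltoall is a transpose of data across ranks: the tensor at position j in rank i's input ends up at position i in rank j's output. A minimal single-process sketch of this data movement (plain Python lists as stand-ins for per-rank tensors; an illustration only, not part of the API):

# inputs[i][j] is what rank i sends to rank j; after alltoall,
# outputs[j][i] == inputs[i][j] on every rank.
inputs = [[[1, 2], [3, 4]],   # rank 0's in_tensor_list
          [[5, 6], [7, 8]]]   # rank 1's in_tensor_list
outputs = [[inputs[i][j] for i in range(2)] for j in range(2)]
assert outputs == [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]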
python/paddle/distributed/communication/stream/alltoall_single.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.distributed import collective


def _alltoall_single_in_dygraph(out_tensor, in_tensor, out_split_sizes,
                                in_split_sizes, group, sync_op,
                                use_calc_stream):
    group = collective._get_default_group() if group is None else group

    if out_split_sizes is None:
        out_split_sizes = []
    if in_split_sizes is None:
        in_split_sizes = []

    if use_calc_stream:
        return group.process_group.alltoall_single_on_calc_stream(
            in_tensor, out_tensor, in_split_sizes, out_split_sizes)

    task = group.process_group.alltoall_single(in_tensor, out_tensor,
                                               in_split_sizes,
                                               out_split_sizes, sync_op)
    if sync_op:
        task.wait()

    return task


def alltoall_single(out_tensor,
                    in_tensor,
                    out_split_sizes=None,
                    in_split_sizes=None,
                    group=None,
                    sync_op=True,
                    use_calc_stream=False):
    """

    Split the input tensor and scatter the pieces to the out tensor across devices.

    Args:
        out_tensor (Tensor): The output tensor. Its data type should be the same as the input.
        in_tensor (Tensor): The input tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        out_split_sizes (List[int], optional): Split sizes of out_tensor along dim[0]. If not given, dim[0] of out_tensor must be divisible
            by the group size, and out_tensor will be gathered evenly from all participants. If none is given, use an empty list as default.
        in_split_sizes (List[int], optional): Split sizes of in_tensor along dim[0]. If not given, dim[0] of in_tensor must be divisible
            by the group size, and in_tensor will be scattered evenly to all participants. If none is given, use an empty list as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            local_rank = dist.get_rank()

            # case 1
            output = paddle.empty([2], dtype="int64")
            if local_rank == 0:
                data = paddle.to_tensor([0, 1])
            else:
                data = paddle.to_tensor([2, 3])
            task = dist.stream.alltoall_single(output, data, sync_op=False)
            task.wait()
            out = output.numpy()
            # [0, 2] (2 GPUs, out for rank 0)
            # [1, 3] (2 GPUs, out for rank 1)

            # case 2
            size = dist.get_world_size()
            output = paddle.empty([(local_rank + 1) * size, size], dtype='float32')
            if local_rank == 0:
                data = paddle.to_tensor([[0., 0.], [0., 0.], [0., 0.]])
            else:
                data = paddle.to_tensor([[1., 1.], [1., 1.], [1., 1.]])
            out_split_sizes = [local_rank + 1 for i in range(size)]
            in_split_sizes = [i + 1 for i in range(size)]
            task = dist.stream.alltoall_single(output,
                                               data,
                                               out_split_sizes,
                                               in_split_sizes,
                                               sync_op=False)
            task.wait()
            out = output.numpy()
            # [[0., 0.], [1., 1.]] (2 GPUs, out for rank 0)
            # [[0., 0.], [0., 0.], [1., 1.], [1., 1.]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if framework.in_dygraph_mode():
        return _alltoall_single_in_dygraph(out_tensor, in_tensor,
                                           out_split_sizes, in_split_sizes,
                                           group, sync_op, use_calc_stream)

    raise RuntimeError(
        "paddle.distributed.stream.alltoall_single is only supported in dygraph mode now."
    )
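
The split sizes must agree across ranks: the number of rows rank i sends to rank j (in_split_sizes[j] on rank i) has to equal the number of rows rank j expects from rank i (out_split_sizes[i] on rank j), and each list must sum to dim[0] of the corresponding tensor. A single-process sanity check mirroring case 2 above (the rank value is assumed for illustration):

world_size = 2
local_rank = 0  # assumed rank for this sketch
in_split_sizes = [i + 1 for i in range(world_size)]  # rows sent to each peer: [1, 2]
out_split_sizes = [local_rank + 1 for _ in range(world_size)]  # rows received from each peer: [1, 1]
assert sum(in_split_sizes) == 3   # dim[0] of in_tensor on rank 0
assert sum(out_split_sizes) == 2  # dim[0] of out_tensor on rank 0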
python/paddle/distributed/communication/stream/broadcast.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.distributed import collective


def _broadcast_in_dygraph(tensor, src, group, sync_op, use_calc_stream):
    group = collective._get_default_group() if group is None else group
    if use_calc_stream:
        return group.process_group.broadcast_on_calc_stream(tensor, src)

    task = group.process_group.broadcast(tensor, src, sync_op)
    if sync_op:
        task.wait()

    return task


def broadcast(tensor, src=0, group=None, sync_op=True, use_calc_stream=False):
    """

    Broadcast a tensor to all devices.

    Args:
        tensor (Tensor): The tensor to broadcast. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type.
        src (int, optional): Rank of the source device. If none is given, use `0` as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            local_rank = dist.get_rank()
            if local_rank == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            task = dist.stream.broadcast(data, src=1, sync_op=False)
            task.wait()
            out = data.numpy()
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if framework.in_dygraph_mode():
        return _broadcast_in_dygraph(tensor, src, group, sync_op,
                                     use_calc_stream)

    raise RuntimeError(
        "paddle.distributed.stream.broadcast is only supported in dygraph mode now."
    )
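
All APIs in this package share the same flag contract, so broadcast is a convenient place to spell it out. A hedged sketch of the valid combinations, assuming a 2-GPU distributed launch as in the docstring example:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.zeros([2, 3])
dist.stream.broadcast(data, src=0)  # sync_op=True (default): already waited on return
task = dist.stream.broadcast(data, src=0, sync_op=False)  # async: wait before reading
task.wait()
dist.stream.broadcast(data, src=0, use_calc_stream=True)  # on the calc stream, ordered with computation
# sync_op=False together with use_calc_stream=True raises a RuntimeError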
python/paddle/distributed/communication/stream/recv.py
...
@@ -64,7 +64,7 @@ def recv(tensor, src=0, group=None, sync_op=True, use_calc_stream=False):
            task = dist.stream.recv(data, src=0, sync_op=False)
            task.wait()
            out = data.numpy()
-           # [[4, 5, 6], [4, 5, 6] (2 GPUs)
+           # [[4, 5, 6], [4, 5, 6]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
...
python/paddle/distributed/communication/stream/reduce.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.distributed.communication.group import _get_global_group
from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp


def _reduce_in_dygraph(tensor, dst, op, group, sync_op, use_calc_stream):
    op_type = _get_reduce_op(op, "reduce")
    group = _get_global_group() if group is None else group
    if use_calc_stream:
        return group.process_group.reduce_on_calc_stream(tensor, dst, op_type)

    task = group.process_group.reduce(tensor, dst, op_type, sync_op)
    if sync_op:
        task.wait()

    return task


def reduce(tensor,
           dst=0,
           op=ReduceOp.SUM,
           group=None,
           sync_op=True,
           use_calc_stream=False):
    """

    Perform a specific reduction (for example, sum, max) on a tensor across devices and send the result to the destination device.

    Args:
        tensor (Tensor): The input tensor on each rank. The result will overwrite this tensor after communication. Support
            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
        dst (int, optional): Rank of the destination device. If none is given, use `0` as default.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            local_rank = dist.get_rank()
            if local_rank == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            task = dist.stream.reduce(data, dst=0, sync_op=False)
            task.wait()
            out = data.numpy()
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
            # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if framework.in_dygraph_mode():
        return _reduce_in_dygraph(tensor, dst, op, group, sync_op,
                                  use_calc_stream)

    raise RuntimeError(
        "paddle.distributed.stream.reduce is only supported in dygraph mode now."
    )
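
Unlike all_reduce, only the destination rank observes the reduced value; every other rank keeps its own input unchanged. A single-process sketch of the arithmetic behind the docstring example (ReduceOp.SUM, dst=0, 2 ranks):

rank_inputs = [[4, 5, 6], [1, 2, 3]]  # one row per rank
reduced = [a + b for a, b in zip(*rank_inputs)]
assert reduced == [5, 7, 9]          # what rank 0 (the dst) holds afterwards
assert rank_inputs[1] == [1, 2, 3]   # rank 1's tensor is left as-is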
python/paddle/distributed/communication/stream/reduce_scatter.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid.framework as framework
from paddle.distributed.communication.group import _get_global_group
from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp


def _check_tensor_shape(tensor, shape, nranks=1):
    expect_shape = list(shape)
    expect_shape[0] //= nranks
    if list(tensor.shape) != expect_shape:
        raise RuntimeError(
            "The in_tensor for reduce_scatter is not correctly-sized.")


def _check_tensor_list_shape(tensor_list, shape, nranks=1):
    if len(tensor_list) != nranks:
        raise RuntimeError(
            "The tensor_list for reduce_scatter is not correctly-sized.")
    for tensor in tensor_list:
        if tensor.shape != shape:
            raise RuntimeError(
                "The tensor_list for reduce_scatter is not correctly-sized.")


def _reduce_scatter_tensor_in_dygraph(out_tensor,
                                      in_tensor,
                                      op,
                                      group,
                                      sync_op,
                                      use_calc_stream,
                                      caller="reduce_scatter"):
    op_type = _get_reduce_op(op, caller)
    group = _get_global_group() if group is None else group

    _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)

    if use_calc_stream:
        return group.process_group.reduce_scatter_tensor_on_calc_stream(
            in_tensor, out_tensor, op_type)

    task = group.process_group.reduce_scatter_tensor(in_tensor, out_tensor,
                                                     op_type, sync_op)
    if sync_op:
        task.wait()

    return task


def _reduce_scatter_in_dygraph(tensor, tensor_list, op, group, sync_op,
                               use_calc_stream):
    op_type = _get_reduce_op(op, "reduce_scatter")
    group = _get_global_group() if group is None else group

    _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)

    if use_calc_stream:
        return group.process_group.reduce_scatter_on_calc_stream(
            tensor_list, tensor, op_type)

    task = group.process_group.reduce_scatter(tensor_list, tensor, op_type,
                                              sync_op)
    if sync_op:
        task.wait()

    return task


def reduce_scatter(tensor,
                   tensor_or_tensor_list,
                   op=ReduceOp.SUM,
                   group=None,
                   sync_op=True,
                   use_calc_stream=False):
    """

    Reduce, then scatter a tensor (or a tensor list) across devices.

    Args:
        tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support
            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
        tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to reduce and scatter.
            If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([0, 1])
                data2 = paddle.to_tensor([2, 3])
            else:
                data1 = paddle.to_tensor([4, 5])
                data2 = paddle.to_tensor([6, 7])
            dist.stream.reduce_scatter(data1, [data1, data2])
            out = data1.numpy()
            # [4, 6] (2 GPUs, out for rank 0)
            # [8, 10] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if framework.in_dygraph_mode():
        if paddle.is_tensor(tensor_or_tensor_list):
            return _reduce_scatter_tensor_in_dygraph(tensor,
                                                     tensor_or_tensor_list,
                                                     op, group, sync_op,
                                                     use_calc_stream)
        else:
            return _reduce_scatter_in_dygraph(tensor, tensor_or_tensor_list,
                                              op, group, sync_op,
                                              use_calc_stream)

    raise RuntimeError(
        "paddle.distributed.stream.reduce_scatter is only supported in dygraph mode now."
    )


def _reduce_scatter_base(out_tensor,
                         in_tensor,
                         op=ReduceOp.SUM,
                         group=None,
                         sync_op=True,
                         use_calc_stream=False):
    """

    Reduce, then scatter a flattened tensor across devices.

    Args:
        out_tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support
            float16, float32, float64, int32 or int64 as the input data type.
        in_tensor (Tensor): The input tensor to reduce and scatter.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API will be deprecated in the future, and only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([0, 1])
            else:
                data = paddle.to_tensor([2, 3])
            output = paddle.empty([1], dtype=data.dtype)
            dist.stream._reduce_scatter_base(output, data)
            out = output.numpy()
            # [2] (2 GPUs, out for rank 0)
            # [4] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if framework.in_dygraph_mode():
        return _reduce_scatter_tensor_in_dygraph(out_tensor, in_tensor, op,
                                                 group, sync_op,
                                                 use_calc_stream,
                                                 "_reduce_scatter_base")

    raise RuntimeError(
        "paddle.distributed.stream._reduce_scatter_base is only supported in dygraph mode now."
    )
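
reduce_scatter is an element-wise reduction followed by an even split along dim[0], which is why _check_tensor_shape divides the expected dim[0] by nranks. A single-process sketch of the arithmetic for 2 ranks with ReduceOp.SUM:

rank0_in = [0, 1, 2, 3]
rank1_in = [4, 5, 6, 7]
summed = [a + b for a, b in zip(rank0_in, rank1_in)]  # reduce step: [4, 6, 8, 10]
assert summed[:2] == [4, 6]    # chunk scattered to rank 0
assert summed[2:] == [8, 10]   # chunk scattered to rank 1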
python/paddle/distributed/communication/stream/scatter.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed as dist
import paddle.fluid.framework as framework
from paddle.distributed import collective


def _check_tensor_shape(tensor, shape, nranks=1):
    expect_shape = list(shape)
    expect_shape[0] //= nranks
    if list(tensor.shape) != expect_shape:
        raise RuntimeError("The in_tensor for scatter is not correctly-sized.")


def _check_tensor_list_shape(tensor_list, shape, nranks=1):
    if len(tensor_list) != nranks:
        raise RuntimeError(
            "The tensor_list for scatter is not correctly-sized.")
    for tensor in tensor_list:
        if tensor.shape != shape:
            raise RuntimeError(
                "The tensor_list for scatter is not correctly-sized.")


def _scatter_tensor_in_dygraph(out_tensor, in_tensor, src, group, sync_op,
                               use_calc_stream):
    group = collective._get_default_group() if group is None else group

    src_rank = group.get_group_rank(src)
    if src_rank == -1:
        raise RuntimeError("Src rank out of group.")

    nranks = group.nranks
    rank = dist.get_rank()
    if rank == src_rank:
        _check_tensor_shape(out_tensor, in_tensor.shape, nranks)

    if use_calc_stream:
        return group.process_group.scatter_tensor_on_calc_stream(
            in_tensor, out_tensor, src)

    task = group.process_group.scatter_tensor(in_tensor, out_tensor, src,
                                              sync_op)
    if sync_op:
        task.wait()

    return task


def _scatter_in_dygraph(tensor, tensor_list, src, group, sync_op,
                        use_calc_stream):
    group = collective._get_default_group() if group is None else group

    src_rank = group.get_group_rank(src)
    if src_rank == -1:
        raise RuntimeError("Src rank out of group.")

    nranks = group.nranks
    rank = dist.get_rank()
    if rank == src_rank:
        if len(tensor_list) == 0:
            raise RuntimeError(
                "The tensor_list should not be empty on src rank.")
        _check_tensor_list_shape(tensor_list, tensor.shape, nranks)
    else:
        # The input is ignored on non-source ranks, so a placeholder list
        # built from the output tensor is enough here.
        tensor_list = [tensor for _ in range(nranks)]

    if use_calc_stream:
        return group.process_group.scatter_on_calc_stream(
            tensor_list, tensor, src)

    task = group.process_group.scatter(tensor_list, tensor, src, sync_op)
    if sync_op:
        task.wait()

    return task


def scatter(tensor,
            tensor_or_tensor_list=None,
            src=0,
            group=None,
            sync_op=True,
            use_calc_stream=False):
    """

    Scatter a tensor (or a tensor list) across devices.

    Args:
        tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support
            float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type.
        tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to scatter (default is `None`; it must be provided on every rank,
            though its contents only matter on the source rank). If it is a tensor, it should be correctly-sized. If it is a list,
            it should contain correctly-sized tensors.
        src (int, optional): Rank of the source device. If none is given, use `0` as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
            option is designed for high-performance scenarios; do not turn it on unless you fully understand its meaning.

    Returns:
        Return a task object.

    Warning:
        This API only supports the dygraph mode now.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data1 = paddle.to_tensor([7, 8, 9])
                data2 = paddle.to_tensor([10, 11, 12])
                dist.stream.scatter(data1, [data1, data2], src=1)
            else:
                data1 = paddle.to_tensor([1, 2, 3])
                data2 = paddle.to_tensor([4, 5, 6])
                dist.stream.scatter(data1, [data1, data2], src=1)
            out = data1.numpy()
            # [1, 2, 3] (2 GPUs, out for rank 0)
            # [4, 5, 6] (2 GPUs, out for rank 1)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
            "The group should not be None and all ranks which invoke this operation should be the member of this group."
        )

    if not sync_op and use_calc_stream:
        raise RuntimeError(
            "use_calc_stream can only be true in sync op behavior.")

    if tensor_or_tensor_list is None:
        raise RuntimeError("The input should be specified.")

    if framework.in_dygraph_mode():
        if paddle.is_tensor(tensor_or_tensor_list):
            return _scatter_tensor_in_dygraph(tensor, tensor_or_tensor_list,
                                              src, group, sync_op,
                                              use_calc_stream)
        else:
            return _scatter_in_dygraph(tensor, tensor_or_tensor_list, src,
                                       group, sync_op, use_calc_stream)

    raise RuntimeError(
        "paddle.distributed.stream.scatter is only supported in dygraph mode now."
    )
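
Note the asymmetry between ranks: the input must be passed everywhere (the API rejects `None`), but it is validated and read only on the source rank; elsewhere the implementation substitutes a placeholder built from the output tensor. A hedged sketch of the tensor form, which splits the source tensor evenly along dim[0] (shapes assumed for a 2-GPU launch):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
out = paddle.empty([3], dtype='int64')  # every rank provides an output buffer
if dist.get_rank() == 1:  # rank 1 acts as the source here
    full = paddle.to_tensor([1, 2, 3, 4, 5, 6])
else:
    full = paddle.empty([6], dtype='int64')  # contents are ignored off-source
dist.stream.scatter(out, full, src=1)  # each rank receives 3 rows of `full`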
python/paddle/distributed/communication/stream/send.py
...
@@ -64,7 +64,7 @@ def send(tensor, dst=0, group=None, sync_op=True, use_calc_stream=False):
            task = dist.stream.recv(data, src=0, sync_op=False)
            task.wait()
            out = data.numpy()
-           # [[4, 5, 6], [4, 5, 6] (2 GPUs)
+           # [[4, 5, 6], [4, 5, 6]] (2 GPUs)
    """
    if group is not None and not group.is_member():
        raise RuntimeError(
...
python/paddle/fluid/tests/unittests/collective/CMakeLists.txt
...
@@ -282,6 +282,54 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  set_tests_properties(test_communication_stream_allreduce_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_alltoall_api MODULES
    test_communication_stream_alltoall_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_alltoall_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_alltoall_single_api MODULES
    test_communication_stream_alltoall_single_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_alltoall_single_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_broadcast_api MODULES
    test_communication_stream_broadcast_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_broadcast_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_reduce_api MODULES
    test_communication_stream_reduce_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_reduce_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_reduce_scatter_api MODULES
    test_communication_stream_reduce_scatter_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_reduce_scatter_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_scatter_api MODULES
    test_communication_stream_scatter_api ENVS
    "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
  set_tests_properties(test_communication_stream_scatter_api
                       PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_communication_stream_sendrecv_api MODULES
...
python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import test_communication_api_base as test_base
import test_collective_api_base as test_collective_base


class StreamAllToAllTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()
        test_data_list = []
        for seed in self._seeds:
            test_data_list.append(
                test_collective_base.create_test_data(shape=self._shape,
                                                      dtype=self._dtype,
                                                      seed=seed))

        nranks = len(test_data_list)
        data1 = test_data_list[0]
        data2 = test_data_list[1]
        result1 = np.vstack(
            [data1[0:data1.shape[0] // 2, :], data2[0:data2.shape[0] // 2, :]])
        result2 = np.vstack(
            [data1[data1.shape[0] // 2:, :], data2[data2.shape[0] // 2:, :]])

        rank = dist.get_rank()
        tensor = paddle.to_tensor(test_data_list[rank])
        t1, t2 = paddle.split(tensor, nranks, axis=0)

        # case 1: pass an empty tensor list
        empty_tensor_list = []
        task = dist.stream.alltoall(empty_tensor_list, [t1, t2],
                                    sync_op=self._sync_op,
                                    use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        result_tensor_list = np.vstack(empty_tensor_list)
        if rank == 0:
            assert np.allclose(result_tensor_list,
                               result1,
                               rtol=1e-05,
                               atol=1e-05)
        else:
            assert np.allclose(result_tensor_list,
                               result2,
                               rtol=1e-05,
                               atol=1e-05)

        # case 2: pass a pre-sized tensor list
        full_tensor_list = [paddle.empty_like(t1) for _ in test_data_list]
        task = dist.stream.alltoall(full_tensor_list, [t1, t2],
                                    sync_op=self._sync_op,
                                    use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        result_tensor_list = np.vstack(full_tensor_list)
        if rank == 0:
            assert np.allclose(result_tensor_list,
                               result1,
                               rtol=1e-05,
                               atol=1e-05)
        else:
            assert np.allclose(result_tensor_list,
                               result2,
                               rtol=1e-05,
                               atol=1e-05)

        # case 3: pass a pre-sized tensor
        out_tensor = paddle.empty_like(tensor)
        task = dist.stream.alltoall(out_tensor,
                                    tensor,
                                    sync_op=self._sync_op,
                                    use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == 0:
            assert np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
    StreamAllToAllTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_single_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import test_communication_api_base as test_base
import test_collective_api_base as test_collective_base


class StreamAllToAllSingleTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()
        test_data_list = []
        for seed in self._seeds:
            test_data_list.append(
                test_collective_base.create_test_data(shape=self._shape,
                                                      dtype=self._dtype,
                                                      seed=seed))

        nranks = len(test_data_list)
        data1 = paddle.to_tensor(test_data_list[0])
        data2 = paddle.to_tensor(test_data_list[1])
        result1 = np.vstack(
            (data1[0:data1.shape[0] // 2, :], data2[0:data2.shape[0] // 2, :]))
        result2 = np.vstack(
            (data1[data1.shape[0] // 2:, :], data2[data2.shape[0] // 2:, :]))

        rank = dist.get_rank()
        tensor = paddle.to_tensor(test_data_list[rank])

        out_tensor = paddle.empty_like(tensor)
        task = dist.stream.alltoall_single(
            out_tensor,
            tensor,
            sync_op=self._sync_op,
            use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == 0:
            assert np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
    StreamAllToAllSingleTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_broadcast_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
import test_collective_api_base as test_collective_base


class StreamBroadcastTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()

        src_rank = 1
        result = test_collective_base.create_test_data(
            shape=self._shape, dtype=self._dtype, seed=self._seeds[src_rank])
        tensor = paddle.to_tensor(result)
        task = dist.stream.broadcast(tensor,
                                     src=src_rank,
                                     sync_op=self._sync_op,
                                     use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()

        assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
    StreamBroadcastTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
import test_collective_api_base as test_collective_base


class StreamReduceTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()
        test_data_list = []
        for seed in self._seeds:
            test_data_list.append(
                test_collective_base.create_test_data(shape=self._shape,
                                                      dtype=self._dtype,
                                                      seed=seed))

        rank = dist.get_rank()
        tensor = paddle.to_tensor(test_data_list[rank])
        task = dist.stream.reduce(tensor,
                                  dst=1,
                                  sync_op=self._sync_op,
                                  use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()

        result = sum(test_data_list)
        if rank == 1:
            assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(tensor,
                               test_data_list[rank],
                               rtol=1e-05,
                               atol=1e-05)


if __name__ == "__main__":
    StreamReduceTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_scatter_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
import test_collective_api_base as test_collective_base


class StreamReduceScatterTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()
        test_data_list = []
        for seed in self._seeds:
            test_data_list.append(
                test_collective_base.create_test_data(shape=self._shape,
                                                      dtype=self._dtype,
                                                      seed=seed))

        reduce_result = sum(test_data_list)
        result1 = reduce_result[0:reduce_result.shape[0] // 2]
        result2 = reduce_result[reduce_result.shape[0] // 2:]

        rank = dist.get_rank()
        tensor = paddle.to_tensor(test_data_list[rank])

        # case 1: pass a pre-sized tensor list
        t1, t2 = paddle.split(tensor, 2, axis=0)
        result_tensor = paddle.empty_like(t1)
        task = dist.stream.reduce_scatter(
            result_tensor, [t1, t2],
            sync_op=self._sync_op,
            use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == 0:
            assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)

        # case 2: pass a pre-sized tensor
        result_tensor = paddle.empty_like(t1)
        task = dist.stream.reduce_scatter(
            result_tensor,
            tensor,
            sync_op=self._sync_op,
            use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == 0:
            assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)

        # case 3: test the legacy API
        result_tensor = paddle.empty_like(t1)
        task = dist.stream._reduce_scatter_base(
            result_tensor,
            tensor,
            sync_op=self._sync_op,
            use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == 0:
            assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
    StreamReduceScatterTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_scatter_api_dygraph.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
import test_collective_api_base as test_collective_base


class StreamScatterTestCase():

    def __init__(self):
        self._sync_op = eval(os.getenv("sync_op"))
        self._use_calc_stream = eval(os.getenv("use_calc_stream"))
        self._backend = os.getenv("backend")
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        if self._backend not in ["nccl", "gloo"]:
            raise NotImplementedError(
                "Only support nccl and gloo as the backend for now.")
        os.environ["PADDLE_DISTRI_BACKEND"] = self._backend

    def run_test_case(self):
        dist.init_parallel_env()
        test_data_list = []
        for seed in self._seeds:
            test_data_list.append(
                test_collective_base.create_test_data(shape=self._shape,
                                                      dtype=self._dtype,
                                                      seed=seed))

        src_rank = 1
        src_data = test_data_list[src_rank]
        result1 = src_data[0:src_data.shape[0] // 2]
        result2 = src_data[src_data.shape[0] // 2:]

        rank = dist.get_rank()

        # case 1: pass a pre-sized tensor list
        tensor = paddle.to_tensor(test_data_list[rank])
        t1, t2 = paddle.split(tensor, 2, axis=0)
        task = dist.stream.scatter(t1, [t1, t2],
                                   src=src_rank,
                                   sync_op=self._sync_op,
                                   use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == src_rank:
            assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05)

        # case 2: pass a pre-sized tensor
        tensor = paddle.to_tensor(src_data)
        t1 = paddle.empty_like(t1)
        task = dist.stream.scatter(t1,
                                   tensor,
                                   src=src_rank,
                                   sync_op=self._sync_op,
                                   use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()
        if rank == src_rank:
            assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05)
        else:
            assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
    StreamScatterTestCase().run_test_case()
python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py
...
@@ -43,22 +43,25 @@ class StreamSendRecvTestCase():
                                                      dtype=self._dtype,
                                                      seed=seed))

+        src_rank = 0
+        dst_rank = 1
+
        rank = dist.get_rank()
        tensor = paddle.to_tensor(test_data_list[rank])
        if rank == 0:
            task = dist.stream.send(tensor,
-                                    dst=1,
+                                    dst=dst_rank,
                                    sync_op=self._sync_op,
                                    use_calc_stream=self._use_calc_stream)
        else:
            task = dist.stream.recv(tensor,
-                                    src=0,
+                                    src=src_rank,
                                    sync_op=self._sync_op,
                                    use_calc_stream=self._use_calc_stream)
        if not self._sync_op:
            task.wait()

-        result = test_data_list[0]
+        result = test_data_list[src_rank]
        assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05)
...
python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamAllToAllAPI(test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamAllToAllAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_alltoall_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case("communication_stream_alltoall_api_dygraph.py",
                               user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamAllToAllAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_single_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamAllToAllSingleAPI(
        test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamAllToAllSingleAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_alltoall_single_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case(
                "communication_stream_alltoall_single_api_dygraph.py",
                user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamAllToAllSingleAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/test_communication_stream_broadcast_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamBroadcastAPI(test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamBroadcastAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_broadcast_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case(
                "communication_stream_broadcast_api_dygraph.py",
                user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamBroadcastAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamReduceAPI(test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamReduceAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_reduce_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case("communication_stream_reduce_api_dygraph.py",
                               user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamReduceAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_scatter_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamReduceScatterAPI(
        test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamReduceScatterAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_reduce_scatter_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case(
                "communication_stream_reduce_scatter_api_dygraph.py",
                user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamReduceScatterAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/test_communication_stream_scatter_api.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamScatterAPI(test_base.CommunicationTestDistBase):

    def setUp(self):
        super(TestCommunicationStreamScatterAPI,
              self).setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_scatter_stream(self):
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
                continue
            self.run_test_case("communication_stream_scatter_api_dygraph.py",
                               user_defined_envs=envs)

    def tearDown(self):
        super(TestCommunicationStreamScatterAPI, self).tearDown()


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/collective/testslist.csv
...
...
@@ -34,6 +34,12 @@ test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_
test_collective_wait,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_communication_stream_allgather_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_alltoall_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_alltoall_single_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_broadcast_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_reduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_reduce_scatter_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_scatter_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_sendrecv_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
...
...