Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
1e56ca8a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1e56ca8a
编写于
4月 13, 2022
作者:
L
lilong12
提交者:
GitHub
4月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use densetensor instead of Tensor for ProcessGroup (#41403)
上级
1cdd88f6
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
501 addition
and
554 deletion
+501
-554
paddle/fluid/distributed/collective/Common.cc
paddle/fluid/distributed/collective/Common.cc
+7
-11
paddle/fluid/distributed/collective/Common.h
paddle/fluid/distributed/collective/Common.h
+4
-4
paddle/fluid/distributed/collective/ProcessGroup.cc
paddle/fluid/distributed/collective/ProcessGroup.cc
+2
-1
paddle/fluid/distributed/collective/ProcessGroup.h
paddle/fluid/distributed/collective/ProcessGroup.h
+17
-20
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
+98
-85
paddle/fluid/distributed/collective/ProcessGroupGloo.h
paddle/fluid/distributed/collective/ProcessGroupGloo.h
+16
-10
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
+35
-82
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
+10
-13
paddle/fluid/distributed/collective/ProcessGroupHeter.cc
paddle/fluid/distributed/collective/ProcessGroupHeter.cc
+77
-148
paddle/fluid/distributed/collective/ProcessGroupHeter.h
paddle/fluid/distributed/collective/ProcessGroupHeter.h
+5
-6
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+107
-131
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+27
-21
paddle/fluid/distributed/collective/reducer.cc
paddle/fluid/distributed/collective/reducer.cc
+35
-7
paddle/fluid/operators/collective/c_allgather_op.cu.cc
paddle/fluid/operators/collective/c_allgather_op.cu.cc
+14
-0
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+6
-1
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+36
-14
python/paddle/fluid/tests/unittests/init_process_group.py
python/paddle/fluid/tests/unittests/init_process_group.py
+5
-0
未找到文件。
paddle/fluid/distributed/collective/Common.cc
浏览文件 @
1e56ca8a
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
)
{
std
::
vector
<
Place
>
places
;
std
::
vector
<
Place
>
places
;
places
.
reserve
(
tensors
.
size
());
places
.
reserve
(
tensors
.
size
());
for
(
auto
&
tensor
:
tensors
)
{
for
(
auto
&
tensor
:
tensors
)
{
places
.
push_back
(
tensor
.
inner_
place
());
places
.
push_back
(
tensor
.
place
());
}
}
return
places
;
return
places
;
}
}
...
@@ -40,15 +40,11 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
...
@@ -40,15 +40,11 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return
placeList
;
return
placeList
;
}
}
static
bool
CheckTensorsInPlace
(
const
std
::
vector
<
Tensor
>&
tensors
,
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
phi
::
AllocationType
type
)
{
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
Tensor
&
t
)
{
[
&
](
const
phi
::
DenseTensor
&
t
)
{
return
t
.
place
().
GetType
()
==
type
;
return
platform
::
is_gpu_place
(
t
.
place
());
});
});
}
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
return
CheckTensorsInPlace
(
tensors
,
phi
::
AllocationType
::
GPU
);
}
}
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/collective/Common.h
浏览文件 @
1e56ca8a
...
@@ -16,18 +16,18 @@
...
@@ -16,18 +16,18 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
Tensor
=
paddle
::
experimental
::
Tensor
;
using
Place
=
paddle
::
platform
::
Place
;
using
Place
=
paddle
::
platform
::
Place
;
// Get the list of devices from list of tensors
// Get the list of devices from list of tensors
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
Tensor
>&
tensors
);
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
);
// Get the deviceList String from the list of devices
// Get the deviceList String from the list of devices
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
Tensor
>&
tensors
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
);
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroup.cc
浏览文件 @
1e56ca8a
...
@@ -17,7 +17,8 @@
...
@@ -17,7 +17,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
Tensor
>&
inputTensors
,
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputTensors
,
CommType
comm_type
)
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
...
...
paddle/fluid/distributed/collective/ProcessGroup.h
浏览文件 @
1e56ca8a
...
@@ -54,7 +54,7 @@ class ProcessGroup {
...
@@ -54,7 +54,7 @@ class ProcessGroup {
public:
public:
class
Task
{
class
Task
{
public:
public:
Task
(
int
rank
,
const
std
::
vector
<
Tensor
>&
inputTensors
,
Task
(
int
rank
,
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputTensors
,
CommType
opType
=
CommType
::
UNKNOWN
);
CommType
opType
=
CommType
::
UNKNOWN
);
virtual
~
Task
();
virtual
~
Task
();
...
@@ -79,25 +79,21 @@ class ProcessGroup {
...
@@ -79,25 +79,21 @@ class ProcessGroup {
virtual
const
std
::
string
GetBackendName
()
const
=
0
;
virtual
const
std
::
string
GetBackendName
()
const
=
0
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
/* tensors */
,
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
const
AllreduceOptions
&
=
AllreduceOptions
())
{
const
AllreduceOptions
&
=
AllreduceOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support allreduce"
,
GetBackendName
()));
"ProcessGroup%s does not support allreduce"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
/* tensors */
,
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
const
BroadcastOptions
&
=
BroadcastOptions
())
{
const
BroadcastOptions
&
=
BroadcastOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support broadcast"
,
GetBackendName
()));
"ProcessGroup%s does not support broadcast"
,
GetBackendName
()));
}
}
virtual
void
Broadcast
(
const
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"ProcessGroup%s does not support broadcast for static mode runtime"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
{
const
BarrierOptions
&
=
BarrierOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
@@ -105,42 +101,43 @@ class ProcessGroup {
...
@@ -105,42 +101,43 @@ class ProcessGroup {
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
Tensor
>&
tensors
/* tensors */
,
int
dst_rank
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support send"
,
GetBackendName
()));
"ProcessGroup%s does not support send"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
vector
<
Tensor
>&
tensors
/* tensors */
,
int
src_rank
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support receive"
,
GetBackendName
()));
"ProcessGroup%s does not support receive"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
/* tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
std
::
vector
<
Tensor
>&
out_tensors
/* tensors */
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support AllGather"
,
GetBackendName
()));
"ProcessGroup%s does not support AllGather"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
vector
<
Tensor
>&
in
/* tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
std
::
vector
<
Tensor
>&
out
/* tensors */
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support AllToAll"
,
GetBackendName
()));
"ProcessGroup%s does not support AllToAll"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
Tensor
>&
tensors
/* tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
const
ReduceOptions
&
opts
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
const
ReduceOptions
&
opts
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support Reduce"
,
GetBackendName
()));
"ProcessGroup%s does not support Reduce"
,
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
/* tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
std
::
vector
<
Tensor
>&
out_tensors
/* tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
// NOLINT
const
ScatterOptions
&
)
{
// NOLINT
const
ScatterOptions
&
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support Scatter"
,
GetBackendName
()));
"ProcessGroup%s does not support Scatter"
,
GetBackendName
()));
}
}
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
浏览文件 @
1e56ca8a
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <gloo/broadcast.h>
#include <gloo/broadcast.h>
#include <gloo/reduce.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <gloo/scatter.h>
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -105,107 +106,104 @@ reduce_func get_function(const ReduceOp& r) {
...
@@ -105,107 +106,104 @@ reduce_func get_function(const ReduceOp& r) {
exit
(
-
1
);
exit
(
-
1
);
}
}
bool
CheckTensorsInCPUPlace
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
Tensor
&
t
)
{
return
t
.
place
()
==
PlaceType
::
kCPU
;
});
}
template
<
typename
T
>
template
<
typename
T
>
T
*
get_data
(
const
Tensor
&
tensor
)
{
T
*
get_data
(
phi
::
DenseTensor
&
tensor
)
{
// NOLINT
auto
raw_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
return
reinterpret_cast
<
T
*>
(
tensor
.
data
());
return
static_cast
<
T
*>
(
raw_tensor
->
data
());
}
}
template
<
typename
T
>
template
<
typename
T
>
std
::
vector
<
T
*>
get_multi_data
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
std
::
vector
<
T
*>
get_multi_data
(
std
::
vector
<
T
*>
ret
(
tensors
.
size
());
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
// NOLINT
std
::
vector
<
T
*>
ret
;
ret
.
reserve
(
tensors
.
size
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
ret
[
i
]
=
get_data
<
T
>
(
tensors
[
i
]
);
ret
.
push_back
(
get_data
<
T
>
(
tensors
[
i
])
);
}
}
return
ret
;
return
ret
;
}
}
template
<
typename
T
,
typename
P
>
template
<
typename
T
,
typename
P
>
void
set_output
(
P
&
opts
,
const
Tensor
&
tensor
)
{
// NOLINT
void
set_output
(
P
&
opts
,
phi
::
Dense
Tensor
&
tensor
)
{
// NOLINT
opts
.
setOutput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
opts
.
setOutput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
}
}
template
<
typename
T
,
typename
P
>
template
<
typename
T
,
typename
P
>
void
set_input
(
P
&
opts
,
const
Tensor
&
tensor
)
{
// NOLINT
void
set_input
(
P
&
opts
,
phi
::
Dense
Tensor
&
tensor
)
{
// NOLINT
opts
.
setInput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
opts
.
setInput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
}
}
template
<
typename
T
,
typename
P
>
template
<
typename
T
,
typename
P
>
void
set_outputs
(
P
&
opts
,
const
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
void
set_outputs
(
P
&
opts
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
// NOLINT
opts
.
setOutputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
opts
.
setOutputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
}
}
template
<
typename
T
,
typename
P
>
template
<
typename
T
,
typename
P
>
void
set_inputs
(
P
&
opts
,
const
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
void
set_inputs
(
P
&
opts
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
// NOLINT
opts
.
setInputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
opts
.
setInputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
}
}
template
<
typename
T
,
typename
P
>
template
<
typename
T
,
typename
P
>
void
set_inputs_for_scatter
(
P
&
opts
,
// NOLINT
void
set_inputs_for_scatter
(
P
&
opts
,
// NOLINT
const
std
::
vector
<
Tensor
>&
tensors
,
// NOLINT
phi
::
DenseTensor
&
tensor
,
// NOLINT
int
nranks
)
{
int
nranks
)
{
std
::
vector
<
T
*>
ret
(
nranks
);
std
::
vector
<
T
*>
ret
;
auto
raw_tensor
=
ret
.
reserve
(
nranks
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
0
].
impl
());
T
*
raw_pointer
=
reinterpret_cast
<
T
*>
(
tensor
.
data
());
T
*
raw_pointer
=
reinterpret_cast
<
T
*>
(
raw_tensor
->
data
());
size_t
offset
=
0
;
size_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
nranks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nranks
;
i
++
)
{
ret
[
i
]
=
raw_pointer
+
offset
;
ret
.
push_back
(
raw_pointer
+
offset
)
;
offset
+=
tensor
s
[
0
]
.
numel
()
/
nranks
;
offset
+=
tensor
.
numel
()
/
nranks
;
}
}
opts
.
setInputs
(
ret
,
tensor
s
[
0
]
.
numel
()
/
nranks
);
opts
.
setInputs
(
ret
,
tensor
.
numel
()
/
nranks
);
}
}
ProcessGroupGloo
::
GlooTask
::
GlooTask
(
int
rank
,
ProcessGroupGloo
::
GlooTask
::
GlooTask
(
const
std
::
vector
<
Tensor
>&
inputs
,
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
)
CommType
comm_type
)
:
ProcessGroup
::
Task
(
rank
,
inputs
,
comm_type
)
{}
:
ProcessGroup
::
Task
(
rank
,
inputs
,
comm_type
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCPUPlace
(
inputs
),
true
,
platform
::
errors
::
Fatal
(
"Only CPU place is supported for ProcessGroupGloo."
));
}
ProcessGroupGloo
::
ProcessGroupGloo
(
ProcessGroupGloo
::
ProcessGroupGloo
(
const
std
::
shared_ptr
<
paddle
::
distributed
::
Store
>&
store
,
int
rank
,
const
std
::
shared_ptr
<
distributed
::
Store
>&
store
,
int
rank
,
int
world_size
,
int
world_size
,
int
gid
,
const
std
::
shared_ptr
<
GlooOptions
>
options
)
int
gid
,
const
std
::
shared_ptr
<
GlooOptions
>
options
)
:
ProcessGroup
(
rank
,
world_size
,
gid
),
:
ProcessGroup
(
rank
,
world_size
,
gid
),
_tag
(
0
),
_tag
(
0
),
_store
(
new
GlooStore
(
store
))
{
_store
(
new
GlooStore
(
store
))
{
_context
=
std
::
make_shared
<
gloo
::
rendezvous
::
Context
>
(
rank
,
world_size
);
_context
=
std
::
make_shared
<
gloo
::
rendezvous
::
Context
>
(
rank
,
world_size
);
auto
prefix_store
=
auto
prefix_store
=
::
gloo
::
rendezvous
::
PrefixStore
(
std
::
to_string
(
0
),
*
_store
);
::
gloo
::
rendezvous
::
PrefixStore
(
std
::
to_string
(
gid
),
*
_store
);
_context
->
connectFullMesh
(
prefix_store
,
options
->
device
);
_context
->
connectFullMesh
(
prefix_store
,
options
->
device
);
}
}
class
BroadcastGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
BroadcastGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
BroadcastGlooTask
(
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
BroadcastGlooTask
(
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
const
std
::
vector
<
Tensor
>&
inputs
,
int
rank
,
int
root
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
uint32_t
tag
)
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
// NOLINT
int
rank
,
int
root
,
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
BROADCAST
),
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
BROADCAST
),
_context
(
context
),
_context
(
context
),
_root
(
root
),
_root
(
root
),
_inputs
(
inputs
),
_inputs
(
inputs
),
_outputs
(
outputs
),
_tag
(
tag
)
{}
_tag
(
tag
)
{}
void
Run
()
override
{
_do_broadcast
(
_inputs
[
0
]);
}
void
Run
()
override
{
_do_broadcast
(
_inputs
[
0
]
,
_outputs
[
0
]
);
}
private:
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
const
int
_root
;
const
int
_root
;
std
::
vector
<
Tensor
>
_inputs
{};
std
::
vector
<
phi
::
DenseTensor
>
_inputs
{};
std
::
vector
<
phi
::
DenseTensor
>
_outputs
{};
const
uint32_t
_tag
;
const
uint32_t
_tag
;
void
_do_broadcast
(
const
Tensor
&
tensor
)
{
void
_do_broadcast
(
phi
::
DenseTensor
&
in
,
phi
::
DenseTensor
&
out
)
{
// NOLINT
gloo
::
BroadcastOptions
opts
(
_context
);
gloo
::
BroadcastOptions
opts
(
_context
);
const
auto
&
dtype
=
tensor
.
type
();
const
auto
&
dtype
=
in
.
dtype
();
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
tensor
);
if
(
rank_
==
_root
)
{
GENERATE_FUNC
(
dtype
,
set_input
,
opts
,
in
);
}
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
out
);
opts
.
setRoot
(
_root
);
opts
.
setRoot
(
_root
);
opts
.
setTag
(
_tag
);
opts
.
setTag
(
_tag
);
gloo
::
broadcast
(
opts
);
gloo
::
broadcast
(
opts
);
...
@@ -213,12 +211,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -213,12 +211,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
};
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
vector
<
Tensor
>&
inputs
,
const
BroadcastOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
const
BroadcastOptions
&
opts
)
{
auto
root
=
opts
.
source_rank
;
auto
root
=
opts
.
source_rank
;
std
::
unique_ptr
<
BroadcastGlooTask
>
task
;
std
::
unique_ptr
<
BroadcastGlooTask
>
task
;
auto
tag
=
next_tag
();
auto
tag
=
next_tag
();
auto
context
=
get_context
();
auto
context
=
get_context
();
task
=
std
::
make_unique
<
BroadcastGlooTask
>
(
context
,
inputs
,
rank_
,
root
,
tag
);
task
=
std
::
make_unique
<
BroadcastGlooTask
>
(
context
,
inputs
,
outputs
,
rank_
,
root
,
tag
);
task
->
Run
();
task
->
Run
();
return
task
;
return
task
;
}
}
...
@@ -226,19 +226,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
...
@@ -226,19 +226,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
class
AllreduceGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
AllreduceGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
AllreduceGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
AllreduceGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
std
::
vector
<
Tensor
>&
inputs
,
ReduceOp
reduce_op
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
uint32_t
tag
)
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
// NOLINT
ReduceOp
reduce_op
,
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
ALLREDUCE
),
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
ALLREDUCE
),
_context
(
context
),
_context
(
context
),
_inputs
(
inputs
),
_inputs
(
inputs
),
_outputs
(
outputs
),
_reduce_op
(
reduce_op
),
_reduce_op
(
reduce_op
),
_tag
(
tag
)
{}
_tag
(
tag
)
{}
void
Run
()
override
{
_do_allreduce
(
_inputs
);
}
void
Run
()
override
{
_do_allreduce
(
_inputs
,
_outputs
);
}
private:
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
vector
<
Tensor
>
_inputs
;
std
::
vector
<
phi
::
DenseTensor
>
_inputs
;
std
::
vector
<
phi
::
DenseTensor
>
_outputs
;
const
ReduceOp
_reduce_op
;
const
ReduceOp
_reduce_op
;
uint32_t
_tag
;
uint32_t
_tag
;
...
@@ -255,11 +258,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -255,11 +258,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
fn
=
get_function
<
T
>
(
op
);
fn
=
get_function
<
T
>
(
op
);
}
}
void
_do_allreduce
(
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
void
_do_allreduce
(
std
::
vector
<
phi
::
DenseTensor
>&
ins
,
// NOLINT
const
auto
&
dtype
=
tensors
[
0
].
type
();
std
::
vector
<
phi
::
DenseTensor
>&
outs
)
{
// NOLINT
const
auto
&
dtype
=
ins
[
0
].
dtype
();
gloo
::
AllreduceOptions
opts
(
_context
);
gloo
::
AllreduceOptions
opts
(
_context
);
GENERATE_FUNC
(
dtype
,
set_inputs
,
opts
,
tensor
s
);
GENERATE_FUNC
(
dtype
,
set_inputs
,
opts
,
in
s
);
GENERATE_FUNC
(
dtype
,
set_outputs
,
opts
,
tensor
s
);
GENERATE_FUNC
(
dtype
,
set_outputs
,
opts
,
out
s
);
opts
.
setReduceFunction
(
_get_function
(
dtype
,
_reduce_op
));
opts
.
setReduceFunction
(
_get_function
(
dtype
,
_reduce_op
));
opts
.
setTag
(
_tag
);
opts
.
setTag
(
_tag
);
gloo
::
allreduce
(
opts
);
gloo
::
allreduce
(
opts
);
...
@@ -267,11 +271,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -267,11 +271,12 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
};
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
AllReduce
(
std
::
vector
<
Tensor
>&
inputs
,
const
AllreduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
const
AllreduceOptions
&
opts
)
{
auto
tag
=
next_tag
();
auto
tag
=
next_tag
();
std
::
shared_ptr
<
GlooTask
>
task
;
std
::
shared_ptr
<
GlooTask
>
task
;
auto
context
=
get_context
();
auto
context
=
get_context
();
task
=
std
::
make_shared
<
AllreduceGlooTask
>
(
rank_
,
context
,
inputs
,
task
=
std
::
make_shared
<
AllreduceGlooTask
>
(
rank_
,
context
,
inputs
,
outputs
,
opts
.
reduce_op
,
tag
);
opts
.
reduce_op
,
tag
);
task
->
Run
();
task
->
Run
();
return
task
;
return
task
;
...
@@ -280,7 +285,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
...
@@ -280,7 +285,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
class
BarrierGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
BarrierGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
BarrierGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
)
BarrierGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
std
::
vector
<
Tensor
>
{},
:
ProcessGroupGloo
::
GlooTask
(
rank
,
std
::
vector
<
phi
::
Dense
Tensor
>
{},
CommType
::
BARRIER
),
CommType
::
BARRIER
),
_context
(
context
)
{}
_context
(
context
)
{}
...
@@ -307,8 +312,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
...
@@ -307,8 +312,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class
AllgatherGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
AllgatherGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
AllgatherGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
AllgatherGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
,
// NOLINT
uint32_t
tag
)
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
ALLGATHER
),
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
ALLGATHER
),
_context
(
context
),
_context
(
context
),
...
@@ -320,13 +325,13 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -320,13 +325,13 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
private:
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
vector
<
Tensor
>
_inputs
;
std
::
vector
<
phi
::
Dense
Tensor
>
_inputs
;
std
::
vector
<
Tensor
>
_outputs
;
std
::
vector
<
phi
::
Dense
Tensor
>
_outputs
;
uint32_t
_tag
;
uint32_t
_tag
;
void
_do_allgather
(
std
::
vector
<
Tensor
>&
in
,
// NOLINT
void
_do_allgather
(
std
::
vector
<
phi
::
Dense
Tensor
>&
in
,
// NOLINT
std
::
vector
<
Tensor
>&
out
)
{
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
out
)
{
// NOLINT
const
auto
&
dtype
=
in
[
0
].
type
();
const
auto
&
dtype
=
in
[
0
].
d
type
();
gloo
::
AllgatherOptions
opts
(
_context
);
gloo
::
AllgatherOptions
opts
(
_context
);
GENERATE_FUNC
(
dtype
,
set_input
,
opts
,
in
[
0
]);
GENERATE_FUNC
(
dtype
,
set_input
,
opts
,
in
[
0
]);
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
out
[
0
]);
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
out
[
0
]);
...
@@ -336,7 +341,8 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -336,7 +341,8 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
};
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
{
std
::
shared_ptr
<
AllgatherGlooTask
>
task
;
std
::
shared_ptr
<
AllgatherGlooTask
>
task
;
auto
tag
=
next_tag
();
auto
tag
=
next_tag
();
auto
context
=
get_context
();
auto
context
=
get_context
();
...
@@ -349,20 +355,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
...
@@ -349,20 +355,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
class
ReduceGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
ReduceGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
ReduceGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
ReduceGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
std
::
vector
<
Tensor
>&
in
,
ReduceOp
reduce_op
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
int
dst
,
uint32_t
tag
)
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
// NOLINT
:
ProcessGroupGloo
::
GlooTask
(
rank
,
in
,
CommType
::
REDUCE
),
ReduceOp
reduce_op
,
int
dst
,
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
REDUCE
),
_context
(
context
),
_context
(
context
),
_inputs
(
in
),
_inputs
(
inputs
),
_outputs
(
outputs
),
_reduce_op
(
reduce_op
),
_reduce_op
(
reduce_op
),
_dst
(
dst
),
_dst
(
dst
),
_tag
(
tag
)
{}
_tag
(
tag
)
{}
void
Run
()
override
{
_do_reduce
(
_inputs
,
_dst
);
}
void
Run
()
override
{
_do_reduce
(
_inputs
,
_
outputs
,
_
dst
);
}
private:
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
vector
<
Tensor
>
_inputs
;
std
::
vector
<
phi
::
DenseTensor
>
_inputs
;
std
::
vector
<
phi
::
DenseTensor
>
_outputs
;
const
ReduceOp
_reduce_op
;
const
ReduceOp
_reduce_op
;
int
_dst
;
int
_dst
;
uint32_t
_tag
;
uint32_t
_tag
;
...
@@ -380,11 +389,13 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -380,11 +389,13 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
fn
=
get_function
<
T
>
(
op
);
fn
=
get_function
<
T
>
(
op
);
}
}
void
_do_reduce
(
std
::
vector
<
Tensor
>&
tensors
,
int
dst
)
{
// NOLINT
void
_do_reduce
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
const
auto
&
dtype
=
tensors
[
0
].
type
();
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
// NOLINT
int
dst
)
{
const
auto
&
dtype
=
inputs
[
0
].
dtype
();
gloo
::
ReduceOptions
opts
(
_context
);
gloo
::
ReduceOptions
opts
(
_context
);
GENERATE_FUNC
(
dtype
,
set_input
,
opts
,
tensor
s
[
0
]);
GENERATE_FUNC
(
dtype
,
set_input
,
opts
,
input
s
[
0
]);
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
tensor
s
[
0
]);
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
output
s
[
0
]);
opts
.
setReduceFunction
(
_get_function
(
dtype
,
_reduce_op
));
opts
.
setReduceFunction
(
_get_function
(
dtype
,
_reduce_op
));
opts
.
setTag
(
_tag
);
opts
.
setTag
(
_tag
);
opts
.
setRoot
(
dst
);
opts
.
setRoot
(
dst
);
...
@@ -393,11 +404,12 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -393,11 +404,12 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
};
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Reduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
ReduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
const
ReduceOptions
&
opts
)
{
std
::
shared_ptr
<
ReduceGlooTask
>
task
;
std
::
shared_ptr
<
ReduceGlooTask
>
task
;
auto
tag
=
next_tag
();
auto
tag
=
next_tag
();
auto
context
=
get_context
();
auto
context
=
get_context
();
task
=
std
::
make_shared
<
ReduceGlooTask
>
(
rank_
,
context
,
tensor
s
,
task
=
std
::
make_shared
<
ReduceGlooTask
>
(
rank_
,
context
,
inputs
,
output
s
,
opts
.
reduce_op
,
opts
.
root_rank
,
tag
);
opts
.
reduce_op
,
opts
.
root_rank
,
tag
);
task
->
Run
();
task
->
Run
();
return
task
;
return
task
;
...
@@ -406,8 +418,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
...
@@ -406,8 +418,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
class
ScatterGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
class
ScatterGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
public:
ScatterGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
ScatterGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
,
// NOLINT
int
src
,
int
size
,
uint32_t
tag
)
int
src
,
int
size
,
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
SCATTER
),
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
SCATTER
),
_context
(
context
),
_context
(
context
),
...
@@ -421,18 +433,19 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -421,18 +433,19 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
private:
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
vector
<
Tensor
>
_inputs
;
std
::
vector
<
phi
::
Dense
Tensor
>
_inputs
;
std
::
vector
<
Tensor
>
_outputs
;
std
::
vector
<
phi
::
Dense
Tensor
>
_outputs
;
int
_src
;
int
_src
;
int
_size
;
int
_size
;
uint32_t
_tag
;
uint32_t
_tag
;
void
_do_scatter
(
std
::
vector
<
Tensor
>&
in
,
std
::
vector
<
Tensor
>&
out
,
// NOLINT
void
_do_scatter
(
std
::
vector
<
phi
::
DenseTensor
>&
in
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out
,
// NOLINT
int
src
)
{
int
src
)
{
const
auto
&
dtype
=
in
[
0
].
type
();
const
auto
&
dtype
=
in
[
0
].
d
type
();
gloo
::
ScatterOptions
opts
(
_context
);
gloo
::
ScatterOptions
opts
(
_context
);
if
(
rank_
==
src
)
{
if
(
rank_
==
src
)
{
GENERATE_FUNC
(
dtype
,
set_inputs_for_scatter
,
opts
,
in
,
_size
);
GENERATE_FUNC
(
dtype
,
set_inputs_for_scatter
,
opts
,
in
[
0
]
,
_size
);
}
}
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
out
[
0
]);
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
out
[
0
]);
opts
.
setRoot
(
src
);
opts
.
setRoot
(
src
);
...
@@ -442,8 +455,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -442,8 +455,8 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
};
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Scatter
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out
_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in
_tensors
,
const
ScatterOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ScatterOptions
&
opts
)
{
std
::
shared_ptr
<
ScatterGlooTask
>
task
;
std
::
shared_ptr
<
ScatterGlooTask
>
task
;
auto
tag
=
next_tag
();
auto
tag
=
next_tag
();
auto
context
=
get_context
();
auto
context
=
get_context
();
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.h
浏览文件 @
1e56ca8a
...
@@ -36,7 +36,8 @@ class ProcessGroupGloo : public ProcessGroup {
...
@@ -36,7 +36,8 @@ class ProcessGroupGloo : public ProcessGroup {
class
GlooTask
:
public
ProcessGroup
::
Task
,
class
GlooTask
:
public
ProcessGroup
::
Task
,
public
std
::
enable_shared_from_this
<
GlooTask
>
{
public
std
::
enable_shared_from_this
<
GlooTask
>
{
public:
public:
explicit
GlooTask
(
int
rank
,
const
std
::
vector
<
Tensor
>&
input_tensors
,
explicit
GlooTask
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
CommType
comm_type
);
CommType
comm_type
);
~
GlooTask
()
=
default
;
~
GlooTask
()
=
default
;
...
@@ -106,26 +107,31 @@ class ProcessGroupGloo : public ProcessGroup {
...
@@ -106,26 +107,31 @@ class ProcessGroupGloo : public ProcessGroup {
~
ProcessGroupGloo
()
=
default
;
~
ProcessGroupGloo
()
=
default
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
const
AllreduceOptions
&
opts
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
opts
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
phi
::
Dense
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
override
;
std
::
vector
<
phi
::
Dense
Tensor
>&
out_tensors
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
ReduceOptions
&
opts
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
,
const
ReduceOptions
&
opts
)
override
;
std
::
vector
<
Tensor
>&
out_tensors
,
const
ScatterOptions
&
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ScatterOptions
&
)
override
;
std
::
shared_ptr
<::
gloo
::
Context
>
get_context
()
{
return
_context
;
}
std
::
shared_ptr
<::
gloo
::
Context
>
get_context
()
{
return
_context
;
}
uint64_t
next_tag
()
{
return
_tag
++
;
}
uint64_t
next_tag
()
{
return
_tag
++
;
}
...
...
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
浏览文件 @
1e56ca8a
...
@@ -44,14 +44,14 @@ void SyncDefaultStream(
...
@@ -44,14 +44,14 @@ void SyncDefaultStream(
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
ProcessGroupHCCL
::
CreateTask
(
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
ProcessGroupHCCL
::
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
Tensor
>&
inputs
)
{
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
{
return
std
::
make_shared
<
ProcessGroupHCCL
::
HCCLTask
>
(
places
,
rank
,
comm_type
,
return
std
::
make_shared
<
ProcessGroupHCCL
::
HCCLTask
>
(
places
,
rank
,
comm_type
,
inputs
);
inputs
);
}
}
ProcessGroupHCCL
::
HCCLTask
::
HCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
ProcessGroupHCCL
::
HCCLTask
::
HCCLTask
(
CommType
CommType
,
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
)
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
control_events_
.
resize
(
places
.
size
());
control_events_
.
resize
(
places
.
size
());
hcclComms_
.
resize
(
places
.
size
());
hcclComms_
.
resize
(
places
.
size
());
...
@@ -60,8 +60,8 @@ ProcessGroupHCCL::HCCLTask::HCCLTask(const std::vector<Place>& places, int rank,
...
@@ -60,8 +60,8 @@ ProcessGroupHCCL::HCCLTask::HCCLTask(const std::vector<Place>& places, int rank,
ProcessGroupHCCL
::
HCCLTask
::~
HCCLTask
()
{}
ProcessGroupHCCL
::
HCCLTask
::~
HCCLTask
()
{}
void
ProcessGroupHCCL
::
HCCLTask
::
SetOutputs
(
void
ProcessGroupHCCL
::
HCCLTask
::
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
)
{
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
)
{
// NOLINT
outputs_
=
std
::
make_shared
<
std
::
vector
<
Tensor
>>
(
outputs
);
outputs_
=
std
::
make_shared
<
std
::
vector
<
phi
::
Dense
Tensor
>>
(
outputs
);
}
}
void
ProcessGroupHCCL
::
HCCLTask
::
SynchronizeStreams
()
{
void
ProcessGroupHCCL
::
HCCLTask
::
SynchronizeStreams
()
{
...
@@ -166,8 +166,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
...
@@ -166,8 +166,8 @@ void ProcessGroupHCCL::CreateHCCLManagerCache(
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
Tensor
>&
outputs
,
Fn
fn
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
op_type
)
{
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
Fn
fn
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
...
@@ -208,91 +208,44 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
...
@@ -208,91 +208,44 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupHCCL::Collective(
return
task
;
return
task
;
}
}
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
Fn
fn
,
int
dst_rank
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
tensors
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_hcclcomm_
.
find
(
key
)
==
places_to_hcclcomm_
.
end
())
{
CreateHCCLManagerCache
(
key
,
places
);
}
}
auto
&
hccl_comms
=
places_to_hcclcomm_
[
key
];
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
tensors
);
// construct uninitialize guard for device
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < tensors.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
const
auto
&
hccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
tensors
[
i
],
hccl_comms
[
i
]
->
GetHcclComm
(),
hccl_stream
,
dst_rank
);
}
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
task
->
control_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
}
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
// PADDLE_ENFORCE_EQ(
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
// CheckTensorsInNPUPlace(tensors), true,
const
AllreduceOptions
&
opts
)
{
// platform::errors::InvalidArgument("All inputs should be in
return
Collective
(
in_tensors
,
out_tensors
,
// NPUPlace."));
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
return
Collective
(
HcclComm
comm
,
const
aclrtStream
&
stream
)
{
tensors
,
tensors
,
return
platform
::
dynload
::
HcclAllReduce
(
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
HcclComm
comm
,
input
.
data
(),
output
.
data
(),
input
.
numel
(),
const
aclrtStream
&
stream
)
{
platform
::
ToHCCLDataType
(
input
.
dtype
()),
auto
input_tensor
=
ToHCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
},
auto
output_tensor
=
CommType
::
ALLREDUCE
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
return
platform
::
dynload
::
HcclAllReduce
(
input_tensor
->
data
(),
output_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToHCCLDataType
(
input
.
type
()),
ToHCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
BroadcastOptions
&
opts
)
{
// PADDLE_ENFORCE_EQ(
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// platform::errors::InvalidArgument("All inputs should be in
// CudaPlace."));
// CudaPlace."));
return
Collective
(
return
Collective
(
tensors
,
tensors
,
in_tensors
,
out_
tensors
,
[
&
](
Tensor
&
input
,
Tensor
&
output
,
HcclComm
comm
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
Dense
Tensor
&
output
,
HcclComm
comm
,
const
aclrtStream
&
stream
)
{
const
aclrtStream
&
stream
)
{
const
auto
root
=
opts
.
source_rank
*
tensors
.
size
()
+
opts
.
source_root
;
int
root
=
opts
.
source_rank
*
in_tensors
.
size
()
+
opts
.
source_root
;
auto
input_tensor
=
if
(
rank_
==
root
)
{
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
return
platform
::
dynload
::
HcclBroadcast
(
auto
output_tensor
=
input
.
data
(),
input
.
numel
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
platform
::
ToHCCLDataType
(
input
.
dtype
()),
root
,
comm
,
stream
);
return
platform
::
dynload
::
HcclBroadcast
(
}
else
{
input_tensor
->
data
(),
input_tensor
->
numel
(),
return
platform
::
dynload
::
HcclBroadcast
(
platform
::
ToHCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
output
.
data
(),
output
.
numel
(),
platform
::
ToHCCLDataType
(
output
.
dtype
()),
root
,
comm
,
stream
);
}
},
},
CommType
::
BROADCAST
);
CommType
::
BROADCAST
);
}
}
...
...
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
浏览文件 @
1e56ca8a
...
@@ -46,7 +46,7 @@ class ProcessGroupHCCL : public ProcessGroup {
...
@@ -46,7 +46,7 @@ class ProcessGroupHCCL : public ProcessGroup {
public
std
::
enable_shared_from_this
<
HCCLTask
>
{
public
std
::
enable_shared_from_this
<
HCCLTask
>
{
public:
public:
HCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
HCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
);
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
);
bool
IsCompleted
();
bool
IsCompleted
();
...
@@ -56,7 +56,7 @@ class ProcessGroupHCCL : public ProcessGroup {
...
@@ -56,7 +56,7 @@ class ProcessGroupHCCL : public ProcessGroup {
void
Synchronize
();
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
);
// NOLINT
void
SetOutputs
(
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
);
// NOLINT
virtual
~
HCCLTask
();
virtual
~
HCCLTask
();
...
@@ -65,7 +65,7 @@ class ProcessGroupHCCL : public ProcessGroup {
...
@@ -65,7 +65,7 @@ class ProcessGroupHCCL : public ProcessGroup {
protected:
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
HCCLCommManager
>>
hcclComms_
;
std
::
vector
<
std
::
shared_ptr
<
HCCLCommManager
>>
hcclComms_
;
std
::
shared_ptr
<
std
::
vector
<
Tensor
>>
outputs_
;
std
::
shared_ptr
<
std
::
vector
<
phi
::
Dense
Tensor
>>
outputs_
;
private:
private:
};
};
...
@@ -78,17 +78,19 @@ class ProcessGroupHCCL : public ProcessGroup {
...
@@ -78,17 +78,19 @@ class ProcessGroupHCCL : public ProcessGroup {
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
protected:
protected:
virtual
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
CreateTask
(
virtual
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
const
std
::
vector
<
Tensor
>&
inputs
);
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
);
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
HCCLCommManager
>
hccl_comm_
;
std
::
shared_ptr
<
HCCLCommManager
>
hccl_comm_
;
...
@@ -113,15 +115,10 @@ class ProcessGroupHCCL : public ProcessGroup {
...
@@ -113,15 +115,10 @@ class ProcessGroupHCCL : public ProcessGroup {
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
,
// NOLINT
Fn
fn
,
CommType
op_type
);
Fn
fn
,
CommType
op_type
);
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
// NOLINT
Fn
fn
,
int
dst_rank
,
CommType
op_type
);
void
CreateHCCLManagerCache
(
const
std
::
string
&
places_key
,
void
CreateHCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
);
const
std
::
vector
<
Place
>&
places
);
};
};
...
...
paddle/fluid/distributed/collective/ProcessGroupHeter.cc
浏览文件 @
1e56ca8a
...
@@ -26,13 +26,13 @@ namespace distributed {
...
@@ -26,13 +26,13 @@ namespace distributed {
using
Place
=
paddle
::
platform
::
Place
;
using
Place
=
paddle
::
platform
::
Place
;
std
::
shared_ptr
<
ProcessGroupHeter
::
HeterTask
>
ProcessGroupHeter
::
CreateTask
(
std
::
shared_ptr
<
ProcessGroupHeter
::
HeterTask
>
ProcessGroupHeter
::
CreateTask
(
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
Tensor
>&
inputs
)
{
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
{
return
std
::
make_shared
<
ProcessGroupHeter
::
HeterTask
>
(
rank
,
comm_type
,
return
std
::
make_shared
<
ProcessGroupHeter
::
HeterTask
>
(
rank
,
comm_type
,
inputs
);
inputs
);
}
}
ProcessGroupHeter
::
HeterTask
::
HeterTask
(
int
rank
,
CommType
CommType
,
ProcessGroupHeter
::
HeterTask
::
HeterTask
(
const
std
::
vector
<
Tensor
>&
inputs
)
int
rank
,
CommType
CommType
,
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
:
Task
(
rank
,
inputs
,
CommType
)
{}
:
Task
(
rank
,
inputs
,
CommType
)
{}
ProcessGroupHeter
::
HeterTask
::~
HeterTask
()
{}
ProcessGroupHeter
::
HeterTask
::~
HeterTask
()
{}
...
@@ -86,248 +86,177 @@ static void _do_add(T* dst, T* src, size_t size) {
...
@@ -86,248 +86,177 @@ static void _do_add(T* dst, T* src, size_t size) {
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHeter
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHeter
::
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
opts
)
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
CheckTensorsInCudaPlace
(
in_
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
out_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All outputs should be in CudaPlace."
));
#endif
#endif
// Step1: do allreduce in inner cluster
// Step1: do allreduce in inner cluster
auto
task
=
inner_pg_
->
AllReduce
(
tensors
,
opts
);
auto
task
=
inner_pg_
->
AllReduce
(
in_tensors
,
in_
tensors
,
opts
);
task
->
Wait
();
task
->
Wait
();
// Step2: copy tensors to CPU
// Step2: copy tensors to CPU
if
(
local_rank_
==
0
)
{
if
(
local_rank_
==
0
)
{
std
::
vector
<
Tensor
>
cpu_tensors
;
std
::
vector
<
phi
::
DenseTensor
>
cpu_tensors
;
cpu_tensors
.
reserve
(
tensors
.
size
());
cpu_tensors
.
reserve
(
in_tensors
.
size
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_tensors
.
size
();
i
++
)
{
auto
dense_gpu_tensor
=
auto
gpu_tensor
=
in_tensors
[
i
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
i
].
impl
());
auto
cpu_tensor
=
cpu_tensors
[
i
];
phi
::
DenseTensorMeta
meta
=
phi
::
DenseTensorMeta
(
cpu_tensor
.
Resize
(
gpu_tensor
.
dims
());
dense_gpu_tensor
->
dtype
(),
dense_gpu_tensor
->
dims
());
framework
::
TensorCopySync
(
gpu_tensor
,
platform
::
CPUPlace
(),
&
cpu_tensor
);
std
::
shared_ptr
<
phi
::
DenseTensor
>
dense_cpu_tensor
=
std
::
make_shared
<
phi
::
DenseTensor
>
(
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
())
.
get
(),
meta
);
dense_cpu_tensor
->
ResizeAndAllocate
(
dense_gpu_tensor
->
dims
());
cpu_tensors
[
i
]
=
paddle
::
experimental
::
Tensor
(
dense_cpu_tensor
);
framework
::
TensorCopySync
(
*
dense_gpu_tensor
,
platform
::
CPUPlace
(),
dense_cpu_tensor
.
get
());
}
}
// Step3: do inter cluster allreduce
// Step3: do inter cluster allreduce
if
(
with_switch_
)
{
if
(
with_switch_
)
{
if
(
local_rank_
==
0
)
{
if
(
local_rank_
==
0
)
{
HeterClient
*
client_
=
HeterClient
*
client_
=
HeterClient
::
GetInstance
({
switch_endpoint_
},
{},
0
).
get
();
HeterClient
::
GetInstance
({
switch_endpoint_
},
{},
0
).
get
();
auto
dense_cpu_tensor
=
auto
dense_cpu_tensor
=
cpu_tensors
[
0
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
cpu_tensors
[
0
].
impl
());
std
::
vector
<
int
>
send_size
;
std
::
vector
<
int
>
send_size
;
send_size
.
push_back
(
dense_cpu_tensor
->
numel
());
send_size
.
push_back
(
dense_cpu_tensor
.
numel
());
int
ret
=
client_
->
Send
(
int
ret
=
client_
->
Send
(
gid_
,
{
dense_cpu_tensor
->
name
()},
send_size
,
gid_
,
{
dense_cpu_tensor
.
name
()},
send_size
,
dense_cpu_tensor
.
data
(),
dense_cpu_tensor
->
data
(),
dense_cpu_tensor
.
numel
()
*
dense_cpu_tensor
->
numel
()
*
framework
::
DataTypeSize
(
dense_cpu_tensor
.
dtype
()));
framework
::
DataTypeSize
(
dense_cpu_tensor
->
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"Send to the switch module error."
));
"Send to the switch module error."
));
phi
::
DenseTensorMeta
meta
=
phi
::
DenseTensorMeta
(
phi
::
DenseTensorMeta
meta
=
phi
::
DenseTensorMeta
(
dense_cpu_tensor
->
dtype
(),
dense_cpu_tensor
->
dims
());
dense_cpu_tensor
.
dtype
(),
dense_cpu_tensor
.
dims
());
std
::
shared_ptr
<
phi
::
DenseTensor
>
dense_cpu_tensor2
=
std
::
shared_ptr
<
phi
::
DenseTensor
>
dense_cpu_tensor2
=
std
::
make_shared
<
phi
::
DenseTensor
>
(
std
::
make_shared
<
phi
::
DenseTensor
>
(
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
())
paddle
::
platform
::
CPUPlace
())
.
get
(),
.
get
(),
meta
);
meta
);
dense_cpu_tensor2
->
ResizeAndAllocate
(
dense_cpu_tensor
->
dims
());
dense_cpu_tensor2
->
ResizeAndAllocate
(
dense_cpu_tensor
.
dims
());
Tensor
cpu_tensor_temp
=
paddle
::
experimental
::
Tensor
(
dense_cpu_tensor2
);
ret
=
client_
->
Recv
(
ret
=
client_
->
Recv
(
gid_
,
{
dense_cpu_tensor
->
name
()},
dense_cpu_tensor2
->
data
(),
gid_
,
{
dense_cpu_tensor
.
name
()},
dense_cpu_tensor2
->
data
(),
dense_cpu_tensor2
->
numel
()
*
dense_cpu_tensor2
->
numel
()
*
framework
::
DataTypeSize
(
dense_cpu_tensor2
->
dtype
()));
framework
::
DataTypeSize
(
dense_cpu_tensor2
->
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"Recv from the switch module error."
));
"Recv from the switch module error."
));
switch
(
dense_cpu_tensor
->
dtype
())
{
switch
(
dense_cpu_tensor
.
dtype
())
{
case
DataType
::
FLOAT32
:
case
DataType
::
FLOAT32
:
_do_add
<
float
>
(
reinterpret_cast
<
float
*>
(
dense_cpu_tensor
->
data
()),
_do_add
<
float
>
(
reinterpret_cast
<
float
*>
(
dense_cpu_tensor
.
data
()),
reinterpret_cast
<
float
*>
(
dense_cpu_tensor2
->
data
()),
reinterpret_cast
<
float
*>
(
dense_cpu_tensor2
->
data
()),
dense_cpu_tensor
->
numel
());
dense_cpu_tensor
.
numel
());
break
;
break
;
case
DataType
::
FLOAT64
:
case
DataType
::
FLOAT64
:
_do_add
<
double
>
(
_do_add
<
double
>
(
reinterpret_cast
<
double
*>
(
dense_cpu_tensor
->
data
()),
reinterpret_cast
<
double
*>
(
dense_cpu_tensor
.
data
()),
reinterpret_cast
<
double
*>
(
dense_cpu_tensor2
->
data
()),
reinterpret_cast
<
double
*>
(
dense_cpu_tensor2
->
data
()),
dense_cpu_tensor
->
numel
());
dense_cpu_tensor
.
numel
());
break
;
break
;
case
DataType
::
INT32
:
case
DataType
::
INT32
:
_do_add
<
int
>
(
reinterpret_cast
<
int
*>
(
dense_cpu_tensor
->
data
()),
_do_add
<
int
>
(
reinterpret_cast
<
int
*>
(
dense_cpu_tensor
.
data
()),
reinterpret_cast
<
int
*>
(
dense_cpu_tensor2
->
data
()),
reinterpret_cast
<
int
*>
(
dense_cpu_tensor2
->
data
()),
dense_cpu_tensor
->
numel
());
dense_cpu_tensor
.
numel
());
break
;
break
;
default:
default:
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Unsupported data type (%s) to do add."
,
"Unsupported data type (%s) to do add."
,
framework
::
DataType2String
(
dense_cpu_tensor
->
dtype
())));
framework
::
DataType2String
(
dense_cpu_tensor
.
dtype
())));
}
}
}
}
}
else
{
}
else
{
auto
gloo_task
=
inter_pg_
->
AllReduce
(
cpu_tensors
,
opts
);
auto
gloo_task
=
inter_pg_
->
AllReduce
(
cpu_tensors
,
cpu_tensors
,
opts
);
gloo_task
->
Wait
();
gloo_task
->
Wait
();
}
}
// Step4: copy cpu tensors to gpu
// Step4: copy cpu tensors to gpu
// copy cpu tensors to gpu
// copy cpu tensors to gpu
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_tensors
.
size
();
i
++
)
{
auto
dense_gpu_tensor
=
auto
gpu_tensor
=
out_tensors
[
i
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
i
].
impl
());
auto
cpu_tensor
=
cpu_tensors
[
i
];
auto
dense_cpu_tensor
=
framework
::
TensorCopySync
(
cpu_tensor
,
cpu_tensor
.
place
(),
&
gpu_tensor
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
cpu_tensors
[
i
].
impl
());
framework
::
TensorCopySync
(
*
dense_cpu_tensor
,
dense_cpu_tensor
->
place
(),
dense_gpu_tensor
.
get
());
}
}
}
}
// Step5: broadcast among inner cluster
// Step5: broadcast among inner cluster
auto
b_opts
=
BroadcastOptions
();
auto
b_opts
=
BroadcastOptions
();
b_opts
.
source_r
oot
=
0
;
b_opts
.
source_r
ank
=
0
;
auto
broadcast_task
=
inner_pg_
->
Broadcast
(
tensors
,
b_opts
);
auto
broadcast_task
=
inner_pg_
->
Broadcast
(
out_tensors
,
out_
tensors
,
b_opts
);
broadcast_task
->
Wait
();
broadcast_task
->
Wait
();
return
CreateTask
(
rank_
,
CommType
::
ALLREDUCE
,
tensors
);
return
CreateTask
(
rank_
,
CommType
::
ALLREDUCE
,
in_
tensors
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHeter
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHeter
::
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
)
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
CheckTensorsInCudaPlace
(
in_
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
out_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All outputs should be in CudaPlace."
));
#endif
#endif
// Step1: do broadcast in inner cluster
// Step1: do broadcast in inner cluster
auto
b_opts
=
BroadcastOptions
();
auto
b_opts
=
BroadcastOptions
();
b_opts
.
source_r
oot
=
0
;
b_opts
.
source_r
ank
=
0
;
inner_pg_
->
Broadcast
(
tensors
,
b_opts
);
inner_pg_
->
Broadcast
(
in_tensors
,
out_
tensors
,
b_opts
);
if
(
local_rank_
==
0
)
{
if
(
local_rank_
==
0
)
{
std
::
vector
<
Tensor
>
cpu_tensors
;
std
::
vector
<
phi
::
DenseTensor
>
cpu_tensors
;
cpu_tensors
.
reserve
(
tensors
.
size
());
cpu_tensors
.
reserve
(
in_tensors
.
size
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_tensors
.
size
();
i
++
)
{
auto
dense_gpu_tensor
=
auto
gpu_tensor
=
in_tensors
[
i
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
i
].
impl
());
auto
cpu_tensor
=
cpu_tensors
[
i
];
phi
::
DenseTensorMeta
meta
=
phi
::
DenseTensorMeta
(
cpu_tensor
.
Resize
(
gpu_tensor
.
dims
());
dense_gpu_tensor
->
dtype
(),
dense_gpu_tensor
->
dims
());
framework
::
TensorCopySync
(
gpu_tensor
,
platform
::
CPUPlace
(),
&
cpu_tensor
);
std
::
shared_ptr
<
phi
::
DenseTensor
>
dense_cpu_tensor
=
std
::
make_shared
<
phi
::
DenseTensor
>
(
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
())
.
get
(),
meta
);
dense_cpu_tensor
->
ResizeAndAllocate
(
dense_gpu_tensor
->
dims
());
cpu_tensors
[
i
]
=
paddle
::
experimental
::
Tensor
(
dense_cpu_tensor
);
framework
::
TensorCopySync
(
*
dense_gpu_tensor
,
platform
::
CPUPlace
(),
dense_cpu_tensor
.
get
());
}
}
if
(
with_switch_
)
{
if
(
with_switch_
)
{
if
(
local_rank_
==
0
)
{
if
(
local_rank_
==
0
)
{
HeterClient
*
client_
=
HeterClient
*
client_
=
HeterClient
::
GetInstance
({
switch_endpoint_
},
{},
0
).
get
();
HeterClient
::
GetInstance
({
switch_endpoint_
},
{},
0
).
get
();
auto
dense_cpu_tensor
=
auto
dense_cpu_tensor
=
cpu_tensors
[
0
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
cpu_tensors
[
0
].
impl
());
if
(
gloo_rank_
==
0
)
{
if
(
gloo_rank_
==
0
)
{
std
::
vector
<
int
>
send_size
;
std
::
vector
<
int
>
send_size
;
send_size
.
push_back
(
dense_cpu_tensor
->
numel
());
send_size
.
push_back
(
dense_cpu_tensor
.
numel
());
int
ret
=
client_
->
Send
(
int
ret
=
client_
->
Send
(
gid_
,
{
dense_cpu_tensor
->
name
()},
send_size
,
gid_
,
{
dense_cpu_tensor
.
name
()},
send_size
,
dense_cpu_tensor
->
data
(),
dense_cpu_tensor
.
data
(),
dense_cpu_tensor
->
numel
()
*
dense_cpu_tensor
.
numel
()
*
framework
::
DataTypeSize
(
dense_cpu_tensor
->
dtype
()));
framework
::
DataTypeSize
(
dense_cpu_tensor
.
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"Send to the switch module error."
));
"Send to the switch module error."
));
}
else
{
}
else
{
int
ret
=
client_
->
Recv
(
int
ret
=
client_
->
Recv
(
gid_
,
{
dense_cpu_tensor
->
name
()},
dense_cpu_tensor
->
data
(),
gid_
,
{
dense_cpu_tensor
.
name
()},
dense_cpu_tensor
.
data
(),
dense_cpu_tensor
->
numel
()
*
dense_cpu_tensor
.
numel
()
*
framework
::
DataTypeSize
(
dense_cpu_tensor
->
dtype
()));
framework
::
DataTypeSize
(
dense_cpu_tensor
.
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Receive from the switch module error."
));
"Receive from the switch module error."
));
ret
=
client_
->
Recv
(
ret
=
client_
->
Recv
(
gid_
,
{
dense_cpu_tensor
->
name
()},
dense_cpu_tensor
->
data
(),
gid_
,
{
dense_cpu_tensor
.
name
()},
dense_cpu_tensor
.
data
(),
dense_cpu_tensor
->
numel
()
*
dense_cpu_tensor
.
numel
()
*
framework
::
DataTypeSize
(
dense_cpu_tensor
->
dtype
()));
framework
::
DataTypeSize
(
dense_cpu_tensor
.
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Receive from the switch module error."
));
"Receive from the switch module error."
));
}
}
}
}
}
else
{
}
else
{
auto
gloo_task
=
inter_pg_
->
Broadcast
(
cpu_tensors
,
opts
);
auto
gloo_task
=
inter_pg_
->
Broadcast
(
cpu_tensors
,
cpu_tensors
,
opts
);
gloo_task
->
Wait
();
gloo_task
->
Wait
();
}
}
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_tensors
.
size
();
i
++
)
{
auto
dense_gpu_tensor
=
auto
gpu_tensor
=
out_tensors
[
i
];
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
i
].
impl
());
auto
cpu_tensor
=
cpu_tensors
[
i
];
auto
dense_cpu_tensor
=
framework
::
TensorCopySync
(
cpu_tensor
,
gpu_tensor
.
place
(),
&
gpu_tensor
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
cpu_tensors
[
i
].
impl
());
framework
::
TensorCopySync
(
*
dense_cpu_tensor
,
dense_cpu_tensor
->
place
(),
dense_gpu_tensor
.
get
());
}
}
}
}
auto
broadcast_task
=
inner_pg_
->
Broadcast
(
tensors
,
b_opts
);
auto
broadcast_task
=
inner_pg_
->
Broadcast
(
out_tensors
,
out_
tensors
,
b_opts
);
broadcast_task
->
Wait
();
broadcast_task
->
Wait
();
return
CreateTask
(
rank_
,
CommType
::
BROADCAST
,
tensors
);
return
CreateTask
(
rank_
,
CommType
::
BROADCAST
,
in_tensors
);
}
void
ProcessGroupHeter
::
Broadcast
(
const
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
)
{
// Step1: do broadcast in inner cluster
inner_pg_
->
Broadcast
(
in
,
out
);
if
(
local_rank_
==
0
)
{
phi
::
DenseTensorMeta
meta
=
phi
::
DenseTensorMeta
(
in
->
dtype
(),
in
->
dims
());
std
::
shared_ptr
<
phi
::
DenseTensor
>
dense_cpu_tensor
=
std
::
make_shared
<
phi
::
DenseTensor
>
(
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
())
.
get
(),
meta
);
dense_cpu_tensor
->
ResizeAndAllocate
(
in
->
dims
());
Tensor
cpu_tensor
=
paddle
::
experimental
::
Tensor
(
dense_cpu_tensor
);
framework
::
TensorCopySync
(
*
in
,
platform
::
CPUPlace
(),
dense_cpu_tensor
.
get
());
if
(
with_switch_
)
{
if
(
local_rank_
==
0
)
{
HeterClient
*
client_
=
HeterClient
::
GetInstance
({
switch_endpoint_
},
{},
0
).
get
();
if
(
gloo_rank_
==
0
)
{
std
::
vector
<
int
>
send_size
;
send_size
.
push_back
(
in
->
numel
());
int
ret
=
client_
->
Send
(
gid_
,
{
in
->
name
()},
send_size
,
dense_cpu_tensor
->
data
(),
in
->
numel
()
*
framework
::
DataTypeSize
(
in
->
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"Send to the switch module error."
));
}
else
{
int
ret
=
client_
->
Recv
(
gid_
,
{
in
->
name
()},
dense_cpu_tensor
->
data
(),
in
->
numel
()
*
framework
::
DataTypeSize
(
in
->
dtype
()));
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"Receive from the switch module error."
));
}
}
}
else
{
std
::
vector
<
Tensor
>
cpu_tensors
=
{
cpu_tensor
};
auto
gloo_task
=
inter_pg_
->
Broadcast
(
cpu_tensors
);
gloo_task
->
Wait
();
}
framework
::
TensorCopySync
(
*
dense_cpu_tensor
,
out
->
place
(),
out
);
}
inner_pg_
->
Broadcast
(
out
,
out
);
}
}
}
//
namespace distributed
}
// namespace distributed
}
//
namespace paddle
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupHeter.h
浏览文件 @
1e56ca8a
...
@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup {
...
@@ -66,7 +66,8 @@ class ProcessGroupHeter : public ProcessGroup {
class
HeterTask
:
public
ProcessGroup
::
Task
,
class
HeterTask
:
public
ProcessGroup
::
Task
,
public
std
::
enable_shared_from_this
<
HeterTask
>
{
public
std
::
enable_shared_from_this
<
HeterTask
>
{
public:
public:
HeterTask
(
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
);
HeterTask
(
int
rank
,
CommType
CommType
,
const
std
::
vector
<
phi
::
DenseTensor
>&
);
bool
IsCompleted
();
bool
IsCompleted
();
...
@@ -89,18 +90,16 @@ class ProcessGroupHeter : public ProcessGroup {
...
@@ -89,18 +90,16 @@ class ProcessGroupHeter : public ProcessGroup {
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
,
std
::
vector
<
phi
::
DenseTensor
>&
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
,
std
::
vector
<
phi
::
DenseTensor
>&
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
void
Broadcast
(
const
phi
::
DenseTensor
*
in
,
phi
::
DenseTensor
*
out
)
override
;
protected:
protected:
virtual
std
::
shared_ptr
<
ProcessGroupHeter
::
HeterTask
>
CreateTask
(
virtual
std
::
shared_ptr
<
ProcessGroupHeter
::
HeterTask
>
CreateTask
(
int
rank
,
CommType
opType
,
const
std
::
vector
<
Tensor
>&
inputs
);
int
rank
,
CommType
opType
,
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
);
private:
private:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
Store
>
store_
;
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
浏览文件 @
1e56ca8a
...
@@ -41,14 +41,14 @@ void SyncDefaultStream(
...
@@ -41,14 +41,14 @@ void SyncDefaultStream(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
Tensor
>&
inputs
)
{
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
{
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
places
,
rank
,
comm_type
,
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
places
,
rank
,
comm_type
,
inputs
);
inputs
);
}
}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
CommType
CommType
,
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
)
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
)
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
control_events_
.
resize
(
places
.
size
());
control_events_
.
resize
(
places
.
size
());
ncclComms_
.
resize
(
places
.
size
());
ncclComms_
.
resize
(
places
.
size
());
...
@@ -57,8 +57,8 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const std::vector<Place>& places, int rank,
...
@@ -57,8 +57,8 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const std::vector<Place>& places, int rank,
ProcessGroupNCCL
::
NCCLTask
::~
NCCLTask
()
{}
ProcessGroupNCCL
::
NCCLTask
::~
NCCLTask
()
{}
void
ProcessGroupNCCL
::
NCCLTask
::
SetOutputs
(
void
ProcessGroupNCCL
::
NCCLTask
::
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
)
{
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
)
{
// NOLINT
outputs_
=
std
::
make_shared
<
std
::
vector
<
Tensor
>>
(
outputs
);
outputs_
=
std
::
make_shared
<
std
::
vector
<
phi
::
Dense
Tensor
>>
(
outputs
);
}
}
void
ProcessGroupNCCL
::
NCCLTask
::
SynchronizeStreams
()
{
void
ProcessGroupNCCL
::
NCCLTask
::
SynchronizeStreams
()
{
...
@@ -180,8 +180,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
...
@@ -180,8 +180,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
Tensor
>&
outputs
,
Fn
fn
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
op_type
)
{
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
Fn
fn
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
...
@@ -205,9 +205,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -205,9 +205,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
auto
dense_tensor
=
memory
::
RecordStream
(
inputs
[
i
].
Holder
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
inputs
[
i
].
impl
());
memory
::
RecordStream
(
dense_tensor
->
Holder
(),
places_to_ctx_
[
key
][
i
]
->
stream
());
places_to_ctx_
[
key
][
i
]
->
stream
());
}
}
}
}
...
@@ -267,7 +265,8 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
...
@@ -267,7 +265,8 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
PointToPoint
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
Fn
fn
,
int
dst_rank
,
CommType
op_type
)
{
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
Fn
fn
,
int
dst_rank
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
tensors
);
const
auto
places
=
GetPlaceList
(
tensors
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
...
@@ -290,9 +289,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -290,9 +289,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
auto
dense_tensor
=
memory
::
RecordStream
(
tensors
[
i
].
Holder
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensors
[
i
].
impl
());
memory
::
RecordStream
(
dense_tensor
->
Holder
(),
places_to_ctx_
[
key
][
i
]
->
stream
());
places_to_ctx_
[
key
][
i
]
->
stream
());
}
}
}
}
...
@@ -314,46 +311,40 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -314,46 +311,40 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
CheckTensorsInCudaPlace
(
in_
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
in_tensors
,
out_tensors
,
tensors
,
tensors
,
[
&
](
const
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllReduce
(
auto
input_tensor
=
input
.
data
(),
output
.
data
(),
input
.
numel
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
platform
::
ToNCCLDataType
(
input
.
type
()),
auto
output_tensor
=
ToNCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
},
return
platform
::
dynload
::
ncclAllReduce
(
CommType
::
ALLREDUCE
);
input_tensor
->
data
(),
output_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
ToNCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
CheckTensorsInCudaPlace
(
in_
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
in_tensors
,
out_tensors
,
tensors
,
tensors
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
[
&
](
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
gpuStream_t
&
stream
)
{
const
auto
root
=
opts
.
source_rank
*
in_tensors
.
size
()
+
const
auto
root
=
opts
.
source_rank
*
tensors
.
size
()
+
opts
.
source_root
;
opts
.
source_root
;
auto
input_tensor
=
return
platform
::
dynload
::
ncclBroadcast
(
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
input
.
data
(),
output
.
data
(),
input
.
numel
(),
auto
output_tensor
=
platform
::
ToNCCLDataType
(
input
.
type
()),
root
,
comm
,
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
stream
);
return
platform
::
dynload
::
ncclBcast
(
},
input_tensor
->
data
(),
input_tensor
->
numel
(),
CommType
::
BROADCAST
);
platform
::
ToNCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
},
CommType
::
BROADCAST
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Barrier
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Barrier
(
...
@@ -374,23 +365,24 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
...
@@ -374,23 +365,24 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
places
.
emplace_back
(
place_id
);
places
.
emplace_back
(
place_id
);
}
}
std
::
vector
<
Tensor
>
barrierTensors
;
std
::
vector
<
phi
::
Dense
Tensor
>
barrierTensors
;
barrierTensors
.
reserve
(
places
.
size
());
barrierTensors
.
reserve
(
places
.
size
());
platform
::
CUDADeviceGuard
gpuGuard
;
platform
::
CUDADeviceGuard
gpuGuard
;
for
(
auto
&
place
:
places
)
{
for
(
auto
&
place
:
places
)
{
gpuGuard
.
SetDeviceIndex
(
place
.
GetDeviceId
());
gpuGuard
.
SetDeviceIndex
(
place
.
GetDeviceId
());
auto
dt
=
full
({
1
},
0
,
phi
::
DataType
::
FLOAT32
,
phi
::
GPUPlace
());
auto
dt
=
full
({
1
},
0
,
phi
::
DataType
::
FLOAT32
,
phi
::
GPUPlace
());
barrierTensors
.
push_back
(
dt
);
barrierTensors
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
dt
.
impl
()));
}
}
auto
task
=
ProcessGroupNCCL
::
AllReduce
(
barrierTensors
);
auto
task
=
ProcessGroupNCCL
::
AllReduce
(
barrierTensors
,
barrierTensors
);
auto
nccl_task
=
dynamic_cast
<
ProcessGroupNCCL
::
NCCLTask
*>
(
task
.
get
());
auto
nccl_task
=
dynamic_cast
<
ProcessGroupNCCL
::
NCCLTask
*>
(
task
.
get
());
nccl_task
->
barrierTensors_
=
std
::
move
(
barrierTensors
);
nccl_task
->
barrierTensors_
=
std
::
move
(
barrierTensors
);
return
task
;
return
task
;
}
}
void
CheckTensorsInDifferentDevices
(
const
std
::
vector
<
Tensor
>&
tensors
,
void
CheckTensorsInDifferentDevices
(
const
size_t
num_devices
)
{
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
const
size_t
num_devices
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
tensors
.
size
()
==
0
,
false
,
tensors
.
size
()
==
0
,
false
,
platform
::
errors
::
InvalidArgument
(
"Tensor list must be nonempty."
));
platform
::
errors
::
InvalidArgument
(
"Tensor list must be nonempty."
));
...
@@ -402,11 +394,11 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
...
@@ -402,11 +394,11 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
std
::
set
<
Place
>
used_devices
;
std
::
set
<
Place
>
used_devices
;
for
(
const
auto
&
t
:
tensors
)
{
for
(
const
auto
&
t
:
tensors
)
{
PADDLE_ENFORCE_EQ
(
t
.
is_gpu
()
&&
t
.
is_dense_tensor
(
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
t
.
place
()
),
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Tensors must be CUDA and dense tensor."
));
"Tensors must be CUDA and dense tensor."
));
const
auto
inserted
=
used_devices
.
insert
(
t
.
inner_
place
()).
second
;
const
auto
inserted
=
used_devices
.
insert
(
t
.
place
()).
second
;
PADDLE_ENFORCE_EQ
(
inserted
,
true
,
PADDLE_ENFORCE_EQ
(
inserted
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Tensors must be on distinct GPU devices."
));
"Tensors must be on distinct GPU devices."
));
...
@@ -414,62 +406,55 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
...
@@ -414,62 +406,55 @@ void CheckTensorsInDifferentDevices(const std::vector<Tensor>& tensors,
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Send
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Send
(
std
::
vector
<
Tensor
>&
tensors
,
int
dst_rank
)
{
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
,
int
dst_rank
)
{
CheckTensorsInDifferentDevices
(
tensors
,
static_cast
<
size_t
>
(
GetSize
()));
CheckTensorsInDifferentDevices
(
tensors
,
static_cast
<
size_t
>
(
GetSize
()));
auto
task
=
PointToPoint
(
auto
task
=
PointToPoint
(
tensors
,
tensors
,
[
&
](
phi
::
DenseTensor
&
input
,
ncclComm_t
comm
,
[
&
](
Tensor
&
input
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
,
const
gpuStream_t
&
stream
,
int
dst_rank
)
{
int
dst_rank
)
{
return
platform
::
dynload
::
ncclSend
(
auto
input_tensor
=
input
.
data
(),
input
.
numel
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
platform
::
ToNCCLDataType
(
input
.
dtype
()),
return
platform
::
dynload
::
ncclSend
(
dst_rank
,
comm
,
stream
);
input_tensor
->
data
(),
input_tensor
->
numel
(),
},
platform
::
ToNCCLDataType
(
input
.
type
()),
dst_rank
,
comm
,
stream
);
dst_rank
,
CommType
::
SEND
);
},
dst_rank
,
CommType
::
SEND
);
return
task
;
return
task
;
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Recv
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Recv
(
std
::
vector
<
Tensor
>&
tensors
,
int
src_rank
)
{
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
,
int
src_rank
)
{
CheckTensorsInDifferentDevices
(
tensors
,
static_cast
<
size_t
>
(
GetSize
()));
CheckTensorsInDifferentDevices
(
tensors
,
static_cast
<
size_t
>
(
GetSize
()));
auto
task
=
PointToPoint
(
auto
task
=
PointToPoint
(
tensors
,
tensors
,
[
&
](
phi
::
DenseTensor
&
output
,
ncclComm_t
comm
,
[
&
](
Tensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
,
const
gpuStream_t
&
stream
,
int
src_rank
)
{
int
src_rank
)
{
return
platform
::
dynload
::
ncclRecv
(
auto
output_tensor
=
output
.
data
(),
output
.
numel
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
platform
::
ToNCCLDataType
(
output
.
dtype
()),
return
platform
::
dynload
::
ncclRecv
(
src_rank
,
comm
,
stream
);
output_tensor
->
data
(),
output_tensor
->
numel
(),
},
platform
::
ToNCCLDataType
(
output
.
type
()),
src_rank
,
comm
,
stream
);
src_rank
,
CommType
::
RECV
);
},
src_rank
,
CommType
::
RECV
);
return
task
;
return
task
;
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
out_tensors
),
true
,
CheckTensorsInCudaPlace
(
out_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All outputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All outputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
in_tensors
,
out_tensors
,
in_tensors
,
out_tensors
,
[
&
](
const
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllGather
(
auto
input_tensor
=
input
.
data
(),
output
.
data
(),
input
.
numel
(),
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
platform
::
ToNCCLDataType
(
input
.
dtype
()),
comm
,
auto
output_tensor
=
stream
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
},
return
platform
::
dynload
::
ncclAllGather
(
CommType
::
ALLGATHER
);
input_tensor
->
data
(),
output_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
comm
,
stream
);
},
CommType
::
ALLGATHER
);
}
}
void
*
GetPointerByOffset
(
void
*
raw_pointer
,
size_t
offset
,
void
*
GetPointerByOffset
(
void
*
raw_pointer
,
size_t
offset
,
...
@@ -493,10 +478,12 @@ void* GetPointerByOffset(void* raw_pointer, size_t offset,
...
@@ -493,10 +478,12 @@ void* GetPointerByOffset(void* raw_pointer, size_t offset,
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This datatype in nccl is not supported."
));
"This datatype in nccl is not supported."
));
}
}
return
nullptr
;
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllToAll
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
...
@@ -505,24 +492,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
...
@@ -505,24 +492,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
in_tensors
,
out_tensors
,
in_tensors
,
out_tensors
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
Dense
Tensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
gpuStream_t
&
stream
)
{
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
size_t
offset
=
0
;
size_t
offset
=
0
;
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
i
=
0
;
i
<
size_
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
size_
;
i
++
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclSend
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclSend
(
GetPointerByOffset
(
input
_tensor
->
data
(),
offset
,
input
.
type
()),
GetPointerByOffset
(
input
.
data
(),
offset
,
input
.
d
type
()),
input
_tensor
->
numel
()
/
size_
,
input
.
numel
()
/
size_
,
platform
::
ToNCCLDataType
(
input
.
dtype
()),
i
,
platform
::
ToNCCLDataType
(
input
.
type
()),
i
,
comm
,
stream
));
comm
,
stream
));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
GetPointerByOffset
(
output
_tensor
->
data
(),
offset
,
input
.
type
()),
GetPointerByOffset
(
output
.
data
(),
offset
,
input
.
d
type
()),
input
_tensor
->
numel
()
/
size_
,
input
.
numel
()
/
size_
,
platform
::
ToNCCLDataType
(
input
.
dtype
()),
i
,
platform
::
ToNCCLDataType
(
input
.
type
()),
i
,
comm
,
stream
));
comm
,
stream
));
offset
+=
input
_tensor
->
numel
()
/
size_
;
offset
+=
input
.
numel
()
/
size_
;
}
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
},
},
...
@@ -530,29 +513,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
...
@@ -530,29 +513,26 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Reduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
ReduceOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ReduceOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
CheckTensorsInCudaPlace
(
in_
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
tensors
,
tensors
,
in_tensors
,
out_tensors
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
[
&
](
const
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
const
gpuStream_t
&
stream
)
{
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclReduce
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclReduce
(
input
_tensor
->
data
(),
output_tensor
->
data
(),
input
.
numel
(),
input
.
data
(),
output
.
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
platform
::
ToNCCLDataType
(
input
.
d
type
()),
ToNCCLRedType
(
opts
.
reduce_op
),
opts
.
root_rank
,
comm
,
stream
));
ToNCCLRedType
(
opts
.
reduce_op
),
opts
.
root_rank
,
comm
,
stream
));
},
},
CommType
::
REDUCE
);
CommType
::
REDUCE
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Scatter
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out
_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in
_tensors
,
const
ScatterOptions
&
opts
)
{
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ScatterOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
...
@@ -561,31 +541,27 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
...
@@ -561,31 +541,27 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
return
Collective
(
in_tensors
,
out_tensors
,
in_tensors
,
out_tensors
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
Dense
Tensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
gpuStream_t
&
stream
)
{
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
size_t
offset
=
0
;
size_t
offset
=
0
;
if
(
rank_
==
opts
.
root_rank
)
{
if
(
rank_
==
opts
.
root_rank
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
auto
i
=
0
;
i
<
size_
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
size_
;
i
++
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclSend
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclSend
(
GetPointerByOffset
(
input
_tensor
->
data
(),
offset
,
input
.
type
()),
GetPointerByOffset
(
input
.
data
(),
offset
,
input
.
d
type
()),
input
_tensor
->
numel
()
/
size_
,
input
.
numel
()
/
size_
,
platform
::
ToNCCLDataType
(
input
.
dtype
())
,
platform
::
ToNCCLDataType
(
input
.
type
()),
i
,
comm
,
stream
));
i
,
comm
,
stream
));
offset
+=
input
_tensor
->
numel
()
/
size_
;
offset
+=
input
.
numel
()
/
size_
;
}
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
output
_tensor
->
data
(),
input_tensor
->
numel
()
/
size_
,
output
.
data
(),
input
.
numel
()
/
size_
,
platform
::
ToNCCLDataType
(
input
.
type
()),
opts
.
root_rank
,
comm
,
platform
::
ToNCCLDataType
(
input
.
d
type
()),
opts
.
root_rank
,
comm
,
stream
));
stream
));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
}
else
{
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclRecv
(
output
_tensor
->
data
(),
input_tensor
->
numel
()
/
size_
,
output
.
data
(),
input
.
numel
()
/
size_
,
platform
::
ToNCCLDataType
(
input
.
type
()),
opts
.
root_rank
,
comm
,
platform
::
ToNCCLDataType
(
input
.
d
type
()),
opts
.
root_rank
,
comm
,
stream
));
stream
));
}
}
},
},
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
浏览文件 @
1e56ca8a
...
@@ -51,7 +51,7 @@ class ProcessGroupNCCL : public ProcessGroup {
...
@@ -51,7 +51,7 @@ class ProcessGroupNCCL : public ProcessGroup {
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public:
public:
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
);
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
);
bool
IsCompleted
();
bool
IsCompleted
();
...
@@ -61,17 +61,17 @@ class ProcessGroupNCCL : public ProcessGroup {
...
@@ -61,17 +61,17 @@ class ProcessGroupNCCL : public ProcessGroup {
void
Synchronize
();
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
);
// NOLINT
void
SetOutputs
(
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
);
// NOLINT
virtual
~
NCCLTask
();
virtual
~
NCCLTask
();
std
::
vector
<
EventManager
>
control_events_
;
std
::
vector
<
EventManager
>
control_events_
;
std
::
vector
<
Tensor
>
barrierTensors_
;
std
::
vector
<
phi
::
Dense
Tensor
>
barrierTensors_
;
protected:
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
ncclComms_
;
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
ncclComms_
;
std
::
shared_ptr
<
std
::
vector
<
Tensor
>>
outputs_
;
std
::
shared_ptr
<
std
::
vector
<
phi
::
Dense
Tensor
>>
outputs_
;
private:
private:
};
};
...
@@ -84,40 +84,46 @@ class ProcessGroupNCCL : public ProcessGroup {
...
@@ -84,40 +84,46 @@ class ProcessGroupNCCL : public ProcessGroup {
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
int
dst_rank
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
dst_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
vector
<
Tensor
>&
tensors
,
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
int
src_rank
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
src_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
phi
::
Dense
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
override
;
std
::
vector
<
phi
::
Dense
Tensor
>&
out_tensors
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
vector
<
Tensor
>&
in
,
std
::
vector
<
Tensor
>&
out
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
in
,
std
::
vector
<
phi
::
DenseTensor
>&
out
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
ReduceOptions
&
opts
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ReduceOptions
&
opts
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
Tensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
const
ScatterOptions
&
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
ScatterOptions
&
)
override
;
protected:
protected:
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
const
std
::
vector
<
Tensor
>&
inputs
);
const
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
);
protected:
protected:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
Store
>
store_
;
...
@@ -142,8 +148,8 @@ class ProcessGroupNCCL : public ProcessGroup {
...
@@ -142,8 +148,8 @@ class ProcessGroupNCCL : public ProcessGroup {
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
outputs
,
// NOLINT
Fn
fn
,
CommType
op_type
);
Fn
fn
,
CommType
op_type
);
template
<
typename
Fn
>
template
<
typename
Fn
>
...
@@ -152,7 +158,7 @@ class ProcessGroupNCCL : public ProcessGroup {
...
@@ -152,7 +158,7 @@ class ProcessGroupNCCL : public ProcessGroup {
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
PointToPoint
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
// NOLINT
std
::
vector
<
phi
::
Dense
Tensor
>&
tensors
,
// NOLINT
Fn
fn
,
int
dst_rank
,
CommType
op_type
);
Fn
fn
,
int
dst_rank
,
CommType
op_type
);
void
CreateNCCLManagerCache
(
const
std
::
string
&
places_key
,
void
CreateNCCLManagerCache
(
const
std
::
string
&
places_key
,
...
...
paddle/fluid/distributed/collective/reducer.cc
浏览文件 @
1e56ca8a
...
@@ -734,7 +734,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
...
@@ -734,7 +734,11 @@ void EagerReducer::ProcessUnusedDenseVars() {
distributed
::
AllreduceOptions
opts
;
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
ReduceOp
::
SUM
;
opts
.
reduce_op
=
ReduceOp
::
SUM
;
std
::
vector
<
Tensor
>
reduce_tensors
=
{
global_used_vars_
};
std
::
vector
<
Tensor
>
reduce_tensors
=
{
global_used_vars_
};
process_group_
->
AllReduce
(
reduce_tensors
,
opts
)
->
Synchronize
();
std
::
vector
<
phi
::
DenseTensor
>
in_out
;
for
(
auto
&
t
:
reduce_tensors
)
{
in_out
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
process_group_
->
AllReduce
(
in_out
,
in_out
,
opts
)
->
Synchronize
();
framework
::
TensorToVector
<
int
>
(
*
global_used_tensor
,
*
dev_ctx
,
framework
::
TensorToVector
<
int
>
(
*
global_used_tensor
,
*
dev_ctx
,
&
local_used_vars_
);
&
local_used_vars_
);
...
@@ -820,7 +824,11 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
...
@@ -820,7 +824,11 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
// all_reduce
// all_reduce
std
::
vector
<
Tensor
>
reduce_tensors
=
{
group
->
dense_contents_
};
std
::
vector
<
Tensor
>
reduce_tensors
=
{
group
->
dense_contents_
};
group
->
task
=
process_group_
->
AllReduce
(
reduce_tensors
,
opts
);
std
::
vector
<
phi
::
DenseTensor
>
in_out
;
for
(
auto
&
t
:
reduce_tensors
)
{
in_out
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
group
->
task
=
process_group_
->
AllReduce
(
in_out
,
in_out
,
opts
);
// split in FinalizeBackward()
// split in FinalizeBackward()
}
}
...
@@ -871,7 +879,11 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
...
@@ -871,7 +879,11 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
distributed
::
AllreduceOptions
opts
;
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
ReduceOp
::
SUM
;
opts
.
reduce_op
=
ReduceOp
::
SUM
;
std
::
vector
<
Tensor
>
reduce_tensors
=
{
rows_num_tensor
};
std
::
vector
<
Tensor
>
reduce_tensors
=
{
rows_num_tensor
};
process_group_
->
AllReduce
(
reduce_tensors
,
opts
)
->
Synchronize
();
std
::
vector
<
phi
::
DenseTensor
>
in_out
;
for
(
auto
&
t
:
reduce_tensors
)
{
in_out
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
process_group_
->
AllReduce
(
in_out
,
in_out
,
opts
)
->
Synchronize
();
framework
::
TensorToVector
<
int64_t
>
(
*
rows_num_dense_tensor
,
*
dev_ctx
,
framework
::
TensorToVector
<
int64_t
>
(
*
rows_num_dense_tensor
,
*
dev_ctx
,
&
rows_num_vector
);
&
rows_num_vector
);
...
@@ -908,8 +920,15 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
...
@@ -908,8 +920,15 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
std
::
vector
<
Tensor
>
src_rows_tensors
=
{
src_rows_tensor
};
std
::
vector
<
Tensor
>
src_rows_tensors
=
{
src_rows_tensor
};
std
::
vector
<
Tensor
>
dst_rows_tensors
=
{
dst_rows_tensor
};
std
::
vector
<
Tensor
>
dst_rows_tensors
=
{
dst_rows_tensor
};
process_group_
->
AllGather
(
src_rows_tensors
,
dst_rows_tensors
)
std
::
vector
<
phi
::
DenseTensor
>
in
;
->
Synchronize
();
std
::
vector
<
phi
::
DenseTensor
>
out
;
for
(
auto
&
t
:
src_rows_tensors
)
{
in
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
for
(
auto
&
t
:
dst_rows_tensors
)
{
out
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
process_group_
->
AllGather
(
in
,
out
)
->
Synchronize
();
framework
::
Vector
<
int64_t
>
dst_rows_vector
(
rows_num
,
0
);
framework
::
Vector
<
int64_t
>
dst_rows_vector
(
rows_num
,
0
);
auto
*
dst_rows_dense_tensor
=
auto
*
dst_rows_dense_tensor
=
...
@@ -934,8 +953,17 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
...
@@ -934,8 +953,17 @@ void EagerReducer::AllReduceSparse(EagerGroup *group,
std
::
vector
<
Tensor
>
src_value_tensors
=
{
src_value_tensor
};
std
::
vector
<
Tensor
>
src_value_tensors
=
{
src_value_tensor
};
std
::
vector
<
Tensor
>
dst_value_tensors
=
{
dst_value_tensor
};
std
::
vector
<
Tensor
>
dst_value_tensors
=
{
dst_value_tensor
};
process_group_
->
AllGather
(
src_value_tensors
,
dst_value_tensors
)
std
::
vector
<
phi
::
DenseTensor
>
src_dense
;
->
Synchronize
();
std
::
vector
<
phi
::
DenseTensor
>
dst_dense
;
for
(
auto
&
t
:
src_value_tensors
)
{
src_dense
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
for
(
auto
&
t
:
dst_value_tensors
)
{
dst_dense
.
push_back
(
*
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
t
.
impl
()));
}
process_group_
->
AllGather
(
src_dense
,
dst_dense
)
->
Synchronize
();
src
->
set_rows
(
dst_rows_vector
);
src
->
set_rows
(
dst_rows_vector
);
*
(
src
->
mutable_value
())
=
*
(
src
->
mutable_value
())
=
...
...
paddle/fluid/operators/collective/c_allgather_op.cu.cc
浏览文件 @
1e56ca8a
...
@@ -18,7 +18,9 @@ limitations under the License. */
...
@@ -18,7 +18,9 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#endif
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/api/include/tensor.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -35,6 +37,18 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -35,6 +37,18 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
nranks
=
ctx
.
Attr
<
int
>
(
"nranks"
);
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
int
rid
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
map
=
distributed
::
ProcessGroupMapFromGid
::
getInstance
();
if
(
map
->
has
(
rid
))
{
// Use ProcessGroup
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
rid
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
*
in
);
out_tensor
.
push_back
(
*
out
);
auto
task
=
pg
->
AllGather
(
in_tensor
,
out_tensor
);
task
->
Wait
();
return
;
}
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
auto
comm
=
platform
::
NCCLCommContext
::
Instance
().
Get
(
rid
,
place
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
...
paddle/fluid/operators/collective/c_broadcast_op.cu.cc
浏览文件 @
1e56ca8a
...
@@ -41,7 +41,12 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -41,7 +41,12 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
if
(
map
->
has
(
rid
))
{
if
(
map
->
has
(
rid
))
{
// Use ProcessGroup
// Use ProcessGroup
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
rid
);
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
rid
);
pg
->
Broadcast
(
x
,
out
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
*
x
);
out_tensor
.
push_back
(
*
out
);
auto
task
=
pg
->
Broadcast
(
in_tensor
,
out_tensor
);
task
->
Wait
();
return
;
return
;
}
}
...
...
paddle/fluid/pybind/distributed_py.cc
浏览文件 @
1e56ca8a
...
@@ -115,8 +115,10 @@ void BindDistributed(py::module *m) {
...
@@ -115,8 +115,10 @@ void BindDistributed(py::module *m) {
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
op
;
opts
.
reduce_op
=
op
;
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
auto
dense
=
return
self
.
AllReduce
(
tensors
,
opts
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
return
self
.
AllReduce
(
tensors
,
tensors
,
opts
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"op"
)
=
distributed
::
ReduceOp
::
SUM
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"op"
)
=
distributed
::
ReduceOp
::
SUM
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
@@ -127,8 +129,10 @@ void BindDistributed(py::module *m) {
...
@@ -127,8 +129,10 @@ void BindDistributed(py::module *m) {
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
;
distributed
::
BroadcastOptions
opts
;
opts
.
source_rank
=
source_rank
;
opts
.
source_rank
=
source_rank
;
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
auto
dense
=
return
self
.
Broadcast
(
tensors
,
opts
);
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
return
self
.
Broadcast
(
tensors
,
tensors
,
opts
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"source_rank"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"source_rank"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
@@ -146,7 +150,9 @@ void BindDistributed(py::module *m) {
...
@@ -146,7 +150,9 @@ void BindDistributed(py::module *m) {
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
int
dst
)
{
int
dst
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
return
self
.
Send
(
tensors
,
dst
);
return
self
.
Send
(
tensors
,
dst
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"dst"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"dst"
),
...
@@ -156,7 +162,9 @@ void BindDistributed(py::module *m) {
...
@@ -156,7 +162,9 @@ void BindDistributed(py::module *m) {
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
int
src
)
{
int
src
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
return
self
.
Recv
(
tensors
,
src
);
return
self
.
Recv
(
tensors
,
src
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"src"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"src"
),
...
@@ -167,8 +175,12 @@ void BindDistributed(py::module *m) {
...
@@ -167,8 +175,12 @@ void BindDistributed(py::module *m) {
py
::
handle
py_out_tensor
)
{
py
::
handle
py_out_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
std
::
vector
<
Tensor
>
in_tensors
=
{
in_tensor
};
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
std
::
vector
<
Tensor
>
out_tensors
=
{
out_tensor
};
in_tensor
.
impl
());
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_tensors
=
{
*
in_dense
};
std
::
vector
<
phi
::
DenseTensor
>
out_tensors
=
{
*
out_dense
};
return
self
.
AllGather
(
in_tensors
,
out_tensors
);
return
self
.
AllGather
(
in_tensors
,
out_tensors
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
...
@@ -179,8 +191,12 @@ void BindDistributed(py::module *m) {
...
@@ -179,8 +191,12 @@ void BindDistributed(py::module *m) {
py
::
handle
py_out_tensor
)
{
py
::
handle
py_out_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
std
::
vector
<
Tensor
>
in_tensors
=
{
in_tensor
};
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
std
::
vector
<
Tensor
>
out_tensors
=
{
out_tensor
};
in_tensor
.
impl
());
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_tensors
=
{
*
in_dense
};
std
::
vector
<
phi
::
DenseTensor
>
out_tensors
=
{
*
out_dense
};
return
self
.
AllToAll
(
in_tensors
,
out_tensors
);
return
self
.
AllToAll
(
in_tensors
,
out_tensors
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
...
@@ -193,8 +209,10 @@ void BindDistributed(py::module *m) {
...
@@ -193,8 +209,10 @@ void BindDistributed(py::module *m) {
distributed
::
ReduceOptions
opts
;
distributed
::
ReduceOptions
opts
;
opts
.
reduce_op
=
op
;
opts
.
reduce_op
=
op
;
opts
.
root_rank
=
dst
;
opts
.
root_rank
=
dst
;
std
::
vector
<
Tensor
>
tensors
=
{
in_tensor
};
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
return
self
.
Reduce
(
tensors
,
opts
);
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
return
self
.
Reduce
(
tensors
,
tensors
,
opts
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"dst"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"dst"
),
py
::
arg
(
"op"
)
=
distributed
::
ReduceOp
::
SUM
,
py
::
arg
(
"op"
)
=
distributed
::
ReduceOp
::
SUM
,
...
@@ -207,8 +225,12 @@ void BindDistributed(py::module *m) {
...
@@ -207,8 +225,12 @@ void BindDistributed(py::module *m) {
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
distributed
::
ScatterOptions
opts
;
distributed
::
ScatterOptions
opts
;
opts
.
root_rank
=
src
;
opts
.
root_rank
=
src
;
std
::
vector
<
Tensor
>
in_tensors
=
{
in_tensor
};
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
std
::
vector
<
Tensor
>
out_tensors
=
{
out_tensor
};
in_tensor
.
impl
());
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_tensors
=
{
*
in_dense
};
std
::
vector
<
phi
::
DenseTensor
>
out_tensors
=
{
*
out_dense
};
return
self
.
Scatter
(
in_tensors
,
out_tensors
,
opts
);
return
self
.
Scatter
(
in_tensors
,
out_tensors
,
opts
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"src"
),
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"src"
),
...
...
python/paddle/fluid/tests/unittests/init_process_group.py
浏览文件 @
1e56ca8a
...
@@ -46,6 +46,11 @@ class TestProcessGroupFp32(unittest.TestCase):
...
@@ -46,6 +46,11 @@ class TestProcessGroupFp32(unittest.TestCase):
group
=
paddle
.
distributed
.
collective
.
Group
(
-
1
,
2
,
0
,
[
-
1
,
-
2
])
group
=
paddle
.
distributed
.
collective
.
Group
(
-
1
,
2
,
0
,
[
-
1
,
-
2
])
ret
=
paddle
.
distributed
.
barrier
(
group
)
ret
=
paddle
.
distributed
.
barrier
(
group
)
assert
ret
==
None
assert
ret
==
None
paddle
.
enable_static
()
in_tensor
=
paddle
.
empty
((
1
,
2
))
in_tensor2
=
paddle
.
empty
((
1
,
2
))
paddle
.
distributed
.
broadcast
(
in_tensor
,
src
=
0
)
paddle
.
distributed
.
all_gather
([
in_tensor
,
in_tensor2
],
in_tensor
)
print
(
"test ok
\n
"
)
print
(
"test ok
\n
"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录