Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e1a1c354
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e1a1c354
编写于
11月 07, 2022
作者:
W
Wen Sun
提交者:
GitHub
11月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)
上级
2337e609
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
676 addition
and
767 deletion
+676
-767
paddle/fluid/distributed/collective/CMakeLists.txt
paddle/fluid/distributed/collective/CMakeLists.txt
+1
-0
paddle/fluid/distributed/collective/Common.cc
paddle/fluid/distributed/collective/Common.cc
+2
-0
paddle/fluid/distributed/collective/Common.h
paddle/fluid/distributed/collective/Common.h
+2
-0
paddle/fluid/distributed/collective/NCCLTools.h
paddle/fluid/distributed/collective/NCCLTools.h
+0
-198
paddle/fluid/distributed/collective/ProcessGroup.cc
paddle/fluid/distributed/collective/ProcessGroup.cc
+13
-9
paddle/fluid/distributed/collective/ProcessGroup.h
paddle/fluid/distributed/collective/ProcessGroup.h
+48
-13
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
+11
-0
paddle/fluid/distributed/collective/ProcessGroupGloo.h
paddle/fluid/distributed/collective/ProcessGroupGloo.h
+7
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+364
-328
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+84
-74
paddle/fluid/distributed/collective/ProcessGroupStream.cc
paddle/fluid/distributed/collective/ProcessGroupStream.cc
+38
-37
paddle/fluid/distributed/collective/ProcessGroupStream.h
paddle/fluid/distributed/collective/ProcessGroupStream.h
+30
-25
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+1
-6
paddle/fluid/operators/fused/fused_feedforward_op.cu
paddle/fluid/operators/fused/fused_feedforward_op.cu
+1
-6
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+66
-64
python/paddle/distributed/communication/stream/all_gather.py
python/paddle/distributed/communication/stream/all_gather.py
+8
-7
未找到文件。
paddle/fluid/distributed/collective/CMakeLists.txt
浏览文件 @
e1a1c354
...
@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
...
@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce
enforce
collective_helper
collective_helper
device_context
device_context
${
DEVICE_EVENT_LIBS
}
dense_tensor
)
dense_tensor
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE
)
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0
)
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0
)
...
...
paddle/fluid/distributed/collective/Common.cc
浏览文件 @
e1a1c354
...
@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
...
@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return
placeList
;
return
placeList
;
}
}
std
::
string
GetKeyFromPlace
(
const
Place
&
place
)
{
return
place
.
DebugString
();
}
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
return
std
::
all_of
(
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
phi
::
DenseTensor
&
t
)
{
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
phi
::
DenseTensor
&
t
)
{
...
...
paddle/fluid/distributed/collective/Common.h
浏览文件 @
e1a1c354
...
@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
...
@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
// Get the deviceList String from the list of devices
// Get the deviceList String from the list of devices
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
// Get the device string from one device
std
::
string
GetKeyFromPlace
(
const
Place
&
place
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
...
...
paddle/fluid/distributed/collective/NCCLTools.h
浏览文件 @
e1a1c354
...
@@ -59,204 +59,6 @@ namespace distributed {
...
@@ -59,204 +59,6 @@ namespace distributed {
} \
} \
} while (0)
} while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
// EventManage is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.
class
EventManager
{
public:
EventManager
()
{}
explicit
EventManager
(
unsigned
int
flags
)
:
flags_
{
flags
}
{}
~
EventManager
()
{
if
(
is_created_
)
{
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_HIP
hipEventDestroy
(
event_
);
#else
cudaEventDestroy
(
event_
);
#endif
}
}
EventManager
(
const
EventManager
&
)
=
delete
;
EventManager
&
operator
=
(
const
EventManager
&
)
=
delete
;
EventManager
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
}
EventManager
&
operator
=
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
return
*
this
;
}
bool
IsCreated
()
const
{
return
is_created_
;
}
bool
DeviceId
()
const
{
return
device_index_
;
}
gpuEvent_t
GetRawCudaEvent
()
const
{
return
event_
;
}
void
Record
(
const
phi
::
GPUContext
&
ctx
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
if
(
!
is_created_
)
{
CreateEvent
(
device_index
);
}
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"phi::GPUContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventRecord
(
event_
,
ctx
.
stream
()));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventRecord
(
event_
,
ctx
.
stream
()));
#endif
}
bool
Query
()
const
{
#ifdef PADDLE_WITH_HIP
gpuError_t
err
=
hipEventQuery
(
event_
);
if
(
err
==
hipSuccess
)
{
return
true
;
}
if
(
err
==
hipErrorNotReady
)
{
return
false
;
}
#else
gpuError_t
err
=
cudaEventQuery
(
event_
);
if
(
err
==
cudaSuccess
)
{
return
true
;
}
if
(
err
==
cudaErrorNotReady
)
{
return
false
;
}
#endif
PADDLE_ENFORCE_GPU_SUCCESS
(
err
);
return
false
;
}
void
Synchronize
()
const
{
if
(
is_created_
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventSynchronize
(
event_
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventSynchronize
(
event_
));
#endif
}
}
void
Block
(
const
phi
::
GPUContext
&
ctx
)
const
{
if
(
is_created_
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"phi::GPUContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipStreamWaitEvent
(
ctx
.
stream
(),
event_
,
0
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaStreamWaitEvent
(
ctx
.
stream
(),
event_
,
0
));
#endif
}
}
private:
#ifdef PADDLE_WITH_HIP
unsigned
int
flags_
=
hipEventDefault
;
#else
unsigned
int
flags_
=
cudaEventDefault
;
#endif
bool
is_created_
{
false
};
gpuEvent_t
event_
{};
int8_t
device_index_
{
0
};
private:
void
CreateEvent
(
int
device_index
)
{
device_index_
=
device_index
;
platform
::
CUDADeviceGuard
guard
(
device_index
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventCreateWithFlags
(
&
event_
,
flags_
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventCreateWithFlags
(
&
event_
,
flags_
));
#endif
is_created_
=
true
;
}
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class
NCCLCommManager
{
public:
explicit
NCCLCommManager
(
ncclComm_t
ncclComm
)
:
nccl_comm_
(
ncclComm
)
{}
NCCLCommManager
()
:
NCCLCommManager
(
nullptr
)
{}
~
NCCLCommManager
()
noexcept
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
nccl_comm_
)
{
platform
::
dynload
::
ncclCommDestroy
(
nccl_comm_
);
}
}
static
std
::
shared_ptr
<
NCCLCommManager
>
Create
(
int
num_ranks
,
int
rank
,
ncclUniqueId
comm_id
)
{
auto
nccl_manager
=
std
::
make_shared
<
NCCLCommManager
>
();
NCCLCHECK
(
platform
::
dynload
::
ncclCommInitRank
(
&
(
nccl_manager
->
nccl_comm_
),
num_ranks
,
comm_id
,
rank
));
nccl_manager
->
nccl_id_
=
comm_id
;
nccl_manager
->
rank_
=
rank
;
return
nccl_manager
;
}
ncclUniqueId
GetNcclId
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_id_
;
}
ncclComm_t
GetNcclComm
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_comm_
;
}
NCCLCommManager
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
NCCLCommManager
&&
other
)
=
delete
;
NCCLCommManager
(
NCCLCommManager
&&
other
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
other
.
mutex_
);
std
::
swap
(
nccl_comm_
,
other
.
nccl_comm_
);
}
protected:
ncclComm_t
nccl_comm_
;
ncclUniqueId
nccl_id_
;
int
rank_
;
mutable
std
::
mutex
mutex_
;
};
ncclRedOp_t
ToNCCLRedType
(
ReduceOp
reduction
);
ncclRedOp_t
ToNCCLRedType
(
ReduceOp
reduction
);
std
::
string
SerializeNCCLUniqueId
(
const
ncclUniqueId
&
ncclID
);
std
::
string
SerializeNCCLUniqueId
(
const
ncclUniqueId
&
ncclID
);
...
...
paddle/fluid/distributed/collective/ProcessGroup.cc
浏览文件 @
e1a1c354
...
@@ -17,15 +17,7 @@
...
@@ -17,15 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
ProcessGroup
::
Task
::
Task
(
int
rank
,
ProcessGroup
::
Task
::
Task
(
int
rank
,
CommType
comm_type
,
bool
sync_op
)
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
)
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
ProcessGroup
::
Task
::~
Task
()
=
default
;
ProcessGroup
::
Task
::~
Task
()
=
default
;
...
@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
...
@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
}
}
}
}
// TODO(sunyilun): methods below will be removed later
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
)
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroup.h
浏览文件 @
e1a1c354
...
@@ -54,13 +54,7 @@ class ProcessGroup {
...
@@ -54,13 +54,7 @@ class ProcessGroup {
public:
public:
class
Task
{
class
Task
{
public:
public:
Task
(
int
rank
,
Task
(
int
rank
,
CommType
comm_type
,
bool
sync_op
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
);
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
);
virtual
~
Task
();
virtual
~
Task
();
virtual
bool
IsCompleted
();
virtual
bool
IsCompleted
();
...
@@ -69,6 +63,15 @@ class ProcessGroup {
...
@@ -69,6 +63,15 @@ class ProcessGroup {
virtual
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
);
virtual
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
);
bool
IsSync
()
const
{
return
sync_op_
;
}
bool
IsSync
()
const
{
return
sync_op_
;
}
// TODO(sunyilun): methods below will be removed later
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
);
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
);
protected:
protected:
const
int
rank_
;
const
int
rank_
;
CommType
comm_type_
{
CommType
::
UNKNOWN
};
CommType
comm_type_
{
CommType
::
UNKNOWN
};
...
@@ -79,6 +82,7 @@ class ProcessGroup {
...
@@ -79,6 +82,7 @@ class ProcessGroup {
bool
sync_op_
{
true
};
bool
sync_op_
{
true
};
};
};
public:
explicit
ProcessGroup
(
int
rank
,
explicit
ProcessGroup
(
int
rank
,
int
size
,
int
size
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
...
@@ -93,12 +97,48 @@ class ProcessGroup {
...
@@ -93,12 +97,48 @@ class ProcessGroup {
int
GetSize
()
const
{
return
size_
;
}
int
GetSize
()
const
{
return
size_
;
}
virtual
std
::
string
GetBackendName
()
const
=
0
;
virtual
std
::
string
GetBackendName
()
const
=
0
;
virtual
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
)
const
{
virtual
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
)
const
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Does not support to get device_context from ProcessGroup%s."
,
"Does not support to get device_context from ProcessGroup%s."
,
GetBackendName
()));
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support all_gather with sync_op flag"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support all_reduce with sync_op flag"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support barrier"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support broadcast with sync_op flag"
,
GetBackendName
()));
}
// TODO(liyurui): This API will be moved later
// TODO(liyurui): This API will be moved later
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
...
@@ -118,6 +158,7 @@ class ProcessGroup {
...
@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName
()));
GetBackendName
()));
}
}
// TODO(sunyilun): methods below will be removed later
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
...
@@ -136,12 +177,6 @@ class ProcessGroup {
...
@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName
()));
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support barrier"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
phi
::
DenseTensor
>&
,
int
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
浏览文件 @
e1a1c354
...
@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
}
}
};
};
// TODO(sunyilun): for compatibility, will be updated later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
in_tensor
};
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_tensor
};
return
Broadcast
(
in_wrapper
,
out_wrapper
,
opts
,
true
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.h
浏览文件 @
e1a1c354
...
@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
...
@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~
ProcessGroupGloo
()
=
default
;
~
ProcessGroupGloo
()
=
default
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
浏览文件 @
e1a1c354
...
@@ -15,12 +15,8 @@
...
@@ -15,12 +15,8 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
DECLARE_bool
(
nccl_blocking_wait
);
DECLARE_bool
(
nccl_blocking_wait
);
DECLARE_bool
(
use_stream_safe_cuda_allocator
);
DECLARE_bool
(
use_stream_safe_cuda_allocator
);
...
@@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10;
...
@@ -30,89 +26,299 @@ constexpr int64_t kWaitBlockTImeout = 10;
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
void
SyncDefaultStream
(
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
const
Place
&
place
,
const
std
::
vector
<
Place
>&
places
,
int
rank
,
std
::
vector
<
EventManager
>&
ncclEvents
,
// NOLINT
CommType
comm_type
,
std
::
vector
<
std
::
unique_ptr
<
phi
::
GPUContext
>>&
dev_ctx
)
{
// NOLINT
bool
sync_op
,
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
bool
use_calc_stream
)
auto
*
default_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
:
TaskStream
(
rank
,
comm_type
,
sync_op
,
use_calc_stream
),
platform
::
DeviceContextPool
::
Instance
().
Get
(
places
[
i
]));
comm_event_
(
place
),
ncclEvents
[
i
].
Record
(
*
default_ctx
);
place_
(
place
)
{}
ncclEvents
[
i
].
Block
(
*
dev_ctx
[
i
]);
ProcessGroupNCCL
::
NCCLTask
::~
NCCLTask
()
{}
bool
ProcessGroupNCCL
::
NCCLTask
::
IsCompleted
()
{
return
comm_event_
.
Query
();
}
void
ProcessGroupNCCL
::
NCCLTask
::
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
)
{
comm_event_
.
Record
(
&
ctx
);
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
bool
ProcessGroupNCCL
::
NCCLTask
::
Wait
(
std
::
chrono
::
milliseconds
timeout
)
{
// Warning here when use calc stream but also invoke waiting explicitly.
if
(
UseCalcStream
())
{
VLOG
(
3
)
<<
"Warning: The communication is on calc stream, wait here is "
"useless."
;
return
true
;
}
const
auto
*
calc_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
comm_event_
.
Wait
(
platform
::
Place2DeviceType
(
place_
),
calc_ctx
);
if
(
FLAGS_nccl_blocking_wait
)
{
// NOTE(shenliang03): It will block host for sync
while
(
!
IsCompleted
())
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
kWaitBlockTImeout
));
}
}
}
if
(
barrier_
)
{
// If we use the work to do barrier, we should block cpu
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaDeviceSynchronize
());
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipDeviceSynchronize
());
#endif
}
return
true
;
}
}
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
// Same as Wait
std
::
vector
<
Place
>
places
,
void
ProcessGroupNCCL
::
NCCLTask
::
Synchronize
()
{
Wait
(
kWaitTimeout
);
}
int
rank
,
CommType
comm_type
,
ProcessGroupNCCL
::
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
)
{
int
rank
,
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
int
size
,
places
,
rank
,
comm_type
,
inputs
);
const
platform
::
Place
&
place
,
int
gid
)
:
ProcessGroupStream
(
rank
,
size
,
place
,
gid
),
store_
(
store
)
{
platform
::
SetDeviceId
(
place_
.
device
);
}
void
ProcessGroupNCCL
::
GroupStart
()
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
}
void
ProcessGroupNCCL
::
GroupEnd
()
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
}
const
phi
::
DeviceContext
&
ProcessGroupNCCL
::
GetDeviceContext
(
const
Place
&
place
)
const
{
return
GetDeviceContext
(
place
,
/*use_calc_stream*/
false
);
}
const
phi
::
DeviceContext
&
ProcessGroupNCCL
::
GetDeviceContext
(
const
Place
&
place
,
bool
use_calc_stream
)
const
{
const
std
::
string
&
key
=
GetKeyFromPlace
(
place
);
if
(
use_calc_stream
)
{
const
auto
&
iter
=
place_to_calc_ctx_
.
find
(
key
);
return
*
iter
->
second
;
}
else
{
const
auto
&
iter
=
place_to_comm_ctx_
.
find
(
key
);
PADDLE_ENFORCE_NE
(
iter
,
place_to_comm_ctx_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find the device context in this process group."
));
return
*
iter
->
second
;
}
}
ncclComm_t
ProcessGroupNCCL
::
NCCLComm
(
const
Place
&
place
)
const
{
const
std
::
string
&
key
=
GetKeyFromPlace
(
place
);
const
auto
&
iter
=
place_to_comm_ctx_
.
find
(
key
);
PADDLE_ENFORCE_NE
(
iter
,
place_to_comm_ctx_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find the NCCL commmunicator in this process group."
));
return
iter
->
second
->
nccl_comm
();
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllGather
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
use_calc_stream
)
{
return
Collective
(
out_tensor
,
in_tensor
,
[
&
](
phi
::
DenseTensor
*
output
,
const
phi
::
DenseTensor
&
input
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllGather
(
input
.
data
(),
output
->
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
dtype
()),
comm
,
stream
);
},
CommType
::
ALLGATHER
,
sync_op
,
use_calc_stream
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllReduce
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
return
Collective
(
out_tensor
,
in_tensor
,
[
&
](
phi
::
DenseTensor
*
output
,
const
phi
::
DenseTensor
&
input
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllReduce
(
input
.
data
(),
output
->
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
ToNCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
,
sync_op
,
use_calc_stream
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Barrier
(
const
BarrierOptions
&
opts
)
{
auto
allocator
=
std
::
unique_ptr
<
phi
::
Allocator
>
(
new
paddle
::
experimental
::
DefaultAllocator
(
place_
));
phi
::
DenseTensorMeta
meta
(
phi
::
DataType
::
FLOAT32
,
phi
::
DDim
{
1
});
phi
::
DenseTensor
barrier_tensor
{
allocator
.
get
(),
meta
};
auto
task
=
AllReduce
(
&
barrier_tensor
,
barrier_tensor
,
{},
/*sync_op*/
true
,
/*use_calc_stream*/
false
);
auto
nccl_task
=
dynamic_cast
<
NCCLTask
*>
(
task
.
get
());
nccl_task
->
barrier_
=
true
;
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
return
Collective
(
out_tensor
,
in_tensor
,
[
&
](
phi
::
DenseTensor
*
output
,
const
phi
::
DenseTensor
&
input
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
int
root
=
opts
.
source_rank
+
opts
.
source_root
;
return
platform
::
dynload
::
ncclBroadcast
(
input
.
data
(),
output
->
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
},
CommType
::
BROADCAST
,
sync_op
,
use_calc_stream
);
}
}
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
const
std
::
vector
<
Place
>&
places
,
const
Place
&
place
,
int
rank
,
int
rank
,
CommType
comm_type
,
CommType
comm_type
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
bool
is_sync
,
bool
is_sync
,
bool
use_calc_stream
)
{
bool
use_calc_stream
)
{
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
place
s
,
rank
,
comm_type
,
inputs
,
is_sync
,
use_calc_stream
);
place
,
rank
,
comm_type
,
is_sync
,
use_calc_stream
);
}
}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
void
ProcessGroupNCCL
::
BroadcastUniqueNCCLID
(
ncclUniqueId
*
nccl_id
)
{
const
std
::
vector
<
Place
>&
places
,
const
std
::
string
key
=
int
rank
,
"ProcessGroupNCCL/nccl_ids/"
+
std
::
to_string
(
gid_
)
+
"/0"
;
CommType
CommType
,
if
(
rank_
==
0
)
{
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
)
std
::
vector
<
uint8_t
>
nccl_id_wrapper
(
:
TaskStream
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
reinterpret_cast
<
uint8_t
*>
(
nccl_id
),
control_events_
.
resize
(
places
.
size
());
reinterpret_cast
<
uint8_t
*>
(
nccl_id
)
+
NCCL_UNIQUE_ID_BYTES
);
ncclComms_
.
resize
(
places
.
size
());
store_
->
set
(
key
,
nccl_id_wrapper
);
}
else
{
const
auto
&
nccl_id_wrapper
=
store_
->
get
(
key
);
std
::
memcpy
(
nccl_id
,
nccl_id_wrapper
.
data
(),
nccl_id_wrapper
.
size
());
}
}
}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
void
ProcessGroupNCCL
::
CreateNCCLEnvCache
(
const
Place
&
place
,
const
std
::
vector
<
Place
>&
places
,
const
std
::
string
&
place_key
)
{
int
rank
,
ncclUniqueId
nccl_id
;
CommType
comm_type
,
if
(
rank_
==
0
)
{
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
&
nccl_id
));
bool
sync_op
,
}
bool
use_calc_stream
)
BroadcastUniqueNCCLID
(
&
nccl_id
);
:
TaskStream
(
rank
,
inputs
,
comm_type
,
sync_op
,
use_calc_stream
),
places_
(
places
)
{
control_events_
.
resize
(
places
.
size
());
ncclComms_
.
resize
(
places
.
size
());
}
ProcessGroupNCCL
::
NCCLTask
::~
NCCLTask
()
{}
VLOG
(
3
)
<<
"init nccl rank: "
<<
rank_
<<
", nranks: "
<<
size_
<<
", place: "
<<
place_key
<<
", nccl uniqueid: "
<<
SerializeNCCLUniqueId
(
nccl_id
);
void
ProcessGroupNCCL
::
NCCLTask
::
SetOutputs
(
calc_event_
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
place
);
std
::
vector
<
phi
::
DenseTensor
>&
outputs
)
{
// NOLINT
auto
*
calc_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
outputs_
=
std
::
make_shared
<
std
::
vector
<
phi
::
DenseTensor
>>
(
outputs
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
auto
comm_ctx
=
std
::
make_unique
<
phi
::
GPUContext
>
(
place
);
ncclComm_t
nccl_comm
;
NCCLCHECK
(
platform
::
dynload
::
ncclCommInitRank
(
&
nccl_comm
,
GetSize
(),
nccl_id
,
GetRank
()));
comm_ctx
->
set_nccl_comm
(
nccl_comm
);
place_to_calc_ctx_
[
place_key
]
=
calc_ctx
;
place_to_comm_ctx_
[
place_key
]
=
std
::
move
(
comm_ctx
);
// TODO(sunyilun): for compatibility, will be removed later
places_to_ctx_
[
place_key
]
=
{
place_to_comm_ctx_
[
place_key
].
get
()};
}
}
void
ProcessGroupNCCL
::
NCCLTask
::
SynchronizeStreams
()
{
void
ProcessGroupNCCL
::
SyncCalcStream
(
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
const
Place
&
place
,
const
std
::
shared_ptr
<
platform
::
DeviceEvent
>&
event
)
{
auto
*
default_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
const
std
::
string
&
key
=
GetKeyFromPlace
(
place
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]));
const
auto
*
calc_ctx
=
place_to_calc_ctx_
[
key
];
default_ctx
->
WaitEvent
(
control_events_
[
i
].
GetRawCudaEvent
());
const
auto
*
comm_ctx
=
place_to_comm_ctx_
[
key
].
get
();
}
event
->
Record
(
calc_ctx
);
event
->
Wait
(
platform
::
Place2DeviceType
(
place
),
comm_ctx
);
}
}
bool
ProcessGroupNCCL
::
NCCLTask
::
IsCompleted
()
{
template
<
typename
Fn
>
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Collective
(
if
(
!
control_events_
[
i
].
Query
())
{
phi
::
DenseTensor
*
out_tensor
,
return
false
;
const
phi
::
DenseTensor
&
in_tensor
,
}
Fn
fn
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
)
{
const
auto
&
place
=
in_tensor
.
place
();
const
auto
&
key
=
GetKeyFromPlace
(
place
);
if
(
!
calc_event_
)
{
CreateNCCLEnvCache
(
place
,
key
);
}
}
return
true
;
if
(
!
use_calc_stream
)
{
}
SyncCalcStream
(
place
,
calc_event_
);
}
void
ProcessGroupNCCL
::
NCCLTask
::
UpdateWaitChain
(
auto
task
=
CreateTask
(
place
,
rank_
,
comm_type
,
sync_op
,
use_calc_stream
);
const
phi
::
DeviceContext
&
ctx
)
{
control_events_
[
0
].
Record
(
*
static_cast
<
const
phi
::
GPUContext
*>
(
&
ctx
));
const
auto
*
calc_ctx
=
place_to_calc_ctx_
[
key
];
const
auto
&
comm_ctx
=
place_to_comm_ctx_
[
key
];
auto
nccl_stream
=
use_calc_stream
?
calc_ctx
->
stream
()
:
comm_ctx
->
stream
();
fn
(
out_tensor
,
in_tensor
,
comm_ctx
->
nccl_comm
(),
nccl_stream
);
if
(
!
use_calc_stream
)
{
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
memory
::
RecordStream
(
in_tensor
.
Holder
(),
nccl_stream
);
}
task
->
comm_event_
.
Record
(
comm_ctx
.
get
());
}
return
task
;
}
}
void
ProcessGroupNCCL
::
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
void
ProcessGroupNCCL
::
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
...
@@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
...
@@ -144,70 +350,58 @@ void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>* split_sizes,
}
}
}
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
// TODO(sunyilun): methods below will be removed later
bool
ProcessGroupNCCL
::
NCCLTask
::
Wait
(
std
::
chrono
::
milliseconds
timeout
)
{
void
SyncDefaultStream
(
const
std
::
vector
<
Place
>&
places
,
// Warning here when use calc stream but also invoke waiting explicitly.
const
std
::
shared_ptr
<
platform
::
DeviceEvent
>&
nccl_event
,
if
(
UseCalcStream
())
{
std
::
vector
<
phi
::
GPUContext
*>&
dev_ctx
)
{
// NOLINT
VLOG
(
3
)
<<
"Warning: The communication is on calc stream, wait here is "
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
"useless."
;
auto
*
default_ctx
=
static_cast
<
phi
::
GPUContext
*>
(
return
true
;
platform
::
DeviceContextPool
::
Instance
().
Get
(
places
[
i
]));
}
nccl_event
->
Record
(
default_ctx
);
nccl_event
->
Wait
(
platform
::
Place2DeviceType
(
places
[
i
]),
dev_ctx
[
i
]);
SynchronizeStreams
();
if
(
FLAGS_nccl_blocking_wait
)
{
// NOTE(shenliang03): It will block host for sync
while
(
!
IsCompleted
())
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
kWaitBlockTImeout
));
}
}
if
(
!
barrierTensors_
.
empty
())
{
// If we use the work to do barrier, we should block cpu
for
(
auto
&
place
:
places_
)
{
platform
::
CUDADeviceGuard
gpuGuard
(
place
);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaDeviceSynchronize
());
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipDeviceSynchronize
());
#endif
}
}
}
return
true
;
}
}
// Same as Wait
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
void
ProcessGroupNCCL
::
NCCLTask
::
Synchronize
()
{
Wait
(
kWaitTimeout
);
}
std
::
vector
<
Place
>
places
,
int
rank
,
ProcessGroupNCCL
::
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
CommType
comm_type
,
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
)
{
int
size
,
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
const
platform
::
Place
&
place
,
places
,
rank
,
comm_type
,
inputs
);
int
gid
)
:
ProcessGroupStream
(
rank
,
size
,
place
,
gid
),
store_
(
store
)
{
platform
::
SetDeviceId
(
place_
.
device
);
}
}
void
ProcessGroupNCCL
::
BroadcastUniqueNCCLID
(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
)
{
// NOLINT
const
std
::
vector
<
Place
>&
places
,
if
(
rank_
==
0
)
{
int
rank
,
for
(
size_t
i
=
0
;
i
<
nccl_ids
.
size
();
i
++
)
{
CommType
comm_type
,
auto
key
=
"ProcessGroupNCCL/nccl_ids/"
+
std
::
to_string
(
gid_
)
+
"/"
+
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
to_string
(
i
);
bool
is_sync
,
auto
nccl_id
=
std
::
vector
<
uint8_t
>
(
bool
use_calc_stream
)
{
reinterpret_cast
<
uint8_t
*>
(
&
nccl_ids
[
i
]),
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
reinterpret_cast
<
uint8_t
*>
(
&
nccl_ids
[
i
])
+
NCCL_UNIQUE_ID_BYTES
);
places
,
rank
,
comm_type
,
inputs
,
is_sync
,
use_calc_stream
);
store_
->
set
(
key
,
nccl_id
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
nccl_ids
.
size
();
i
++
)
{
auto
key
=
"ProcessGroupNCCL/nccl_ids/"
+
std
::
to_string
(
gid_
)
+
"/"
+
std
::
to_string
(
i
);
auto
ret
=
store_
->
get
(
key
);
std
::
memcpy
(
&
nccl_ids
[
i
],
ret
.
data
(),
ret
.
size
());
}
}
}
}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
)
:
TaskStream
(
rank
,
inputs
,
CommType
),
comm_event_
(
places
[
0
]),
place_
(
places
[
0
])
{}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
bool
sync_op
,
bool
use_calc_stream
)
:
TaskStream
(
rank
,
inputs
,
comm_type
,
sync_op
,
use_calc_stream
),
comm_event_
(
places
[
0
]),
place_
(
places
[
0
])
{}
// create NCCLManager cache for places_key
// create NCCLManager cache for places_key
void
ProcessGroupNCCL
::
CreateNCCLManagerCache
(
void
ProcessGroupNCCL
::
CreateNCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
)
{
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
)
{
...
@@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
...
@@ -217,22 +411,11 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
"Not able to create/get the NCCL Communicator since "
"Not able to create/get the NCCL Communicator since "
"the GPU place are not known"
));
"the GPU place are not known"
));
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
nccl_comms
;
ncclUniqueId
nccl_id
;
nccl_comms
.
resize
(
places
.
size
());
// using vector just for broadcast
std
::
vector
<
ncclUniqueId
>
nccl_ids
;
nccl_ids
.
resize
(
1
);
auto
&
nccl_id
=
nccl_ids
.
front
();
for
(
auto
&
place
:
places
)
{
used_place_ids_
.
insert
(
place
.
GetDeviceId
());
}
if
(
rank_
==
0
)
{
if
(
rank_
==
0
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
&
nccl_id
));
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
&
nccl_id
));
}
}
BroadcastUniqueNCCLID
(
nccl_ids
);
BroadcastUniqueNCCLID
(
&
nccl_id
);
VLOG
(
3
)
<<
"init nccl rank: "
<<
rank_
<<
", nranks: "
<<
size_
VLOG
(
3
)
<<
"init nccl rank: "
<<
rank_
<<
", nranks: "
<<
size_
<<
", place: "
<<
places_key
<<
", place: "
<<
places_key
...
@@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
...
@@ -241,23 +424,33 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
std
::
vector
<
std
::
unique_ptr
<
phi
::
GPUContext
>>
dev_ctx
;
std
::
vector
<
std
::
unique_ptr
<
phi
::
GPUContext
>>
dev_ctx
;
dev_ctx
.
resize
(
places
.
size
());
dev_ctx
.
resize
(
places
.
size
());
std
::
vector
<
phi
::
GPUContext
*>
dev_ctx_raw
;
dev_ctx_raw
.
resize
(
places
.
size
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
platform
::
CUDADeviceGuard
guard
(
places
[
i
]);
platform
::
CUDADeviceGuard
guard
(
places
[
i
]);
nccl_comms
[
i
]
=
NCCLCommManager
::
Create
(
GetSize
(),
GetRank
(),
nccl_id
);
dev_ctx
[
i
].
reset
(
new
phi
::
GPUContext
(
places
[
i
]));
dev_ctx
[
i
].
reset
(
new
phi
::
GPUContext
(
places
[
i
]));
ncclComm_t
nccl_comm
;
NCCLCHECK
(
platform
::
dynload
::
ncclCommInitRank
(
&
nccl_comm
,
GetSize
(),
nccl_id
,
GetRank
()));
dev_ctx
[
i
]
->
set_nccl_comm
(
nccl_comm
);
dev_ctx_raw
[
i
]
=
dev_ctx
[
i
].
get
();
}
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
calc_event_
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
places
[
0
]);
std
::
vector
<
EventManager
>
events
;
// TODO(sunyilun): for compatibility, will be removed later
events
.
resize
(
places
.
size
());
place_to_calc_ctx_
[
places_key
]
=
static_cast
<
phi
::
GPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
places
[
0
]));
place_to_comm_ctx_
[
places_key
]
=
std
::
move
(
dev_ctx
[
0
]);
// These caches will be useful to process sync/wait/communicate
// These caches will be useful to process sync/wait/communicate
places_to_events_
.
emplace
(
places_key
,
std
::
move
(
events
));
places_to_ctx_
.
emplace
(
places_key
,
std
::
move
(
dev_ctx_raw
));
places_to_ncclcomm_
.
emplace
(
places_key
,
std
::
move
(
nccl_comms
));
places_to_ctx_
.
emplace
(
places_key
,
std
::
move
(
dev_ctx
));
}
}
template
<
typename
Fn
>
template
<
typename
Fn
>
...
@@ -273,15 +466,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -273,15 +466,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
()
)
{
if
(
!
calc_event_
)
{
CreateNCCLManagerCache
(
key
,
places
);
CreateNCCLManagerCache
(
key
,
places
);
}
}
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
if
(
!
use_calc_stream
)
{
if
(
!
use_calc_stream
)
{
SyncDefaultStream
(
places
,
places_to_events_
[
key
]
,
places_to_ctx_
[
key
]);
SyncDefaultStream
(
places
,
calc_event_
,
places_to_ctx_
[
key
]);
}
}
auto
task
=
auto
task
=
...
@@ -304,7 +495,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -304,7 +495,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
}
}
fn
(
inputs
[
i
],
outputs
[
i
],
nccl_comms
[
i
]
->
GetNcclComm
(),
nccl_stream
);
fn
(
inputs
[
i
],
outputs
[
i
],
places_to_ctx_
[
key
][
i
]
->
nccl_comm
(),
nccl_stream
);
}
}
}
}
...
@@ -330,7 +524,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -330,7 +524,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
if
(
!
use_calc_stream
)
{
if
(
!
use_calc_stream
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
task
->
co
ntrol_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
task
->
co
mm_event_
.
Record
(
places_to_ctx_
[
key
][
i
]);
}
}
}
}
...
@@ -348,14 +542,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -348,14 +542,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
()
)
{
if
(
!
calc_event_
)
{
CreateNCCLManagerCache
(
key
,
places
);
CreateNCCLManagerCache
(
key
,
places
);
}
}
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
SyncDefaultStream
(
places
,
calc_event_
,
places_to_ctx_
[
key
]);
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
inputs
);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
inputs
);
...
@@ -367,7 +559,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -367,7 +559,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
inputs
[
i
],
outputs
[
i
],
nccl_comms
[
i
]
->
GetNcclComm
(),
nccl_stream
);
fn
(
inputs
[
i
],
outputs
[
i
],
places_to_ctx_
[
key
][
i
]
->
nccl_comm
(),
nccl_stream
);
}
}
}
}
...
@@ -381,7 +576,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
...
@@ -381,7 +576,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
task
->
co
ntrol_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
task
->
co
mm_event_
.
Record
(
places_to_ctx_
[
key
][
i
]);
}
}
return
task
;
return
task
;
}
}
...
@@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
...
@@ -393,18 +588,16 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
CommType
op_type
)
{
CommType
op_type
)
{
std
::
vector
<
Place
>
places
;
std
::
vector
<
Place
>
places
;
places
.
push_back
(
in
->
place
());
places
.
push_back
(
in
->
place
());
const
auto
key
=
GetKeyFromPlaces
(
places
);
const
std
::
string
&
key
=
GetKeyFromPlaces
(
places
);
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
()
)
{
if
(
!
calc_event_
)
{
CreateNCCLManagerCache
(
key
,
places
);
CreateNCCLManagerCache
(
key
,
places
);
}
}
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
SyncDefaultStream
(
places
,
calc_event_
,
places_to_ctx_
[
key
]);
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
// construct uninitialize guard for device
// construct uninitialize guard for device
platform
::
CUDADeviceGuard
cuda_guard
;
platform
::
CUDADeviceGuard
cuda_guard
;
...
@@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
...
@@ -418,7 +611,7 @@ void ProcessGroupNCCL::Collective(const phi::DenseTensor* in,
platform
::
NCCLGroupGuard
nccl_guard
;
platform
::
NCCLGroupGuard
nccl_guard
;
cuda_guard
.
SetDevice
(
places
[
0
]);
cuda_guard
.
SetDevice
(
places
[
0
]);
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
0
]
->
stream
();
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
0
]
->
stream
();
fn
(
in
,
out
,
nccl_comms
[
0
]
->
GetNcclC
omm
(),
nccl_stream
);
fn
(
in
,
out
,
places_to_ctx_
[
key
][
0
]
->
nccl_c
omm
(),
nccl_stream
);
}
}
cuda_guard
.
SetDevice
(
places
[
0
]);
cuda_guard
.
SetDevice
(
places
[
0
]);
...
@@ -437,15 +630,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -437,15 +630,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
()
)
{
if
(
!
calc_event_
)
{
CreateNCCLManagerCache
(
key
,
places
);
CreateNCCLManagerCache
(
key
,
places
);
}
}
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
if
(
!
use_calc_stream
)
{
if
(
!
use_calc_stream
)
{
SyncDefaultStream
(
places
,
places_to_events_
[
key
]
,
places_to_ctx_
[
key
]);
SyncDefaultStream
(
places
,
calc_event_
,
places_to_ctx_
[
key
]);
}
}
auto
task
=
auto
task
=
...
@@ -466,7 +657,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -466,7 +657,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
}
else
{
}
else
{
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
}
}
fn
(
tensors
[
i
],
nccl_comms
[
i
]
->
GetNcclComm
(),
nccl_stream
,
dst_rank
);
fn
(
tensors
[
i
],
places_to_ctx_
[
key
][
i
]
->
nccl_comm
(),
nccl_stream
,
dst_rank
);
}
}
}
}
...
@@ -489,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -489,7 +683,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
if
(
!
use_calc_stream
)
{
if
(
!
use_calc_stream
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
task
->
co
ntrol_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
task
->
co
mm_event_
.
Record
(
places_to_ctx_
[
key
][
i
]);
}
}
}
}
...
@@ -507,14 +701,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -507,14 +701,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
()
)
{
if
(
!
calc_event_
)
{
CreateNCCLManagerCache
(
key
,
places
);
CreateNCCLManagerCache
(
key
,
places
);
}
}
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
SyncDefaultStream
(
places
,
calc_event_
,
places_to_ctx_
[
key
]);
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
tensors
);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
tensors
);
...
@@ -526,7 +718,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -526,7 +718,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
tensors
[
i
],
nccl_comms
[
i
]
->
GetNcclComm
(),
nccl_stream
,
dst_rank
);
fn
(
tensors
[
i
],
places_to_ctx_
[
key
][
i
]
->
nccl_comm
(),
nccl_stream
,
dst_rank
);
}
}
}
}
...
@@ -540,7 +735,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
...
@@ -540,7 +735,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::PointToPoint(
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
cuda_guard
.
SetDevice
(
places
[
i
]);
task
->
co
ntrol_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
task
->
co
mm_event_
.
Record
(
places_to_ctx_
[
key
][
i
]);
}
}
return
task
;
return
task
;
}
}
...
@@ -572,37 +767,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
...
@@ -572,37 +767,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
CommType
::
ALLREDUCE
);
CommType
::
ALLREDUCE
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
in_tensors
,
out_tensors
,
[
&
](
const
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllReduce
(
input
.
data
(),
output
.
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
ToNCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
,
sync_op
,
use_calc_stream
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
@@ -633,63 +797,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
...
@@ -633,63 +797,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
CommType
::
BROADCAST
);
CommType
::
BROADCAST
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
in_tensors
,
out_tensors
,
[
&
](
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
auto
root
=
opts
.
source_rank
*
in_tensors
.
size
()
+
opts
.
source_root
;
return
platform
::
dynload
::
ncclBroadcast
(
input
.
data
(),
output
.
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
},
CommType
::
BROADCAST
,
sync_op
,
use_calc_stream
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Barrier
(
const
BarrierOptions
&
opts
)
{
// Only support single card single process
std
::
vector
<
phi
::
GPUPlace
>
places
=
{
place_
};
std
::
vector
<
phi
::
DenseTensor
>
barrierTensors
;
barrierTensors
.
reserve
(
places
.
size
());
platform
::
CUDADeviceGuard
gpuGuard
;
for
(
auto
&
place
:
places
)
{
gpuGuard
.
SetDeviceIndex
(
place
.
GetDeviceId
());
phi
::
DenseTensorMeta
meta
(
phi
::
DataType
::
FLOAT32
,
phi
::
DDim
({
1
}));
auto
allocator
=
std
::
unique_ptr
<
phi
::
Allocator
>
(
new
paddle
::
experimental
::
DefaultAllocator
(
place
));
barrierTensors
.
emplace_back
(
allocator
.
get
(),
meta
);
}
auto
task
=
ProcessGroupNCCL
::
AllReduce
(
barrierTensors
,
barrierTensors
,
AllreduceOptions
());
auto
nccl_task
=
dynamic_cast
<
ProcessGroupNCCL
::
NCCLTask
*>
(
task
.
get
());
nccl_task
->
barrierTensors_
=
std
::
move
(
barrierTensors
);
return
task
;
}
void
CheckTensorsInDifferentDevices
(
void
CheckTensorsInDifferentDevices
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
const
size_t
num_devices
)
{
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
const
size_t
num_devices
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
@@ -975,39 +1082,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
...
@@ -975,39 +1082,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType
::
ALLGATHER
);
CommType
::
ALLGATHER
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
in_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
out_tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All outputs should be in CudaPlace."
));
return
Collective
(
in_tensors
,
out_tensors
,
[
&
](
const
phi
::
DenseTensor
&
input
,
phi
::
DenseTensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
return
platform
::
dynload
::
ncclAllGather
(
input
.
data
(),
output
.
data
(),
input
.
numel
(),
platform
::
ToNCCLDataType
(
input
.
dtype
()),
comm
,
stream
);
},
CommType
::
ALLGATHER
,
sync_op
,
use_calc_stream
);
}
void
*
GetPointerByOffset
(
void
*
raw_pointer
,
void
*
GetPointerByOffset
(
void
*
raw_pointer
,
size_t
offset
,
size_t
offset
,
experimental
::
DataType
type
)
{
experimental
::
DataType
type
)
{
...
@@ -1578,43 +1652,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
...
@@ -1578,43 +1652,5 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
CommType
::
REDUCE_SCATTER
);
CommType
::
REDUCE_SCATTER
);
}
}
void
ProcessGroupNCCL
::
GroupStart
()
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
}
void
ProcessGroupNCCL
::
GroupEnd
()
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
}
ncclComm_t
ProcessGroupNCCL
::
NCCLComm
(
const
Place
&
place
)
const
{
std
::
vector
<
Place
>
places
=
{
place
};
const
auto
&
iter
=
places_to_ncclcomm_
.
find
(
GetKeyFromPlaces
(
places
));
PADDLE_ENFORCE_NE
(
iter
,
places_to_ncclcomm_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Cannot find nccl comm in process group."
));
return
iter
->
second
[
0
]
->
GetNcclComm
();
}
const
phi
::
DeviceContext
&
ProcessGroupNCCL
::
GetDeviceContext
(
const
Place
&
place
)
const
{
return
GetDeviceContext
(
place
,
/*use_calc_stream*/
false
);
}
const
phi
::
DeviceContext
&
ProcessGroupNCCL
::
GetDeviceContext
(
const
Place
&
place
,
bool
use_calc_stream
)
const
{
if
(
use_calc_stream
)
{
return
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
}
else
{
std
::
vector
<
Place
>
places
=
{
place
};
const
auto
&
iter
=
places_to_ctx_
.
find
(
GetKeyFromPlaces
(
places
));
PADDLE_ENFORCE_NE
(
iter
,
places_to_ctx_
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Cannot find device context in process group."
));
return
*
iter
->
second
[
0
];
}
}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
浏览文件 @
e1a1c354
...
@@ -24,10 +24,10 @@
...
@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_
contex
t.h"
#include "paddle/fluid/platform/device_
even
t.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/
fluid/platform/gen_comm_id_helper
.h"
#include "paddle/
phi/common/place
.h"
#include "paddle/
fluid/platform/place
.h"
#include "paddle/
phi/core/device_context
.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
...
@@ -44,16 +44,28 @@ namespace distributed {
...
@@ -44,16 +44,28 @@ namespace distributed {
using
Place
=
paddle
::
platform
::
Place
;
using
Place
=
paddle
::
platform
::
Place
;
class
ProcessGroupNCCL
:
public
ProcessGroupStream
{
class
ProcessGroupNCCL
final
:
public
ProcessGroupStream
{
public:
public:
class
NCCLTask
:
public
ProcessGroupStream
::
TaskStream
,
class
NCCLTask
final
:
public
ProcessGroupStream
::
TaskStream
,
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public:
public:
NCCLTask
(
const
Place
&
place
,
int
rank
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
);
virtual
~
NCCLTask
();
bool
IsCompleted
()
override
;
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
)
override
;
void
Synchronize
()
override
;
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
)
override
;
// TODO(sunyilun): methods below will be removed later
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
CommType
,
CommType
CommType
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
comm_type
,
CommType
comm_type
,
...
@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
bool
IsCompleted
();
public:
bool
barrier_
{
false
};
void
SynchronizeStreams
();
platform
::
DeviceEvent
comm_event_
;
// event on comm stream
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
);
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
phi
::
DenseTensor
>&
outputs
);
// NOLINT
virtual
~
NCCLTask
();
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
)
override
;
std
::
vector
<
EventManager
>
control_events_
;
std
::
vector
<
phi
::
DenseTensor
>
barrierTensors_
;
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
ncclComms_
;
std
::
shared_ptr
<
std
::
vector
<
phi
::
DenseTensor
>>
outputs_
;
private:
private:
Place
place_
;
};
};
public:
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
int
rank
,
int
rank
,
int
size
,
int
size
,
...
@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
,
bool
use_calc_stream
)
const
override
;
const
Place
&
place
,
bool
use_calc_stream
)
const
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
options
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
bool
use_calc_stream
)
override
;
static
void
GroupStart
();
static
void
GroupEnd
();
ncclComm_t
NCCLComm
(
const
Place
&
place
)
const
;
// TODO(liyurui): This API will be moved later
// TODO(liyurui): This API will be moved later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
dst_rank
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
dst_rank
)
override
;
...
@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather_Partial
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather_Partial
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi
::
DenseTensor
&
,
// NOLINT
phi
::
DenseTensor
&
,
// NOLINT
const
ReduceScatterOptions
&
)
override
;
const
ReduceScatterOptions
&
)
override
;
static
void
GroupStart
();
private:
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
Place
&
place
,
int
rank
,
CommType
op_type
,
bool
sync_op
,
bool
use_calc_stream
);
static
void
GroupEnd
(
);
void
BroadcastUniqueNCCLID
(
ncclUniqueId
*
nccl_id
);
ncclComm_t
NCCLComm
(
const
Place
&
place
)
const
;
void
CreateNCCLEnvCache
(
const
Place
&
place
,
const
std
::
string
&
place_key
);
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroupStream
::
Task
>
Collective
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
Fn
fn
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
);
protected:
void
SyncCalcStream
(
const
Place
&
place
,
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
std
::
shared_ptr
<
platform
::
DeviceEvent
>&
event
);
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
std
::
vector
<
Place
>
places
,
int
rank
,
int
rank
,
CommType
op_type
,
CommType
op_type
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
std
::
vector
<
Place
>&
places
,
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
op_type
,
CommType
op_type
,
...
@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
protected:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
NCCLCommManager
>
nccl_comm_
;
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>>
places_to_ncclcomm_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
EventManager
>>
places_to_events_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
phi
::
GPUContext
>>>
places_to_ctx_
;
std
::
set
<
int
>
used_place_ids_
;
private:
void
BcastNCCLId
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
,
// NOLINT
int
root
,
// NOLINT
int
server_fd
);
void
BroadcastUniqueNCCLID
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
);
// NOLINT
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
...
@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
void
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
std
::
vector
<
int64_t
>
tensor_shape
);
std
::
vector
<
int64_t
>
tensor_shape
);
private:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
platform
::
DeviceEvent
>
calc_event_
;
// event on calc stream
std
::
unordered_map
<
std
::
string
,
phi
::
GPUContext
*>
place_to_calc_ctx_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
phi
::
GPUContext
>>
place_to_comm_ctx_
;
// TODO(sunyilun): attrs below will be removed later
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
phi
::
GPUContext
*>>
places_to_ctx_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/collective/ProcessGroupStream.cc
浏览文件 @
e1a1c354
...
@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
...
@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
return
AllGather
(
input_tensors
,
return
AllGather
(
out_tensor
,
output_tensors
,
in_tensor
,
sync_op
,
sync_op
,
/*use_calc_stream*/
false
);
/*use_calc_stream*/
false
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
{
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
...
@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
{
bool
sync_op
)
{
return
AllReduce
(
input_tensors
,
return
AllReduce
(
out_tensor
,
output_tensors
,
in_tensor
,
opt
ion
s
,
opts
,
sync_op
,
sync_op
,
/*use_calc_stream*/
false
);
/*use_calc_stream*/
false
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
{
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do all_reduce"
,
GetBackendName
()));
"ProcessGroup%s does not support do all_reduce"
,
GetBackendName
()));
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
return
Broadcast
(
out_tensor
,
in_tensor
,
opts
,
sync_op
,
/*use_calc_stream*/
false
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do broadcast"
,
GetBackendName
()));
}
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllToAll
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
...
@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single"
,
GetBackendName
()));
"ProcessGroup%s does not support do alltoall_single"
,
GetBackendName
()));
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
return
Broadcast
(
in_tensors
,
out_tensors
,
opts
,
sync_op
,
/*use_calc_stream*/
false
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do broadcast"
,
GetBackendName
()));
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Reduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
...
paddle/fluid/distributed/collective/ProcessGroupStream.h
浏览文件 @
e1a1c354
...
@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public:
public:
class
TaskStream
:
public
ProcessGroup
::
Task
{
class
TaskStream
:
public
ProcessGroup
::
Task
{
public:
public:
TaskStream
(
int
rank
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
)
:
Task
(
rank
,
comm_type
,
sync_op
),
use_calc_stream_
(
use_calc_stream
)
{}
virtual
~
TaskStream
()
=
default
;
// TODO(liyurui): This constructor is temporary here for compatible reason,
// TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon.
// will be deleted soon.
TaskStream
(
int
rank
,
TaskStream
(
int
rank
,
...
@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
:
Task
(
rank
,
inputs
,
comm_type
,
sync_op
),
:
Task
(
rank
,
inputs
,
comm_type
,
sync_op
),
use_calc_stream_
(
use_calc_stream
)
{}
use_calc_stream_
(
use_calc_stream
)
{}
virtual
~
TaskStream
()
=
default
;
protected:
protected:
bool
UseCalcStream
()
const
{
return
use_calc_stream_
;
}
bool
UseCalcStream
()
const
{
return
use_calc_stream_
;
}
...
@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool
use_calc_stream_
{
false
};
bool
use_calc_stream_
{
false
};
};
};
public:
ProcessGroupStream
(
int
rank
,
int
size
,
const
platform
::
Place
&
place
,
int
gid
);
ProcessGroupStream
(
int
rank
,
int
size
,
const
platform
::
Place
&
place
,
int
gid
);
virtual
~
ProcessGroupStream
()
=
default
;
virtual
~
ProcessGroupStream
()
=
default
;
...
@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const
Place
&
place
,
bool
use_calc_stream
)
const
;
const
Place
&
place
,
bool
use_calc_stream
)
const
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
override
;
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
override
;
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
options
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
...
@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
e1a1c354
...
@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
...
@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if
(
map
->
has
(
ring_id
))
{
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg_nccl
->
AllReduce
(
in_tensor
,
out_
tensor
,
opts
,
true
,
true
);
auto
task
=
pg_nccl
->
AllReduce
(
&
tensor
,
tensor
,
opts
,
true
,
true
);
task
->
Wait
();
task
->
Wait
();
}
else
{
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
auto
dtype
=
platform
::
ToNCCLDataType
(
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cu
浏览文件 @
e1a1c354
...
@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
...
@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if
(
map
->
has
(
ring_id
))
{
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg_nccl
->
AllReduce
(
in_tensor
,
out_
tensor
,
opts
,
true
,
true
);
auto
task
=
pg_nccl
->
AllReduce
(
&
tensor
,
tensor
,
opts
,
true
,
true
);
task
->
Wait
();
task
->
Wait
();
}
else
{
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
auto
dtype
=
platform
::
ToNCCLDataType
(
...
...
paddle/fluid/pybind/distributed_py.cc
浏览文件 @
e1a1c354
...
@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
...
@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed
::
ReduceOp
op
,
distributed
::
ReduceOp
op
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
auto
p_dense
=
opts
.
reduce_op
=
op
;
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
AllReduce
(
tensors
,
tensors
,
opts
,
sync_op
);
auto
in_dense
=
*
p_dense
;
distributed
::
AllreduceOptions
opts
{
op
};
return
self
.
AllReduce
(
out_dense
,
in_dense
,
opts
,
sync_op
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"op"
),
py
::
arg
(
"op"
),
...
@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
...
@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int
src
,
int
src
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
{
src
};
auto
p_dense
=
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
Broadcast
(
tensors
,
tensors
,
opts
,
sync_op
);
auto
in_dense
=
*
p_dense
;
distributed
::
BroadcastOptions
opts
{
src
};
return
self
.
Broadcast
(
out_dense
,
in_dense
,
opts
,
sync_op
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"src"
),
py
::
arg
(
"src"
),
...
@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
...
@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.
def
(
.
def
(
"allgather"
,
"allgather"
,
[](
distributed
::
ProcessGroup
&
self
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor_list
=
auto
out_tensor_list
=
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
concat_out_tensor
.
impl
());
concat_out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
());
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
());
auto
task
=
self
.
AllGather
(
in_wrapper
,
out_wrapper
,
sync_op
);
auto
task
=
self
.
AllGather
(
out_dense
,
in_dense
,
sync_op
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
task
->
UpdateWaitChain
(
dev_ctx
);
task
->
UpdateWaitChain
(
dev_ctx
);
return
task
;
return
task
;
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"sync_op"
),
py
::
arg
(
"sync_op"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
"allgather_into_tensor"
,
"allgather_into_tensor"
,
[](
distributed
::
ProcessGroup
&
self
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
return
self
.
AllGather
(
in_wrapper
,
out_wrapper
,
sync_op
);
return
self
.
AllGather
(
out_dense
,
in_dense
,
sync_op
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"sync_op"
),
py
::
arg
(
"sync_op"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
...
@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.
def
(
.
def
(
"allgather_on_calc_stream"
,
"allgather_on_calc_stream"
,
[](
distributed
::
ProcessGroupStream
&
self
,
[](
distributed
::
ProcessGroupStream
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_out_tensor_list
)
{
py
::
handle
py_in_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor_list
=
auto
out_tensor_list
=
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
concat_out_tensor
.
impl
());
concat_out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
const
auto
&
dev_ctx
=
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
(),
true
);
self
.
GetDeviceContext
(
in_tensor
.
place
(),
true
);
auto
task
=
self
.
AllGather
(
in_wrapper
,
auto
task
=
self
.
AllGather
(
out_dense
,
out_wrapper
,
in_dense
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
return
task
;
return
task
;
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
"allgather_into_tensor_on_calc_stream"
,
"allgather_into_tensor_on_calc_stream"
,
[](
distributed
::
ProcessGroupStream
&
self
,
[](
distributed
::
ProcessGroupStream
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_out_tensor
)
{
py
::
handle
py_in_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
}
;
auto
*
out_dense
=
p_out_tensor
.
get
()
;
return
self
.
AllGather
(
in_wrapper
,
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
out_wrapper
,
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
return
self
.
AllGather
(
out_dense
,
in_dense
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
...
@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
...
@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py
::
handle
py_tensor
,
py
::
handle
py_tensor
,
distributed
::
ReduceOp
op
)
{
distributed
::
ReduceOp
op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
auto
p_dense
=
opts
.
reduce_op
=
op
;
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
in_dense
=
*
p_dense
;
return
self
.
AllReduce
(
tensors
,
auto
*
out_dense
=
p_dense
.
get
();
tensors
,
distributed
::
AllreduceOptions
opts
{
op
};
return
self
.
AllReduce
(
out_dense
,
in_dense
,
opts
,
opts
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
...
@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
...
@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py
::
handle
py_tensor
,
py
::
handle
py_tensor
,
int
src
)
{
int
src
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
{
src
};
auto
p_dense
=
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
Broadcast
(
tensors
,
auto
in_dense
=
*
p_dense
;
tensors
,
distributed
::
BroadcastOptions
opts
{
src
};
return
self
.
Broadcast
(
out_dense
,
in_dense
,
opts
,
opts
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
...
...
python/paddle/distributed/communication/stream/all_gather.py
浏览文件 @
e1a1c354
...
@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
...
@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape
=
list
(
shape
)
expect_shape
=
list
(
shape
)
expect_shape
[
0
]
*=
nranks
expect_shape
[
0
]
*=
nranks
if
list
(
tensor
.
shape
)
!=
expect_shape
:
if
list
(
tensor
.
shape
)
!=
expect_shape
:
raise
RuntimeError
(
'The tensor for all_gather is not correctly-sized.'
)
raise
RuntimeError
(
"The tensor for all_gather is not correctly-sized."
)
def
_check_tensor_list_shape
(
tensor_list
,
shape
,
nranks
=
1
):
def
_check_tensor_list_shape
(
tensor_list
,
shape
,
nranks
=
1
):
if
len
(
tensor_list
)
!=
nranks
:
if
len
(
tensor_list
)
!=
nranks
:
raise
RuntimeError
(
raise
RuntimeError
(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
)
for
tensor
in
tensor_list
:
for
tensor
in
tensor_list
:
if
tensor
.
shape
!=
shape
:
if
tensor
.
shape
!=
shape
:
raise
RuntimeError
(
raise
RuntimeError
(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
)
...
@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
...
@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if
use_calc_stream
:
if
use_calc_stream
:
return
group
.
process_group
.
allgather_into_tensor_on_calc_stream
(
return
group
.
process_group
.
allgather_into_tensor_on_calc_stream
(
in_tensor
,
out_tensor
out_tensor
,
in_tensor
,
)
)
task
=
group
.
process_group
.
allgather_into_tensor
(
task
=
group
.
process_group
.
allgather_into_tensor
(
in_tensor
,
out
_tensor
,
sync_op
out_tensor
,
in
_tensor
,
sync_op
)
)
if
sync_op
:
if
sync_op
:
task
.
wait
()
task
.
wait
()
...
@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
...
@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape
(
tensor_list
,
tensor
.
shape
,
group
.
nranks
)
_check_tensor_list_shape
(
tensor_list
,
tensor
.
shape
,
group
.
nranks
)
if
use_calc_stream
:
if
use_calc_stream
:
return
group
.
process_group
.
allgather_on_calc_stream
(
tensor
,
tensor_list
)
return
group
.
process_group
.
allgather_on_calc_stream
(
tensor
_list
,
tensor
)
task
=
group
.
process_group
.
allgather
(
tensor
,
tensor_list
,
sync_op
)
task
=
group
.
process_group
.
allgather
(
tensor
_list
,
tensor
,
sync_op
)
if
sync_op
:
if
sync_op
:
task
.
wait
()
task
.
wait
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录