Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e1a1c354
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e1a1c354
编写于
11月 07, 2022
作者:
W
Wen Sun
提交者:
GitHub
11月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor collective communication all_gather, all_reduce, broadcast & barrier C++ API (#47481)
上级
2337e609
变更
16
展开全部
隐藏空白更改
内联
并排
Showing
16 changed file
with
676 addition
and
767 deletion
+676
-767
paddle/fluid/distributed/collective/CMakeLists.txt
paddle/fluid/distributed/collective/CMakeLists.txt
+1
-0
paddle/fluid/distributed/collective/Common.cc
paddle/fluid/distributed/collective/Common.cc
+2
-0
paddle/fluid/distributed/collective/Common.h
paddle/fluid/distributed/collective/Common.h
+2
-0
paddle/fluid/distributed/collective/NCCLTools.h
paddle/fluid/distributed/collective/NCCLTools.h
+0
-198
paddle/fluid/distributed/collective/ProcessGroup.cc
paddle/fluid/distributed/collective/ProcessGroup.cc
+13
-9
paddle/fluid/distributed/collective/ProcessGroup.h
paddle/fluid/distributed/collective/ProcessGroup.h
+48
-13
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
+11
-0
paddle/fluid/distributed/collective/ProcessGroupGloo.h
paddle/fluid/distributed/collective/ProcessGroupGloo.h
+7
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+364
-328
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+84
-74
paddle/fluid/distributed/collective/ProcessGroupStream.cc
paddle/fluid/distributed/collective/ProcessGroupStream.cc
+38
-37
paddle/fluid/distributed/collective/ProcessGroupStream.h
paddle/fluid/distributed/collective/ProcessGroupStream.h
+30
-25
paddle/fluid/operators/fused/fused_attention_op.cu
paddle/fluid/operators/fused/fused_attention_op.cu
+1
-6
paddle/fluid/operators/fused/fused_feedforward_op.cu
paddle/fluid/operators/fused/fused_feedforward_op.cu
+1
-6
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+66
-64
python/paddle/distributed/communication/stream/all_gather.py
python/paddle/distributed/communication/stream/all_gather.py
+8
-7
未找到文件。
paddle/fluid/distributed/collective/CMakeLists.txt
浏览文件 @
e1a1c354
...
@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
...
@@ -28,6 +28,7 @@ if(WITH_NCCL OR WITH_RCCL)
enforce
enforce
collective_helper
collective_helper
device_context
device_context
${
DEVICE_EVENT_LIBS
}
dense_tensor
)
dense_tensor
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE
)
if
(
WITH_DISTRIBUTE AND WITH_PSCORE
)
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0
)
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0
)
...
...
paddle/fluid/distributed/collective/Common.cc
浏览文件 @
e1a1c354
...
@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
...
@@ -40,6 +40,8 @@ std::string GetKeyFromPlaces(const std::vector<Place>& places) {
return
placeList
;
return
placeList
;
}
}
std
::
string
GetKeyFromPlace
(
const
Place
&
place
)
{
return
place
.
DebugString
();
}
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
)
{
return
std
::
all_of
(
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
phi
::
DenseTensor
&
t
)
{
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
phi
::
DenseTensor
&
t
)
{
...
...
paddle/fluid/distributed/collective/Common.h
浏览文件 @
e1a1c354
...
@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
...
@@ -25,6 +25,8 @@ using Place = paddle::platform::Place;
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
// Get the deviceList String from the list of devices
// Get the deviceList String from the list of devices
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
);
// Get the device string from one device
std
::
string
GetKeyFromPlace
(
const
Place
&
place
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
phi
::
DenseTensor
>&
tensors
);
...
...
paddle/fluid/distributed/collective/NCCLTools.h
浏览文件 @
e1a1c354
...
@@ -59,204 +59,6 @@ namespace distributed {
...
@@ -59,204 +59,6 @@ namespace distributed {
} \
} \
} while (0)
} while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
// EventManage is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.
class
EventManager
{
public:
EventManager
()
{}
explicit
EventManager
(
unsigned
int
flags
)
:
flags_
{
flags
}
{}
~
EventManager
()
{
if
(
is_created_
)
{
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_HIP
hipEventDestroy
(
event_
);
#else
cudaEventDestroy
(
event_
);
#endif
}
}
EventManager
(
const
EventManager
&
)
=
delete
;
EventManager
&
operator
=
(
const
EventManager
&
)
=
delete
;
EventManager
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
}
EventManager
&
operator
=
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
return
*
this
;
}
bool
IsCreated
()
const
{
return
is_created_
;
}
bool
DeviceId
()
const
{
return
device_index_
;
}
gpuEvent_t
GetRawCudaEvent
()
const
{
return
event_
;
}
void
Record
(
const
phi
::
GPUContext
&
ctx
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
if
(
!
is_created_
)
{
CreateEvent
(
device_index
);
}
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"phi::GPUContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventRecord
(
event_
,
ctx
.
stream
()));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventRecord
(
event_
,
ctx
.
stream
()));
#endif
}
bool
Query
()
const
{
#ifdef PADDLE_WITH_HIP
gpuError_t
err
=
hipEventQuery
(
event_
);
if
(
err
==
hipSuccess
)
{
return
true
;
}
if
(
err
==
hipErrorNotReady
)
{
return
false
;
}
#else
gpuError_t
err
=
cudaEventQuery
(
event_
);
if
(
err
==
cudaSuccess
)
{
return
true
;
}
if
(
err
==
cudaErrorNotReady
)
{
return
false
;
}
#endif
PADDLE_ENFORCE_GPU_SUCCESS
(
err
);
return
false
;
}
void
Synchronize
()
const
{
if
(
is_created_
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventSynchronize
(
event_
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventSynchronize
(
event_
));
#endif
}
}
void
Block
(
const
phi
::
GPUContext
&
ctx
)
const
{
if
(
is_created_
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"phi::GPUContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipStreamWaitEvent
(
ctx
.
stream
(),
event_
,
0
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaStreamWaitEvent
(
ctx
.
stream
(),
event_
,
0
));
#endif
}
}
private:
#ifdef PADDLE_WITH_HIP
unsigned
int
flags_
=
hipEventDefault
;
#else
unsigned
int
flags_
=
cudaEventDefault
;
#endif
bool
is_created_
{
false
};
gpuEvent_t
event_
{};
int8_t
device_index_
{
0
};
private:
void
CreateEvent
(
int
device_index
)
{
device_index_
=
device_index
;
platform
::
CUDADeviceGuard
guard
(
device_index
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipEventCreateWithFlags
(
&
event_
,
flags_
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventCreateWithFlags
(
&
event_
,
flags_
));
#endif
is_created_
=
true
;
}
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class
NCCLCommManager
{
public:
explicit
NCCLCommManager
(
ncclComm_t
ncclComm
)
:
nccl_comm_
(
ncclComm
)
{}
NCCLCommManager
()
:
NCCLCommManager
(
nullptr
)
{}
~
NCCLCommManager
()
noexcept
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
nccl_comm_
)
{
platform
::
dynload
::
ncclCommDestroy
(
nccl_comm_
);
}
}
static
std
::
shared_ptr
<
NCCLCommManager
>
Create
(
int
num_ranks
,
int
rank
,
ncclUniqueId
comm_id
)
{
auto
nccl_manager
=
std
::
make_shared
<
NCCLCommManager
>
();
NCCLCHECK
(
platform
::
dynload
::
ncclCommInitRank
(
&
(
nccl_manager
->
nccl_comm_
),
num_ranks
,
comm_id
,
rank
));
nccl_manager
->
nccl_id_
=
comm_id
;
nccl_manager
->
rank_
=
rank
;
return
nccl_manager
;
}
ncclUniqueId
GetNcclId
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_id_
;
}
ncclComm_t
GetNcclComm
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_comm_
;
}
NCCLCommManager
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
NCCLCommManager
&&
other
)
=
delete
;
NCCLCommManager
(
NCCLCommManager
&&
other
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
other
.
mutex_
);
std
::
swap
(
nccl_comm_
,
other
.
nccl_comm_
);
}
protected:
ncclComm_t
nccl_comm_
;
ncclUniqueId
nccl_id_
;
int
rank_
;
mutable
std
::
mutex
mutex_
;
};
ncclRedOp_t
ToNCCLRedType
(
ReduceOp
reduction
);
ncclRedOp_t
ToNCCLRedType
(
ReduceOp
reduction
);
std
::
string
SerializeNCCLUniqueId
(
const
ncclUniqueId
&
ncclID
);
std
::
string
SerializeNCCLUniqueId
(
const
ncclUniqueId
&
ncclID
);
...
...
paddle/fluid/distributed/collective/ProcessGroup.cc
浏览文件 @
e1a1c354
...
@@ -17,15 +17,7 @@
...
@@ -17,15 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
ProcessGroup
::
Task
::
Task
(
int
rank
,
ProcessGroup
::
Task
::
Task
(
int
rank
,
CommType
comm_type
,
bool
sync_op
)
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
)
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
ProcessGroup
::
Task
::~
Task
()
=
default
;
ProcessGroup
::
Task
::~
Task
()
=
default
;
...
@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
...
@@ -62,5 +54,17 @@ ProcessGroup::ProcessGroup(int rank, int size, int gid)
}
}
}
}
// TODO(sunyilun): methods below will be removed later
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
)
:
rank_
(
rank
),
comm_type_
(
comm_type
),
sync_op_
(
sync_op
)
{}
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroup.h
浏览文件 @
e1a1c354
...
@@ -54,13 +54,7 @@ class ProcessGroup {
...
@@ -54,13 +54,7 @@ class ProcessGroup {
public:
public:
class
Task
{
class
Task
{
public:
public:
Task
(
int
rank
,
Task
(
int
rank
,
CommType
comm_type
,
bool
sync_op
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
);
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
);
virtual
~
Task
();
virtual
~
Task
();
virtual
bool
IsCompleted
();
virtual
bool
IsCompleted
();
...
@@ -69,6 +63,15 @@ class ProcessGroup {
...
@@ -69,6 +63,15 @@ class ProcessGroup {
virtual
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
);
virtual
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
);
bool
IsSync
()
const
{
return
sync_op_
;
}
bool
IsSync
()
const
{
return
sync_op_
;
}
// TODO(sunyilun): methods below will be removed later
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
);
Task
(
int
rank
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
CommType
comm_type
,
bool
sync_op
);
protected:
protected:
const
int
rank_
;
const
int
rank_
;
CommType
comm_type_
{
CommType
::
UNKNOWN
};
CommType
comm_type_
{
CommType
::
UNKNOWN
};
...
@@ -79,6 +82,7 @@ class ProcessGroup {
...
@@ -79,6 +82,7 @@ class ProcessGroup {
bool
sync_op_
{
true
};
bool
sync_op_
{
true
};
};
};
public:
explicit
ProcessGroup
(
int
rank
,
explicit
ProcessGroup
(
int
rank
,
int
size
,
int
size
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
...
@@ -93,12 +97,48 @@ class ProcessGroup {
...
@@ -93,12 +97,48 @@ class ProcessGroup {
int
GetSize
()
const
{
return
size_
;
}
int
GetSize
()
const
{
return
size_
;
}
virtual
std
::
string
GetBackendName
()
const
=
0
;
virtual
std
::
string
GetBackendName
()
const
=
0
;
virtual
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
)
const
{
virtual
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
)
const
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Does not support to get device_context from ProcessGroup%s."
,
"Does not support to get device_context from ProcessGroup%s."
,
GetBackendName
()));
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support all_gather with sync_op flag"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support all_reduce with sync_op flag"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support barrier"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support broadcast with sync_op flag"
,
GetBackendName
()));
}
// TODO(liyurui): This API will be moved later
// TODO(liyurui): This API will be moved later
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
...
@@ -118,6 +158,7 @@ class ProcessGroup {
...
@@ -118,6 +158,7 @@ class ProcessGroup {
GetBackendName
()));
GetBackendName
()));
}
}
// TODO(sunyilun): methods below will be removed later
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* input tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
/* output tensors */
,
// NOLINT
...
@@ -136,12 +177,6 @@ class ProcessGroup {
...
@@ -136,12 +177,6 @@ class ProcessGroup {
GetBackendName
()));
GetBackendName
()));
}
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support barrier"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
phi
::
DenseTensor
>&
,
int
)
{
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
,
int
)
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
浏览文件 @
e1a1c354
...
@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
...
@@ -229,6 +229,17 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
}
}
};
};
// TODO(sunyilun): for compatibility, will be updated later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
in_tensor
};
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_tensor
};
return
Broadcast
(
in_wrapper
,
out_wrapper
,
opts
,
true
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.h
浏览文件 @
e1a1c354
...
@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
...
@@ -108,6 +108,13 @@ class ProcessGroupGloo : public ProcessGroup {
~
ProcessGroupGloo
()
=
default
;
~
ProcessGroupGloo
()
=
default
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
std
::
vector
<
phi
::
DenseTensor
>&
outputs
,
...
...
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
浏览文件 @
e1a1c354
此差异已折叠。
点击以展开。
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
浏览文件 @
e1a1c354
...
@@ -24,10 +24,10 @@
...
@@ -24,10 +24,10 @@
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_
contex
t.h"
#include "paddle/fluid/platform/device_
even
t.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/
fluid/platform/gen_comm_id_helper
.h"
#include "paddle/
phi/common/place
.h"
#include "paddle/
fluid/platform/place
.h"
#include "paddle/
phi/core/device_context
.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
...
@@ -44,16 +44,28 @@ namespace distributed {
...
@@ -44,16 +44,28 @@ namespace distributed {
using
Place
=
paddle
::
platform
::
Place
;
using
Place
=
paddle
::
platform
::
Place
;
class
ProcessGroupNCCL
:
public
ProcessGroupStream
{
class
ProcessGroupNCCL
final
:
public
ProcessGroupStream
{
public:
public:
class
NCCLTask
:
public
ProcessGroupStream
::
TaskStream
,
class
NCCLTask
final
:
public
ProcessGroupStream
::
TaskStream
,
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public:
public:
NCCLTask
(
const
Place
&
place
,
int
rank
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
);
virtual
~
NCCLTask
();
bool
IsCompleted
()
override
;
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
)
override
;
void
Synchronize
()
override
;
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
)
override
;
// TODO(sunyilun): methods below will be removed later
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
CommType
,
CommType
CommType
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
comm_type
,
CommType
comm_type
,
...
@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -61,31 +73,15 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
bool
IsCompleted
();
public:
bool
barrier_
{
false
};
void
SynchronizeStreams
();
platform
::
DeviceEvent
comm_event_
;
// event on comm stream
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
);
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
phi
::
DenseTensor
>&
outputs
);
// NOLINT
virtual
~
NCCLTask
();
void
UpdateWaitChain
(
const
phi
::
DeviceContext
&
ctx
)
override
;
std
::
vector
<
EventManager
>
control_events_
;
std
::
vector
<
phi
::
DenseTensor
>
barrierTensors_
;
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
ncclComms_
;
std
::
shared_ptr
<
std
::
vector
<
phi
::
DenseTensor
>>
outputs_
;
private:
private:
Place
place_
;
};
};
public:
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
ProcessGroupNCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
int
rank
,
int
rank
,
int
size
,
int
size
,
...
@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -99,34 +95,47 @@ class ProcessGroupNCCL : public ProcessGroupStream {
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
phi
::
DeviceContext
&
GetDeviceContext
(
const
Place
&
place
,
bool
use_calc_stream
)
const
override
;
const
Place
&
place
,
bool
use_calc_stream
)
const
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
options
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
bool
use_calc_stream
)
override
;
static
void
GroupStart
();
static
void
GroupEnd
();
ncclComm_t
NCCLComm
(
const
Place
&
place
)
const
;
// TODO(liyurui): This API will be moved later
// TODO(liyurui): This API will be moved later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
dst_rank
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
tensors
,
int
dst_rank
)
override
;
...
@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -175,12 +184,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
override
;
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
bool
sync_op
,
bool
use_calc_stream
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather_Partial
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather_Partial
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -255,20 +258,37 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi
::
DenseTensor
&
,
// NOLINT
phi
::
DenseTensor
&
,
// NOLINT
const
ReduceScatterOptions
&
)
override
;
const
ReduceScatterOptions
&
)
override
;
static
void
GroupStart
();
private:
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
Place
&
place
,
int
rank
,
CommType
op_type
,
bool
sync_op
,
bool
use_calc_stream
);
static
void
GroupEnd
(
);
void
BroadcastUniqueNCCLID
(
ncclUniqueId
*
nccl_id
);
ncclComm_t
NCCLComm
(
const
Place
&
place
)
const
;
void
CreateNCCLEnvCache
(
const
Place
&
place
,
const
std
::
string
&
place_key
);
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroupStream
::
Task
>
Collective
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
Fn
fn
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
);
protected:
void
SyncCalcStream
(
const
Place
&
place
,
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
std
::
shared_ptr
<
platform
::
DeviceEvent
>&
event
);
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
std
::
vector
<
Place
>
places
,
int
rank
,
int
rank
,
CommType
op_type
,
CommType
op_type
,
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
const
std
::
vector
<
phi
::
DenseTensor
>&
inputs
);
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
const
std
::
vector
<
Place
>&
places
,
const
std
::
vector
<
Place
>&
places
,
int
rank
,
int
rank
,
CommType
op_type
,
CommType
op_type
,
...
@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -276,27 +296,6 @@ class ProcessGroupNCCL : public ProcessGroupStream {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
protected:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
NCCLCommManager
>
nccl_comm_
;
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>>
places_to_ncclcomm_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
EventManager
>>
places_to_events_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
phi
::
GPUContext
>>>
places_to_ctx_
;
std
::
set
<
int
>
used_place_ids_
;
private:
void
BcastNCCLId
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
,
// NOLINT
int
root
,
// NOLINT
int
server_fd
);
void
BroadcastUniqueNCCLID
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
);
// NOLINT
template
<
typename
Fn
>
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
inputs
,
// NOLINT
...
@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
...
@@ -340,6 +339,17 @@ class ProcessGroupNCCL : public ProcessGroupStream {
void
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
void
CheckSplitSizes
(
std
::
vector
<
int64_t
>*
split_sizes
,
std
::
vector
<
int64_t
>
tensor_shape
);
std
::
vector
<
int64_t
>
tensor_shape
);
private:
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
platform
::
DeviceEvent
>
calc_event_
;
// event on calc stream
std
::
unordered_map
<
std
::
string
,
phi
::
GPUContext
*>
place_to_calc_ctx_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
phi
::
GPUContext
>>
place_to_comm_ctx_
;
// TODO(sunyilun): attrs below will be removed later
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
phi
::
GPUContext
*>>
places_to_ctx_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/collective/ProcessGroupStream.cc
浏览文件 @
e1a1c354
...
@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
...
@@ -30,18 +30,18 @@ const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
return
AllGather
(
input_tensors
,
return
AllGather
(
out_tensor
,
output_tensors
,
in_tensor
,
sync_op
,
sync_op
,
/*use_calc_stream*/
false
);
/*use_calc_stream*/
false
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
{
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
...
@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
...
@@ -49,27 +49,50 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
{
bool
sync_op
)
{
return
AllReduce
(
input_tensors
,
return
AllReduce
(
out_tensor
,
output_tensors
,
in_tensor
,
opt
ion
s
,
opts
,
sync_op
,
sync_op
,
/*use_calc_stream*/
false
);
/*use_calc_stream*/
false
);
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
)
{
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do all_reduce"
,
GetBackendName
()));
"ProcessGroup%s does not support do all_reduce"
,
GetBackendName
()));
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
return
Broadcast
(
out_tensor
,
in_tensor
,
opts
,
sync_op
,
/*use_calc_stream*/
false
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do broadcast"
,
GetBackendName
()));
}
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
AllToAll
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
...
@@ -114,28 +137,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAllSingle(
"ProcessGroup%s does not support do alltoall_single"
,
GetBackendName
()));
"ProcessGroup%s does not support do alltoall_single"
,
GetBackendName
()));
}
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
{
return
Broadcast
(
in_tensors
,
out_tensors
,
opts
,
sync_op
,
/*use_calc_stream*/
false
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support do broadcast"
,
GetBackendName
()));
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupStream
::
Reduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
...
...
paddle/fluid/distributed/collective/ProcessGroupStream.h
浏览文件 @
e1a1c354
...
@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -27,6 +27,11 @@ class ProcessGroupStream : public ProcessGroup {
public:
public:
class
TaskStream
:
public
ProcessGroup
::
Task
{
class
TaskStream
:
public
ProcessGroup
::
Task
{
public:
public:
TaskStream
(
int
rank
,
CommType
comm_type
,
bool
sync_op
,
bool
use_calc_stream
)
:
Task
(
rank
,
comm_type
,
sync_op
),
use_calc_stream_
(
use_calc_stream
)
{}
virtual
~
TaskStream
()
=
default
;
// TODO(liyurui): This constructor is temporary here for compatible reason,
// TODO(liyurui): This constructor is temporary here for compatible reason,
// will be deleted soon.
// will be deleted soon.
TaskStream
(
int
rank
,
TaskStream
(
int
rank
,
...
@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -42,8 +47,6 @@ class ProcessGroupStream : public ProcessGroup {
:
Task
(
rank
,
inputs
,
comm_type
,
sync_op
),
:
Task
(
rank
,
inputs
,
comm_type
,
sync_op
),
use_calc_stream_
(
use_calc_stream
)
{}
use_calc_stream_
(
use_calc_stream
)
{}
virtual
~
TaskStream
()
=
default
;
protected:
protected:
bool
UseCalcStream
()
const
{
return
use_calc_stream_
;
}
bool
UseCalcStream
()
const
{
return
use_calc_stream_
;
}
...
@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -51,6 +54,7 @@ class ProcessGroupStream : public ProcessGroup {
bool
use_calc_stream_
{
false
};
bool
use_calc_stream_
{
false
};
};
};
public:
ProcessGroupStream
(
int
rank
,
int
size
,
const
platform
::
Place
&
place
,
int
gid
);
ProcessGroupStream
(
int
rank
,
int
size
,
const
platform
::
Place
&
place
,
int
gid
);
virtual
~
ProcessGroupStream
()
=
default
;
virtual
~
ProcessGroupStream
()
=
default
;
...
@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -58,29 +62,43 @@ class ProcessGroupStream : public ProcessGroup {
const
Place
&
place
,
bool
use_calc_stream
)
const
;
const
Place
&
place
,
bool
use_calc_stream
)
const
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
)
override
;
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
opt
ion
s
,
const
AllreduceOptions
&
opts
,
bool
sync_op
)
override
;
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
phi
::
DenseTensor
>&
input_tensors
,
// NOLINT
phi
::
DenseTensor
*
out_tensor
,
std
::
vector
<
phi
::
DenseTensor
>&
output_tensors
,
// NOLINT
const
phi
::
DenseTensor
&
in_tensor
,
const
AllreduceOptions
&
options
,
const
AllreduceOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
phi
::
DenseTensor
*
out_tensor
,
const
phi
::
DenseTensor
&
in_tensor
,
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
// TODO(sunyilun): methods below will be removed later
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
...
@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
...
@@ -107,19 +125,6 @@ class ProcessGroupStream : public ProcessGroup {
bool
sync_op
,
bool
sync_op
,
bool
use_calc_stream
);
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
BroadcastOptions
&
opts
,
bool
sync_op
)
override
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
const
BroadcastOptions
&
opts
,
bool
sync_op
,
bool
use_calc_stream
);
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
in_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
std
::
vector
<
phi
::
DenseTensor
>&
out_tensors
,
// NOLINT
...
...
paddle/fluid/operators/fused/fused_attention_op.cu
浏览文件 @
e1a1c354
...
@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
...
@@ -51,14 +51,9 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
if
(
map
->
has
(
ring_id
))
{
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg_nccl
->
AllReduce
(
in_tensor
,
out_
tensor
,
opts
,
true
,
true
);
auto
task
=
pg_nccl
->
AllReduce
(
&
tensor
,
tensor
,
opts
,
true
,
true
);
task
->
Wait
();
task
->
Wait
();
}
else
{
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
auto
dtype
=
platform
::
ToNCCLDataType
(
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cu
浏览文件 @
e1a1c354
...
@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
...
@@ -44,14 +44,9 @@ static void AllReduce(phi::DenseTensor& tensor, // NOLINT
if
(
map
->
has
(
ring_id
))
{
if
(
map
->
has
(
ring_id
))
{
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
paddle
::
distributed
::
ProcessGroup
*
pg
=
map
->
get
(
ring_id
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
auto
pg_nccl
=
static_cast
<
distributed
::
ProcessGroupNCCL
*>
(
pg
);
std
::
vector
<
phi
::
DenseTensor
>
in_tensor
;
std
::
vector
<
phi
::
DenseTensor
>
out_tensor
;
in_tensor
.
push_back
(
tensor
);
out_tensor
.
push_back
(
tensor
);
paddle
::
distributed
::
AllreduceOptions
opts
;
paddle
::
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
opts
.
reduce_op
=
distributed
::
ReduceOp
::
SUM
;
auto
task
=
pg_nccl
->
AllReduce
(
in_tensor
,
out_
tensor
,
opts
,
true
,
true
);
auto
task
=
pg_nccl
->
AllReduce
(
&
tensor
,
tensor
,
opts
,
true
,
true
);
task
->
Wait
();
task
->
Wait
();
}
else
{
}
else
{
auto
dtype
=
platform
::
ToNCCLDataType
(
auto
dtype
=
platform
::
ToNCCLDataType
(
...
...
paddle/fluid/pybind/distributed_py.cc
浏览文件 @
e1a1c354
...
@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
...
@@ -147,12 +147,12 @@ void BindDistributed(py::module *m) {
distributed
::
ReduceOp
op
,
distributed
::
ReduceOp
op
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
auto
p_dense
=
opts
.
reduce_op
=
op
;
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
AllReduce
(
tensors
,
tensors
,
opts
,
sync_op
);
auto
in_dense
=
*
p_dense
;
distributed
::
AllreduceOptions
opts
{
op
};
return
self
.
AllReduce
(
out_dense
,
in_dense
,
opts
,
sync_op
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"op"
),
py
::
arg
(
"op"
),
...
@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
...
@@ -183,11 +183,12 @@ void BindDistributed(py::module *m) {
int
src
,
int
src
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
{
src
};
auto
p_dense
=
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
Broadcast
(
tensors
,
tensors
,
opts
,
sync_op
);
auto
in_dense
=
*
p_dense
;
distributed
::
BroadcastOptions
opts
{
src
};
return
self
.
Broadcast
(
out_dense
,
in_dense
,
opts
,
sync_op
);
},
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"tensor"
),
py
::
arg
(
"src"
),
py
::
arg
(
"src"
),
...
@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
...
@@ -380,52 +381,52 @@ void BindDistributed(py::module *m) {
.
def
(
.
def
(
"allgather"
,
"allgather"
,
[](
distributed
::
ProcessGroup
&
self
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor_list
=
auto
out_tensor_list
=
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
concat_out_tensor
.
impl
());
concat_out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
());
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
());
auto
task
=
self
.
AllGather
(
in_wrapper
,
out_wrapper
,
sync_op
);
auto
task
=
self
.
AllGather
(
out_dense
,
in_dense
,
sync_op
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
task
->
UpdateWaitChain
(
dev_ctx
);
task
->
UpdateWaitChain
(
dev_ctx
);
return
task
;
return
task
;
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"sync_op"
),
py
::
arg
(
"sync_op"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
"allgather_into_tensor"
,
"allgather_into_tensor"
,
[](
distributed
::
ProcessGroup
&
self
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_in_tensor
,
bool
sync_op
)
{
bool
sync_op
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
return
self
.
AllGather
(
in_wrapper
,
out_wrapper
,
sync_op
);
return
self
.
AllGather
(
out_dense
,
in_dense
,
sync_op
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
arg
(
"sync_op"
),
py
::
arg
(
"sync_op"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
...
@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
...
@@ -784,55 +785,55 @@ void BindDistributed(py::module *m) {
.
def
(
.
def
(
"allgather_on_calc_stream"
,
"allgather_on_calc_stream"
,
[](
distributed
::
ProcessGroupStream
&
self
,
[](
distributed
::
ProcessGroupStream
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor_list
,
py
::
handle
py_out_tensor_list
)
{
py
::
handle
py_in_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor_list
=
auto
out_tensor_list
=
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
CastPyArg2VectorOfTensor
(
py_out_tensor_list
.
ptr
(),
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
Tensor
concat_out_tensor
=
paddle
::
concat
(
out_tensor_list
,
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
concat_out_tensor
.
impl
());
concat_out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
};
auto
*
out_dense
=
p_out_tensor
.
get
();
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
const
auto
&
dev_ctx
=
const
auto
&
dev_ctx
=
self
.
GetDeviceContext
(
in_tensor
.
place
(),
true
);
self
.
GetDeviceContext
(
in_tensor
.
place
(),
true
);
auto
task
=
self
.
AllGather
(
in_wrapper
,
auto
task
=
self
.
AllGather
(
out_dense
,
out_wrapper
,
in_dense
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
distributed
::
SplitTensor
(
dev_ctx
,
*
out_dense
,
&
out_tensor_list
);
return
task
;
return
task
;
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
"allgather_into_tensor_on_calc_stream"
,
"allgather_into_tensor_on_calc_stream"
,
[](
distributed
::
ProcessGroupStream
&
self
,
[](
distributed
::
ProcessGroupStream
&
self
,
py
::
handle
py_in_tensor
,
py
::
handle
py_out_tensor
,
py
::
handle
py_out_tensor
)
{
py
::
handle
py_in_tensor
)
{
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
auto
in_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
in_wrapper
=
{
*
in_dense
};
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_tensor
=
CastPyArg2Tensor
(
py_out_tensor
.
ptr
(),
0
);
auto
out_dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
auto
p_out_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
out_tensor
.
impl
());
out_tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
out_wrapper
=
{
*
out_dense
}
;
auto
*
out_dense
=
p_out_tensor
.
get
()
;
return
self
.
AllGather
(
in_wrapper
,
auto
in_tensor
=
CastPyArg2Tensor
(
py_in_tensor
.
ptr
(),
0
);
out_wrapper
,
auto
p_in_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
in_tensor
.
impl
());
auto
in_dense
=
*
p_in_tensor
;
return
self
.
AllGather
(
out_dense
,
in_dense
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
},
},
py
::
arg
(
"in"
),
py
::
arg
(
"out"
),
py
::
arg
(
"out"
),
py
::
arg
(
"in"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
.
def
(
...
@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
...
@@ -872,13 +873,13 @@ void BindDistributed(py::module *m) {
py
::
handle
py_tensor
,
py
::
handle
py_tensor
,
distributed
::
ReduceOp
op
)
{
distributed
::
ReduceOp
op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
auto
p_dense
=
opts
.
reduce_op
=
op
;
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
in_dense
=
*
p_dense
;
return
self
.
AllReduce
(
tensors
,
auto
*
out_dense
=
p_dense
.
get
();
tensors
,
distributed
::
AllreduceOptions
opts
{
op
};
return
self
.
AllReduce
(
out_dense
,
in_dense
,
opts
,
opts
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
...
@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
...
@@ -980,12 +981,13 @@ void BindDistributed(py::module *m) {
py
::
handle
py_tensor
,
py
::
handle
py_tensor
,
int
src
)
{
int
src
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
{
src
};
auto
p_dense
=
auto
dense
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
std
::
vector
<
phi
::
DenseTensor
>
tensors
=
{
*
dense
};
auto
*
out_dense
=
p_dense
.
get
();
return
self
.
Broadcast
(
tensors
,
auto
in_dense
=
*
p_dense
;
tensors
,
distributed
::
BroadcastOptions
opts
{
src
};
return
self
.
Broadcast
(
out_dense
,
in_dense
,
opts
,
opts
,
/*sync_op*/
true
,
/*sync_op*/
true
,
/*use_calc_stream*/
true
);
/*use_calc_stream*/
true
);
...
...
python/paddle/distributed/communication/stream/all_gather.py
浏览文件 @
e1a1c354
...
@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
...
@@ -21,18 +21,18 @@ def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape
=
list
(
shape
)
expect_shape
=
list
(
shape
)
expect_shape
[
0
]
*=
nranks
expect_shape
[
0
]
*=
nranks
if
list
(
tensor
.
shape
)
!=
expect_shape
:
if
list
(
tensor
.
shape
)
!=
expect_shape
:
raise
RuntimeError
(
'The tensor for all_gather is not correctly-sized.'
)
raise
RuntimeError
(
"The tensor for all_gather is not correctly-sized."
)
def
_check_tensor_list_shape
(
tensor_list
,
shape
,
nranks
=
1
):
def
_check_tensor_list_shape
(
tensor_list
,
shape
,
nranks
=
1
):
if
len
(
tensor_list
)
!=
nranks
:
if
len
(
tensor_list
)
!=
nranks
:
raise
RuntimeError
(
raise
RuntimeError
(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
)
for
tensor
in
tensor_list
:
for
tensor
in
tensor_list
:
if
tensor
.
shape
!=
shape
:
if
tensor
.
shape
!=
shape
:
raise
RuntimeError
(
raise
RuntimeError
(
'The tensor_list for all_gather is not correctly-sized.'
"The tensor_list for all_gather is not correctly-sized."
)
)
...
@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
...
@@ -45,11 +45,12 @@ def _all_gather_into_tensor_in_dygraph(
if
use_calc_stream
:
if
use_calc_stream
:
return
group
.
process_group
.
allgather_into_tensor_on_calc_stream
(
return
group
.
process_group
.
allgather_into_tensor_on_calc_stream
(
in_tensor
,
out_tensor
out_tensor
,
in_tensor
,
)
)
task
=
group
.
process_group
.
allgather_into_tensor
(
task
=
group
.
process_group
.
allgather_into_tensor
(
in_tensor
,
out
_tensor
,
sync_op
out_tensor
,
in
_tensor
,
sync_op
)
)
if
sync_op
:
if
sync_op
:
task
.
wait
()
task
.
wait
()
...
@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
...
@@ -68,9 +69,9 @@ def _all_gather_in_dygraph(
_check_tensor_list_shape
(
tensor_list
,
tensor
.
shape
,
group
.
nranks
)
_check_tensor_list_shape
(
tensor_list
,
tensor
.
shape
,
group
.
nranks
)
if
use_calc_stream
:
if
use_calc_stream
:
return
group
.
process_group
.
allgather_on_calc_stream
(
tensor
,
tensor_list
)
return
group
.
process_group
.
allgather_on_calc_stream
(
tensor
_list
,
tensor
)
task
=
group
.
process_group
.
allgather
(
tensor
,
tensor_list
,
sync_op
)
task
=
group
.
process_group
.
allgather
(
tensor
_list
,
tensor
,
sync_op
)
if
sync_op
:
if
sync_op
:
task
.
wait
()
task
.
wait
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录