Commit e7711592 (unverified). Authored Dec 12, 2022 by Wen Sun; committed via GitHub on Dec 12, 2022.

Add dynamic checks for collective communication on NCCL (#48915)

* chore: unify `SingleTensor`
* feat: dynamic check

Parent: e66dbc38
Showing 7 changed files with 474 additions and 107 deletions (+474, −107).
- paddle/fluid/distributed/collective/CMakeLists.txt (+1, −1)
- paddle/fluid/distributed/collective/NCCLTools.cc (+1, −1)
- paddle/fluid/distributed/collective/NCCLTools.h (+12, −25)
- paddle/fluid/distributed/collective/ProcessGroupNCCL.cc (+126, −74)
- paddle/fluid/distributed/collective/ProcessGroupNCCL.h (+2, −2)
- paddle/fluid/distributed/collective/check.cc (+290, −0)
- paddle/fluid/distributed/collective/check.h (+42, −4)
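In summary: `static_check.{h,cc}` is renamed to `check.{h,cc}` and gains a `CommDynamicCheck` helper; the NCCL entry points move from `platform::dynload` to `phi::dynload`; and each collective in `ProcessGroupNCCL` learns to cross-validate tensor metadata across ranks before communicating. The dynamic checks are gated at compile time; as the comment added to ProcessGroupNCCL.cc below puts it, they are enabled by flipping one constant and recompiling:

```cpp
// Excerpt from the ProcessGroupNCCL.cc hunk below.
// set this flag to `true` and recompile to enable dynamic checks
constexpr bool FLAGS_enable_nccl_dynamic_check = false;
```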
paddle/fluid/distributed/collective/CMakeLists.txt

```diff
@@ -21,7 +21,7 @@ endif()
 if(WITH_NCCL OR WITH_RCCL)
   cc_library(
     processgroup_nccl
-    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc static_check.cc
+    SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc check.cc
     DEPS processgroup
          processgroup_stream
          place
```
paddle/fluid/distributed/collective/NCCLTools.cc

```diff
@@ -14,7 +14,7 @@
 #include "paddle/fluid/distributed/collective/NCCLTools.h"
 
-#include "paddle/fluid/distributed/collective/Types.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace distributed {
```
paddle/fluid/distributed/collective/NCCLTools.h

```diff
@@ -21,29 +21,16 @@
 #include <hip/hip_runtime.h>
 #endif
 
-#include <error.h>
-
 #include <string>
 
 #include "paddle/fluid/distributed/collective/Types.h"
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/variable.h"
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/fluid/platform/cuda_device_guard.h"
-#endif
-
-#include "paddle/fluid/platform/device_context.h"
+
 #ifdef PADDLE_WITH_RCCL
-#include "paddle/fluid/platform/dynload/rccl.h"
+#include "paddle/phi/backends/dynload/rccl.h"
 #else
-#include "paddle/fluid/platform/dynload/nccl.h"
+#include "paddle/phi/backends/dynload/nccl.h"
 #endif
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/utils/variant.h"
 
 namespace paddle {
 namespace distributed {
@@ -54,7 +41,7 @@ namespace distributed {
       printf("Failed, NCCL error %s:%d '%s'\n", \
              __FILE__,                          \
              __LINE__,                          \
-             platform::dynload::ncclGetErrorString(r)); \
+             phi::dynload::ncclGetErrorString(r));      \
       exit(EXIT_FAILURE);                       \
     }                                           \
   } while (0)
```
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc

```diff
@@ -16,7 +16,7 @@
 #include "paddle/fluid/distributed/collective/Common.h"
 #include "paddle/fluid/distributed/collective/NCCLTools.h"
-#include "paddle/fluid/distributed/collective/static_check.h"
+#include "paddle/fluid/distributed/collective/check.h"
 #include "paddle/fluid/distributed/collective/utils.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/fluid/platform/place.h"
@@ -25,6 +25,8 @@
 DECLARE_bool(nccl_blocking_wait);
 DECLARE_bool(use_stream_safe_cuda_allocator);
 
+// set this flag to `true` and recompile to enable dynamic checks
+constexpr bool FLAGS_enable_nccl_dynamic_check = false;
 constexpr int64_t kWaitBlockTImeout = 10;
 
 namespace paddle {
@@ -89,12 +91,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
     : ProcessGroupStream(rank, size, gid), store_(store) {}
 
 void ProcessGroupNCCL::GroupStart() {
-  NCCL_CHECK(platform::dynload::ncclGroupStart());
+  NCCL_CHECK(phi::dynload::ncclGroupStart());
 }
 
-void ProcessGroupNCCL::GroupEnd() {
-  NCCL_CHECK(platform::dynload::ncclGroupEnd());
-}
+void ProcessGroupNCCL::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }
 
 phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
     const Place& place) const {
@@ -146,7 +146,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
       size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclAllGather(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor,
+                                       /*root_rank*/ 0,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclAllGather(
             in_tensor_maybe_partial.data(),
             out_tensor->data(),
             in_tensor_maybe_partial.numel(),
@@ -173,7 +179,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
       size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclAllReduce(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor,
+                                       /*root_rank*/ 0,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclAllReduce(
             in_tensor.data(),
             out_tensor->data(),
             in_tensor.numel(),
@@ -219,9 +231,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
   CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
   CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
 
-  // NOTE: Since `all_to_all` needs other processes's participation, it cannot
+  // NOTE: Since `all_to_all` needs other processes' participation, it cannot
   // simply be covered by static checks. Factors are set to 0 here to skip the
-  // shape check. Its shape check will be done by dynamic checks in debug mode.
+  // shape check. Its shape check will be done by dynamic checks with
+  // FLAGS_enable_nccl_dynamic_check.
   CommStaticCheck::CheckShape(*out_tensor,
                               in_tensor,
                               /*dst_rank*/ rank_,
@@ -231,6 +244,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
                               /*in_size_factor*/ 0);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(
+              *out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm);
+        }
         int64_t in_row_size = in_tensor.numel() / in_dim[0],
                 out_row_size = out_tensor->numel() / out_dim[0];
         int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
@@ -240,7 +257,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
         for (auto i = 0; i < size_; i++) {
           in_numel = in_size_each_rank[i] * in_row_size;
           input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
-          NCCL_CHECK(platform::dynload::ncclSend(
+          NCCL_CHECK(phi::dynload::ncclSend(
               input_partial.data(),
               in_numel,
               platform::ToNCCLDataType(input_partial.dtype()),
@@ -251,7 +268,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
           out_numel = out_size_each_rank[i] * out_row_size;
           output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
-          NCCL_CHECK(platform::dynload::ncclRecv(
+          NCCL_CHECK(phi::dynload::ncclRecv(
               output_partial.data(),
               out_numel,
               platform::ToNCCLDataType(output_partial.dtype()),
@@ -304,7 +321,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
         int root = opts.source_rank + opts.source_root;
-        NCCL_CHECK(platform::dynload::ncclBroadcast(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor, root, rank_, comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclBroadcast(
             in_tensor.data(),
             out_tensor->data(),
             in_tensor.numel(),
@@ -332,7 +352,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
       size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclReduce(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor,
+                                       /*root_rank*/ opts.root_rank,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclReduce(
             in_tensor.data(),
             out_tensor->data(),
             in_tensor.numel(),
@@ -361,7 +387,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
       size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclReduceScatter(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor,
+                                       /*root_rank*/ 0,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclReduceScatter(
             in_tensor.data(),
             out_tensor->data(),
             out_tensor->numel(),
@@ -389,6 +421,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
       size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*out_tensor,
+                                       /*root_rank*/ opts.root_rank,
+                                       rank_,
+                                       comm);
+        }
         int64_t numel = in_tensor.numel() / size_;
         if (rank_ == opts.root_rank) {
           int64_t offset = 0;
@@ -396,7 +434,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
           GroupStart();
           for (auto i = 0; i < size_; i++) {
             partial_tensor = GetPartialTensor(in_tensor, offset, numel);
-            NCCL_CHECK(platform::dynload::ncclSend(
+            NCCL_CHECK(phi::dynload::ncclSend(
                 partial_tensor.data(),
                 numel,
                 platform::ToNCCLDataType(partial_tensor.dtype()),
@@ -405,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
                 stream));
             offset += numel;
           }
-          NCCL_CHECK(platform::dynload::ncclRecv(
+          NCCL_CHECK(phi::dynload::ncclRecv(
               out_tensor->data(),
               numel,
              platform::ToNCCLDataType(out_tensor->dtype()),
@@ -414,7 +452,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
               stream));
           GroupEnd();
         } else {
-          NCCL_CHECK(platform::dynload::ncclRecv(
+          NCCL_CHECK(phi::dynload::ncclRecv(
               out_tensor->data(),
               numel,
               platform::ToNCCLDataType(out_tensor->dtype()),
@@ -443,11 +481,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
     tensor = &partial_tensor;
   }
 
-  CommStaticCheck::SingleTensor(*tensor, rank_, size_);
+  CommStaticCheck::CheckShape(*tensor, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclRecv(
-            tensor->data(),
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(*tensor,
+                                       /*root_rank*/ src_rank,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclRecv(
+            tensor->data(),
             tensor->numel(),
             platform::ToNCCLDataType(tensor->dtype()),
             src_rank,
@@ -471,10 +515,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
   const phi::DenseTensor& tensor_maybe_partial =
       numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
 
-  CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
+  CommStaticCheck::CheckShape(tensor_maybe_partial, rank_, size_);
   return RunFnInNCCLEnv(
       [&](ncclComm_t comm, gpuStream_t stream) {
-        NCCL_CHECK(platform::dynload::ncclSend(
+        if (FLAGS_enable_nccl_dynamic_check) {
+          CommDynamicCheck::CheckShape(tensor_maybe_partial,
+                                       /*root_rank*/ rank_,
+                                       rank_,
+                                       comm);
+        }
+        NCCL_CHECK(phi::dynload::ncclSend(
             tensor_maybe_partial.data(),
             tensor_maybe_partial.numel(),
             platform::ToNCCLDataType(tensor_maybe_partial.dtype()),
@@ -520,7 +570,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
   ncclUniqueId nccl_id;
   if (rank_ == 0) {
-    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
+    NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
   }
   BroadcastUniqueNCCLID(&nccl_id);
@@ -532,7 +582,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
       platform::DeviceContextPool::Instance().Get(place));
   auto comm_ctx = std::make_unique<phi::GPUContext>(place);
 
   ncclComm_t nccl_comm;
-  NCCL_CHECK(platform::dynload::ncclCommInitRank(
+  NCCL_CHECK(phi::dynload::ncclCommInitRank(
       &nccl_comm, GetSize(), nccl_id, GetRank()));
   comm_ctx->set_nccl_comm(nccl_comm);
@@ -589,6 +639,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::RunFnInNCCLEnv(
     task->UpdateWaitChain(*comm_ctx);
   }
 
+  if (FLAGS_enable_nccl_dynamic_check) {
+    task->SetBlockCPUInWait();
+    task->Wait();
+  }
   return task;
 }
@@ -633,7 +687,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
   ncclUniqueId nccl_id;
   if (rank_ == 0) {
-    NCCL_CHECK(platform::dynload::ncclGetUniqueId(&nccl_id));
+    NCCL_CHECK(phi::dynload::ncclGetUniqueId(&nccl_id));
   }
   BroadcastUniqueNCCLID(&nccl_id);
@@ -654,7 +708,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
     dev_ctx[i].reset(new phi::GPUContext(places[i]));
     ncclComm_t nccl_comm;
-    NCCL_CHECK(platform::dynload::ncclCommInitRank(
+    NCCL_CHECK(phi::dynload::ncclCommInitRank(
         &nccl_comm, GetSize(), nccl_id, GetRank()));
     dev_ctx[i]->set_nccl_comm(nccl_comm);
     dev_ctx_raw[i] = dev_ctx[i].get();
@@ -791,7 +845,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
-        return platform::dynload::ncclAllReduce(
+        return phi::dynload::ncclAllReduce(
            input.data(),
            output.data(),
            input.numel(),
@@ -821,7 +875,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
          const gpuStream_t& stream) {
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
-        return platform::dynload::ncclBroadcast(
+        return phi::dynload::ncclBroadcast(
             input.data(),
             output.data(),
             input.numel(),
@@ -871,8 +925,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
          ncclComm_t comm,
          const gpuStream_t& stream,
          int dst_rank) {
-        return platform::dynload::ncclSend(
-            input.data(),
-            input.numel(),
-            platform::ToNCCLDataType(input.dtype()),
-            dst_rank,
+        return phi::dynload::ncclSend(input.data(),
+                                      input.numel(),
+                                      platform::ToNCCLDataType(input.dtype()),
+                                      dst_rank,
@@ -894,8 +947,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
          ncclComm_t comm,
          const gpuStream_t& stream,
          int src_rank) {
-        return platform::dynload::ncclRecv(
-            output.data(),
-            output.numel(),
-            platform::ToNCCLDataType(output.dtype()),
-            src_rank,
+        return phi::dynload::ncclRecv(output.data(),
+                                      output.numel(),
+                                      platform::ToNCCLDataType(output.dtype()),
+                                      src_rank,
@@ -925,7 +977,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
-        return platform::dynload::ncclAllGather(
+        return phi::dynload::ncclAllGather(
            input.data(),
            output.data(),
            input.numel(),
@@ -994,14 +1046,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
         size_t offset = 0;
         GroupStart();
         for (auto i = 0; i < size_; i++) {
-          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
               GetPointerByOffset(input.data(), offset, input.dtype()),
               input.numel() / size_,
               platform::ToNCCLDataType(input.dtype()),
               i,
               comm,
               stream));
-          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
+          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv(
               GetPointerByOffset(output.data(), offset, input.dtype()),
               input.numel() / size_,
               platform::ToNCCLDataType(input.dtype()),
@@ -1030,8 +1082,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
          phi::DenseTensor& output,
          ncclComm_t comm,
          const gpuStream_t& stream) {
-        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
-            input.data(),
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            phi::dynload::ncclReduce(input.data(),
             output.data(),
             input.numel(),
             platform::ToNCCLDataType(input.dtype()),
@@ -1066,7 +1118,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
         if (rank_ == opts.root_rank) {
           GroupStart();
           for (auto i = 0; i < size_; i++) {
-            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
+            PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend(
                 GetPointerByOffset(input.data(), offset, input.dtype()),
                 input.numel() / size_,
                 platform::ToNCCLDataType(input.dtype()),
@@ -1075,8 +1127,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
                 stream));
             offset += input.numel() / size_;
           }
-          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
-              output.data(),
+          PADDLE_ENFORCE_GPU_SUCCESS(
+              phi::dynload::ncclRecv(output.data(),
               input.numel() / size_,
               platform::ToNCCLDataType(input.dtype()),
               opts.root_rank,
@@ -1084,8 +1136,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
               stream));
           GroupEnd();
         } else {
-          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
-              output.data(),
+          PADDLE_ENFORCE_GPU_SUCCESS(
+              phi::dynload::ncclRecv(output.data(),
              input.numel() / size_,
              platform::ToNCCLDataType(input.dtype()),
              opts.root_rank,
```
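Every collective above follows one guard pattern: inside the lambda passed to `RunFnInNCCLEnv`, run the dynamic check against the operation's root rank when the flag is set, then issue the NCCL call through `phi::dynload`. A condensed sketch of that pattern (not a verbatim excerpt; buffers and the trailing `RunFnInNCCLEnv` arguments are elided):

```cpp
// Sketch of the guard shared by AllGather, AllReduce, Broadcast, Reduce,
// ReduceScatter, Scatter, Send and Recv in this file. Each op picks the
// root rank that makes sense for it (0, opts.root_rank, src_rank, ...).
return RunFnInNCCLEnv(
    [&](ncclComm_t comm, gpuStream_t stream) {
      if (FLAGS_enable_nccl_dynamic_check) {
        // all ranks confirm dtype and numel against the root before the call
        CommDynamicCheck::CheckShape(*out_tensor, /*root_rank*/ 0, rank_, comm);
      }
      NCCL_CHECK(phi::dynload::ncclAllGather(/* buffers, count, dtype, */
                                             /* comm, stream */));
    },
    /* remaining RunFnInNCCLEnv arguments elided */);
```

Note the matching hunk in `RunFnInNCCLEnv` itself: with the flag on, every task is waited on synchronously (`SetBlockCPUInWait()` plus `Wait()`), so a metadata mismatch surfaces at the offending collective rather than at some later stream synchronization.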
paddle/fluid/distributed/collective/ProcessGroupNCCL.h

```diff
@@ -33,9 +33,9 @@
 #endif
 
 #ifdef PADDLE_WITH_RCCL
-#include "paddle/fluid/platform/dynload/rccl.h"
+#include "paddle/phi/backends/dynload/rccl.h"
 #elif PADDLE_WITH_NCCL
-#include "paddle/fluid/platform/dynload/nccl.h"
+#include "paddle/phi/backends/dynload/nccl.h"
 #endif
 
 namespace paddle {
```
paddle/fluid/distributed/collective/static_check.cc → paddle/fluid/distributed/collective/check.cc

```diff
@@ -12,16 +12,32 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/distributed/collective/static_check.h"
+#include "paddle/fluid/distributed/collective/check.h"
 
+#include "paddle/fluid/distributed/collective/NCCLTools.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/errors.h"
 
+#ifdef PADDLE_WITH_HIP
+#define gpuMalloc hipMalloc
+#define gpuMemcpy hipMemcpy
+#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
+#define gpuFree hipFree
+#else
+#define gpuMalloc cudaMalloc
+#define gpuMemcpy cudaMemcpy
+#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
+#define gpuFree cudaFree
+#endif
+
 namespace paddle {
 namespace distributed {
 
+// static checks
 void CommStaticCheck::CheckRank(int rank, int world_size) {
   PADDLE_ENFORCE_GE(rank,
                     0,
@@ -102,7 +118,7 @@ void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
   }
 }
 
-void CommStaticCheck::SingleTensor(const phi::DenseTensor& tensor,
-                                   int rank,
-                                   int world_size) {
+void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor,
+                                 int rank,
+                                 int world_size) {
   CheckPlace(tensor);
@@ -151,5 +167,124 @@ void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor,
                /*in_size_factor*/ world_size);
 }
 
+// dynamic checks
+void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
+                                     int64_t dtype) {
+  PADDLE_ENFORCE_EQ(
+      static_cast<int64_t>(tensor.dtype()),
+      dtype,
+      phi::errors::InvalidArgument(
+          "Tensors in communication are expected to have the same data type."));
+}
+
+void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
+                                     int root_rank,
+                                     int cur_rank,
+                                     ncclComm_t comm) {
+  constexpr int kSize = sizeof(int64_t);
+  int64_t dtype_host = static_cast<int64_t>(tensor.dtype());
+  int64_t* dtype_device;
+  PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&dtype_device, kSize));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      gpuMemcpy(dtype_device, &dtype_host, kSize, gpuMemcpyHostToDevice));
+  NCCL_CHECK(phi::dynload::ncclBroadcast(dtype_device,
+                                         dtype_device,
+                                         kSize,
+                                         ncclInt64,
+                                         root_rank,
+                                         comm,
+                                         kDefaultStream));
+
+  if (root_rank == cur_rank) {
+    VLOG(3) << "Dynamic check broadcast metadata, dtype: " << dtype_host;
+  } else {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        gpuMemcpy(&dtype_host, dtype_device, kSize, gpuMemcpyDeviceToHost));
+    VLOG(3) << "Dynamic check recv metadata, dtype: " << dtype_host;
+    CheckDataType(tensor, dtype_host);
+  }
+  PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(dtype_device));
+}
+
+void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
+                                  int64_t shape) {
+  PADDLE_ENFORCE_EQ(
+      tensor.numel(),
+      shape,
+      phi::errors::InvalidArgument(
+          "Tensors in communication are expected to have matching sizes."));
+}
+
+void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
+                                  int root_rank,
+                                  int cur_rank,
+                                  ncclComm_t comm) {
+  CheckDataType(tensor, root_rank, cur_rank, comm);
+
+  constexpr int kSize = sizeof(int64_t);
+  int64_t shape_host = tensor.numel();
+  int64_t* shape_device;
+  PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&shape_device, kSize));
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      gpuMemcpy(shape_device, &shape_host, kSize, gpuMemcpyHostToDevice));
+  NCCL_CHECK(phi::dynload::ncclBroadcast(shape_device,
+                                         shape_device,
+                                         kSize,
+                                         ncclInt64,
+                                         root_rank,
+                                         comm,
+                                         kDefaultStream));
+
+  if (root_rank == cur_rank) {
+    VLOG(3) << "Dynamic check broadcast metadata, shape: " << shape_host;
+  } else {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        gpuMemcpy(&shape_host, shape_device, kSize, gpuMemcpyDeviceToHost));
+    VLOG(3) << "Dynamic check recv metadata, shape: " << shape_host;
+    CheckShape(tensor, shape_host);
+  }
+  PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(shape_device));
+}
+
+void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
+                                  const phi::DenseTensor& in_tensor,
+                                  const std::vector<int64_t>& in_size_each_rank,
+                                  int cur_rank,
+                                  int world_size,
+                                  ncclComm_t comm) {
+  CheckDataType(out_tensor, /*root_rank*/ 0, cur_rank, comm);
+  CheckDataType(in_tensor, /*root_rank*/ 0, cur_rank, comm);
+
+  constexpr int kSize = sizeof(int64_t);
+  int64_t in_row_size = in_tensor.numel() / in_tensor.dims()[0];
+
+  for (int rank = 0; rank < world_size; ++rank) {
+    int64_t in_shape_host = in_size_each_rank[rank] * in_row_size;
+    int64_t* in_shape_device;
+    PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize));
+    PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
+        in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice));
+    NCCL_CHECK(phi::dynload::ncclReduce(in_shape_device,
+                                        in_shape_device,
+                                        kSize,
+                                        ncclInt64,
+                                        ncclSum,
+                                        rank,
+                                        comm,
+                                        kDefaultStream));
+    if (rank == cur_rank) {
+      PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
+          &in_shape_host, in_shape_device, kSize, gpuMemcpyDeviceToHost));
+      VLOG(3) << "Dynamic check recv metadata, shape: " << in_shape_host;
+      CheckShape(out_tensor, in_shape_host);
+    }
+    PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device));
+  }
+}
+
 }  // namespace distributed
 }  // namespace paddle
```
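The handshake implemented above stages one int64 of metadata (a dtype id, then an element count) in device memory, broadcasts the root's value through the same communicator, and has every non-root rank compare the received value with its local tensor. A self-contained sketch of the idea, with hypothetical names, assuming a CUDA build and raw NCCL (count passed in elements, per the NCCL API):

```cpp
#include <cassert>
#include <cstdint>
#include <cuda_runtime.h>
#include <nccl.h>

// Hypothetical standalone sketch of the metadata handshake that
// CommDynamicCheck performs before a real collective: the root rank
// broadcasts one int64 of metadata (a dtype id or an element count) and
// every other rank checks it against its local tensor.
void CheckMetaAgainstRoot(int64_t local_meta, int root_rank, int cur_rank,
                          ncclComm_t comm, cudaStream_t stream) {
  int64_t host = local_meta;
  int64_t* device = nullptr;
  cudaMalloc(&device, sizeof(int64_t));
  cudaMemcpy(device, &host, sizeof(int64_t), cudaMemcpyHostToDevice);

  // in-place broadcast: afterwards, `device` holds the root's value everywhere
  ncclBroadcast(device, device, /*count=*/1, ncclInt64, root_rank, comm,
                stream);
  cudaStreamSynchronize(stream);

  if (cur_rank != root_rank) {
    cudaMemcpy(&host, device, sizeof(int64_t), cudaMemcpyDeviceToHost);
    assert(host == local_meta && "metadata mismatch across ranks");
  }
  cudaFree(device);
}
```

Because the handshake runs on the very communicator being validated, a rank that disagrees fails before any payload data moves.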
paddle/fluid/distributed/collective/static_check.h → paddle/fluid/distributed/collective/check.h

```diff
@@ -14,7 +14,18 @@
 #pragma once
 
-// forward declaration to reduce deps
+#include <cstdint>
+#include <vector>
+
+#include "paddle/phi/backends/gpu/forwards.h"
+
+#ifdef PADDLE_WITH_HIP
+using gpuStream_t = hipStream_t;
+#else
+using gpuStream_t = cudaStream_t;
+#endif
+
+// forward declarations
 namespace phi {
 class DenseTensor;
 }
@@ -49,7 +60,7 @@ struct CommStaticCheck {
       int in_size_factor);
 
   // for p2p
-  static void SingleTensor(const phi::DenseTensor& tensor,
-                           int rank,
-                           int world_size);
+  static void CheckShape(const phi::DenseTensor& tensor,
+                         int rank,
+                         int world_size);
@@ -73,5 +84,32 @@ struct CommStaticCheck {
       int world_size);
 };
 
+struct CommDynamicCheck {
+  static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype);
+
+  static void CheckDataType(const phi::DenseTensor& tensor,
+                            int root_rank,
+                            int cur_rank,
+                            ncclComm_t comm);
+
+  static void CheckShape(const phi::DenseTensor& tensor, int64_t shape);
+
+  static void CheckShape(const phi::DenseTensor& tensor,
+                         int root_rank,
+                         int cur_rank,
+                         ncclComm_t comm);
+
+  static void CheckShape(const phi::DenseTensor& out_tensor,
+                         const phi::DenseTensor& in_tensor,
+                         const std::vector<int64_t>& in_size_each_rank,
+                         int cur_rank,
+                         int world_size,
+                         ncclComm_t comm);
+
+ private:
+  // `0` represents default stream for both cuda & hip
+  static constexpr gpuStream_t kDefaultStream = 0;
+};
+
 }  // namespace distributed
 }  // namespace paddle
```
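Putting the renames together, a point-to-point call site now reads like this condensed excerpt from the `ProcessGroupNCCL::Send` hunk above (surrounding plumbing elided): the unified `CommStaticCheck::CheckShape` overload replaces `SingleTensor` for the local checks, and the sending rank doubles as root of the dynamic handshake.

```cpp
// Condensed from the ProcessGroupNCCL::Send hunk; not a verbatim excerpt.
CommStaticCheck::CheckShape(tensor_maybe_partial, rank_, size_);
return RunFnInNCCLEnv(
    [&](ncclComm_t comm, gpuStream_t stream) {
      if (FLAGS_enable_nccl_dynamic_check) {
        CommDynamicCheck::CheckShape(tensor_maybe_partial,
                                     /*root_rank*/ rank_,
                                     rank_,
                                     comm);
      }
      NCCL_CHECK(phi::dynload::ncclSend(/* ... */));
    },
    /* remaining arguments elided */);
```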