Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9fcdaeba
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9fcdaeba
编写于
3月 01, 2021
作者:
L
lw921014
提交者:
GitHub
3月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add allreduce and broadcast without test (#31024)
add allreduce and broadcast without test
上级
5618f140
变更
29
隐藏空白更改
内联
并排
Showing
29 changed file
with
1895 addition
and
19 deletion
+1895
-19
cmake/external/ascend.cmake
cmake/external/ascend.cmake
+4
-0
cmake/external/protobuf.cmake
cmake/external/protobuf.cmake
+5
-2
cmake/flags.cmake
cmake/flags.cmake
+2
-0
paddle/fluid/memory/allocation/allocator_facade.cc
paddle/fluid/memory/allocation/allocator_facade.cc
+2
-0
paddle/fluid/memory/allocation/allocator_strategy.cc
paddle/fluid/memory/allocation/allocator_strategy.cc
+1
-0
paddle/fluid/operators/collective/CMakeLists.txt
paddle/fluid/operators/collective/CMakeLists.txt
+14
-9
paddle/fluid/operators/collective/c_allreduce_max_op_npu.cc
paddle/fluid/operators/collective/c_allreduce_max_op_npu.cc
+31
-0
paddle/fluid/operators/collective/c_allreduce_min_op_npu.cc
paddle/fluid/operators/collective/c_allreduce_min_op_npu.cc
+31
-0
paddle/fluid/operators/collective/c_allreduce_op.h
paddle/fluid/operators/collective/c_allreduce_op.h
+93
-1
paddle/fluid/operators/collective/c_allreduce_prod_op_npu.cc
paddle/fluid/operators/collective/c_allreduce_prod_op_npu.cc
+31
-0
paddle/fluid/operators/collective/c_allreduce_sum_op_npu.cc
paddle/fluid/operators/collective/c_allreduce_sum_op_npu.cc
+31
-0
paddle/fluid/operators/collective/c_broadcast_op.cc
paddle/fluid/operators/collective/c_broadcast_op.cc
+5
-0
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
+94
-0
paddle/fluid/operators/collective/c_comm_init_hccl_op.cc
paddle/fluid/operators/collective/c_comm_init_hccl_op.cc
+79
-0
paddle/fluid/operators/collective/c_create_group_op.cc
paddle/fluid/operators/collective/c_create_group_op.cc
+76
-0
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
+192
-0
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-1
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+161
-4
paddle/fluid/platform/collective_helper_npu.cc
paddle/fluid/platform/collective_helper_npu.cc
+111
-0
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+3
-1
paddle/fluid/platform/dynload/base.h
paddle/fluid/platform/dynload/base.h
+127
-0
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+29
-0
paddle/fluid/platform/dynload/dynamic_loader.h
paddle/fluid/platform/dynload/dynamic_loader.h
+1
-0
paddle/fluid/platform/dynload/hccl.cc
paddle/fluid/platform/dynload/hccl.cc
+38
-0
paddle/fluid/platform/dynload/hccl.h
paddle/fluid/platform/dynload/hccl.h
+84
-0
paddle/fluid/platform/dynload/hcom.h
paddle/fluid/platform/dynload/hcom.h
+275
-0
paddle/fluid/platform/enforce.h
paddle/fluid/platform/enforce.h
+9
-0
paddle/fluid/platform/hccl_helper.h
paddle/fluid/platform/hccl_helper.h
+364
-0
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+1
-1
未找到文件。
cmake/external/ascend.cmake
浏览文件 @
9fcdaeba
...
...
@@ -62,6 +62,7 @@ endif()
if
(
WITH_ASCEND_CL
)
set
(
ASCEND_CL_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/lib64
)
set
(
ascend_hccl_lib
${
ASCEND_CL_DIR
}
/libhccl.so
)
set
(
ascendcl_lib
${
ASCEND_CL_DIR
}
/libascendcl.so
)
set
(
acl_op_compiler_lib
${
ASCEND_CL_DIR
}
/libacl_op_compiler.so
)
set
(
ASCEND_CL_INC_DIR
${
ASCEND_DIR
}
/ascend-toolkit/latest/fwkacllib/include
)
...
...
@@ -73,6 +74,9 @@ if(WITH_ASCEND_CL)
ADD_LIBRARY
(
ascendcl SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascendcl PROPERTY IMPORTED_LOCATION
${
ascendcl_lib
}
)
ADD_LIBRARY
(
ascend_hccl SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET ascend_hccl PROPERTY IMPORTED_LOCATION
${
ascend_hccl_lib
}
)
ADD_LIBRARY
(
acl_op_compiler SHARED IMPORTED GLOBAL
)
SET_PROPERTY
(
TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION
${
acl_op_compiler_lib
}
)
add_custom_target
(
extern_ascend_cl DEPENDS ascendcl acl_op_compiler
)
...
...
cmake/external/protobuf.cmake
浏览文件 @
9fcdaeba
...
...
@@ -205,8 +205,11 @@ elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
SET
(
PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git
)
SET
(
PROTOBUF_TAG v3.8.0
)
else
()
SET
(
PROTOBUF_REPOSITORY
${
GIT_URL
}
/protocolbuffers/protobuf.git
)
SET
(
PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546
)
SET
(
PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git
)
SET
(
PROTOBUF_TAG v3.8.0
)
# SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
# SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
endif
()
cache_third_party
(
${
TARGET_NAME
}
...
...
cmake/flags.cmake
浏览文件 @
9fcdaeba
...
...
@@ -151,6 +151,8 @@ set(COMMON_FLAGS
-Wno-error=int-in-bool-context
# Warning in Eigen gcc 7.2
-Wimplicit-fallthrough=0
# Warning in tinyformat.h
-Wno-error=maybe-uninitialized
# Warning in boost gcc 7.2
-Wno-error=nonnull-compare
# Warning in boost gcc 7.2
-Wno-error=address
# Warning in boost gcc 7.2
${
fsanitize
}
)
...
...
paddle/fluid/memory/allocation/allocator_facade.cc
浏览文件 @
9fcdaeba
...
...
@@ -79,6 +79,7 @@ class AllocatorFacadePrivate {
InitNaiveBestFitCUDAPinnedAllocator
();
#endif
#ifdef PADDLE_WITH_ASCEND_CL
VLOG
(
3
)
<<
"npu num: "
<<
platform
::
GetNPUDeviceCount
();
for
(
int
dev_id
=
0
;
dev_id
<
platform
::
GetNPUDeviceCount
();
++
dev_id
)
{
InitNaiveBestFitNPUAllocator
(
platform
::
NPUPlace
(
dev_id
));
}
...
...
@@ -141,6 +142,7 @@ class AllocatorFacadePrivate {
(
size
>
0
?
(
UNLIKELY
(
FLAGS_use_system_allocator
)
?
system_allocators_
:
allocators_
)
:
zero_size_allocators_
);
VLOG
(
3
)
<<
size
;
auto
iter
=
allocators
.
find
(
place
);
PADDLE_ENFORCE_NE
(
iter
,
allocators
.
end
(),
platform
::
errors
::
NotFound
(
...
...
paddle/fluid/memory/allocation/allocator_strategy.cc
浏览文件 @
9fcdaeba
...
...
@@ -24,6 +24,7 @@ namespace memory {
namespace
allocation
{
static
AllocatorStrategy
GetStrategyFromFlag
()
{
VLOG
(
3
)
<<
"FLAGS_allocator_strategy"
<<
FLAGS_allocator_strategy
;
if
(
FLAGS_allocator_strategy
==
"naive_best_fit"
)
{
return
AllocatorStrategy
::
kNaiveBestFit
;
}
...
...
paddle/fluid/operators/collective/CMakeLists.txt
浏览文件 @
9fcdaeba
...
...
@@ -11,24 +11,29 @@ foreach(src ${OPS})
set_source_files_properties
(
${
src
}
PROPERTIES COMPILE_FLAGS
${
COLLECTIVE_COMPILE_FLAGS
}
)
endforeach
()
register_operators
(
EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
register_operators
(
EXCLUDES c_gen_
bkcl_id_op gen_bkcl_id_op c_gen_
nccl_id_op gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
if
(
WITH_NCCL
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
nccl_common collective_helper
)
cc_library
(
gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc DEPS nccl_common
)
op_library
(
c_gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
gen_nccl_id_op_helper
)
op_library
(
gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
gen_nccl_id_op_helper
)
op_library
(
c_gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
op_library
(
gen_nccl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
endif
()
if
(
WITH_ASCEND
)
op_library
(
gen_nccl_id_op
)
op_library
(
c_gen_nccl_id_op
)
if
(
WITH_GLOO
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
gloo_wrapper
)
endif
()
if
(
WITH_XPU_BKCL
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
collective_helper
)
op_library
(
c_gen_bkcl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
op_library
(
gen_bkcl_id_op DEPS
${
COLLECTIVE_DEPS
}
)
endif
()
if
(
WITH_
GLOO
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
gloo_wrap
per
)
if
(
WITH_
ASCEND_CL
)
set
(
COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
collective_hel
per
)
endif
()
set
(
OPERATOR_DEPS
${
OPERATOR_DEPS
}
${
COLLECTIVE_DEPS
}
PARENT_SCOPE
)
set
(
GLOB_COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
CACHE INTERNAL
"collective dependency"
)
cc_test
(
c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hccl_op c_create_group_op
${
COLLECTIVE_DEPS
}
ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor
)
paddle/fluid/operators/collective/c_allreduce_max_op_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
platform
{
struct
ASCENDPlace
;
struct
float16
;
}
// namespace platform
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
c_allreduce_max
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMax
,
float
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMax
,
int
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMax
,
int8_t
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMax
,
plat
::
float16
>
)
paddle/fluid/operators/collective/c_allreduce_min_op_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
platform
{
struct
ASCENDPlace
;
struct
float16
;
}
// namespace platform
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
c_allreduce_min
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMin
,
float
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMin
,
int
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMin
,
int8_t
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedMin
,
plat
::
float16
>
)
paddle/fluid/operators/collective/c_allreduce_op.h
浏览文件 @
9fcdaeba
...
...
@@ -30,6 +30,11 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -105,6 +110,88 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
}
};
template
<
ReduceType
red_type
,
typename
T
>
class
CAllReduceOpASCENDKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_ASCEND_CL)
auto
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
hcclDataType_t
dtype
=
platform
::
ToHCCLDataType
(
in
->
type
());
int64_t
numel
=
in
->
numel
();
void
*
sendbuff
=
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
in
->
data
<
T
>
()));
// void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->mutable_data<T>(place)));
out
->
Resize
(
in
->
dims
());
// void* recvbuff = reinterpret_cast<void*>(const_cast<T*>(out->data<T>()));
void
*
recvbuff
=
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
out
->
mutable_data
<
T
>
(
place
)));
// void* recvbuff = sendbuff;
std
::
string
tag
=
ctx
.
Attr
<
std
::
string
>
(
"tag"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
// s他的:
std
::
string
group
=
std
::
string
(
HCOM_GROUP_PREFIX
)
+
std
::
to_string
(
ring_id
);
group
=
"hccl_world_group"
;
// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
();
aclrtStream
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
NPUDeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
hcclRedOp_t
hccl_red_type
=
HCCL_REP_OP_SUM
;
switch
(
red_type
)
{
case
kRedSum
:
hccl_red_type
=
HCCL_REP_OP_SUM
;
break
;
case
kRedMax
:
hccl_red_type
=
HCCL_REP_OP_MAX
;
break
;
case
kRedMin
:
hccl_red_type
=
HCCL_REP_OP_MIN
;
break
;
case
kRedProd
:
hccl_red_type
=
HCCL_REP_OP_PROD
;
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Invalid reduce type: %d"
,
red_type
));
}
VLOG
(
3
)
<<
"begin hccl allreduce, parameter is: "
<<
"input num: "
<<
numel
<<
"dtype: "
<<
dtype
<<
"hccl_red_type: "
<<
hccl_red_type
<<
", group is: "
<<
group
<<
", tag is "
<<
tag
;
printf
(
"sendbuff: %p
\n
"
,
sendbuff
);
printf
(
"recvbuff: %p
\n
"
,
recvbuff
);
// printf("sendbuff: %p, %d\n", sendbuff, ((int*)sendbuff)[0]);
// printf("recvbuff: %p, %d\n", recvbuff, ((int*)recvbuff)[0]);
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_all_reduce
(
tag
.
c_str
(),
sendbuff
,
recvbuff
,
numel
,
dtype
,
hccl_red_type
,
group
.
c_str
(),
(
void
*
)
stream
));
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
template
<
ReduceType
red_type
,
typename
T
>
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -114,7 +201,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
ncclDataType_t
dtype
=
platform
::
To
N
CCLDataType
(
in
->
type
());
ncclDataType_t
dtype
=
platform
::
To
H
CCLDataType
(
in
->
type
());
int64_t
numel
=
in
->
numel
();
const
void
*
sendbuff
=
in
->
data
<
void
>
();
out
->
Resize
(
in
->
dims
());
...
...
@@ -170,6 +257,11 @@ class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"(Tensor) the allreduced result."
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) communication ring id."
)
.
SetDefault
(
0
);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("hccl CAllReduceOpMaker need tag attr")
AddAttr
<
std
::
string
>
(
"tag"
,
"(string default tag) tag for all reduce."
)
.
SetDefault
(
"tag"
);
#endif
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
...
...
paddle/fluid/operators/collective/c_allreduce_prod_op_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
platform
{
struct
ASCENDPlace
;
struct
float16
;
}
// namespace platform
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
c_allreduce_prod
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedProd
,
float
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedProd
,
int
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedProd
,
int8_t
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedProd
,
plat
::
float16
>
)
paddle/fluid/operators/collective/c_allreduce_sum_op_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace
paddle
{
namespace
platform
{
struct
ASCENDPlace
;
struct
float16
;
}
// namespace platform
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
c_allreduce_sum
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
float
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
int
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
int8_t
>
,
ops
::
CAllReduceOpASCENDKernel
<
ops
::
kRedSum
,
plat
::
float16
>
)
paddle/fluid/operators/collective/c_broadcast_op.cc
浏览文件 @
9fcdaeba
...
...
@@ -42,6 +42,11 @@ class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"root"
,
"(int default 0) root id for broadcasting."
)
.
SetDefault
(
0
);
#if defined(PADDLE_WITH_ASCEND_CL)
#pragma message("tag")
AddAttr
<
std
::
string
>
(
"tag"
,
"(string default tag) tag for broadcasting."
)
.
SetDefault
(
"tag"
);
#endif
AddAttr
<
bool
>
(
"use_calc_stream"
,
"(bool default false) eject CUDA operations to calculation stream."
)
...
...
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
CBroadcastOpASCENDKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_ASCEND_CL)
auto
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
void
*
ptr
=
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
x
->
data
<
T
>
()));
int
numel
=
x
->
numel
();
hcclDataType_t
dtype
=
platform
::
ToHCCLDataType
(
x
->
type
());
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
();
aclrtStream
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
auto
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
stream
=
static_cast
<
platform
::
NPUDeviceContext
*>
(
dev_ctx
)
->
stream
();
}
else
{
stream
=
comm
->
stream
();
}
int
root
=
ctx
.
Attr
<
int
>
(
"root"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
std
::
string
group
=
std
::
string
(
HCOM_GROUP_PREFIX
)
+
std
::
to_string
(
ring_id
);
std
::
string
tag
=
ctx
.
Attr
<
std
::
string
>
(
"tag"
);
VLOG
(
3
)
<<
"begin hccl broadcast, parameter is: "
<<
"root "
<<
root
<<
", group is "
<<
group
<<
", tag is "
<<
tag
;
if
(
root
==
static_cast
<
int
>
(
comm
->
rank
()))
{
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_broadcast
(
tag
.
c_str
(),
ptr
,
numel
,
dtype
,
(
uint32_t
)
root
,
group
.
c_str
(),
(
void
*
)
stream
));
VLOG
(
3
)
<<
"rank "
<<
comm
->
rank
()
<<
" invoke Bcast. sent "
<<
x
->
numel
();
}
else
{
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_broadcast
(
tag
.
c_str
(),
ptr
,
numel
,
dtype
,
(
uint32_t
)
root
,
group
.
c_str
(),
(
void
*
)
stream
));
VLOG
(
3
)
<<
"rank "
<<
comm
->
rank
()
<<
" invoke Bcast. recieved "
<<
framework
::
product
(
out
->
dims
());
}
if
(
out
!=
x
)
{
framework
::
TensorCopy
(
*
static_cast
<
const
framework
::
Tensor
*>
(
x
),
place
,
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
),
static_cast
<
framework
::
Tensor
*>
(
out
));
}
out
->
Resize
(
x
->
dims
());
out
->
set_lod
(
x
->
lod
());
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with GPU."
));
#endif
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
c_broadcast
,
ops
::
CBroadcastOpASCENDKernel
<
float
>
,
ops
::
CBroadcastOpASCENDKernel
<
int
>
,
ops
::
CBroadcastOpASCENDKernel
<
int8_t
>
,
ops
::
CBroadcastOpASCENDKernel
<
plat
::
float16
>
);
paddle/fluid/operators/collective/c_comm_init_hccl_op.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
CCommInitOpNPU
:
public
framework
::
OperatorBase
{
public:
CCommInitOpNPU
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
rank_table_file
=
Attr
<
std
::
string
>
(
"rank_table_file"
);
uint32_t
rank_id
=
Attr
<
int
>
(
"rank_id"
);
uint32_t
device_id
=
Attr
<
int
>
(
"device_id"
);
VLOG
(
3
)
<<
"begin init hccl, parameter is: "
<<
"rank_table_file "
<<
rank_table_file
<<
" rank_id "
<<
rank_id
<<
" device_id "
<<
device_id
;
platform
::
HCCLCommContext
::
Instance
().
CreateHCCLComm
(
rank_table_file
,
rank_id
,
device_id
);
}
};
class
CCommInitOpNPUMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddComment
(
R"DOC(
CCommInit operator on NPU
Initialize collective communication context within this trainer
)DOC"
);
AddAttr
<
std
::
string
>
(
"rank_table_file"
,
"(string) path to rank_table_file"
);
AddAttr
<
int
>
(
"rank_id"
,
"(int) world rank id of the process"
);
AddAttr
<
int
>
(
"device_id"
,
"(int) device id of the process/thread"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_comm_init_hccl
,
ops
::
CCommInitOpNPU
,
ops
::
CCommInitOpNPUMaker
);
#endif
paddle/fluid/operators/collective/c_create_group_op.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
CCreateGroupOpNPU
:
public
framework
::
OperatorBase
{
public:
CCreateGroupOpNPU
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
group_name
=
Attr
<
std
::
string
>
(
"group_name"
);
int
nranks
=
Attr
<
int
>
(
"nranks"
);
std
::
vector
<
int
>
rank_ids
=
Attr
<
std
::
vector
<
int
>>
(
"rank_ids"
);
paddle
::
platform
::
HCCLCommContext
::
Instance
().
CreateHCCLGroup
(
group_name
,
(
uint32_t
)
nranks
,
std
::
vector
<
uint32_t
>
(
rank_ids
.
begin
(),
rank_ids
.
end
()));
}
};
class
CCreateGroupOpNPUMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddComment
(
R"DOC(
CCreateGroup operator on NPU
Create collective communication group on NPU
)DOC"
);
AddAttr
<
std
::
string
>
(
"group_name"
,
"(string) name of the collective communication group"
);
AddAttr
<
int
>
(
"nranks"
,
"(int) number of the group"
);
AddAttr
<
std
::
vector
<
int
>>
(
"rank_ids"
,
"(list of int) The world rank id of the group members"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_create_group
,
ops
::
CCreateGroupOpNPU
,
ops
::
CCreateGroupOpNPUMaker
);
#endif
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#endif
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
namespace
m
=
paddle
::
operators
::
math
;
USE_OP
(
c_broadcast
);
USE_OP
(
c_allreduce_sum
);
USE_NO_KERNEL_OP
(
c_comm_init_hccl
);
USE_NO_KERNEL_OP
(
c_create_group
);
USE_OP_DEVICE_KERNEL
(
c_broadcast
,
NPU
);
USE_OP_DEVICE_KERNEL
(
c_allreduce_sum
,
NPU
);
void
Prepare
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
){
std
::
string
rank_table_file
=
getenv
(
"RANK_TABLE_FILE"
);
int
rank_id
=
atoi
(
getenv
(
"RANK_ID"
));
int
device_id
=
atoi
(
getenv
(
"DEVICE_ID"
));
printf
(
"rank_table_file: %s, rank_id = %d, device_id = %d
\n
"
,
rank_table_file
.
c_str
(),
rank_id
,
device_id
);
f
::
AttributeMap
attrs
;
attrs
[
"rank_table_file"
]
=
rank_table_file
;
attrs
[
"rank_id"
]
=
rank_id
;
attrs
[
"device_id"
]
=
device_id
;
auto
comm_init_op
=
f
::
OpRegistry
::
CreateOp
(
"c_comm_init_hccl"
,
{},
{},
attrs
);
auto
place
=
ctx
.
GetPlace
();
comm_init_op
->
Run
(
*
scope
,
place
);
ctx
.
Wait
();
f
::
AttributeMap
create_attrs
;
create_attrs
[
"group_name"
]
=
HCOM_GROUP_PREFIX
+
std
::
to_string
(
0
);
create_attrs
[
"nranks"
]
=
2
;
std
::
vector
<
int
>
rank_ids
{
0
,
1
};
create_attrs
[
"rank_ids"
]
=
rank_ids
;
auto
create_group_op
=
f
::
OpRegistry
::
CreateOp
(
"c_create_group"
,
{},
{},
create_attrs
);
create_group_op
->
Run
(
*
scope
,
place
);
ctx
.
Wait
();
}
void
TestHCCLBroadcastOp
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
)
{
std
::
cout
<<
"BEGIN TEST:"
<<
__FUNCTION__
<<
std
::
endl
;
// init
auto
x
=
scope
->
Var
(
"X"
);
auto
tensor_x
=
x
->
GetMutable
<
f
::
LoDTensor
>
();
int
num
=
2
;
std
::
vector
<
float
>
init
;
int
rank_id
=
atoi
(
getenv
(
"RANK_ID"
));
std
::
cout
<<
"rank_id:"
<<
rank_id
<<
std
::
endl
;
for
(
int64_t
i
=
0
;
i
<
num
*
num
;
++
i
)
{
init
.
push_back
(
1.0
+
rank_id
);
std
::
cout
<<
init
[
0
];
}
std
::
cout
<<
std
::
endl
;
TensorFromVector
(
init
,
ctx
,
tensor_x
);
tensor_x
->
Resize
({
num
,
num
});
ctx
.
Wait
();
auto
place
=
ctx
.
GetPlace
();
auto
out
=
scope
->
Var
(
"Out"
);
auto
tensor_out
=
out
->
GetMutable
<
f
::
LoDTensor
>
();
tensor_out
->
Resize
({
num
,
num
});
tensor_out
->
mutable_data
<
float
>
(
place
);
// allocate
ctx
.
Wait
();
// run
f
::
AttributeMap
attrs
;
attrs
[
"tag"
]
=
std
::
string
(
"tagx"
);
attrs
[
"root"
]
=
0
;
attrs
[
"ring_id"
]
=
0
;
auto
op
=
f
::
OpRegistry
::
CreateOp
(
"c_broadcast"
,
{{
"X"
,
{
"X"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
op
->
Run
(
*
scope
,
place
);
std
::
vector
<
float
>
out_vec
;
TensorToVector
(
*
tensor_out
,
ctx
,
&
out_vec
);
ctx
.
Wait
();
EXPECT_EQ
(
out_vec
.
size
(),
init
.
size
());
for
(
uint32_t
i
=
0
;
i
<
out_vec
.
size
();
i
++
)
{
EXPECT_EQ
(
out_vec
[
i
],
1.0
);
}
}
void
TestHCCLAllReduceOp
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
)
{
std
::
cout
<<
"BEGIN TEST:"
<<
__FUNCTION__
<<
std
::
endl
;
// init
auto
x
=
scope
->
Var
(
"X"
);
auto
tensor_x
=
x
->
GetMutable
<
f
::
LoDTensor
>
();
std
::
vector
<
float
>
init
;
int
rank_id
=
atoi
(
getenv
(
"RANK_ID"
));
std
::
cout
<<
"rank_id:"
<<
rank_id
<<
std
::
endl
;
int
num1
=
1
;
int
num2
=
4
;
for
(
int64_t
i
=
0
;
i
<
num1
*
num2
;
++
i
)
{
init
.
push_back
(
1.0
);
// init.push_back(1.0 + rank_id * 3);
std
::
cout
<<
init
[
0
];
}
std
::
cout
<<
std
::
endl
;
TensorFromVector
(
init
,
ctx
,
tensor_x
);
tensor_x
->
Resize
({
num1
,
num2
});
ctx
.
Wait
();
auto
place
=
ctx
.
GetPlace
();
auto
out
=
scope
->
Var
(
"Out"
);
auto
tensor_out
=
out
->
GetMutable
<
f
::
LoDTensor
>
();
tensor_out
->
Resize
({
num1
,
num2
});
tensor_out
->
mutable_data
<
float
>
(
place
);
// allocate
ctx
.
Wait
();
// run
f
::
AttributeMap
attrs
;
attrs
[
"tag"
]
=
std
::
string
(
"tagx"
);
attrs
[
"ring_id"
]
=
0
;
auto
op
=
f
::
OpRegistry
::
CreateOp
(
"c_allreduce_sum"
,
{{
"X"
,
{
"X"
}}},
{{
"Out"
,
{
"Out"
}}},
attrs
);
op
->
Run
(
*
scope
,
place
);
std
::
vector
<
float
>
out_vec
;
TensorToVector
(
*
tensor_out
,
ctx
,
&
out_vec
);
ctx
.
Wait
();
EXPECT_EQ
(
out_vec
.
size
(),
init
.
size
());
for
(
uint32_t
i
=
0
;
i
<
out_vec
.
size
();
i
++
)
{
EXPECT_EQ
(
out_vec
[
i
],
2.0
);
}
}
TEST
(
c_broadcast
,
NPU
)
{
f
::
Scope
scope
;
char
*
npu_id
=
getenv
(
"FLAGS_selected_npus"
);
p
::
NPUDeviceContext
ctx
(
p
::
NPUPlace
(
atoi
(
npu_id
)));
Prepare
(
&
scope
,
ctx
);
// TestHCCLBroadcastOp(&scope, ctx);
TestHCCLAllReduceOp
(
&
scope
,
ctx
);
}
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
9fcdaeba
...
...
@@ -128,7 +128,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool
place eigen3 stringpiece cpu_helper cpu_info framework_proto
${
GPU_CTX_DEPS
}
${
NPU_CTX_DEPS
}
${
MKLDNN_CTX_DEPS
}
${
dgc_deps
}
dlpack cudnn_workspace_helper
${
XPU_CTX_DEPS
}
)
cc_library
(
collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce
)
cc_library
(
collective_helper SRCS collective_helper.cc
collective_helper_npu.cc
DEPS framework_proto device_context enforce
)
if
(
WITH_GPU
)
cc_library
(
cuda_resource_pool SRCS cuda_resource_pool.cc DEPS gpu_info
)
...
...
paddle/fluid/platform/collective_helper.h
浏览文件 @
9fcdaeba
...
...
@@ -14,20 +14,21 @@
#pragma once
#if defined(PADDLE_WITH_NCCL)
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "boost/variant.hpp"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/hccl.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
#if defined(PADDLE_WITH_NCCL)
// In order to apply hierarchical communication with NCCL, we need
// a communication ring contains NCCL communicators associated to a global
// ncclUniqueId. E.g. for a hierarchical case,
...
...
@@ -47,6 +48,8 @@ namespace platform {
//
// The NCCLComm instance is created and reversed in the NCCLCommContext
// singleton with a global user specified group id.
class
CUDADeviceContext
;
class
NCCLComm
{
public:
virtual
int
ring_id
()
const
=
0
;
...
...
@@ -120,8 +123,162 @@ class NCCLCommContext {
NCCLCommContext
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
NCCLCommContext
);
};
#endif
}
// namespace platform
}
// namespace paddle
#if defined(PADDLE_WITH_ASCEND_CL)
// In order to apply hierarchical communication with HCCL, we need
// a communication ring contains HCCL communicators associated to a global
// HCCLUniqueId. E.g. for a hierarchical case,
//
// 11 - 12 21 - 22
// | | | |
// 13 - 14 - 23 - 24
// | |
// 31 - 32 - 41 - 42
// | | | |
// 33 - 34 43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flatten case
//
// The HCCLComm instance is created and reversed in the HCCLCommContext
// singleton with a global user specified group id.
class
NPUDeviceContext
;
class
HCCLComm
{
public:
virtual
std
::
string
rank_table_file
()
const
=
0
;
virtual
uint32_t
rank
()
const
=
0
;
virtual
uint32_t
device_id
()
const
=
0
;
virtual
aclrtStream
stream
()
const
=
0
;
virtual
NPUDeviceContext
*
dev_context
()
const
=
0
;
virtual
~
HCCLComm
()
=
default
;
};
// A singleton HCCL communicator context reserves communication ring ids
class
HCCLCommContext
{
public:
static
HCCLCommContext
&
Instance
()
{
static
HCCLCommContext
comm_ctx
;
return
comm_ctx
;
}
HCCLComm
*
CreateHCCLComm
(
const
std
::
string
&
config_file
,
uint32_t
rank
,
uint32_t
device_id
);
void
CreateHCCLGroup
(
const
std
::
string
&
group_name
,
uint32_t
nranks
,
const
std
::
vector
<
uint32_t
>&
rank_ids
);
// retrieve a communicator by the ring id and place
HCCLComm
*
Get
()
const
{
return
comm_
.
get
();
}
private:
std
::
once_flag
once_flag_
;
std
::
mutex
comm_map_mutex_
;
std
::
unique_ptr
<
HCCLComm
>
comm_
;
HCCLComm
*
AssignHCCLComm
(
const
std
::
string
&
config_file
,
uint32_t
rank
,
uint32_t
device_id
);
HCCLCommContext
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
HCCLCommContext
);
};
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
// In order to apply hierarchical communication with BKCL, we need
// a communication ring contains BKCL communicators associated to a global
// BKCLUniqueId. E.g. for a hierarchical case,
//
// 11 - 12 21 - 22
// | | | |
// 13 - 14 - 23 - 24
// | |
// 31 - 32 - 41 - 42
// | | | |
// 33 - 34 43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flatten case
//
// The BKCLComm instance is created and reversed in the BKCLCommContext
// singleton with a global user specified group id.
class
BKCLComm
{
public:
virtual
int
ring_id
()
const
=
0
;
virtual
int
nranks
()
const
=
0
;
virtual
int
rank
()
const
=
0
;
virtual
int
device_id
()
const
=
0
;
virtual
BKCLContext_t
comm
()
const
=
0
;
virtual
XPUStream
stream
()
const
=
0
;
virtual
XPUDeviceContext
*
dev_context
()
const
=
0
;
virtual
~
BKCLComm
()
=
default
;
};
// A singleton BKCL communicator context reserves communication ring ids
class
BKCLCommContext
{
public:
static
BKCLCommContext
&
Instance
()
{
static
BKCLCommContext
comm_ctx
;
return
comm_ctx
;
}
BKCLComm
*
CreateBKCLComm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
void
CreateAllBKCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
// a latter comm with the same dev_id and the same ring_id
// will override the former
BKCLComm
*
AssignBKCLComm
(
BKCLContext_t
comm
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
// retrieve a communicator by the ring id in multiprocessing mode
BKCLComm
*
Get
(
int
ring_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator in ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_EQ
(
comm_map_
.
at
(
ring_id
).
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"One device id should be specified to retrieve from "
"multiple communicators."
));
return
comm_map_
.
at
(
ring_id
).
begin
()
->
second
.
get
();
}
// retrieve a communicator by the ring id and the device id
BKCLComm
*
Get
(
int
ring_id
,
int
dev_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator of ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_GT
(
comm_map_
.
at
(
ring_id
).
count
(
dev_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator at device id %d has not been initialized in ring %d."
,
dev_id
,
ring_id
));
return
comm_map_
.
at
(
ring_id
).
at
(
dev_id
).
get
();
}
// retrieve a communicator by the ring id and place
BKCLComm
*
Get
(
int
ring_id
,
Place
place
)
const
{
return
Get
(
ring_id
,
BOOST_GET_CONST
(
XPUPlace
,
place
).
device
);
}
private:
std
::
once_flag
once_flag_
;
std
::
mutex
comm_map_mutex_
;
// ring id to dev-BKCLComm
std
::
map
<
int
,
std
::
map
<
int
,
std
::
unique_ptr
<
BKCLComm
>>>
comm_map_
;
void
ReleaseBKCLComms
();
BKCLCommContext
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
BKCLCommContext
);
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/collective_helper_npu.cc
0 → 100644
浏览文件 @
9fcdaeba
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/platform/collective_helper.h"
#include <utility>
namespace
paddle
{
namespace
platform
{
class
HCCLCommImpl
:
public
HCCLComm
{
public:
void
set_rank_table_file
(
const
std
::
string
&
rank_table_file
)
{
rank_table_file_
=
rank_table_file
;
}
std
::
string
rank_table_file
()
const
override
{
return
rank_table_file_
;
}
void
set_rank
(
uint32_t
rank
)
{
rank_
=
rank
;
}
uint32_t
rank
()
const
override
{
return
rank_
;
}
void
set_device_id
(
uint32_t
device_id
)
{
device_id_
=
device_id
;
}
uint32_t
device_id
()
const
override
{
return
device_id_
;
}
aclrtStream
stream
()
const
override
{
return
dev_ctx_
->
stream
();
}
void
set_dev_ctx
(
std
::
unique_ptr
<
NPUDeviceContext
>&&
dev_ctx
)
{
dev_ctx_
=
std
::
move
(
dev_ctx
);
}
NPUDeviceContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
private:
std
::
string
rank_table_file_
;
uint32_t
rank_
;
uint32_t
device_id_
;
std
::
unique_ptr
<
NPUDeviceContext
>
dev_ctx_
;
};
HCCLComm
*
HCCLCommContext
::
CreateHCCLComm
(
const
std
::
string
&
rank_table_file
,
uint32_t
rank
,
uint32_t
device_id
)
{
/*
PADDLE_ENFORCE_NOT_NULL(rank_table_file,
platform::errors::InvalidArgument(
"The rank table file should not be null."));
PADDLE_ENFORCE_GE(rank, 0,
platform::errors::InvalidArgument(
"Expected rank >= 0. But received rank is %d.", rank));
PADDLE_ENFORCE_GE(device_id, 0,
platform::errors::InvalidArgument(
"Expected dev_id >= 0. But received dev_id is %d.", device_id));
*/
auto
*
comm_wrapper
=
AssignHCCLComm
(
rank_table_file
,
rank
,
device_id
);
platform
::
dynload
::
hcom_init
(
rank_table_file
.
c_str
(),
std
::
to_string
(
rank
).
c_str
());
platform
::
dynload
::
hcom_bind_model
(
comm_wrapper
->
stream
(),
comm_wrapper
->
stream
());
VLOG
(
1
)
<<
"hccl communicator of rank "
<<
rank
<<
" has been created"
;
return
comm_wrapper
;
}
HCCLComm
*
HCCLCommContext
::
AssignHCCLComm
(
const
std
::
string
&
rank_table_file
,
uint32_t
rank
,
uint32_t
device_id
)
{
std
::
unique_ptr
<
NPUDeviceContext
>
dev_ctx
(
new
NPUDeviceContext
(
NPUPlace
(
device_id
)));
VLOG
(
3
)
<<
"device_id"
<<
device_id
;
VLOG
(
3
)
<<
"dev_ctx->stream()"
<<
dev_ctx
->
stream
();
HCCLCommImpl
*
c
=
new
HCCLCommImpl
;
c
->
set_rank_table_file
(
rank_table_file
);
c
->
set_rank
(
rank
);
c
->
set_device_id
(
device_id
);
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
// comm_ = c
comm_
.
reset
(
c
);
return
c
;
}
void
HCCLCommContext
::
CreateHCCLGroup
(
const
std
::
string
&
group_name
,
uint32_t
nranks
,
const
std
::
vector
<
uint32_t
>&
rank_ids
)
{
/*
PADDLE_ENFORCE_NOT_NULL(group_name,
platform::errors::InvalidArgument(
"The group name should not be null."));
PADDLE_ENFORCE_GT(nranks, 0,
platform::errors::InvalidArgument(
"Expected nranks > 0. But received nranks is %d.", nranks));
PADDLE_ENFORCE_NOT_NULL(rank_ids,
platform::errors::InvalidArgument(
"The rank ids should not be null."));
*/
platform
::
dynload
::
hcom_create_group
(
group_name
.
c_str
(),
nranks
,
(
unsigned
int
*
)
rank_ids
.
data
());
VLOG
(
1
)
<<
"hccl group with name "
<<
group_name
<<
" has been created"
;
}
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
9fcdaeba
...
...
@@ -9,7 +9,7 @@ endif()
# There is no macOS version of NCCL.
# Disable nvrtc and cuda_driver api on MacOS and Windows, and only do a early test on Linux.
if
(
NOT APPLE AND NOT WIN32
)
list
(
APPEND CUDA_SRCS nvrtc.cc cuda_driver.cc
)
list
(
APPEND CUDA_SRCS nvrtc.cc cuda_driver.cc
)
if
(
WITH_NCCL
)
list
(
APPEND CUDA_SRCS nccl.cc
)
endif
()
...
...
@@ -32,6 +32,8 @@ endif(CUPTI_FOUND)
if
(
WITH_ROCM_PLATFORM
)
hip_library
(
dynload_cuda SRCS
${
HIP_SRCS
}
DEPS dynamic_loader
)
hip_library
(
dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc
)
elseif
(
WITH_ASCEND_CL
)
cc_library
(
dynload_warpctc SRCS warpctc.cc hccl.cc DEPS dynamic_loader warpctc
)
else
()
nv_library
(
dynload_cuda SRCS
${
CUDA_SRCS
}
DEPS dynamic_loader
)
cc_library
(
dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc
)
...
...
paddle/fluid/platform/dynload/base.h
0 → 100644
浏览文件 @
9fcdaeba
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file base.h
* @brief HCOM data type definition
*
*/
#ifndef HCCL_BASE_H_
#define HCCL_BASE_H_
#define HCOM_GROUP_PREFIX "HCOM_GROUP_"
#ifdef __cplusplus
extern
"C"
{
#endif // __cplusplus
typedef
signed
char
s8
;
typedef
signed
short
s16
;
typedef
signed
int
s32
;
typedef
signed
long
long
s64
;
typedef
unsigned
char
u8
;
typedef
unsigned
short
u16
;
typedef
unsigned
int
u32
;
typedef
unsigned
long
long
u64
;
/**
* @brief HCOM functions return value definition
*/
typedef
enum
tagHcclResult
{
HCCL_SUCCESS
=
0
,
/**< success */
HCCL_E_PARA
=
1
,
/**< parameter error */
HCCL_E_PTR
=
2
,
/**< empty pointer */
HCCL_E_MEMORY
=
3
,
/**< memory error */
HCCL_E_INTERNAL
=
4
,
/**< internal error */
HCCL_E_NOT_SUPPORT
=
5
,
/**< not support feature */
HCCL_E_NOT_FOUND
=
6
,
/**< not found specific resource */
HCCL_E_UNAVAIL
=
7
,
/**< resource unavailable */
HCCL_E_SYSCALL
=
8
,
/**< call system interface error */
HCCL_E_TIMEOUT
=
9
,
/**< timeout */
HCCL_E_OPEN_FILE_FAILURE
=
10
,
/**< open file fail */
HCCL_E_TCP_CONNECT
=
11
,
/**< tcp connect fail */
HCCL_E_ROCE_CONNECT
=
12
,
/**< roce connect fail */
HCCL_E_TCP_TRANSFER
=
13
,
/**< tcp transfer fail */
HCCL_E_ROCE_TRANSFER
=
14
,
/**< roce transfer fail */
HCCL_E_RUNTIME
=
15
,
/**< call runtime api fail */
HCCL_E_DRV
=
16
,
/**< call driver api fail */
HCCL_E_PROFILING
=
17
,
/**< call profiling api fail */
HCCL_E_CCE
=
18
,
/**< call cce api fail */
HCCL_E_NETWORK
=
19
,
/**< call network api fail */
HCCL_E_RESERVED
/**< reserved */
}
hcclResult_t
;
/* handle to communicator */
typedef
void
*
hcclComm_t
;
/**
* @brief HCCL Reduction opperation
*/
typedef
enum
tagHcclRedOp
{
HCCL_REP_OP_SUM
=
0
,
/**< sum */
HCCL_REP_OP_PROD
=
1
,
/**< prod */
HCCL_REP_OP_MAX
=
2
,
/**< max */
HCCL_REP_OP_MIN
=
3
,
/**< min */
HCCL_REP_OP_RESERVED
/**< reserved */
}
hcclRedOp_t
;
/**
* @brief HCCL data type
*/
typedef
enum
tagHcclDataType
{
HCCL_DATA_TYPE_INT8
=
0
,
/**< int8 */
HCCL_DATA_TYPE_INT16
=
1
,
/**< int16 */
HCCL_DATA_TYPE_INT32
=
2
,
/**< int32 */
HCCL_DATA_TYPE_FP16
=
3
,
/**< fp16 */
HCCL_DATA_TYPE_FP32
=
4
,
/**< fp32 */
HCCL_DATA_TYPE_INT64
=
5
,
/**< fp32 */
HCCL_DATA_TYPE_UINT64
=
6
,
/**< fp32 */
HCCL_DATA_TYPE_RESERVED
/**< reserved */
}
hcclDataType_t
;
const
u32
HCCL_MAX_SEGMENT_NUM
=
8
;
// The max number of gradient segments.
/**
* @brief the feature of the model
*/
struct
model_feature
{
const
char
*
model_name
;
/**< The model name */
u32
gradient_num
;
/**< The number of gradients */
float
*
gradient_size
;
/**< The size of each gradient */
float
*
gradient_time
;
/**< The BP compution time of each gradient */
};
enum
GradSplitForceMode
{
FORCE_NONE
,
/**< no force */
FORCE_SIZE
,
/**< force split gradient by size */
FORCE_RESERVED
/**< reserved */
};
/**
* @brief stream handle.
*/
typedef
void
*
rtStream_t
;
/**
* @brief model handle.
*/
typedef
void
*
rtModel_t
;
#ifdef __cplusplus
}
#endif // __cplusplus
#endif // HCCL_BASE_H_
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
9fcdaeba
...
...
@@ -21,6 +21,10 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cupti_lib_path.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_string
(
cudnn_dir
,
""
,
"Specify path for loading libcudnn.so. For instance, "
"/usr/local/cudnn/lib. If empty [default], dlopen "
...
...
@@ -36,6 +40,11 @@ DEFINE_string(nccl_dir, "",
"For instance, /usr/local/cuda/lib64. If default, "
"dlopen will search cuda from LD_LIBRARY_PATH"
);
DEFINE_string
(
hccl_dir
,
""
,
"Specify path for loading hccl library, such as libhccl.so. "
"For instance, /usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64/. If default, "
"dlopen will search hccl from LD_LIBRARY_PATH"
);
DEFINE_string
(
cupti_dir
,
""
,
"Specify path for loading cupti.so."
);
DEFINE_string
(
...
...
@@ -383,6 +392,26 @@ void* GetNCCLDsoHandle() {
warning_msg
);
#endif
}
void
*
GetHCCLDsoHandle
()
{
std
::
string
warning_msg
(
"You may need to install 'hccl2' from Huawei official website: "
"before install PaddlePaddle."
);
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_nccl_dir
,
"libnccl.dylib"
,
true
,
{},
warning_msg
);
#elif defined(PADDLE_WITH_HIP) && defined(PADDLE_WITH_RCCL)
return
GetDsoHandleFromSearchPath
(
FLAGS_rccl_dir
,
"librccl.so"
,
true
);
#elif defined(PADDLE_WITH_ASCEND_CL)
return
GetDsoHandleFromSearchPath
(
FLAGS_hccl_dir
,
"libhccl.so"
,
true
,
{},
warning_msg
);
#else
return
GetDsoHandleFromSearchPath
(
FLAGS_nccl_dir
,
"libnccl.so"
,
true
,
{},
warning_msg
);
#endif
}
void
*
GetTensorRtDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
...
...
paddle/fluid/platform/dynload/dynamic_loader.h
浏览文件 @
9fcdaeba
...
...
@@ -34,6 +34,7 @@ void* GetNVRTCDsoHandle();
void
*
GetCUDADsoHandle
();
void
*
GetWarpCTCDsoHandle
();
void
*
GetNCCLDsoHandle
();
void
*
GetHCCLDsoHandle
();
void
*
GetTensorRtDsoHandle
();
void
*
GetMKLMLDsoHandle
();
void
*
GetOpDsoHandle
(
const
std
::
string
&
dso_name
);
...
...
paddle/fluid/platform/dynload/hccl.cc
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/hccl.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
std
::
once_flag
hccl_dso_flag
;
void
*
hccl_dso_handle
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
HCCL_RAND_ROUTINE_EACH
(
DEFINE_WRAP
);
#if HCCL_VERSION_CODE >= 2212
HCCL_RAND_ROUTINE_EACH_AFTER_2212
(
DEFINE_WRAP
)
#endif
#if HCCL_VERSION_CODE >= 2703
HCCL_RAND_ROUTINE_EACH_AFTER_2703
(
DEFINE_WRAP
)
#endif
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/hccl.h
0 → 100644
浏览文件 @
9fcdaeba
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include <hccl/hccl.h>
// #include <hccl/hccl_types.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/dynload/hcom.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
extern
std
::
once_flag
hccl_dso_flag
;
extern
void
*
hccl_dso_handle
;
#define DECLARE_DYNAMIC_LOAD_HCCL_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using HCCL_func = decltype(&::__name); \
std::call_once(hccl_dso_flag, []() { \
hccl_dso_handle = paddle::platform::dynload::GetHCCLDsoHandle(); \
}); \
static void* p_##__name = dlsym(hccl_dso_handle, #__name); \
return reinterpret_cast<HCCL_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define HCCL_RAND_ROUTINE_EACH(__macro) \
__macro(hcom_init); \
__macro(hcom_destroy); \
__macro(hcom_bind_model); \
__macro(hcom_unbind_model); \
__macro(hcom_send); \
__macro(hcom_receive); \
__macro(hcom_broadcast); \
__macro(hcom_all_gather); \
__macro(hcom_all_reduce); \
__macro(hcom_reduce_scatter); \
__macro(hcom_create_group); \
__macro(hcom_destroy_group); \
__macro(hcom_get_rank_id); \
__macro(hcom_get_local_rank_id); \
__macro(hcom_get_local_rank_size); \
__macro(hcom_get_split_strategy); \
__macro(hcom_set_split_strategy_by_size); \
__macro(hcom_set_split_strategy_by_index); \
__macro(hcom_get_group_rank_from_world_rank); \
__macro(hcom_get_world_rank_from_group_rank);
HCCL_RAND_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_HCCL_WRAP
)
#if HCCL_VERSION_CODE >= 2212
#define HCCL_RAND_ROUTINE_EACH_AFTER_2212(__macro) __macro(HCCLBroadcast);
HCCL_RAND_ROUTINE_EACH_AFTER_2212
(
DECLARE_DYNAMIC_LOAD_HCCL_WRAP
)
#endif
#if HCCL_VERSION_CODE >= 2703
#define HCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \
__macro(HCCLSend); \
__macro(HCCLRecv);
HCCL_RAND_ROUTINE_EACH_AFTER_2703
(
DECLARE_DYNAMIC_LOAD_HCCL_WRAP
)
#endif
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/hcom.h
0 → 100644
浏览文件 @
9fcdaeba
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file hcom.h
* @brief HCOM API
*/
#ifndef HCOM_H_
#define HCOM_H_
// #include <runtime/rt.h>
#include "paddle/fluid/platform/dynload/base.h"
#ifdef __cplusplus
extern
"C"
{
#endif // __cplusplus
/**
* @brief Initialize HCOM.
*
* @param rank_table A string identifying the rank table file path, include file name.
* @param identify A string identifying the identify for the rank.
* @return hcclResult_t
* @see hcom_destroy()
*/
extern
hcclResult_t
hcom_init
(
const
char
*
rank_table
,
const
char
*
identify
);
/**
* @brief Destroy HCOM
*
* @return hcclResult_t
* @see hcom_init()
*/
extern
hcclResult_t
hcom_destroy
(
void
);
/**
* @brief Bind the model.
*
* @param model A pointer identifying the model information.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
* @see hcom_unbind_model()
*/
extern
hcclResult_t
hcom_bind_model
(
rtModel_t
model
,
rtStream_t
stream
);
/**
* @brief Unbind the model.
*
* @param model An pointer identifying the model information.
* @return hcclResult_t
* @see hcom_unbind_model()
*/
extern
hcclResult_t
hcom_unbind_model
(
rtModel_t
model
);
/**
* @brief All-gather operator.
*
* @param tag A string identifying the tag of the operator.
* @param inputPtr A pointer identifying the input data address of the operator.
* @param outputPtr A pointer identifying the output data address of the operator.
* @param inputCount An integer(u64) identifying the number of the input data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_all_gather
(
const
char
*
tag
,
void
*
inputPtr
,
void
*
outputPtr
,
u64
inputCount
,
hcclDataType_t
dataType
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief All-reduce operator.
*
* @param tag A string identifying the tag of the operator.
* @param inputPtr A pointer identifying the input data address of the operator.
* @param outputPtr A pointer identifying the output data address of the operator.
* @param count An integer(u64) identifying the number of the output data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod.
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_all_reduce
(
const
char
*
tag
,
void
*
inputPtr
,
void
*
outputPtr
,
u64
count
,
hcclDataType_t
dataType
,
hcclRedOp_t
op
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief Broadcast operator.
*
* @param tag A string identifying the tag of the operator.
* @param ptr A pointer identifying the data address of the operator.
* @param count An integer(u64) identifying the number of the data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param root An integer(u32) identifying the the root rank in the operator.
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_broadcast
(
const
char
*
tag
,
void
*
ptr
,
u64
count
,
hcclDataType_t
dataType
,
u32
root
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief Reduce-scatter operator.
*
* @param tag A string identifying the tag of the operator.
* @param inputPtr A pointer identifying the input data address of the operator.
* @param outputPtr A pointer identifying the output data address of the operator.
* @param count An integer(u64) identifying the number of the data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod.
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_reduce_scatter
(
const
char
*
tag
,
void
*
inputPtr
,
void
*
outputPtr
,
u64
count
,
hcclDataType_t
dataType
,
hcclRedOp_t
op
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief Get the rank number in the group.
*
* @param group A string identifying the group name.
* @param rankSize A pointer identifying the rank number.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_rank_size
(
const
char
*
group
,
u32
*
rankSize
);
/**
* @brief Get the rank number of this rank's server within the group.
*
* @param group A string identifying the group name.
* @param localRankSize A pointer identifying the rank number.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_local_rank_size
(
const
char
*
group
,
u32
*
localRankSize
);
/**
* @brief Get the rank id of this rank.
*
* @param group A string identifying the group name.
* @param rankId A pointer identifying the rank id.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_rank_id
(
const
char
*
group
,
u32
*
rankId
);
/**
* @brief Get the local rank id of this rank's server within the group.
*
* @param group A string identifying the group name.
* @param localRankId A pointer identifying the local rank id.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_local_rank_id
(
const
char
*
group
,
u32
*
localRankId
);
/**
* @brief Get the world rank id according to the group rank id.
*
* @param group A string identifying the group name.
* @param groupRank An integer(u32) identifying the group rank id.
* @param worldRank A pointer identifying the world rank id.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_world_rank_from_group_rank
(
const
char
*
group
,
u32
groupRank
,
u32
*
worldRank
);
/**
* @brief Get the group rank id according to the world rank id.
*
* @param worldRank An integer(u32) identifying the world rank id.
* @param group A string identifying the group name.
* @param groupRank A pointer identifying the group rank id.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_group_rank_from_world_rank
(
u32
worldRank
,
const
char
*
group
,
u32
*
groupRank
);
/**
* @brief Create group.
*
* @param group A string identifying the group name.
* @param rankNum An integer(u32) identifying the number of ranks in the group.
* @param rankIds A list identifying the ranks in the group.
* @return hcclResult_t
*/
hcclResult_t
hcom_create_group
(
const
char
*
group
,
u32
rankNum
,
u32
*
rankIds
);
/**
* @brief Destroy group
*
* @param group A string identifying the group name.
* @return hcclResult_t
*/
hcclResult_t
hcom_destroy_group
(
const
char
*
group
);
/**
* @brief Send operator.
*
* @param tag A string identifying the tag of the operator.
* @param inputPtr A pointer identifying the input data address of the operator.
* @param count An integer(u64) identifying the number of the data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param destRank An integer identifying the destination rank.
* @param srTag An integer identifying the send/recv message tag.
* The message will be send by the receive operator with the same "sr_tag".
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
hcclResult_t
hcom_send
(
const
char
*
tag
,
void
*
inputPtr
,
u64
count
,
hcclDataType_t
dataType
,
u32
destRank
,
u32
srTag
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief Receive operator.
*
* @param tag A string identifying the tag of the operator.
* @param outputPtr A pointer identifying the output data address of the operator.
* @param count An integer(u64) identifying the number of the data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param srcRank An integer identifying the source rank.
* @param srTag An integer identifying the send/recv message tag.
* The message will be send by the send operator with the same "sr_tag".
* @param group A string identifying the group name of ranks participating in the operator.
* @param stream A pointer identifying the stream information.
* @return hcclResult_t
*/
hcclResult_t
hcom_receive
(
const
char
*
tag
,
void
*
outputPtr
,
u64
count
,
hcclDataType_t
dataType
,
u32
srcRank
,
u32
srTag
,
const
char
*
group
,
rtStream_t
stream
);
/**
* @brief Get the gradient split strategy with in the group.
*
* @param group A string identifying the group name.
* @param feature A pointer identifying the feature of the model.
* @param maxSegmentNum An integer(u32) identifying the max segments of gradients.
* @param segmentNum A pointer identifying the segments number of gradients.
* @param segmentIdx A list identifying the index of end gradient in each segment.
* @return hcclResult_t
*/
hcclResult_t
hcom_get_split_strategy
(
const
char
*
group
,
const
struct
model_feature
*
feature
,
u32
maxSegmentNum
,
u32
*
segmentNum
,
u32
*
segmentIdx
,
GradSplitForceMode
force
=
FORCE_NONE
);
/**
* @brief Set the gradient split strategy with in the group, according to gradient index.
*
* @param group A string identifying the group name.
* @param segmentNum An integer(u32) identifying the segments number of gradients.
* @param IdxList A list identifying the index of end gradient in each segment.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_set_split_strategy_by_index
(
const
char
*
group
,
u32
segmentNum
,
const
u32
*
IdxList
);
/**
* @brief Set the gradient split strategy with in the group, according to gradient data size.
*
* @param group A string identifying the group name.
* @param segmentNum An integer(u32) identifying the segments number of gradients.
* @param sizeList A list identifying the percent of each segment.
* @return hcclResult_t
*/
extern
hcclResult_t
hcom_set_split_strategy_by_size
(
const
char
*
group
,
u32
segmentNum
,
const
float
*
sizeList
);
#ifdef __cplusplus
}
#endif // __cplusplus
#endif // HCOM_H_
paddle/fluid/platform/enforce.h
浏览文件 @
9fcdaeba
...
...
@@ -40,6 +40,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "acl/acl.h"
#include "paddle/fluid/platform/dynload/hcom.h"
#endif // PADDLE_WITH_ASCEND_CL
#include <fstream>
...
...
@@ -1012,6 +1013,7 @@ struct NPUStatusType {};
}
DEFINE_NPU_STATUS_TYPE
(
aclError
,
ACL_ERROR_NONE
);
DEFINE_NPU_STATUS_TYPE
(
hcclResult_t
,
HCCL_SUCCESS
);
}
// namespace details
inline
std
::
string
build_npu_error_msg
(
aclError
stat
)
{
...
...
@@ -1020,6 +1022,13 @@ inline std::string build_npu_error_msg(aclError stat) {
return
sout
.
str
();
}
inline
std
::
string
build_npu_error_msg
(
hcclResult_t
stat
)
{
std
::
ostringstream
sout
;
sout
<<
" HCCL error, the error code is : "
<<
stat
<<
". "
;
return
sout
.
str
();
}
#define PADDLE_ENFORCE_NPU_SUCCESS(COND) \
do { \
auto __cond__ = (COND); \
...
...
paddle/fluid/platform/hccl_helper.h
0 → 100644
浏览文件 @
9fcdaeba
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_ASCEND_CL)
#include <stdio.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_RCCL
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/dynload/hccl.h"
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID"
namespace
paddle
{
namespace
platform
{
inline
hcclDataType_t
ToHCCLDataType
(
framework
::
proto
::
VarType
::
Type
type
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
return
HCCL_DATA_TYPE_FP32
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
FP16
)
{
return
HCCL_DATA_TYPE_FP16
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
return
HCCL_DATA_TYPE_INT32
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT8
)
{
return
HCCL_DATA_TYPE_INT8
;
}
// else if (type == framework::proto::VarType::FP64) {
// return HCCL_DATA_TYPE_FP32;
// }
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This datatype in hccl is not supported."
));
}
}
// // NOTE(minqiyang): according to the ncclGroupEnd documentations:
// // https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// // ncclGroupEnd will wait for all communicators to be initialized, which will
// // cause blocking problem when a runtime_error was thrown, so try only guard
// // NCCL actions when use it.
// class NCCLGroupGuard {
// public:
// static std::mutex &NCCLMutex() {
// static std::mutex mtx;
// return mtx;
// }
// inline NCCLGroupGuard() {
// NCCLMutex().lock();
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupStart());
// }
// inline ~NCCLGroupGuard() PADDLE_MAY_THROW {
// PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd());
// NCCLMutex().unlock();
// }
// };
// struct NCCLContext {
// std::unique_ptr<CUDADeviceContext> ctx_;
// ncclComm_t comm_;
// explicit NCCLContext(int dev_id)
// : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}
// gpuStream_t stream() const { return ctx_->stream(); }
// ncclComm_t comm() const { return comm_; }
// int device_id() const {
// return BOOST_GET_CONST(platform::CUDAPlace, ctx_->GetPlace()).device;
// }
// };
// struct NCCLContextMap {
// std::unordered_map<int, NCCLContext> contexts_;
// std::vector<int> order_;
// explicit NCCLContextMap(const std::vector<platform::Place> &places,
// ncclUniqueId *nccl_id = nullptr,
// size_t num_trainers = 1, size_t trainer_id = 0) {
// PADDLE_ENFORCE_EQ(!places.empty(), true,
// platform::errors::InvalidArgument(
// "The NCCL place should not be empty."));
// order_.reserve(places.size());
// for (auto &p : places) {
// int dev_id = BOOST_GET_CONST(CUDAPlace, p).device;
// order_.emplace_back(dev_id);
// contexts_.emplace(dev_id, NCCLContext(dev_id));
// }
// PADDLE_ENFORCE_EQ(
// order_.size(), contexts_.size(),
// platform::errors::Unavailable("NCCL Context Map does not support "
// "contain two or more same device."));
// std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// // if num_trainers == 1, should create a new nccl id for local comms.
// if (num_trainers == 1 && nccl_id == nullptr) {
// std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitAll(
// comms.get(), static_cast<int>(order_.size()), order_.data()));
// } else {
// PADDLE_ENFORCE_NOT_NULL(nccl_id, platform::errors::InvalidArgument(
// "The NCCL id should not be null."));
// {
// int nranks = num_trainers * order_.size();
// NCCLGroupGuard gurad;
// for (size_t i = 0; i < order_.size(); ++i) {
// int gpu_id = order_[i];
// int rank;
// if (order_.size() > 1) {
// rank = trainer_id * order_.size() + i;
// } else {
// rank = trainer_id;
// }
// VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
// << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
// SetDeviceId(gpu_id);
// PADDLE_RETRY_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
// comms.get() + i, nranks, *nccl_id, rank));
// }
// }
// }
// int i = 0;
// for (auto &dev_id : order_) {
// contexts_.at(dev_id).comm_ = comms[i++];
// }
// }
// NCCLContextMap(const NCCLContextMap &other) = delete;
// NCCLContextMap &operator=(const NCCLContextMap &other) = delete;
// CUDADeviceContext *DevCtx(int dev_id) const { return at(dev_id).ctx_.get(); }
// CUDADeviceContext *DevCtx(platform::Place p) const {
// return DevCtx(BOOST_GET_CONST(CUDAPlace, p).device);
// }
// const NCCLContext &at(platform::Place p) const {
// return this->at(BOOST_GET_CONST(CUDAPlace, p).device);
// }
// const NCCLContext &at(int dev_id) const { return contexts_.at(dev_id); }
// void WaitAll() {
// for (auto &p : contexts_) {
// p.second.ctx_->Wait();
// }
// }
// };
// inline std::string GetFlatNCCLVarName(size_t pos) {
// if (pos == 0) {
// return NCCL_ID_VARNAME;
// }
// return string::Sprintf("%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
// }
// inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
// static_cast<int>(pos));
// }
// inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
// return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME,
// static_cast<int>(pos));
// }
// class NCCLCommunicator {
// public:
// NCCLCommunicator() {}
// virtual ~NCCLCommunicator() PADDLE_MAY_THROW {}
// NCCLContextMap *DefaultFlatCtx() const {
// if (flat_ctxs_.size() == 0) {
// return nullptr;
// }
// return flat_ctxs_[0].get();
// }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() {
// return &flat_ctxs_;
// }
// NCCLContextMap *GetFlatCtx(size_t run_order) const {
// return flat_ctxs_[run_order % flat_ctxs_.size()].get();
// }
// NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order,
// bool use_hierarchical_allreduce) const {
// if (!use_hierarchical_allreduce) {
// return GetFlatCtx(run_order);
// }
// return GetHierarchicalInterCtx(run_order);
// }
// *When nccl inits nccl comm using ncclCommInitAll, it meets error when
// *allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
// *create a new nccl comm for sync_batch_norm_op. And these codes should be
// *polished with a unified nccl management.
// NCCLContextMap *GetSyncBatchNormCtx(
// framework::Scope *scope, const std::vector<platform::Place> &places) {
// auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
// if (nccl_id_var != nullptr) {
// return DefaultFlatCtx();
// }
// if (sync_batch_norm_ctx_.get() == nullptr) {
// sync_batch_norm_ctx_.reset(new NCCLContextMap(places));
// }
// return sync_batch_norm_ctx_.get();
// }
// void InitFlatCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &nccl_ids,
// size_t trainers_num, size_t trainer_id) {
// if (nccl_ids.size() == 0) {
// auto ptr = new platform::NCCLContextMap(places);
// VLOG(1) << "init local trainer";
// flat_ctxs_.emplace_back(ptr);
// } else {
// for (size_t i = 0; i < nccl_ids.size(); i++) {
// auto ptr = new platform::NCCLContextMap(places, nccl_ids[i],
// trainers_num, trainer_id);
// VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
// flat_ctxs_.emplace_back(ptr);
// }
// }
// // as Executor have no way to use ncclComm created by ParallelExecutor,
// // we assign all flatten contexts to NCCLCommContext to fix.
// int nranks = static_cast<int>(trainers_num * places.size());
// int nrings = static_cast<int>(flat_ctxs_.size());
// for (int ring_id = 0; ring_id < nrings; ++ring_id) {
// for (size_t p = 0; p < places.size(); ++p) {
// int rank = trainer_id * places.size() + p;
// int dev_id = BOOST_GET_CONST(CUDAPlace, places[p]).device;
// auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
// NCCLCommContext::Instance().AssignNCCLComm(ctx.comm_, nranks, rank,
// dev_id, ring_id);
// }
// }
// }
// void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
// const std::vector<ncclUniqueId *> &inter_nccl_ids,
// const std::vector<ncclUniqueId *> &exter_nccl_ids,
// size_t trainers_num, size_t trainer_id,
// size_t inter_trainers_num,
// size_t exter_trainers_num) {
// PADDLE_ENFORCE_EQ(
// trainers_num, inter_trainers_num * exter_trainers_num,
// platform::errors::InvalidArgument(
// "trainers_num:%llu != inter_trainers_num:%llu * "
// "exter_trainers_num:%llu",
// trainers_num, inter_trainers_num, exter_trainers_num));
// PADDLE_ENFORCE_GT(
// inter_trainers_num, 1,
// platform::errors::InvalidArgument(
// "The inter_trainers_num:%llu should be larger than 1.",
// inter_trainers_num));
// int inter_trainer_id = trainer_id % inter_trainers_num;
// for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
// VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
// << ", comm no:" << i;
// auto local = new NCCLContextMap(places, inter_nccl_ids[i],
// inter_trainers_num, inter_trainer_id);
// h_inter_ctxs_.emplace_back(local);
// }
// int exter_trainer_id = -1;
// if (trainer_id % inter_trainers_num == 0) {
// exter_trainer_id = trainer_id / inter_trainers_num;
// }
// if (exter_trainer_id >= 0) {
// for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
// auto ex = new NCCLContextMap(places, exter_nccl_ids[i],
// exter_trainers_num, exter_trainer_id);
// VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
// << ", comm no:" << i;
// h_exter_ctxs_.emplace_back(ex);
// }
// }
// }
// bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
// NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_inter_ctxs_.size(), 0,
// platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!"));
// return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
// }
// NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
// PADDLE_ENFORCE_GT(h_exter_ctxs_.size(), 0,
// platform::errors::InvalidArgument(
// "Hierarchical ctxs should be initialized firstly!"));
// return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
// }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() {
// return &h_inter_ctxs_;
// }
// std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() {
// return &h_exter_ctxs_;
// }
// protected:
// // Support multi nccl comm on default nccl ring while NCCLContextMap can't.
// std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_;
// // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
// // And h_exter_ctxs_ can support multi comm too.
// std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
// std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
// // just used for sync_batch_norm op.
// std::unique_ptr<NCCLContextMap> sync_batch_norm_ctx_;
// };
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
9fcdaeba
...
...
@@ -326,7 +326,7 @@ GenerateOpFunctions(const std::string& module_name) {
}
ins_initializer
+=
"}"
;
if
(
input_args
.
back
()
==
','
)
{
if
(
!
input_args
.
empty
()
&&
input_args
.
back
()
==
','
)
{
input_args
.
pop_back
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录