Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
45765d6e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
45765d6e
编写于
3月 02, 2021
作者:
V
Void Main
提交者:
GitHub
3月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor HCCLCommContext to be compatible with Paddle (#31359)
Refactor HCCLCommContext to be compatible with Paddle (#31359)
上级
8497e2aa
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
185 addition
and
190 deletion
+185
-190
paddle/fluid/operators/collective/CMakeLists.txt
paddle/fluid/operators/collective/CMakeLists.txt
+1
-1
paddle/fluid/operators/collective/c_allreduce_op.h
paddle/fluid/operators/collective/c_allreduce_op.h
+1
-1
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
+2
-2
paddle/fluid/operators/collective/c_comm_init_hcom_op.cc
paddle/fluid/operators/collective/c_comm_init_hcom_op.cc
+31
-16
paddle/fluid/operators/collective/c_create_group_op.cc
paddle/fluid/operators/collective/c_create_group_op.cc
+0
-76
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
+14
-22
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+49
-10
paddle/fluid/platform/collective_helper_npu.cc
paddle/fluid/platform/collective_helper_npu.cc
+87
-54
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+0
-8
未找到文件。
paddle/fluid/operators/collective/CMakeLists.txt
浏览文件 @
45765d6e
...
...
@@ -36,4 +36,4 @@ endif()
set
(
OPERATOR_DEPS
${
OPERATOR_DEPS
}
${
COLLECTIVE_DEPS
}
PARENT_SCOPE
)
set
(
GLOB_COLLECTIVE_DEPS
${
COLLECTIVE_DEPS
}
CACHE INTERNAL
"collective dependency"
)
cc_test
(
c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hc
cl_op c_create_group
_op
${
COLLECTIVE_DEPS
}
ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor
)
cc_test
(
c_hcom_op_npu_test SRCS c_hcom_op_npu_test.cc DEPS op_registry c_broadcast_op c_allreduce_sum_op c_comm_init_hc
om
_op
${
COLLECTIVE_DEPS
}
ascend_hccl dynamic_loader dynload_warpctc scope device_context enforce executor
)
paddle/fluid/operators/collective/c_allreduce_op.h
浏览文件 @
45765d6e
...
...
@@ -135,7 +135,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
std
::
string
group
=
std
::
string
(
HCOM_GROUP_PREFIX
)
+
std
::
to_string
(
ring_id
);
group
=
"hccl_world_group"
;
// std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
();
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
aclrtStream
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
...
...
paddle/fluid/operators/collective/c_broadcast_op_npu.cc
浏览文件 @
45765d6e
...
...
@@ -34,8 +34,9 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
();
auto
comm
=
paddle
::
platform
::
HCCLCommContext
::
Instance
().
Get
(
ring_id
,
place
);
aclrtStream
stream
=
nullptr
;
if
(
ctx
.
Attr
<
bool
>
(
"use_calc_stream"
))
{
...
...
@@ -46,7 +47,6 @@ class CBroadcastOpASCENDKernel : public framework::OpKernel<T> {
}
int
root
=
ctx
.
Attr
<
int
>
(
"root"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
std
::
string
group
=
std
::
string
(
HCOM_GROUP_PREFIX
)
+
std
::
to_string
(
ring_id
);
std
::
string
tag
=
ctx
.
Attr
<
std
::
string
>
(
"tag"
);
...
...
paddle/fluid/operators/collective/c_comm_init_hc
cl
_op.cc
→
paddle/fluid/operators/collective/c_comm_init_hc
om
_op.cc
浏览文件 @
45765d6e
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
...
...
@@ -40,16 +41,24 @@ class CCommInitOpNPU : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
rank_table_file
=
Attr
<
std
::
string
>
(
"rank_table_file"
);
uint32_t
rank_id
=
Attr
<
int
>
(
"rank_id"
);
uint32_t
device_id
=
Attr
<
int
>
(
"device_id"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
int
nranks
=
Attr
<
int
>
(
"nranks"
);
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
device_id
=
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place
).
device
;
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
device_id
=
Attr
<
int
>
(
"device_id"
);
}
std
::
vector
<
int
>
rank_ids
=
Attr
<
std
::
vector
<
int
>>
(
"rank_ids"
);
VLOG
(
3
)
<<
"begin init hccl, parameter is: "
<<
"rank_table_file "
<<
rank_table_file
<<
" rank_id "
<<
rank_id
<<
" device_id "
<<
device_id
;
VLOG
(
3
)
<<
"begin c_comm_init on npu, parameters are: "
<<
"ring id["
<<
rid
<<
"], nranks["
<<
nranks
<<
"], rank_id["
<<
rank_id
<<
"], device_id["
<<
device_id
<<
"]"
;
platform
::
HCCLCommContext
::
Instance
().
CreateHCCLComm
(
rank_table_file
,
rank_id
,
device_id
);
platform
::
HCCLCommContext
::
Instance
().
CreateHCCLComm
(
rank_ids
,
rank_id
,
device_id
,
rid
);
}
};
...
...
@@ -61,10 +70,17 @@ CCommInit operator on NPU
Initialize collective communication context within this trainer
)DOC"
);
AddAttr
<
std
::
string
>
(
"rank_table_file"
,
"(string) path to rank_table_file"
);
AddAttr
<
int
>
(
"rank_id"
,
"(int) world rank id of the process"
);
AddAttr
<
int
>
(
"device_id"
,
"(int) device id of the process/thread"
);
AddAttr
<
int
>
(
"nranks"
,
"(int) The number of ranks of distributed trainers"
);
AddAttr
<
std
::
vector
<
int
>>
(
"rank_ids"
,
"The world rank ids of the group"
);
AddAttr
<
int
>
(
"rank"
,
"(int) The rank of the trainer in distributed training."
);
AddAttr
<
int
>
(
"device_id"
,
"(int) The deivce_id on which to initialize the communicator."
"Now, you only have to set this attr manually for pipeline "
"training. Otherwise, make it as default."
)
.
SetDefault
(
-
1
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
};
...
...
@@ -73,7 +89,6 @@ Initialize collective communication context within this trainer
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_comm_init_hccl
,
ops
::
CCommInitOpNPU
,
ops
::
CCommInitOpNPUMaker
);
REGISTER_OPERATOR
(
c_comm_init_hcom
,
ops
::
CCommInitOpNPU
,
ops
::
CCommInitOpNPUMaker
);
#endif
paddle/fluid/operators/collective/c_create_group_op.cc
已删除
100644 → 0
浏览文件 @
8497e2aa
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/hccl_helper.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
class
CCreateGroupOpNPU
:
public
framework
::
OperatorBase
{
public:
CCreateGroupOpNPU
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
std
::
string
group_name
=
Attr
<
std
::
string
>
(
"group_name"
);
int
nranks
=
Attr
<
int
>
(
"nranks"
);
std
::
vector
<
int
>
rank_ids
=
Attr
<
std
::
vector
<
int
>>
(
"rank_ids"
);
paddle
::
platform
::
HCCLCommContext
::
Instance
().
CreateHCCLGroup
(
group_name
,
(
uint32_t
)
nranks
,
std
::
vector
<
uint32_t
>
(
rank_ids
.
begin
(),
rank_ids
.
end
()));
}
};
class
CCreateGroupOpNPUMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddComment
(
R"DOC(
CCreateGroup operator on NPU
Create collective communication group on NPU
)DOC"
);
AddAttr
<
std
::
string
>
(
"group_name"
,
"(string) name of the collective communication group"
);
AddAttr
<
int
>
(
"nranks"
,
"(int) number of the group"
);
AddAttr
<
std
::
vector
<
int
>>
(
"rank_ids"
,
"(list of int) The world rank id of the group members"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
c_create_group
,
ops
::
CCreateGroupOpNPU
,
ops
::
CCreateGroupOpNPUMaker
);
#endif
paddle/fluid/operators/collective/c_hcom_op_npu_test.cc
浏览文件 @
45765d6e
...
...
@@ -43,37 +43,29 @@ namespace m = paddle::operators::math;
USE_OP
(
c_broadcast
);
USE_OP
(
c_allreduce_sum
);
USE_NO_KERNEL_OP
(
c_comm_init_hccl
);
USE_NO_KERNEL_OP
(
c_create_group
);
USE_NO_KERNEL_OP
(
c_comm_init_hcom
);
USE_OP_DEVICE_KERNEL
(
c_broadcast
,
NPU
);
USE_OP_DEVICE_KERNEL
(
c_allreduce_sum
,
NPU
);
void
Prepare
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
){
std
::
string
rank_table_file
=
getenv
(
"RANK_TABLE_FILE"
);
int
rank_id
=
atoi
(
getenv
(
"RANK_ID"
));
int
device_id
=
atoi
(
getenv
(
"DEVICE_ID"
));
printf
(
"rank_
table_file: %s, rank_id = %d, device_id = %d
\n
"
,
rank_table_file
.
c_str
()
,
rank_id
,
device_id
);
printf
(
"rank_
id = %d, device_id = %d
\n
"
,
rank_id
,
device_id
);
f
::
AttributeMap
attrs
;
attrs
[
"rank_table_file"
]
=
rank_table_file
;
attrs
[
"rank_id"
]
=
rank_id
;
attrs
[
"device_id"
]
=
device_id
;
std
::
vector
<
int
>
rank_ids
{
0
,
1
};
f
::
AttributeMap
comm_init_attrs
;
comm_init_attrs
[
"ring_id"
]
=
0
;
comm_init_attrs
[
"nranks"
]
=
2
;
comm_init_attrs
[
"rank"
]
=
rank_id
;
comm_init_attrs
[
"device_id"
]
=
device_id
;
comm_init_attrs
[
"rank_ids"
]
=
rank_ids
;
auto
comm_init_op
=
f
::
OpRegistry
::
CreateOp
(
"c_comm_init_hc
cl"
,
{},
{},
attrs
);
f
::
OpRegistry
::
CreateOp
(
"c_comm_init_hc
om"
,
{},
{},
comm_init_
attrs
);
auto
place
=
ctx
.
GetPlace
();
comm_init_op
->
Run
(
*
scope
,
place
);
ctx
.
Wait
();
f
::
AttributeMap
create_attrs
;
create_attrs
[
"group_name"
]
=
HCOM_GROUP_PREFIX
+
std
::
to_string
(
0
);
create_attrs
[
"nranks"
]
=
2
;
std
::
vector
<
int
>
rank_ids
{
0
,
1
};
create_attrs
[
"rank_ids"
]
=
rank_ids
;
auto
create_group_op
=
f
::
OpRegistry
::
CreateOp
(
"c_create_group"
,
{},
{},
create_attrs
);
create_group_op
->
Run
(
*
scope
,
place
);
ctx
.
Wait
();
}
void
TestHCCLBroadcastOp
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
)
{
std
::
cout
<<
"BEGIN TEST:"
<<
__FUNCTION__
<<
std
::
endl
;
...
...
@@ -187,6 +179,6 @@ TEST(c_broadcast, NPU) {
p
::
NPUDeviceContext
ctx
(
p
::
NPUPlace
(
atoi
(
npu_id
)));
Prepare
(
&
scope
,
ctx
);
//
TestHCCLBroadcastOp(&scope, ctx);
TestHCCLAllReduceOp
(
&
scope
,
ctx
);
TestHCCLBroadcastOp
(
&
scope
,
ctx
);
//
TestHCCLAllReduceOp(&scope, ctx);
}
paddle/fluid/platform/collective_helper.h
浏览文件 @
45765d6e
...
...
@@ -147,11 +147,16 @@ class NCCLCommContext {
// singleton with a global user specified group id.
class
NPUDeviceContext
;
#define ENV_RANK_TABLE_FILE "RANK_TABLE_FILE"
#define ENV_RANK_ID "RANK_ID"
#define ENV_DEV_ID "DEV_ID"
class
HCCLComm
{
public:
virtual
std
::
string
rank_table_file
()
const
=
0
;
virtual
uint32_t
rank
()
const
=
0
;
virtual
uint32_t
device_id
()
const
=
0
;
virtual
int
ring_id
()
const
=
0
;
virtual
int
nranks
()
const
=
0
;
virtual
int
rank
()
const
=
0
;
virtual
int
device_id
()
const
=
0
;
virtual
aclrtStream
stream
()
const
=
0
;
virtual
NPUDeviceContext
*
dev_context
()
const
=
0
;
virtual
~
HCCLComm
()
=
default
;
...
...
@@ -165,22 +170,56 @@ class HCCLCommContext {
return
comm_ctx
;
}
HCCLComm
*
CreateHCCLComm
(
const
std
::
string
&
config_file
,
uint32_t
rank
,
uint32_t
device_id
);
HCCLComm
*
CreateHCCLComm
(
const
std
::
vector
<
int
>&
world_rank_ids
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
// a latter comm with the same dev_id and the same ring_id
// will override the former
HCCLComm
*
AssignHCCLComm
(
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
// retrieve a communicator by the ring id in multiprocessing mode
HCCLComm
*
Get
(
int
ring_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator in ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_EQ
(
comm_map_
.
at
(
ring_id
).
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"One device id should be specified to retrieve from "
"multiple communicators."
));
return
comm_map_
.
at
(
ring_id
).
begin
()
->
second
.
get
();
}
void
CreateHCCLGroup
(
const
std
::
string
&
group_name
,
uint32_t
nranks
,
const
std
::
vector
<
uint32_t
>&
rank_ids
);
// retrieve a communicator by the ring id and the device id
HCCLComm
*
Get
(
int
ring_id
,
int
dev_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator of ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_GT
(
comm_map_
.
at
(
ring_id
).
count
(
dev_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator at device id %d has not been initialized in ring %d."
,
dev_id
,
ring_id
));
return
comm_map_
.
at
(
ring_id
).
at
(
dev_id
).
get
();
}
// retrieve a communicator by the ring id and place
HCCLComm
*
Get
()
const
{
return
comm_
.
get
(
);
HCCLComm
*
Get
(
int
ring_id
,
Place
place
)
const
{
return
Get
(
ring_id
,
BOOST_GET_CONST
(
NPUPlace
,
place
).
device
);
}
private:
// Init global hcom
HCCLCommContext
()
{
InitHcomWorldGroup
();
}
std
::
once_flag
once_flag_
;
std
::
mutex
comm_map_mutex_
;
std
::
unique_ptr
<
HCCLComm
>
comm_
;
// ring id to dev-HCCLComm
std
::
map
<
int
,
std
::
map
<
int
,
std
::
unique_ptr
<
HCCLComm
>>>
comm_map_
;
HCCLComm
*
AssignHCCLComm
(
const
std
::
string
&
config_file
,
uint32_t
rank
,
uint32_t
device_id
);
void
InitHcomWorldGroup
();
void
ReleaseHCCLComms
();
HCCLCommContext
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
HCCLCommContext
);
};
#endif
...
...
paddle/fluid/platform/collective_helper_npu.cc
浏览文件 @
45765d6e
...
...
@@ -21,14 +21,18 @@ namespace platform {
class
HCCLCommImpl
:
public
HCCLComm
{
public:
void
set_r
ank_table_file
(
const
std
::
string
&
rank_table_file
)
{
rank_table_file_
=
rank_table_file
;
}
std
::
string
rank_table_file
()
const
override
{
return
rank_table_file
_
;
}
void
set_r
ing_id
(
int
ring_id
)
{
ring_id_
=
ring_id
;
}
int
ring_id
()
const
override
{
return
ring_id
_
;
}
void
set_
rank
(
uint32_t
rank
)
{
rank_
=
rank
;
}
uint32_t
rank
()
const
override
{
return
rank
_
;
}
void
set_
nranks
(
int
nranks
)
{
nranks_
=
nranks
;
}
int
nranks
()
const
override
{
return
nranks
_
;
}
void
set_device_id
(
uint32_t
device_id
)
{
device_id_
=
device_id
;
}
uint32_t
device_id
()
const
override
{
return
device_id_
;
}
void
set_rank
(
int
rank
)
{
rank_
=
rank
;
}
int
rank
()
const
override
{
return
rank_
;
}
int
device_id
()
const
override
{
return
BOOST_GET_CONST
(
NPUPlace
,
dev_ctx_
->
GetPlace
()).
device
;
}
aclrtStream
stream
()
const
override
{
return
dev_ctx_
->
stream
();
}
...
...
@@ -38,74 +42,103 @@ class HCCLCommImpl : public HCCLComm {
NPUDeviceContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
private:
std
::
string
rank_table_file
_
;
uint32_t
rank
_
;
uint32_t
device_id
_
;
int
ring_id
_
;
int
nranks
_
;
int
rank
_
;
std
::
unique_ptr
<
NPUDeviceContext
>
dev_ctx_
;
};
HCCLComm
*
HCCLCommContext
::
CreateHCCLComm
(
const
std
::
string
&
rank_table_file
,
uint32_t
rank
,
uint32_t
device_id
)
{
/*
PADDLE_ENFORCE_NOT_NULL(rank_table_file,
HCCLComm
*
HCCLCommContext
::
CreateHCCLComm
(
const
std
::
vector
<
int
>&
world_rank_ids
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_GT
(
world_rank_ids
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The rank table file should not be null."));
"Expected world_rank_ids.size() > 1. But received size is %d."
,
world_rank_ids
.
size
()));
PADDLE_ENFORCE_GE
(
rank
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected rank >= 0. But received rank is %d."
,
rank
));
PADDLE_ENFORCE_GE(device_id, 0
,
PADDLE_ENFORCE_LT
(
rank
,
world_rank_ids
.
size
()
,
platform
::
errors
::
InvalidArgument
(
"Expected dev_id >= 0. But received dev_id is %d.", device_id));
*/
auto
*
comm_wrapper
=
AssignHCCLComm
(
rank_table_file
,
rank
,
device_id
);
"Expected rank < nranks. But received rank is %d, nranks is %d."
,
rank
,
world_rank_ids
.
size
()));
PADDLE_ENFORCE_GE
(
dev_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected dev_id >= 0. But received dev_id is %d."
,
dev_id
));
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected ring_id >= 0. But received ring_id is %d."
,
ring_id
));
auto
*
comm_wrapper
=
AssignHCCLComm
(
world_rank_ids
.
size
(),
rank
,
dev_id
,
ring_id
);
platform
::
dynload
::
hcom_init
(
rank_table_file
.
c_str
(),
std
::
to_string
(
rank
).
c_str
());
platform
::
dynload
::
hcom_bind_model
(
comm_wrapper
->
stream
(),
comm_wrapper
->
stream
());
// HACK(sunpeng17): hcom API requires bind stream to a model
// but we don't need model in Paddle, so we feed stream pointer as model pointer
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_bind_model
(
comm_wrapper
->
stream
(),
comm_wrapper
->
stream
()));
// Get world_rank_ids registered in gen_nccl_id op
std
::
string
group_name
=
HCOM_GROUP_PREFIX
+
std
::
to_string
(
ring_id
);
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_create_group
(
group_name
.
c_str
(),
world_rank_ids
.
size
(),
(
unsigned
int
*
)
world_rank_ids
.
data
()));
VLOG
(
1
)
<<
"hccl communicator of rank "
<<
rank
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_id
<<
", group name: "
<<
group_name
;
std
::
call_once
(
once_flag_
,
[]()
{
std
::
atexit
([]()
{
HCCLCommContext
::
Instance
().
ReleaseHCCLComms
();
});
});
VLOG
(
1
)
<<
"hccl communicator of rank "
<<
rank
<<
" has been created"
;
return
comm_wrapper
;
}
HCCLComm
*
HCCLCommContext
::
AssignHCCLComm
(
const
std
::
string
&
rank_table_file
,
uint32_t
rank
,
uint32_t
device_id
)
{
HCCLComm
*
HCCLCommContext
::
AssignHCCLComm
(
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
std
::
unique_ptr
<
NPUDeviceContext
>
dev_ctx
(
new
NPUDeviceContext
(
NPUPlace
(
device_id
)));
VLOG
(
3
)
<<
"device_id"
<<
device_id
;
VLOG
(
3
)
<<
"dev_ctx->stream()"
<<
dev_ctx
->
stream
();
new
NPUDeviceContext
(
NPUPlace
(
dev_id
)));
HCCLCommImpl
*
c
=
new
HCCLCommImpl
;
c
->
set_rank_table_file
(
rank_table_file
);
c
->
set_ring_id
(
ring_id
);
c
->
set_nranks
(
nranks
);
c
->
set_rank
(
rank
);
c
->
set_device_id
(
device_id
);
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
// comm_ = c
comm_
.
reset
(
c
);
return
c
;
comm_map_mutex_
.
lock
();
if
(
comm_map_
.
count
(
ring_id
)
==
0
)
{
comm_map_
.
emplace
(
ring_id
,
std
::
map
<
int
,
std
::
unique_ptr
<
HCCLComm
>>
());
}
auto
&
dev2comm
=
comm_map_
[
ring_id
];
dev2comm
.
emplace
(
dev_id
,
std
::
unique_ptr
<
HCCLComm
>
(
c
));
comm_map_mutex_
.
unlock
();
return
comm_map_
[
ring_id
][
dev_id
].
get
();
}
void
HCCLCommContext
::
CreateHCCLGroup
(
const
std
::
string
&
group_name
,
uint32_t
nranks
,
const
std
::
vector
<
uint32_t
>&
rank_ids
)
{
/*
PADDLE_ENFORCE_NOT_NULL(group_name,
platform::errors::InvalidArgument(
"The group name should not be null."));
PADDLE_ENFORCE_GT(nranks, 0,
platform::errors::InvalidArgument(
"Expected nranks > 0. But received nranks is %d.", nranks));
PADDLE_ENFORCE_NOT_NULL(rank_ids,
platform::errors::InvalidArgument(
"The rank ids should not be null."));
*/
platform
::
dynload
::
hcom_create_group
(
group_name
.
c_str
(),
nranks
,
(
unsigned
int
*
)
rank_ids
.
data
());
void
HCCLCommContext
::
InitHcomWorldGroup
()
{
const
char
*
rank_table_file
=
getenv
(
ENV_RANK_TABLE_FILE
);
PADDLE_ENFORCE_NOT_NULL
(
rank_table_file
,
platform
::
errors
::
InvalidArgument
(
"The RANK_TABLE_FILE environment variable should not be null."
));
VLOG
(
1
)
<<
"hccl group with name "
<<
group_name
<<
" has been created"
;
const
char
*
rank_id
=
getenv
(
ENV_RANK_ID
);
PADDLE_ENFORCE_NOT_NULL
(
rank_id
,
platform
::
errors
::
InvalidArgument
(
"The RANK_ID environment variable should not be null."
));
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
hcom_init
(
rank_table_file
,
rank_id
));
VLOG
(
3
)
<<
"Successfully initialized hcom. rank_table_file: "
<<
rank_table_file
<<
", rank_id "
<<
rank_id
;
}
void
HCCLCommContext
::
ReleaseHCCLComms
()
{
for
(
auto
&
p
:
comm_map_
)
{
for
(
auto
&
q
:
p
.
second
)
{
q
.
second
.
reset
();
}
}
}
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/platform/device_context.h
浏览文件 @
45765d6e
...
...
@@ -178,14 +178,6 @@ class NPUDeviceContext : public DeviceContext {
/*! \brief Return npu stream in the device context. */
aclrtStream
stream
()
const
;
#ifdef PADDLE_WITH_ASCEND_HCCL
/*! \brief Return bkcl context. */
HCCLContext_t
hccl_context
()
const
{
return
hccl_context_
;
}
/*! \brief Set bkcl context. */
void
set_hccl_context
(
HCCLContext_t
context
)
{
hccl_context_
=
context
;
}
#endif
private:
NPUPlace
place_
;
aclrtContext
context_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录