Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d28f6f7b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d28f6f7b
编写于
1月 30, 2022
作者:
mhhhh1
提交者:
GitHub
1月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(cncl_mlu): add cncl dev for mlu distributed backend (#39294)
上级
eefe5feb
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
402 addition
and
1 deletion
+402
-1
CMakeLists.txt
CMakeLists.txt
+8
-0
cmake/neuware.cmake
cmake/neuware.cmake
+8
-1
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+4
-0
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+100
-0
paddle/fluid/platform/device/mlu/CMakeLists.txt
paddle/fluid/platform/device/mlu/CMakeLists.txt
+1
-0
paddle/fluid/platform/device/mlu/cncl_helper.h
paddle/fluid/platform/device/mlu/cncl_helper.h
+57
-0
paddle/fluid/platform/device/mlu/device_context.h
paddle/fluid/platform/device/mlu/device_context.h
+15
-0
paddle/fluid/platform/device/mlu/enforce.h
paddle/fluid/platform/device/mlu/enforce.h
+14
-0
paddle/fluid/platform/device/mlu/enforce_test.cc
paddle/fluid/platform/device/mlu/enforce_test.cc
+10
-0
paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
+179
-0
paddle/fluid/platform/device/mlu/mlu_info.h
paddle/fluid/platform/device/mlu/mlu_info.h
+6
-0
未找到文件。
CMakeLists.txt
浏览文件 @
d28f6f7b
...
...
@@ -230,6 +230,7 @@ option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF)
option
(
WITH_NCCL
"Compile PaddlePaddle with NCCL support"
ON
)
option
(
WITH_RCCL
"Compile PaddlePaddle with RCCL support"
ON
)
option
(
WITH_XPU_BKCL
"Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL"
OFF
)
option
(
WITH_CNCL
"Compile PaddlePaddle with CNCL support"
OFF
)
option
(
WITH_CRYPTO
"Compile PaddlePaddle with crypto support"
ON
)
option
(
WITH_ARM
"Compile PaddlePaddle with arm support"
OFF
)
option
(
WITH_SW
"Compile PaddlePaddle with sw support"
OFF
)
...
...
@@ -292,6 +293,13 @@ if (NOT WITH_XPU AND WITH_XPU_BKCL)
"Disable BKCL when compiling without XPU"
FORCE
)
endif
()
if
(
NOT WITH_MLU AND WITH_CNCL
)
MESSAGE
(
WARNING
"Disable CNCL when compiling without MLU. Force WITH_MLU=OFF."
)
set
(
WITH_MLU OFF CACHE STRING
"Disable CNCL when compiling without MLU"
FORCE
)
endif
()
if
(
WITH_NCCL
)
add_definitions
(
"-DPADDLE_WITH_NCCL"
)
include
(
nccl
)
...
...
cmake/neuware.cmake
浏览文件 @
d28f6f7b
...
...
@@ -19,4 +19,11 @@ set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set
(
CNDRV_LIB
${
NEUWARE_LIB_DIR
}
/libcndrv.so
)
generate_dummy_static_lib
(
LIB_NAME
"neuware_lib"
GENERATOR
"neuware.cmake"
)
TARGET_LINK_LIBRARIES
(
neuware_lib
${
CNNL_LIB
}
${
CNRT_LIB
}
${
CNDRV_LIB
}
)
if
(
WITH_CNCL
)
MESSAGE
(
STATUS
"Compile with CNCL!"
)
ADD_DEFINITIONS
(
-DPADDLE_WITH_CNCL
)
set
(
CNCL_LIB
${
NEUWARE_LIB_DIR
}
/libcncl.so
)
TARGET_LINK_LIBRARIES
(
neuware_lib
${
CNCL_LIB
}
${
CNNL_LIB
}
${
CNRT_LIB
}
${
CNDRV_LIB
}
)
else
()
TARGET_LINK_LIBRARIES
(
neuware_lib
${
CNNL_LIB
}
${
CNRT_LIB
}
${
CNDRV_LIB
}
)
endif
()
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
d28f6f7b
...
...
@@ -133,6 +133,10 @@ if(WITH_ASCEND_CL)
target_link_libraries
(
collective_helper npu_collective_helper
)
endif
()
if
(
WITH_CNCL
)
target_link_libraries
(
collective_helper mlu_collective_helper
)
endif
()
if
(
WITH_GPU OR WITH_ROCM
)
target_link_libraries
(
device_context gpu_resource_pool
)
endif
()
...
...
paddle/fluid/platform/collective_helper.h
浏览文件 @
d28f6f7b
...
...
@@ -24,6 +24,9 @@
#include "paddle/fluid/platform/device/npu/dynload/hccl.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/device/mlu/device_context.h"
#endif
namespace
paddle
{
namespace
platform
{
...
...
@@ -333,5 +336,102 @@ class BKCLCommContext {
};
#endif
#if defined(PADDLE_WITH_CNCL)
// In order to apply hierarchical communication with CNCL, we need
// a communication ring contains CNCL communicators associated to a global
// cnclUniqueId. E.g. for a hierarchical case,
//
// 11 - 12 21 - 22
// | | | |
// 13 - 14 - 23 - 24
// | |
// 31 - 32 - 41 - 42
// | | | |
// 33 - 34 43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flatten case
//
// The CNCLComm instance is created and reversed in the CNCLCommContext
// singleton with a global user specified group id.
class
MLUDeviceContext
;
class
CNCLComm
{
public:
virtual
int
ring_id
()
const
=
0
;
virtual
int
nranks
()
const
=
0
;
virtual
int
rank
()
const
=
0
;
virtual
int
device_id
()
const
=
0
;
virtual
cnclComm_t
comm
()
const
=
0
;
virtual
mluStream
stream
()
const
=
0
;
virtual
MLUDeviceContext
*
dev_context
()
const
=
0
;
virtual
~
CNCLComm
()
=
default
;
};
// A singleton CNCL communicator context reserves communication ring ids
class
CNCLCommContext
{
public:
static
CNCLCommContext
&
Instance
()
{
static
CNCLCommContext
comm_ctx
;
return
comm_ctx
;
}
CNCLComm
*
CreateComm
(
cnclCliqueId
*
cncl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
void
CreateAllCNCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
// a latter comm with the same dev_id and the same ring_id
// will override the former
CNCLComm
*
AssignCNCLComm
(
cnclComm_t
comm
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
// retrieve a communicator by the ring id in multiprocessing mode
CNCLComm
*
Get
(
int
ring_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator in ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_EQ
(
comm_map_
.
at
(
ring_id
).
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"One device id should be specified to retrieve from "
"multiple communicators."
));
return
comm_map_
.
at
(
ring_id
).
begin
()
->
second
.
get
();
}
// retrieve a communicator by the ring id and the device id
CNCLComm
*
Get
(
int
ring_id
,
int
dev_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator of ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_GT
(
comm_map_
.
at
(
ring_id
).
count
(
dev_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator at device id %d has not been initialized in ring %d."
,
dev_id
,
ring_id
));
return
comm_map_
.
at
(
ring_id
).
at
(
dev_id
).
get
();
}
// retrieve a communicator by the ring id and place
CNCLComm
*
Get
(
int
ring_id
,
Place
place
)
const
{
return
Get
(
ring_id
,
place
.
device
);
}
private:
std
::
once_flag
once_flag_
;
std
::
mutex
comm_map_mutex_
;
// ring id to dev-CNCLComm
std
::
map
<
int
,
std
::
map
<
int
,
std
::
unique_ptr
<
CNCLComm
>>>
comm_map_
;
void
ReleaseCNCLComms
();
CNCLCommContext
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
CNCLCommContext
);
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/mlu/CMakeLists.txt
浏览文件 @
d28f6f7b
...
...
@@ -8,3 +8,4 @@ cc_library(mlu_info SRCS mlu_info.cc DEPS enforce glog monitor neuware_lib)
cc_library
(
mlu_stream SRCS mlu_stream.cc DEPS boost mlu_info stream_callback_manager eigen3
${
MKLDNN_CTX_DEPS
}
)
cc_library
(
mlu_device_context SRCS device_context.cc DEPS mlu_stream
)
cc_test
(
mlu_device_context_test SRCS device_context_test.cc DEPS mlu_device_context
)
cc_library
(
mlu_collective_helper SRCS mlu_collective_helper.cc DEPS mlu_stream mlu_info
)
paddle/fluid/platform/device/mlu/cncl_helper.h
0 → 100644
浏览文件 @
d28f6f7b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#include <stdio.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
platform
{
inline
cnclDataType_t
ToCNCLDataType
(
framework
::
proto
::
VarType
::
Type
type
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
return
cnclFloat32
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
FP16
)
{
return
cnclFloat16
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
return
cnclInt32
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT16
)
{
return
cnclInt16
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT8
)
{
return
cnclInt8
;
}
else
if
(
type
==
framework
::
proto
::
VarType
::
UINT8
)
{
return
cnclUint8
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This datatype in cncl is not supported."
));
}
}
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/platform/device/mlu/device_context.h
浏览文件 @
d28f6f7b
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/platform/device/mlu/enforce.h"
#include "paddle/fluid/platform/device/mlu/mlu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#endif
namespace
Eigen
{
struct
DefaultDevice
;
...
...
@@ -88,6 +91,14 @@ class MLUDeviceContext : public DeviceContext {
/*! \brief Return mlu stream in the device context. */
mluStream
stream
()
const
;
#ifdef PADDLE_WITH_CNCL
/*! \brief Return cncl communicators. */
cnclComm_t
cncl_comm
()
const
{
return
cncl_comm_
;
}
/*! \brief Set cncl communicators. */
void
set_cncl_comm
(
cnclComm_t
comm
)
{
cncl_comm_
=
comm
;
}
#endif
template
<
typename
Callback
>
void
RecordEvent
(
mluEventHandle
ev
,
Callback
callback
)
const
{
return
context
()
->
Stream
()
->
RecordEvent
(
ev
,
callback
);
...
...
@@ -132,6 +143,10 @@ class MLUDeviceContext : public DeviceContext {
thread_ctx_
;
static
thread_local
std
::
mutex
ctx_mtx_
;
#ifdef PADDLE_WITH_CNCL
cnclComm_t
cncl_comm_
{
nullptr
};
#endif
DISABLE_COPY_AND_ASSIGN
(
MLUDeviceContext
);
};
...
...
paddle/fluid/platform/device/mlu/enforce.h
浏览文件 @
d28f6f7b
...
...
@@ -42,6 +42,9 @@ struct MLUStatusType {};
DEFINE_MLU_STATUS_TYPE
(
cnrtStatus
,
cnrtSuccess
,
CNRT
);
DEFINE_MLU_STATUS_TYPE
(
cnnlStatus
,
CNNL_STATUS_SUCCESS
,
CNNL
);
DEFINE_MLU_STATUS_TYPE
(
cnStatus
,
CN_SUCCESS
,
CN
);
#ifdef PADDLE_WITH_CNCL
DEFINE_MLU_STATUS_TYPE
(
cnclStatus
,
CNCL_RET_SUCCESS
,
CNCL
);
#endif
}
// namespace details
...
...
@@ -80,6 +83,17 @@ inline std::string build_mlu_error_msg(cnStatus stat) {
return
sout
.
str
();
}
/*************** CNCL ERROR ***************/
#ifdef PADDLE_WITH_CNCL
inline
bool
is_error
(
cnclStatus
e
)
{
return
e
!=
CNCL_RET_SUCCESS
;
}
inline
std
::
string
build_mlu_error_msg
(
cnclStatus
e
)
{
std
::
ostringstream
sout
;
sout
<<
"MLU CNCL error("
<<
e
<<
"), "
<<
cnclGetErrorStr
(
e
)
<<
". "
;
return
sout
.
str
();
}
#endif
#define PADDLE_ENFORCE_MLU_SUCCESS(COND) \
do { \
auto __cond__ = (COND); \
...
...
paddle/fluid/platform/device/mlu/enforce_test.cc
浏览文件 @
d28f6f7b
...
...
@@ -58,5 +58,15 @@ TEST(mlu_enforce, mlu_success) {
CheckMluStatusFailure
(
CN_ERROR_INVALID_VALUE
,
"invalid argument"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CN_MEMORY_ERROR_OUT_OF_MEMORY
,
"device has no memory to alloc"
));
#ifdef PADDLE_WITH_CNCL
EXPECT_TRUE
(
CheckMluStatusSuccess
(
CNCL_RET_SUCCESS
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_INTERNAL
,
"CNCL error"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_NULL_POINTER
,
"CNCL error"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_INIT
,
"CNCL error"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_NOT_INIT
,
"CNCL error"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_REINIT
,
"CNCL error"
));
EXPECT_TRUE
(
CheckMluStatusFailure
(
CNCL_RET_ERR_INVALID_VERSION
,
"CNCL error"
));
#endif
}
#endif
paddle/fluid/platform/device/mlu/mlu_collective_helper.cc
0 → 100644
浏览文件 @
d28f6f7b
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CNCL)
#include <utility>
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/mlu/enforce.h"
namespace
paddle
{
namespace
platform
{
class
CNCLCommImpl
:
public
CNCLComm
{
public:
void
set_ring_id
(
int
ring_id
)
{
ring_id_
=
ring_id
;
}
int
ring_id
()
const
override
{
return
ring_id_
;
}
void
set_nranks
(
int
nranks
)
{
nranks_
=
nranks
;
}
int
nranks
()
const
override
{
return
nranks_
;
}
void
set_rank
(
int
rank
)
{
rank_
=
rank
;
}
int
rank
()
const
override
{
return
rank_
;
}
int
device_id
()
const
override
{
return
dev_ctx_
->
GetPlace
().
device
;
}
void
set_comm
(
cnclComm_t
comm
)
{
comm_
=
comm
;
}
cnclComm_t
comm
()
const
override
{
return
comm_
;
}
mluStream
stream
()
const
override
{
return
dev_ctx_
->
stream
();
}
void
set_dev_ctx
(
std
::
unique_ptr
<
MLUDeviceContext
>&&
dev_ctx
)
{
dev_ctx_
=
std
::
move
(
dev_ctx
);
}
MLUDeviceContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
~
CNCLCommImpl
()
{
if
(
comm_
)
{
PADDLE_ENFORCE_MLU_SUCCESS
(
cnclFreeComm
(
comm_
));
}
}
private:
int
ring_id_
;
int
nranks_
;
int
rank_
;
cnclComm_t
comm_
;
std
::
unique_ptr
<
MLUDeviceContext
>
dev_ctx_
;
};
CNCLComm
*
CNCLCommContext
::
CreateComm
(
cnclCliqueId
*
cncl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
cncl_id
,
platform
::
errors
::
InvalidArgument
(
"The cncl unique id should not be null."
));
PADDLE_ENFORCE_GT
(
nranks
,
1
,
platform
::
errors
::
InvalidArgument
(
"Expected nranks > 1. But received nranks is %d."
,
nranks
));
PADDLE_ENFORCE_GE
(
rank
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected rank >= 0. But received rank is %d."
,
rank
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
platform
::
errors
::
InvalidArgument
(
"Expected rank < nranks. But received rank is %d, nranks is %d."
,
rank
,
nranks
));
PADDLE_ENFORCE_GE
(
dev_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected dev_id >= 0. But received dev_id is %d."
,
dev_id
));
cnclComm_t
comm
;
int
dev_list
[]
=
{
dev_id
};
int
rank_list
[]
=
{
rank
};
SetMLUDeviceId
(
dev_id
);
PADDLE_ENFORCE_MLU_SUCCESS
(
cnclInitComms
(
&
comm
,
1
,
dev_list
,
rank_list
,
nranks
,
cncl_id
));
auto
*
comm_wrapper
=
AssignCNCLComm
(
comm
,
nranks
,
rank
,
dev_id
,
ring_id
);
VLOG
(
1
)
<<
"cncl communicator of rank "
<<
rank
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_id
;
std
::
call_once
(
once_flag_
,
[]()
{
std
::
atexit
([]()
{
CNCLCommContext
::
Instance
().
ReleaseCNCLComms
();
});
});
return
comm_wrapper
;
}
void
CNCLCommContext
::
CreateAllCNCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
)
{
PADDLE_ENFORCE_GT
(
dev_ids
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"Expected the size of dev_ids > 0. But "
"received the size of dev_ids is %d."
,
dev_ids
.
size
()));
const
int
kDevices
=
dev_ids
.
size
();
cnclComm_t
comms
[
kDevices
];
int
*
rank_list
=
new
int
[
kDevices
];
for
(
int
i
=
0
;
i
<
kDevices
;
i
++
)
{
rank_list
[
i
]
=
i
;
}
cnclCliqueId
clique_id
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnclGetCliqueId
(
&
clique_id
));
PADDLE_ENFORCE_MLU_SUCCESS
(
cnclInitComms
(
comms
,
dev_ids
.
size
(),
dev_ids
.
data
(),
rank_list
,
dev_ids
.
size
(),
&
clique_id
));
PADDLE_ENFORCE_EQ
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Expected comm_map_.count(ring_id) = 0. But received "
"comm_map_.count(ring_id) is %d."
,
comm_map_
.
count
(
ring_id
)));
for
(
size_t
i
=
0
;
i
<
dev_ids
.
size
();
++
i
)
{
AssignCNCLComm
(
comms
[
i
],
dev_ids
.
size
(),
i
,
dev_ids
[
i
],
ring_id
);
VLOG
(
1
)
<<
"cncl communicator of rank "
<<
i
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_ids
[
i
];
}
std
::
call_once
(
once_flag_
,
[]()
{
std
::
atexit
([]()
{
CNCLCommContext
::
Instance
().
ReleaseCNCLComms
();
});
});
delete
[]
rank_list
;
}
CNCLComm
*
CNCLCommContext
::
AssignCNCLComm
(
cnclComm_t
comm
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
std
::
unique_ptr
<
MLUDeviceContext
>
dev_ctx
(
new
MLUDeviceContext
(
MLUPlace
(
dev_id
)));
CNCLCommImpl
*
c
=
new
CNCLCommImpl
;
c
->
set_ring_id
(
ring_id
);
c
->
set_nranks
(
nranks
);
c
->
set_rank
(
rank
);
c
->
set_comm
(
comm
);
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
comm_map_mutex_
.
lock
();
if
(
comm_map_
.
count
(
ring_id
)
==
0
)
{
comm_map_
.
emplace
(
ring_id
,
std
::
map
<
int
,
std
::
unique_ptr
<
CNCLComm
>>
());
}
auto
&
dev2comm
=
comm_map_
[
ring_id
];
dev2comm
.
emplace
(
dev_id
,
std
::
unique_ptr
<
CNCLComm
>
(
c
));
comm_map_mutex_
.
unlock
();
if
(
ring_id
==
0
)
{
auto
*
dev_ctx
=
static_cast
<
platform
::
MLUDeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
MLUPlace
(
dev_id
)));
dev_ctx
->
set_cncl_comm
(
comm
);
}
return
comm_map_
[
ring_id
][
dev_id
].
get
();
}
void
CNCLCommContext
::
ReleaseCNCLComms
()
{
for
(
auto
&
p
:
comm_map_
)
{
for
(
auto
&
q
:
p
.
second
)
{
q
.
second
.
reset
();
}
}
}
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/platform/device/mlu/mlu_info.h
浏览文件 @
d28f6f7b
...
...
@@ -18,6 +18,9 @@ limitations under the License. */
#include <cn_api.h>
#include <cnnl.h>
#include <cnrt.h>
#ifdef PADDLE_WITH_CNCL
#include <cncl.h>
#endif
#include <vector>
namespace
paddle
{
...
...
@@ -25,6 +28,9 @@ namespace paddle {
using
cnStatus
=
CNresult
;
using
cnrtStatus
=
cnrtRet_t
;
using
cnnlStatus
=
cnnlStatus_t
;
#ifdef PADDLE_WITH_CNCL
using
cnclStatus
=
cnclResult_t
;
#endif
using
mluStream
=
cnrtQueue_t
;
using
mluCnnlHandle
=
cnnlHandle_t
;
using
mluEventHandle
=
CNnotifier
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录