Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1bec737d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1bec737d
编写于
11月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(distributed): support distributed opr for rocm
GitOrigin-RevId: 4840100d07dbaa2b7d8e3e113b444ddf81eeea51
上级
a31b7c6e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
42 addition
and
20 deletion
+42
-20
CMakeLists.txt
CMakeLists.txt
+4
-2
src/core/impl/comp_node/comp_node.cpp
src/core/impl/comp_node/comp_node.cpp
+2
-0
src/core/impl/comp_node/rocm/comp_node.cpp
src/core/impl/comp_node/rocm/comp_node.cpp
+16
-0
src/opr-mm/impl/collective_comm.cpp
src/opr-mm/impl/collective_comm.cpp
+1
-8
src/opr-mm/impl/io_remote.cpp
src/opr-mm/impl/io_remote.cpp
+2
-6
src/opr-mm/impl/megray_helper.cpp
src/opr-mm/impl/megray_helper.cpp
+13
-0
src/opr-mm/impl/zmq_rpc.cpp
src/opr-mm/impl/zmq_rpc.cpp
+0
-2
src/opr-mm/include/megbrain/opr/megray_helper.h
src/opr-mm/include/megbrain/opr/megray_helper.h
+4
-0
src/opr-mm/include/megbrain/opr/zmq_rpc.h
src/opr-mm/include/megbrain/opr/zmq_rpc.h
+0
-2
未找到文件。
CMakeLists.txt
浏览文件 @
1bec737d
...
@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
...
@@ -339,8 +339,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD 17
)
endif
()
endif
()
if
(
NOT
MGE_WITH_CUDA
)
if
(
NOT
${
MGE_WITH_CUDA
}
AND NOT
${
MGE_WITH_ROCM
}
)
message
(
STATUS
"Disable distributed support, as
CUDA is not en
abled."
)
message
(
STATUS
"Disable distributed support, as
both CUDA and ROCm are dis
abled."
)
set
(
MGE_WITH_DISTRIBUTED OFF
)
set
(
MGE_WITH_DISTRIBUTED OFF
)
endif
()
endif
()
...
@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
...
@@ -903,6 +903,8 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
endif
()
endif
()
if
(
MGE_WITH_DISTRIBUTED
)
if
(
MGE_WITH_DISTRIBUTED
)
set
(
MEGRAY_WITH_NCCL
${
MGE_WITH_CUDA
}
CACHE BOOL
"Override MegRay option"
FORCE
)
set
(
MEGRAY_WITH_RCCL
${
MGE_WITH_ROCM
}
CACHE BOOL
"Override MegRay option"
FORCE
)
add_subdirectory
(
${
PROJECT_SOURCE_DIR
}
/third_party/MegRay
)
add_subdirectory
(
${
PROJECT_SOURCE_DIR
}
/third_party/MegRay
)
endif
()
endif
()
...
...
src/core/impl/comp_node/comp_node.cpp
浏览文件 @
1bec737d
...
@@ -79,6 +79,8 @@ namespace {
...
@@ -79,6 +79,8 @@ namespace {
if
(
g_unspec_locator_type
==
DT
::
UNSPEC
)
{
if
(
g_unspec_locator_type
==
DT
::
UNSPEC
)
{
if
(
CudaCompNode
::
available
())
{
if
(
CudaCompNode
::
available
())
{
g_unspec_locator_type
=
DT
::
CUDA
;
g_unspec_locator_type
=
DT
::
CUDA
;
}
else
if
(
ROCmCompNode
::
available
())
{
g_unspec_locator_type
=
DT
::
ROCM
;
}
else
{
}
else
{
g_unspec_locator_type
=
DT
::
CPU
;
g_unspec_locator_type
=
DT
::
CPU
;
}
}
...
...
src/core/impl/comp_node/rocm/comp_node.cpp
浏览文件 @
1bec737d
...
@@ -217,6 +217,11 @@ public:
...
@@ -217,6 +217,11 @@ public:
Locator
locator
()
override
{
return
m_locator
;
}
Locator
locator
()
override
{
return
m_locator
;
}
Locator
locator_logical
()
override
{
return
m_locator_logical
;
}
Locator
locator_logical
()
override
{
return
m_locator_logical
;
}
uint64_t
get_uid
()
override
{
return
m_uid
;
}
private:
uint64_t
m_uid
;
};
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ROCmCompNode
::
CompNodeImpl
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ROCmCompNode
::
CompNodeImpl
);
...
@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
...
@@ -278,6 +283,17 @@ void ROCmCompNodeImpl::init(const Locator& locator,
m_locator_logical
=
locator_logical
;
m_locator_logical
=
locator_logical
;
m_initialized
=
true
;
m_initialized
=
true
;
#if defined(__linux__) || defined(TARGET_OS_MAC)
FILE
*
fp
;
fp
=
fopen
(
"/dev/urandom"
,
"r"
);
mgb_assert
(
fread
(
&
m_uid
,
sizeof
(
m_uid
),
1
,
fp
)
==
1
);
fclose
(
fp
);
#else
m_uid
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
nanoseconds
>
(
std
::
chrono
::
system_clock
::
now
().
time_since_epoch
()
).
count
();
#endif
auto
on_succ
=
[
this
](
hipStream_t
stream
)
{
auto
on_succ
=
[
this
](
hipStream_t
stream
)
{
auto
locator
=
m_locator
;
auto
locator
=
m_locator
;
log_comp_node_created
(
locator
,
m_locator_logical
);
log_comp_node_created
(
locator
,
m_locator_logical
);
...
...
src/opr-mm/impl/collective_comm.cpp
浏览文件 @
1bec737d
...
@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
...
@@ -47,9 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}
}
}
cudaStream_t
get_stream
(
VarNode
*
var
)
{
return
CompNodeEnv
::
from_comp_node
(
var
->
comp_node
()).
cuda_env
().
stream
;
}
}
// anonymous namespace
}
// anonymous namespace
/* ================= ModeTrait ================= */
/* ================= ModeTrait ================= */
...
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
...
@@ -519,8 +516,6 @@ CollectiveComm::CollectiveComm(
// add input
// add input
mgb_assert
(
inputs
.
size
()
<=
1
,
"one or zero input expected, got %zu"
,
inputs
.
size
());
mgb_assert
(
inputs
.
size
()
<=
1
,
"one or zero input expected, got %zu"
,
inputs
.
size
());
if
(
inputs
.
size
()
>
0
)
{
if
(
inputs
.
size
()
>
0
)
{
mgb_assert
(
inputs
[
0
]
->
comp_node
().
device_type
()
==
CompNode
::
DeviceType
::
CUDA
,
"CollectiveComm currectly only supports CUDA"
);
add_input
({
inputs
[
0
]});
add_input
({
inputs
[
0
]});
}
}
...
@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
...
@@ -531,8 +526,6 @@ CollectiveComm::CollectiveComm(
const
auto
&
cns
=
config
.
comp_node
();
const
auto
&
cns
=
config
.
comp_node
();
mgb_assert
(
cns
.
size
()
<=
1
,
"one or zero comp node expected, got %zu"
,
cns
.
size
());
mgb_assert
(
cns
.
size
()
<=
1
,
"one or zero comp node expected, got %zu"
,
cns
.
size
());
if
(
cns
.
size
()
>
0
)
{
if
(
cns
.
size
()
>
0
)
{
mgb_assert
(
cns
[
0
].
device_type
()
==
CompNode
::
DeviceType
::
CUDA
,
"CollectiveComm currectly only supports CUDA"
);
output
(
0
)
->
comp_node
(
cns
[
0
]);
output
(
0
)
->
comp_node
(
cns
[
0
]);
}
else
{
}
else
{
output
(
0
)
->
comp_node
(
inputs
[
0
]
->
comp_node
());
output
(
0
)
->
comp_node
(
inputs
[
0
]
->
comp_node
());
...
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
...
@@ -609,7 +602,7 @@ void CollectiveComm::opr_register() {
reg_info
.
hash
,
m_key
,
m_nr_devices
,
m_rank
,
reg_info
.
hash
,
m_key
,
m_nr_devices
,
m_rank
,
get_megray_backend
(
m_backend
),
m_group_client
);
get_megray_backend
(
m_backend
),
m_group_client
);
m_megray_ctx
=
MegRay
::
CudaContext
::
make
(
get_stream
(
output
(
0
)
));
m_megray_ctx
=
get_megray_context
(
output
(
0
)
->
comp_node
(
));
m_init
=
true
;
m_init
=
true
;
}
}
...
...
src/opr-mm/impl/io_remote.cpp
浏览文件 @
1bec737d
...
@@ -18,10 +18,6 @@
...
@@ -18,10 +18,6 @@
using
namespace
mgb
;
using
namespace
mgb
;
using
namespace
opr
;
using
namespace
opr
;
cudaStream_t
get_stream
(
VarNode
*
var
)
{
return
CompNodeEnv
::
from_comp_node
(
var
->
comp_node
()).
cuda_env
().
stream
;
}
/* ===================== RemoteSend ===================== */
/* ===================== RemoteSend ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
RemoteSend
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
RemoteSend
);
...
@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
...
@@ -70,7 +66,7 @@ void RemoteSend::scn_do_execute() {
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
reg_info
.
hash
,
m_key
,
2
,
0
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
reg_info
.
hash
,
m_key
,
2
,
0
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
m_megray_ctx
=
MegRay
::
CudaContext
::
make
(
get_stream
(
output
(
0
)
));
m_megray_ctx
=
get_megray_context
(
output
(
0
)
->
comp_node
(
));
m_init
=
true
;
m_init
=
true
;
}
}
...
@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
...
@@ -207,7 +203,7 @@ void RemoteRecv::scn_do_execute() {
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
m_megray_comm
=
MegRayCommBuilder
::
get_megray_comm
(
reg_info
.
hash
,
m_key
,
2
,
1
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
reg_info
.
hash
,
m_key
,
2
,
1
,
MegRay
::
MEGRAY_NCCL
,
m_group_client
);
m_megray_ctx
=
MegRay
::
CudaContext
::
make
(
get_stream
(
output
(
0
)
));
m_megray_ctx
=
get_megray_context
(
output
(
0
)
->
comp_node
(
));
m_init
=
true
;
m_init
=
true
;
}
}
...
...
src/opr-mm/impl/megray_helper.cpp
浏览文件 @
1bec737d
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
*/
*/
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
using
namespace
mgb
;
using
namespace
mgb
;
using
namespace
opr
;
using
namespace
opr
;
...
@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
...
@@ -34,6 +35,8 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
MegRay
::
Backend
mgb
::
opr
::
get_megray_backend
(
const
std
::
string
&
backend
)
{
MegRay
::
Backend
mgb
::
opr
::
get_megray_backend
(
const
std
::
string
&
backend
)
{
if
(
backend
==
"nccl"
)
{
if
(
backend
==
"nccl"
)
{
return
MegRay
::
MEGRAY_NCCL
;
return
MegRay
::
MEGRAY_NCCL
;
}
else
if
(
backend
==
"rccl"
)
{
return
MegRay
::
MEGRAY_RCCL
;
}
else
if
(
backend
==
"ucx"
)
{
}
else
if
(
backend
==
"ucx"
)
{
return
MegRay
::
MEGRAY_UCX
;
return
MegRay
::
MEGRAY_UCX
;
}
else
{
}
else
{
...
@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
...
@@ -41,6 +44,16 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
}
}
}
}
std
::
shared_ptr
<
MegRay
::
Context
>
mgb
::
opr
::
get_megray_context
(
CompNode
comp_node
){
#if MGB_CUDA
return
MegRay
::
CudaContext
::
make
(
CompNodeEnv
::
from_comp_node
(
comp_node
).
cuda_env
().
stream
);
#elif MGB_ROCM
return
MegRay
::
HipContext
::
make
(
CompNodeEnv
::
from_comp_node
(
comp_node
).
rocm_env
().
stream
);
#else
#error "neither CUDA nor ROCm is enabled"
#endif
}
bool
MegRayCommBuilder
::
find
(
uint64_t
hash
,
std
::
shared_ptr
<
MegRay
::
Communicator
>&
comm
)
{
bool
MegRayCommBuilder
::
find
(
uint64_t
hash
,
std
::
shared_ptr
<
MegRay
::
Communicator
>&
comm
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
m_map_mtx
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
m_map_mtx
);
auto
it
=
m_megray_comms
.
find
(
hash
);
auto
it
=
m_megray_comms
.
find
(
hash
);
...
...
src/opr-mm/impl/zmq_rpc.cpp
浏览文件 @
1bec737d
#include "megbrain_build_config.h"
#include "megbrain_build_config.h"
#if MGB_CUDA
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/common.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/exception.h"
...
@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
...
@@ -228,4 +227,3 @@ void ZmqRpcClient::request(message_t& request, message_t& reply) {
DISCARD_RETVAL
(
client
->
recv
(
reply
));
DISCARD_RETVAL
(
client
->
recv
(
reply
));
add_socket
(
client
);
add_socket
(
client
);
}
}
#endif // MGB_CUDA
src/opr-mm/include/megbrain/opr/megray_helper.h
浏览文件 @
1bec737d
...
@@ -12,7 +12,9 @@
...
@@ -12,7 +12,9 @@
#pragma once
#pragma once
#include <mutex>
#include <mutex>
#include <memory>
#include "megbrain/comp_node.h"
#include "megbrain/opr/group_manager.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
#include "megray.h"
...
@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType);
...
@@ -23,6 +25,8 @@ MegRay::DType get_megray_dtype(megdnn::DType);
MegRay
::
Backend
get_megray_backend
(
const
std
::
string
&
backend
);
MegRay
::
Backend
get_megray_backend
(
const
std
::
string
&
backend
);
std
::
shared_ptr
<
MegRay
::
Context
>
get_megray_context
(
CompNode
comp_node
);
/*!
/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
* gather MegRay unique ids and build communicator, use hash for deduplication
*/
*/
...
...
src/opr-mm/include/megbrain/opr/zmq_rpc.h
浏览文件 @
1bec737d
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include "megbrain_build_config.h"
#include "megbrain_build_config.h"
#if MGB_CUDA
#include <unistd.h>
#include <unistd.h>
#include <cassert>
#include <cassert>
#include <iostream>
#include <iostream>
...
@@ -101,4 +100,3 @@ private:
...
@@ -101,4 +100,3 @@ private:
std
::
vector
<
std
::
shared_ptr
<
zmq
::
socket_t
>>
m_own_sockets
;
std
::
vector
<
std
::
shared_ptr
<
zmq
::
socket_t
>>
m_own_sockets
;
};
};
}
// namespace ZmqRpc
}
// namespace ZmqRpc
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录