Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
116eee52
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
116eee52
编写于
9月 11, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
build(third_party): update megray
GitOrigin-RevId: da5e05f82b5112474d51f9eab78318b1d6432742
上级
e507228e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
37 addition
and
32 deletion
+37
-32
src/opr-mm/impl/collective_comm.cpp
src/opr-mm/impl/collective_comm.cpp
+0
-27
src/opr-mm/impl/io_remote.cpp
src/opr-mm/impl/io_remote.cpp
+6
-4
src/opr-mm/impl/megray_helper.cpp
src/opr-mm/impl/megray_helper.cpp
+27
-0
src/opr-mm/include/megbrain/opr/megray_helper.h
src/opr-mm/include/megbrain/opr/megray_helper.h
+4
-1
未找到文件。
src/opr-mm/impl/collective_comm.cpp
浏览文件 @
116eee52
...
...
@@ -47,33 +47,6 @@ const char* get_param_name(CollectiveComm::Param param) {
}
}
MegRay
::
DType
get_megray_dtype
(
megdnn
::
DType
dtype
)
{
switch
(
dtype
.
enumv
())
{
case
DTypeEnum
::
Int8
:
return
MegRay
::
DType
::
MEGRAY_INT8
;
case
DTypeEnum
::
Int32
:
return
MegRay
::
DType
::
MEGRAY_INT32
;
case
DTypeEnum
::
Float32
:
return
MegRay
::
DType
::
MEGRAY_FLOAT32
;
#ifndef MEGDNN_DISABLE_FLOAT16
case
DTypeEnum
::
Float16
:
return
MegRay
::
DType
::
MEGRAY_FLOAT16
;
#endif
default:
mgb_throw
(
MegBrainError
,
"bad CollectiveComm dtype"
);
}
}
MegRay
::
Backend
get_megray_backend
(
const
std
::
string
&
backend
)
{
if
(
backend
==
"nccl"
)
{
return
MegRay
::
MEGRAY_NCCL
;
}
else
if
(
backend
==
"ucx"
)
{
return
MegRay
::
MEGRAY_UCX
;
}
else
{
mgb_throw
(
MegBrainError
,
"back CollectiveComm backend"
);
}
}
cudaStream_t
get_stream
(
VarNode
*
var
)
{
return
CompNodeEnv
::
from_comp_node
(
var
->
comp_node
()).
cuda_env
().
stream
;
}
...
...
src/opr-mm/impl/io_remote.cpp
浏览文件 @
116eee52
...
...
@@ -82,8 +82,9 @@ void RemoteSend::scn_do_execute() {
for
(
size_t
i
=
0
;
i
<
ishp
.
ndim
;
i
++
)
{
data_size
*=
ishp
[
i
];
}
data_size
*=
tensor
.
dtype
().
size
();
auto
status
=
m_megray_comm
->
send
(
tensor
.
raw_ptr
(),
data_size
,
1
,
m_megray_ctx
);
auto
status
=
m_megray_comm
->
send
(
tensor
.
raw_ptr
(),
data_size
,
get_megray_dtype
(
tensor
.
dtype
()),
1
,
m_megray_ctx
);
mgb_assert
(
status
==
MegRay
::
MEGRAY_OK
,
"MegRay send failed"
);
if
(
m_is_grad
)
{
...
...
@@ -192,8 +193,9 @@ void RemoteRecv::scn_do_execute() {
for
(
size_t
i
=
0
;
i
<
ishp
.
ndim
;
i
++
)
{
data_size
*=
ishp
[
i
];
}
data_size
*=
tensor
.
dtype
().
size
();
auto
status
=
m_megray_comm
->
recv
(
tensor
.
raw_ptr
(),
data_size
,
0
,
m_megray_ctx
);
auto
status
=
m_megray_comm
->
recv
(
tensor
.
raw_ptr
(),
data_size
,
get_megray_dtype
(
tensor
.
dtype
()),
0
,
m_megray_ctx
);
mgb_assert
(
status
==
MegRay
::
MEGRAY_OK
,
"MegRay recv failed"
);
}
...
...
src/opr-mm/impl/megray_helper.cpp
浏览文件 @
116eee52
...
...
@@ -14,6 +14,33 @@
using
namespace
mgb
;
using
namespace
opr
;
MegRay
::
DType
mgb
::
opr
::
get_megray_dtype
(
megdnn
::
DType
dtype
)
{
switch
(
dtype
.
enumv
())
{
case
DTypeEnum
::
Int8
:
return
MegRay
::
DType
::
MEGRAY_INT8
;
case
DTypeEnum
::
Int32
:
return
MegRay
::
DType
::
MEGRAY_INT32
;
case
DTypeEnum
::
Float32
:
return
MegRay
::
DType
::
MEGRAY_FLOAT32
;
#ifndef MEGDNN_DISABLE_FLOAT16
case
DTypeEnum
::
Float16
:
return
MegRay
::
DType
::
MEGRAY_FLOAT16
;
#endif
default:
mgb_throw
(
MegBrainError
,
"bad CollectiveComm dtype"
);
}
}
MegRay
::
Backend
mgb
::
opr
::
get_megray_backend
(
const
std
::
string
&
backend
)
{
if
(
backend
==
"nccl"
)
{
return
MegRay
::
MEGRAY_NCCL
;
}
else
if
(
backend
==
"ucx"
)
{
return
MegRay
::
MEGRAY_UCX
;
}
else
{
mgb_throw
(
MegBrainError
,
"back CollectiveComm backend"
);
}
}
bool
MegRayCommBuilder
::
find
(
uint64_t
hash
,
std
::
shared_ptr
<
MegRay
::
Communicator
>&
comm
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
m_map_mtx
);
auto
it
=
m_megray_comms
.
find
(
hash
);
...
...
src/opr-mm/include/megbrain/opr/megray_helper.h
浏览文件 @
116eee52
...
...
@@ -13,13 +13,16 @@
#include <mutex>
#include "megbrain/utils/metahelper.h"
#include "megbrain/opr/group_manager.h"
#include "megray.h"
namespace
mgb
{
namespace
opr
{
MegRay
::
DType
get_megray_dtype
(
megdnn
::
DType
);
MegRay
::
Backend
get_megray_backend
(
const
std
::
string
&
backend
);
/*!
* gather MegRay unique ids and build communicator, use hash for deduplication
*/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录