Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b38e8225
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b38e8225
编写于
7月 18, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
7月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/opr-mm): update megray communicator init interface and fix ci
GitOrigin-RevId: 55c59879f2cda27f678aaa55c44120c698709a07
上级
5e912edd
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
86 addition
and
67 deletion
+86
-67
src/opr-mm/impl/group_manager.cpp
src/opr-mm/impl/group_manager.cpp
+20
-18
src/opr-mm/impl/megray_helper.cpp
src/opr-mm/impl/megray_helper.cpp
+15
-3
src/opr-mm/impl/mm_handler.cpp
src/opr-mm/impl/mm_handler.cpp
+19
-19
src/opr-mm/include/megbrain/opr/group_manager.h
src/opr-mm/include/megbrain/opr/group_manager.h
+12
-11
src/opr-mm/include/megbrain/opr/mm_handler.h
src/opr-mm/include/megbrain/opr/mm_handler.h
+2
-2
src/opr-mm/proto/mm_handler.proto
src/opr-mm/proto/mm_handler.proto
+12
-9
src/opr-mm/test/mock_client.h
src/opr-mm/test/mock_client.h
+5
-4
third_party/MegRay
third_party/MegRay
+1
-1
未找到文件。
src/opr-mm/impl/group_manager.cpp
浏览文件 @
b38e8225
...
...
@@ -139,27 +139,29 @@ GroupManager::RegisterInfo GroupManager::opr_register(const std::string& key,
return
ret
;
}
std
::
vector
<
std
::
string
>
GroupManager
::
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
{
m_key2uids_mtx
};
if
(
m_key2uids_size
[
key
]
==
0
)
m_key2uids
[
key
].
resize
(
size
);
m_key2uids
[
key
][
rank
]
=
uid
;
m_key2uids_size
[
key
]
++
;
if
(
m_key2uids_size
[
key
]
==
size
)
{
m_key2uids_flag
[
key
]
=
true
;
m_gather_uid_cv
.
notify_all
();
void
GroupManager
::
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
{
m_key2addr_mtx
};
if
(
rank
==
root
)
{
m_key2master_ip
[
key
]
=
master_ip
;
m_key2port
[
key
]
=
port
;
}
m_key2addr_size
[
key
]
++
;
if
(
m_key2addr_size
[
key
]
==
size
)
{
m_key2addr_flag
[
key
]
=
true
;
m_bcast_cv
.
notify_all
();
}
else
{
m_
gather_uid
_cv
.
wait
(
lk
,
[
&
]
{
return
m_key2
uids
_flag
.
count
(
key
)
>
0
;
});
m_
bcast
_cv
.
wait
(
lk
,
[
&
]
{
return
m_key2
addr
_flag
.
count
(
key
)
>
0
;
});
}
auto
uids
=
m_key2uids
[
key
];
m_key2uids_size
[
key
]
--
;
if
(
m_key2uids_size
[
key
]
==
0
)
{
m_key2uids
.
erase
(
key
);
m_key2uids_flag
.
erase
(
key
);
master_ip
=
m_key2master_ip
[
key
];
port
=
m_key2port
[
key
];
m_key2addr_size
[
key
]
--
;
if
(
m_key2addr_size
[
key
]
==
0
)
{
m_key2master_ip
.
erase
(
key
);
m_key2port
.
erase
(
key
);
m_key2addr_flag
.
erase
(
key
);
}
return
uids
;
}
void
GroupManager
::
set_output_shape
(
const
std
::
string
&
key
,
...
...
src/opr-mm/impl/megray_helper.cpp
浏览文件 @
b38e8225
...
...
@@ -44,10 +44,22 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
std
::
shared_ptr
<
MegRay
::
Communicator
>
comm
;
if
(
!
sm_instance
->
find
(
hash
,
comm
))
{
uint32_t
root
=
0
;
std
::
string
master_ip
;
int
port
=
0
;
if
(
rank
==
root
)
{
char
*
c
=
MegRay
::
get_host_ip
();
master_ip
=
std
::
string
(
c
);
delete
c
;
port
=
MegRay
::
get_free_port
();
auto
ret
=
MegRay
::
create_server
(
size
,
port
);
mgb_assert
(
ret
==
MegRay
::
Status
::
MEGRAY_OK
);
}
group_client
->
bcast_addr
(
master_ip
,
port
,
key
,
size
,
rank
,
root
);
comm
=
MegRay
::
get_communicator
(
size
,
rank
,
backend
);
auto
uid
=
comm
->
get_uid
();
auto
uids
=
group_client
->
gather_uid
(
uid
,
key
,
size
,
rank
);
mgb_assert
(
comm
->
init
(
uids
)
==
MegRay
::
Status
::
MEGRAY_OK
);
auto
ret
=
comm
->
init
(
master_ip
.
c_str
(),
port
);
mgb_assert
(
ret
==
MegRay
::
Status
::
MEGRAY_OK
);
sm_instance
->
emplace
(
hash
,
comm
);
}
return
comm
;
...
...
src/opr-mm/impl/mm_handler.cpp
浏览文件 @
b38e8225
...
...
@@ -41,7 +41,7 @@ public:
RUNSERVER
(
opr_register
);
RUNSERVER
(
set_output_shape
);
RUNSERVER
(
get_output_shape
);
RUNSERVER
(
gather_uid
);
RUNSERVER
(
bcast_addr
);
RUNSERVER
(
group_barrier
);
mgb_assert
(
false
,
"invalid rpc request"
);
}
...
...
@@ -49,7 +49,7 @@ private:
void
opr_register
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
void
set_output_shape
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
void
get_output_shape
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
void
gather_uid
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
void
bcast_addr
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
void
group_barrier
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
);
private:
...
...
@@ -101,15 +101,14 @@ void GroupServerProxy::get_output_shape(void* input_ptr, size_t input_len,
rsp
.
SerializeToString
(
output
);
}
void
GroupServerProxy
::
gather_uid
(
void
*
input_ptr
,
size_t
input_len
,
void
GroupServerProxy
::
bcast_addr
(
void
*
input_ptr
,
size_t
input_len
,
std
::
string
*
output
)
{
INFO_INIT
(
mm_handler
,
GatherUid
);
auto
uid
=
req
.
uid
();
auto
uids
=
m_mgr
.
gather_uid
(
uid
,
req
.
key
(),
req
.
size
(),
req
.
rank
());
for
(
size_t
i
=
0
;
i
<
uids
.
size
();
i
++
)
{
rsp
.
add_uids
();
rsp
.
set_uids
(
i
,
uids
[
i
].
data
(),
uids
[
i
].
size
());
}
INFO_INIT
(
mm_handler
,
BcastAddr
);
std
::
string
master_ip
=
req
.
master_ip
();
int
port
=
req
.
port
();
m_mgr
.
bcast_addr
(
master_ip
,
port
,
req
.
key
(),
req
.
size
(),
req
.
rank
(),
req
.
root
());
rsp
.
set_master_ip
(
master_ip
);
rsp
.
set_port
(
port
);
rsp
.
SerializeToString
(
output
);
}
...
...
@@ -184,19 +183,20 @@ TensorShape GroupClientProxy::get_output_shape(const std::string& key) {
}
return
shape
;
}
std
::
vector
<
std
::
string
>
GroupClientProxy
::
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
)
{
INFO_INIT
(
mm_handler
,
gather_uid
,
GatherUid
);
req
.
set_uid
(
uid
.
data
(),
uid
.
size
());
void
GroupClientProxy
::
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
)
{
INFO_INIT
(
mm_handler
,
bcast_addr
,
BcastAddr
);
req
.
set_master_ip
(
master_ip
.
data
(),
master_ip
.
size
());
req
.
set_port
(
port
);
req
.
set_key
(
key
.
data
(),
key
.
size
());
req
.
set_size
(
size
);
req
.
set_rank
(
rank
);
req
.
set_root
(
root
);
SOLVE_REQUEST
(
func_name
,
req
,
rsp
);
std
::
vector
<
std
::
string
>
rst
;
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
rst
.
push_back
(
rsp
.
uids
(
i
));
}
return
rst
;
master_ip
=
rsp
.
master_ip
();
port
=
rsp
.
port
();
}
uint32_t
GroupClientProxy
::
group_barrier
(
uint32_t
size
,
uint32_t
rank
)
{
...
...
src/opr-mm/include/megbrain/opr/group_manager.h
浏览文件 @
b38e8225
...
...
@@ -82,9 +82,9 @@ class GroupManager {
RegisterInfo
opr_register
(
const
std
::
string
&
key
,
size_t
nr_devices
,
bool
is_root
,
int
rank
,
uint64_t
comp_node_hash
);
//!
gather uids from all ranks
std
::
vector
<
std
::
string
>
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
);
//!
broadcast master_ip and port
void
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
);
//! Set output shape of this key
void
set_output_shape
(
const
std
::
string
&
key
,
const
TensorShape
&
shape
);
...
...
@@ -102,12 +102,13 @@ class GroupManager {
std
::
unordered_map
<
std
::
string
,
GroupInfo
>
m_key2group_info
;
std
::
mutex
m_key2group_info_mtx
;
//! key -> uid
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
m_key2uids
;
std
::
unordered_map
<
std
::
string
,
uint32_t
>
m_key2uids_size
;
std
::
unordered_map
<
std
::
string
,
bool
>
m_key2uids_flag
;
std
::
mutex
m_key2uids_mtx
;
std
::
condition_variable
m_gather_uid_cv
;
//! key -> addr
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m_key2master_ip
;
std
::
unordered_map
<
std
::
string
,
int
>
m_key2port
;
std
::
unordered_map
<
std
::
string
,
uint32_t
>
m_key2addr_size
;
std
::
unordered_map
<
std
::
string
,
bool
>
m_key2addr_flag
;
std
::
mutex
m_key2addr_mtx
;
std
::
condition_variable
m_bcast_cv
;
//! barrier
uint32_t
m_barrier_size
;
...
...
@@ -133,8 +134,8 @@ class GroupClient {
bool
is_root
,
int
rank
,
uint64_t
comp_node_hash
)
=
0
;
virtual
std
::
vector
<
std
::
string
>
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
)
=
0
;
virtual
void
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
)
=
0
;
virtual
void
set_output_shape
(
const
std
::
string
&
key
,
const
TensorShape
&
shape
)
=
0
;
...
...
src/opr-mm/include/megbrain/opr/mm_handler.h
浏览文件 @
b38e8225
...
...
@@ -37,8 +37,8 @@ public:
int
rank
,
uint64_t
comp_node_hash
)
override
;
std
::
vector
<
std
::
string
>
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
)
override
;
void
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
)
override
;
void
set_output_shape
(
const
std
::
string
&
key
,
const
TensorShape
&
shape
)
override
;
...
...
src/opr-mm/proto/mm_handler.proto
浏览文件 @
b38e8225
...
...
@@ -16,15 +16,18 @@ message OprRegisterResponse {
int32
root_rank
=
3
;
}
message
GatherUidRequest
{
bytes
uid
=
1
;
string
key
=
2
;
uint32
size
=
3
;
uint32
rank
=
4
;
}
message
GatherUidResponse
{
repeated
bytes
uids
=
1
;
message
BcastAddrRequest
{
string
master_ip
=
1
;
int32
port
=
2
;
string
key
=
3
;
uint32
size
=
4
;
uint32
rank
=
5
;
uint32
root
=
6
;
}
message
BcastAddrResponse
{
string
master_ip
=
1
;
int32
port
=
2
;
}
message
SetOutputShapeRequest
{
...
...
src/opr-mm/test/mock_client.h
浏览文件 @
b38e8225
...
...
@@ -29,13 +29,14 @@ class MockGroupClient final : public opr::GroupClient {
}
RegisterInfo
opr_register
(
const
std
::
string
&
key
,
size_t
nr_devices
,
bool
is_root
,
int
rank
,
uint64_t
comp_node_hash
)
{
bool
is_root
,
int
rank
,
uint64_t
comp_node_hash
)
override
{
return
m_mgr
.
opr_register
(
key
,
nr_devices
,
is_root
,
rank
,
comp_node_hash
);
}
std
::
vector
<
std
::
string
>
gather_uid
(
const
std
::
string
&
uid
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
)
{
return
m_mgr
.
gather_uid
(
uid
,
key
,
size
,
rank
);
void
bcast_addr
(
std
::
string
&
master_ip
,
int
&
port
,
const
std
::
string
&
key
,
uint32_t
size
,
uint32_t
rank
,
uint32_t
root
)
override
{
return
m_mgr
.
bcast_addr
(
master_ip
,
port
,
key
,
size
,
rank
,
root
);
}
void
set_output_shape
(
const
std
::
string
&
key
,
...
...
MegRay
@
e14e4f84
比较
d06c215d
...
e14e4f84
Subproject commit
d06c215dc1425fa932e20ecfaab7b07c0343a5bc
Subproject commit
e14e4f84c1349598ba17c49923168db47a4e9642
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录