MegEngine 天元 / MegEngine
Commit b9918c32
Authored Nov 18, 2020 by Megvii Engine Team
feat(mge/distributed): support distributed key-value store
GitOrigin-RevId: b4abe8001459020a1b371a188eba856830ce86df
Parent: ab82c8da
Showing 2 changed files with 54 additions and 0 deletions (+54 -0)
imperative/python/megengine/distributed/server.py (+22 -0)
imperative/python/test/unit/distributed/test_distributed.py (+32 -0)
imperative/python/megengine/distributed/server.py @ b9918c32
...
@@ -35,6 +35,7 @@ class Methods:
        self.dict_pack_list = defaultdict(partial(Future, False))
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)
        self.user_dict = defaultdict(partial(Future, False))

    def connect(self):
        """Method for checking connection success."""
...
@@ -113,6 +114,19 @@ class Methods:
            event.wait()
        return True

    def user_set(self, key, val):
        """Set user defined key-value pairs across processes."""
        with self.lock:
            future = self.user_dict[key]
        future.set(val)
        return True

    def user_get(self, key):
        """Get user defined key-value pairs across processes."""
        with self.lock:
            future = self.user_dict[key]
        return future.get()


class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    pass
...
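The server-side hunk above stores each key in a future-like object, so a `user_get` that arrives before the matching `user_set` can wait for the value instead of failing. The following is a minimal sketch of that pattern using only the standard library. `SimpleFuture`, `KeyValueStore`, `demo`, and the `timeout` parameter are hypothetical names introduced here for illustration; they are not MegEngine's `Future` or `Methods` classes, whose internals are not shown in this diff.

```python
import threading
from collections import defaultdict


class SimpleFuture:
    """Hypothetical stand-in for the Future used in server.py (an assumption)."""

    def __init__(self):
        self._ready = threading.Event()
        self._value = None

    def set(self, value):
        # Publish the value and wake up every waiting reader.
        self._value = value
        self._ready.set()

    def get(self, timeout=None):
        # Block until a value has been published, or give up after `timeout`.
        if not self._ready.wait(timeout):
            raise TimeoutError("value was never set")
        return self._value


class KeyValueStore:
    """Illustrative store mirroring the shape of user_set / user_get."""

    def __init__(self):
        self.lock = threading.Lock()
        self.user_dict = defaultdict(SimpleFuture)

    def user_set(self, key, val):
        # The lock only guards the dictionary lookup/creation.
        with self.lock:
            future = self.user_dict[key]
        future.set(val)
        return True

    def user_get(self, key, timeout=None):
        with self.lock:
            future = self.user_dict[key]
        # Waiting happens outside the lock so a setter can still get in.
        return future.get(timeout)


def demo():
    """Start a reader that may run before the writer; both orders work."""
    store = KeyValueStore()
    results = []
    reader = threading.Thread(
        target=lambda: results.append(store.user_get("foo", timeout=5))
    )
    reader.start()
    store.user_set("foo", 1)
    reader.join()
    return results
```

In this sketch the blocking wait is kept outside the `with self.lock:` block. If the reader held the lock while waiting, the writer could never reach the dictionary and both sides would stall.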
@@ -220,3 +234,11 @@ class Client:
        :param size: group size.
        """
        self.proxy.group_barrier(key, size)

    def user_set(self, key, val):
        """Set user defined key-value pairs across processes."""
        self.proxy.user_set(key, val)

    def user_get(self, key):
        """Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)
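The client methods above simply forward to `self.proxy`, and the server class in this file combines `ThreadingMixIn` with `SimpleXMLRPCServer`. As a rough, standalone illustration of why a threaded XML-RPC server suits a blocking `user_get`, the sketch below wires the same two method names through Python's standard `xmlrpc` modules. `ThreadedServer`, `Slot`, `DemoMethods`, and `run_demo` are hypothetical names, and the value holder is a simplified assumption rather than MegEngine's `Future`.

```python
import threading
from collections import defaultdict
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer


class ThreadedServer(ThreadingMixIn, SimpleXMLRPCServer):
    # One thread per request, so a blocked user_get does not stall user_set.
    daemon_threads = True


class Slot:
    """Hypothetical one-shot value holder (not MegEngine's Future)."""

    def __init__(self):
        self.ready = threading.Event()
        self.value = None


class DemoMethods:
    """Illustrative RPC methods with the same names as in the diff."""

    def __init__(self):
        self.lock = threading.Lock()
        self.user_dict = defaultdict(Slot)

    def user_set(self, key, val):
        with self.lock:
            slot = self.user_dict[key]
        slot.value = val
        slot.ready.set()
        return True

    def user_get(self, key):
        with self.lock:
            slot = self.user_dict[key]
        slot.ready.wait()  # blocks this request thread until the key is set
        return slot.value


def run_demo():
    # Port 0 asks the OS for any free port.
    server = ThreadedServer(("127.0.0.1", 0), logRequests=False)
    server.register_instance(DemoMethods())
    port = server.server_address[1]
    threading.Thread(target=server.serve_forever, daemon=True).start()

    url = f"http://127.0.0.1:{port}"
    results = []

    def reader():
        # Each thread uses its own proxy object.
        results.append(ServerProxy(url).user_get("foo"))

    t = threading.Thread(target=reader)
    t.start()
    ServerProxy(url).user_set("foo", 1)
    t.join(10)

    server.shutdown()
    server.server_close()
    return results
```

With a single-threaded server, a `user_get` request that arrives first would occupy the only request handler and the later `user_set` could not be served; handling each request in its own thread avoids that in this sketch.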
imperative/python/test/unit/distributed/test_distributed.py @ b9918c32
...
@@ -195,6 +195,38 @@ def test_synchronized():
        assert p.exitcode == 0


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_user_set_get():
    world_size = 2
    port = dist.get_free_ports(1)[0]
    server = dist.Server(port)

    def worker(rank):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        # set in race condition
        dist.get_client().user_set("foo", 1)
        # get in race condition
        ret = dist.get_client().user_get("foo")
        assert ret == 1

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank,))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0


def test_oprmm_hashable():
    lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
    rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
...