Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4c5141d6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4c5141d6
编写于
5月 21, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(dist): speed up bcast_val
GitOrigin-RevId: 21c4123b09480b425676681a16a50962141b1eda
上级
0c6ee228
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
28 addition
and
12 deletion
+28
-12
imperative/python/megengine/distributed/server.py
imperative/python/megengine/distributed/server.py
+28
-12
未找到文件。
imperative/python/megengine/distributed/server.py
浏览文件 @
4c5141d6
...
@@ -36,6 +36,7 @@ class Methods:
...
@@ -36,6 +36,7 @@ class Methods:
self
.
dict_barrier_counter
=
defaultdict
(
int
)
self
.
dict_barrier_counter
=
defaultdict
(
int
)
self
.
dict_barrier_event
=
defaultdict
(
threading
.
Event
)
self
.
dict_barrier_event
=
defaultdict
(
threading
.
Event
)
self
.
user_dict
=
defaultdict
(
partial
(
Future
,
False
))
self
.
user_dict
=
defaultdict
(
partial
(
Future
,
False
))
self
.
bcast_dict
=
{}
def
connect
(
self
):
def
connect
(
self
):
"""Method for checking connection success."""
"""Method for checking connection success."""
...
@@ -127,6 +128,23 @@ class Methods:
...
@@ -127,6 +128,23 @@ class Methods:
future
=
self
.
user_dict
[
key
]
future
=
self
.
user_dict
[
key
]
return
future
.
get
()
return
future
.
get
()
def
bcast_val
(
self
,
val
,
key
,
size
):
with
self
.
lock
:
if
key
not
in
self
.
bcast_dict
:
self
.
bcast_dict
[
key
]
=
[
Future
(
False
),
size
]
arr
=
self
.
bcast_dict
[
key
]
if
val
is
not
None
:
arr
[
0
].
set
(
val
)
val
=
None
else
:
val
=
arr
[
0
].
get
()
with
self
.
lock
:
cnt
=
arr
[
1
]
-
1
arr
[
1
]
=
cnt
if
cnt
==
0
:
del
self
.
bcast_dict
[
key
]
return
val
class
ThreadXMLRPCServer
(
ThreadingMixIn
,
SimpleXMLRPCServer
):
class
ThreadXMLRPCServer
(
ThreadingMixIn
,
SimpleXMLRPCServer
):
pass
pass
...
@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
...
@@ -142,7 +160,9 @@ def _start_server(py_server_port, queue):
"""
"""
try
:
try
:
mm_server_port
=
create_mm_server
(
"0.0.0.0"
,
0
)
mm_server_port
=
create_mm_server
(
"0.0.0.0"
,
0
)
server
=
ThreadXMLRPCServer
((
"0.0.0.0"
,
py_server_port
),
logRequests
=
False
)
server
=
ThreadXMLRPCServer
(
(
"0.0.0.0"
,
py_server_port
),
logRequests
=
False
,
allow_none
=
True
)
server
.
register_instance
(
Methods
(
mm_server_port
))
server
.
register_instance
(
Methods
(
mm_server_port
))
_
,
py_server_port
=
server
.
server_address
_
,
py_server_port
=
server
.
server_address
queue
.
put
((
py_server_port
,
mm_server_port
))
queue
.
put
((
py_server_port
,
mm_server_port
))
...
@@ -185,13 +205,14 @@ class Client:
...
@@ -185,13 +205,14 @@ class Client:
self
.
master_ip
=
master_ip
self
.
master_ip
=
master_ip
self
.
port
=
port
self
.
port
=
port
self
.
connect
()
self
.
connect
()
self
.
bcast_dict
=
defaultdict
(
lambda
:
0
)
def
connect
(
self
):
def
connect
(
self
):
"""Check connection success."""
"""Check connection success."""
while
True
:
while
True
:
try
:
try
:
self
.
proxy
=
ServerProxy
(
self
.
proxy
=
ServerProxy
(
"http://{}:{}"
.
format
(
self
.
master_ip
,
self
.
port
)
"http://{}:{}"
.
format
(
self
.
master_ip
,
self
.
port
)
,
allow_none
=
True
)
)
if
self
.
proxy
.
connect
():
if
self
.
proxy
.
connect
():
break
break
...
@@ -247,22 +268,17 @@ class Client:
...
@@ -247,22 +268,17 @@ class Client:
def
user_set
(
self
,
key
,
val
):
def
user_set
(
self
,
key
,
val
):
"""Set user defined key-value pairs across processes."""
"""Set user defined key-value pairs across processes."""
self
.
proxy
.
user_set
(
key
,
val
)
return
self
.
proxy
.
user_set
(
key
,
val
)
def
user_get
(
self
,
key
):
def
user_get
(
self
,
key
):
"""Get user defined key-value pairs across processes."""
"""Get user defined key-value pairs across processes."""
return
self
.
proxy
.
user_get
(
key
)
return
self
.
proxy
.
user_get
(
key
)
def
bcast_val
(
self
,
val
,
key
,
size
):
def
bcast_val
(
self
,
val
,
key
,
size
):
if
val
is
not
None
:
idx
=
self
.
bcast_dict
[
key
]
+
1
self
.
user_set
(
key
+
"_sync"
,
val
)
self
.
bcast_dict
[
key
]
=
idx
self
.
group_barrier
(
key
,
size
)
key
=
key
+
"_bcast_"
+
str
(
idx
)
self
.
group_barrier
(
key
,
size
)
return
self
.
proxy
.
bcast_val
(
val
,
key
,
size
)
else
:
self
.
group_barrier
(
key
,
size
)
val
=
self
.
user_get
(
key
+
"_sync"
)
self
.
group_barrier
(
key
,
size
)
return
val
def
main
(
port
=
0
,
verbose
=
True
):
def
main
(
port
=
0
,
verbose
=
True
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录