Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f5f86a05
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f5f86a05
编写于
9月 28, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(mge/distributed): add distributed.server docs
GitOrigin-RevId: 929d6adfcc2e5301c8bedf592871acc3e06ea126
上级
6c5cf25f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
94 addition
and
38 deletion
+94
-38
imperative/python/megengine/distributed/server.py
imperative/python/megengine/distributed/server.py
+94
-38
未找到文件。
imperative/python/megengine/distributed/server.py
浏览文件 @
f5f86a05
...
...
@@ -21,6 +21,12 @@ from .util import get_free_ports
class
Methods
:
"""Distributed Server Method.
Used for exchange information between distributed nodes.
:param mm_server_port: multiple machine rpc server port.
"""
def
__init__
(
self
,
mm_server_port
):
self
.
lock
=
threading
.
Lock
()
self
.
mm_server_port
=
mm_server_port
...
...
@@ -31,51 +37,65 @@ class Methods:
self
.
dict_barrier_event
=
defaultdict
(
threading
.
Event
)
def
connect
(
self
):
"""Method for checking connection success."""
return
True
def
get_mm_server_port
(
self
):
"""Get multiple machine rpc server port."""
return
self
.
mm_server_port
def
set_is_grad
(
self
,
rank_peer
,
is_grad
):
def
set_is_grad
(
self
,
key
,
is_grad
):
"""Mark send/recv need gradiants by key.
:param key: key to match send/recv op.
:param is_grad: whether this op need grad.
"""
with
self
.
lock
:
future
=
self
.
dict_is_grad
[
rank_peer
]
future
=
self
.
dict_is_grad
[
key
]
future
.
set
(
is_grad
)
return
True
def
check_is_grad
(
self
,
rank_peer
):
def
check_is_grad
(
self
,
key
):
"""Check whether send/recv need gradiants.
:param key: key to match send/recv op.
"""
with
self
.
lock
:
future
=
self
.
dict_is_grad
[
rank_peer
]
future
=
self
.
dict_is_grad
[
key
]
ret
=
future
.
get
()
with
self
.
lock
:
del
self
.
dict_is_grad
[
rank_peer
]
del
self
.
dict_is_grad
[
key
]
return
ret
def
set_remote_tracer
(
self
,
rank_peer
,
tracer_set
):
def
set_remote_tracer
(
self
,
key
,
tracer_set
):
"""Set tracer dict for tracing send/recv op.
:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
"""
with
self
.
lock
:
future
=
self
.
dict_remote_tracer
[
rank_peer
]
future
=
self
.
dict_remote_tracer
[
key
]
future
.
set
(
tracer_set
)
return
True
def
check_remote_tracer
(
self
,
rank_peer
):
def
check_remote_tracer
(
self
,
key
):
"""Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
with
self
.
lock
:
future
=
self
.
dict_remote_tracer
[
rank_peer
]
future
=
self
.
dict_remote_tracer
[
key
]
ret
=
future
.
get
()
with
self
.
lock
:
del
self
.
dict_remote_tracer
[
rank_peer
]
del
self
.
dict_remote_tracer
[
key
]
return
ret
def
set_pack_list
(
self
,
key
,
pack_list
):
with
self
.
lock
:
future
=
self
.
dict_pack_list
[
key
]
future
.
set
(
pack_list
)
return
True
def
get_pack_list
(
self
,
key
):
with
self
.
lock
:
future
=
self
.
dict_pack_list
[
key
]
return
future
.
get
()
def
group_barrier
(
self
,
key
,
size
):
"""A barrier wait for all group member.
:param key: group key to match each other.
:param size: group size.
"""
with
self
.
lock
:
self
.
dict_barrier_counter
[
key
]
+=
1
counter
=
self
.
dict_barrier_counter
[
key
]
...
...
@@ -94,12 +114,23 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
def
start_server
(
py_server_port
,
mm_server_port
):
"""Start python distributed server and multiple machine server.
:param py_server_port: python server port.
:param mm_server_port: multiple machine server port.
"""
server
=
ThreadXMLRPCServer
((
"0.0.0.0"
,
py_server_port
),
logRequests
=
False
)
server
.
register_instance
(
Methods
(
mm_server_port
))
server
.
serve_forever
()
class
Server
:
"""Distributed Server for distributed training.
Should be running at master node.
:param port: python server port.
"""
def
__init__
(
self
,
port
):
self
.
py_server_port
=
get_free_ports
(
1
)[
0
]
if
port
==
0
else
port
self
.
mm_server_port
=
create_mm_server
(
"0.0.0.0"
,
0
)
...
...
@@ -112,12 +143,19 @@ class Server:
class
Client
:
"""Distributed Client for distributed training.
:param master_ip: ip address of master node.
:param port: port of server at master node.
"""
def
__init__
(
self
,
master_ip
,
port
):
self
.
master_ip
=
master_ip
self
.
port
=
port
self
.
connect
()
def
connect
(
self
):
"""Check connection success."""
while
True
:
try
:
self
.
proxy
=
ServerProxy
(
...
...
@@ -129,25 +167,43 @@ class Client:
time
.
sleep
(
1
)
def
get_mm_server_port
(
self
):
"""Get multiple machine server port."""
return
self
.
proxy
.
get_mm_server_port
()
def
set_is_grad
(
self
,
rank_peer
,
is_grad
):
self
.
proxy
.
set_is_grad
(
rank_peer
,
is_grad
)
def
check_is_grad
(
self
,
rank_peer
):
return
self
.
proxy
.
check_is_grad
(
rank_peer
)
def
set_remote_tracer
(
self
,
rank_peer
,
tracer_set
):
self
.
proxy
.
set_remote_tracer
(
rank_peer
,
tracer_set
)
def
check_remote_tracer
(
self
,
rank_peer
):
return
self
.
proxy
.
check_remote_tracer
(
rank_peer
)
def
set_pack_list
(
self
,
key
,
pack_list
):
self
.
proxy
.
set_pack_list
(
key
,
pack_list
)
def
get_pack_list
(
self
,
key
):
return
self
.
proxy
.
get_pack_list
(
key
)
def
set_is_grad
(
self
,
key
,
is_grad
):
"""Mark send/recv need gradiants by key.
:param key: key to match send/recv op.
:param is_grad: whether this op need grad.
"""
self
.
proxy
.
set_is_grad
(
key
,
is_grad
)
def
check_is_grad
(
self
,
key
):
"""Check whether send/recv need gradiants.
:param key: key to match send/recv op.
"""
return
self
.
proxy
.
check_is_grad
(
key
)
def
set_remote_tracer
(
self
,
key
,
tracer_set
):
"""Set tracer dict for tracing send/recv op.
:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
"""
self
.
proxy
.
set_remote_tracer
(
key
,
tracer_set
)
def
check_remote_tracer
(
self
,
key
):
"""Get tracer dict for send/recv op.
:param key: key to match send/recv op.
"""
return
self
.
proxy
.
check_remote_tracer
(
key
)
def
group_barrier
(
self
,
key
,
size
):
"""A barrier wait for all group member.
:param key: group key to match each other.
:param size: group size.
"""
self
.
proxy
.
group_barrier
(
key
,
size
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录