提交 f5f86a05 编写于 作者: M Megvii Engine Team

docs(mge/distributed): add distributed.server docs

GitOrigin-RevId: 929d6adfcc2e5301c8bedf592871acc3e06ea126
上级 6c5cf25f
......@@ -21,6 +21,12 @@ from .util import get_free_ports
class Methods:
"""Distributed Server Method.
Used for exchanging information between distributed nodes.
:param mm_server_port: multiple machine rpc server port.
"""
def __init__(self, mm_server_port):
self.lock = threading.Lock()
self.mm_server_port = mm_server_port
......@@ -31,51 +37,65 @@ class Methods:
self.dict_barrier_event = defaultdict(threading.Event)
def connect(self):
"""Method for checking connection success.
Always returns True; a caller that receives this reply knows the server answered.
"""
return True
def get_mm_server_port(self):
"""Get multiple machine rpc server port.
:return: the value stored in self.mm_server_port.
"""
return self.mm_server_port
def set_is_grad(self, rank_peer, is_grad):
def set_is_grad(self, key, is_grad):
"""Mark whether the send/recv op matched by key needs gradients.
Looks up the future stored in dict_is_grad under key and passes is_grad to its set().
:param key: key to match send/recv op.
:param is_grad: whether this op needs grad.
:return: always True.
"""
with self.lock:
future = self.dict_is_grad[rank_peer]
future = self.dict_is_grad[key]
future.set(is_grad)
return True
def check_is_grad(self, rank_peer):
def check_is_grad(self, key):
"""Check whether the send/recv op matched by key needs gradients.
Obtains the value through get() on the future stored in dict_is_grad under key,
then deletes that entry from dict_is_grad.
:param key: key to match send/recv op.
:return: the value obtained from the future.
"""
with self.lock:
future = self.dict_is_grad[rank_peer]
future = self.dict_is_grad[key]
ret = future.get()
with self.lock:
del self.dict_is_grad[rank_peer]
del self.dict_is_grad[key]
return ret
def set_remote_tracer(self, rank_peer, tracer_set):
def set_remote_tracer(self, key, tracer_set):
"""Set the tracer set used for tracing the send/recv op matched by key.
Looks up the future stored in dict_remote_tracer under key and passes tracer_set to its set().
:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
:return: always True.
"""
with self.lock:
future = self.dict_remote_tracer[rank_peer]
future = self.dict_remote_tracer[key]
future.set(tracer_set)
return True
def check_remote_tracer(self, rank_peer):
def check_remote_tracer(self, key):
"""Get the tracer set for the send/recv op matched by key.
Obtains the value through get() on the future stored in dict_remote_tracer under key,
then deletes that entry from dict_remote_tracer.
:param key: key to match send/recv op.
:return: the value obtained from the future.
"""
with self.lock:
future = self.dict_remote_tracer[rank_peer]
future = self.dict_remote_tracer[key]
ret = future.get()
with self.lock:
del self.dict_remote_tracer[rank_peer]
del self.dict_remote_tracer[key]
return ret
def set_pack_list(self, key, pack_list):
"""Store pack_list for key.
Looks up the future in dict_pack_list under key while holding self.lock
and passes pack_list to its set().
:param key: key identifying the entry in dict_pack_list.
:param pack_list: value handed to the future.
:return: always True.
"""
with self.lock:
future = self.dict_pack_list[key]
future.set(pack_list)
return True
def get_pack_list(self, key):
"""Get the value stored for key in dict_pack_list.
Returns the result of get() on the future found under key; this method does not
remove the entry from dict_pack_list.
:param key: key identifying the entry in dict_pack_list.
:return: the value obtained from the future.
"""
with self.lock:
future = self.dict_pack_list[key]
return future.get()
def group_barrier(self, key, size):
"""A barrier that waits for all group members.
:param key: group key used to match members with each other.
:param size: group size.
"""
with self.lock:
self.dict_barrier_counter[key] += 1
counter = self.dict_barrier_counter[key]
......@@ -94,12 +114,23 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
def start_server(py_server_port, mm_server_port):
"""Start the python distributed server.
Creates a ThreadXMLRPCServer on ("0.0.0.0", py_server_port) with request logging
disabled, registers a Methods instance built from mm_server_port, and calls
serve_forever() on it.
:param py_server_port: python server port.
:param mm_server_port: multiple machine server port, passed to Methods.
"""
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port))
server.serve_forever()
class Server:
"""Distributed Server for distributed training.
Should be running at master node.
:param port: python server port.
"""
def __init__(self, port):
self.py_server_port = get_free_ports(1)[0] if port == 0 else port
self.mm_server_port = create_mm_server("0.0.0.0", 0)
......@@ -112,12 +143,19 @@ class Server:
class Client:
"""Distributed Client for distributed training.
:param master_ip: ip address of master node.
:param port: port of server at master node.
"""
def __init__(self, master_ip, port):
self.master_ip = master_ip
self.port = port
self.connect()
def connect(self):
"""Check connection success."""
while True:
try:
self.proxy = ServerProxy(
......@@ -129,25 +167,43 @@ class Client:
time.sleep(1)
def get_mm_server_port(self):
"""Get multiple machine server port.
:return: the result of calling get_mm_server_port through self.proxy.
"""
return self.proxy.get_mm_server_port()
def set_is_grad(self, rank_peer, is_grad):
"""Forward rank_peer and is_grad to set_is_grad on self.proxy; the reply is not returned."""
self.proxy.set_is_grad(rank_peer, is_grad)
def check_is_grad(self, rank_peer):
"""Call check_is_grad on self.proxy with rank_peer and return its result."""
return self.proxy.check_is_grad(rank_peer)
def set_remote_tracer(self, rank_peer, tracer_set):
"""Forward rank_peer and tracer_set to set_remote_tracer on self.proxy; the reply is not returned."""
self.proxy.set_remote_tracer(rank_peer, tracer_set)
def check_remote_tracer(self, rank_peer):
"""Call check_remote_tracer on self.proxy with rank_peer and return its result."""
return self.proxy.check_remote_tracer(rank_peer)
def set_pack_list(self, key, pack_list):
"""Forward key and pack_list to set_pack_list on self.proxy; the reply is not returned.
:param key: key passed through to the server.
:param pack_list: value passed through to the server.
"""
self.proxy.set_pack_list(key, pack_list)
def get_pack_list(self, key):
"""Call get_pack_list on self.proxy with key and return its result.
:param key: key passed through to the server.
"""
return self.proxy.get_pack_list(key)
def set_is_grad(self, key, is_grad):
"""Mark whether the send/recv op matched by key needs gradients.
Forwards both arguments to set_is_grad on self.proxy; the reply is not returned.
:param key: key to match send/recv op.
:param is_grad: whether this op needs grad.
"""
self.proxy.set_is_grad(key, is_grad)
def check_is_grad(self, key):
"""Check whether the send/recv op matched by key needs gradients.
:param key: key to match send/recv op.
:return: the result of check_is_grad called through self.proxy.
"""
return self.proxy.check_is_grad(key)
def set_remote_tracer(self, key, tracer_set):
"""Set the tracer set used for tracing the send/recv op matched by key.
Forwards both arguments to set_remote_tracer on self.proxy; the reply is not returned.
:param key: key to match send/recv op.
:param tracer_set: valid tracer set.
"""
self.proxy.set_remote_tracer(key, tracer_set)
def check_remote_tracer(self, key):
"""Get the tracer set for the send/recv op matched by key.
:param key: key to match send/recv op.
:return: the result of check_remote_tracer called through self.proxy.
"""
return self.proxy.check_remote_tracer(key)
def group_barrier(self, key, size):
"""A barrier that waits for all group members.
Forwards both arguments to group_barrier on self.proxy; the reply is not returned.
:param key: group key used to match members with each other.
:param size: group size.
"""
self.proxy.group_barrier(key, size)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册