From 9c92701f63d36e1962e3b0d125bdb230c3f1cf8c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 30 Dec 2020 18:29:22 +0800 Subject: [PATCH] feat(mge): support python -m megengine.distributed.server GitOrigin-RevId: f1e5c8e3cf441248157060efc04b8ff9a19b154b --- .../python/megengine/distributed/server.py | 37 ++++++++++++++----- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index 8567e5702..955d28df0 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -49,7 +49,7 @@ class Methods: def set_is_grad(self, key, is_grad): """ Mark send/recv need gradiants by key. - + :param key: key to match send/recv op. :param is_grad: whether this op need grad. """ @@ -61,7 +61,7 @@ class Methods: def check_is_grad(self, key): """ Check whether send/recv need gradiants. - + :param key: key to match send/recv op. """ with self.lock: @@ -86,7 +86,7 @@ class Methods: def check_remote_tracer(self, key): """ Get tracer dict for send/recv op. - + :param key: key to match send/recv op. """ with self.lock: @@ -99,7 +99,7 @@ class Methods: def group_barrier(self, key, size): """ A barrier wait for all group member. - + :param key: group key to match each other. :param size: group size. """ @@ -136,7 +136,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): def _start_server(py_server_port, mm_server_port, queue): """ Start python distributed server and multiple machine server. - + :param py_server_port: python server port. :param mm_server_port: multiple machine server port. :param queue: server port will put in this queue, puts exception when process fails. @@ -205,7 +205,7 @@ class Client: def set_is_grad(self, key, is_grad): """ Mark send/recv need gradiants by key. - + :param key: key to match send/recv op. :param is_grad: whether this op need grad. """ @@ -214,7 +214,7 @@ class Client: def check_is_grad(self, key): """ Check whether send/recv need gradiants. - + :param key: key to match send/recv op. """ return self.proxy.check_is_grad(key) @@ -231,7 +231,7 @@ class Client: def check_remote_tracer(self, key): """ Get tracer dict for send/recv op. - + :param key: key to match send/recv op. """ return self.proxy.check_remote_tracer(key) @@ -239,7 +239,7 @@ class Client: def group_barrier(self, key, size): """ A barrier wait for all group member. - + :param key: group key to match each other. :param size: group size. """ @@ -252,3 +252,22 @@ class Client: def user_get(self, key): """Get user defined key-value pairs across processes.""" return self.proxy.user_get(key) + + +def main(port=0, verbose=True): + mm_server_port = create_mm_server("0.0.0.0", 0) + server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose) + server.register_instance(Methods(mm_server_port)) + _, port = server.server_address + print("serving on port", port) + server.serve_forever() + + +if __name__ == "__main__": + import argparse + + ap = argparse.ArgumentParser() + ap.add_argument("-p", "--port", type=int, default=0) + ap.add_argument("-v", "--verbose", type=bool, default=True) + args = ap.parse_args() + main(port=args.port, verbose=args.verbose) -- GitLab