From bbde58fb9b4846cdd68e4eeb7aeea6454a8172f6 Mon Sep 17 00:00:00 2001 From: Bo Zhou Date: Wed, 27 Feb 2019 12:20:02 +0800 Subject: [PATCH] first version of network communication (#49) * first version of network communication * fix code styple problems * add a script to get machine's information * code styple problems#2 * fix unit test problems * update dockfile to fix the installation issue of cmake * thread-saftey ensurance & copright * resolve comments --- .copyright.hook | 3 - Dockerfile | 1 + README.md | 2 +- parl/framework/tests/agent_base_test.py | 4 +- parl/remote/decorator.py | 203 +++++++++++++++++++++ parl/remote/server.py | 86 +++++++++ parl/utils/__init__.py | 2 +- parl/utils/communication.py | 71 +++++++ parl/utils/{gputils.py => machine_info.py} | 17 +- 9 files changed, 380 insertions(+), 9 deletions(-) create mode 100644 parl/remote/decorator.py create mode 100644 parl/remote/server.py create mode 100644 parl/utils/communication.py rename parl/utils/{gputils.py => machine_info.py} (83%) diff --git a/.copyright.hook b/.copyright.hook index 09afff2..312b0fa 100644 --- a/.copyright.hook +++ b/.copyright.hook @@ -103,9 +103,6 @@ def main(argv=None): first_line = fd.readline() second_line = fd.readline() if "COPYRIGHT (C)" in first_line.upper(): continue - if first_line.startswith("#!") or PYTHON_ENCODE.match( - second_line) != None or PYTHON_ENCODE.match(first_line) != None: - continue original_contents = io.open(filename, encoding="utf-8").read() new_contents = generate_copyright( COPYRIGHT, lang_type(filename)) + original_contents diff --git a/Dockerfile b/Dockerfile index 27fe916..9080cfb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,6 +17,7 @@ FROM paddlepaddle/paddle:1.1.0-gpu-cuda9.0-cudnn7 +Run apt-get update RUN apt-get install -y cmake RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple details diff --git a/README.md b/README.md index 0d25da3..a85a171 100644 --- 
a/README.md +++ b/README.md @@ -41,7 +41,7 @@ class AtariModel(parl.Model): def __init__(self, img_shape, action_dim): # define your layers self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5, - stride=[1, 1], padding=[2, 2], act='relu') + stride=1, padding=2, act='relu') ... self.fc1 = layers.fc(action_dim) def value(self, img): diff --git a/parl/framework/tests/agent_base_test.py b/parl/framework/tests/agent_base_test.py index f00485a..f93c1c3 100644 --- a/parl/framework/tests/agent_base_test.py +++ b/parl/framework/tests/agent_base_test.py @@ -19,7 +19,7 @@ from paddle import fluid from parl.framework.agent_base import Agent from parl.framework.algorithm_base import Algorithm from parl.framework.model_base import Model -from parl.utils import gputils +from parl.utils.machine_info import get_gpu_count class TestModel(Model): @@ -66,7 +66,7 @@ class AgentBaseTest(unittest.TestCase): self.algorithm = TestAlgorithm(self.model) def test_agent_with_gpu(self): - if gputils.get_gpu_count() > 0: + if get_gpu_count() > 0: agent = TestAgent(self.algorithm, gpu_id=0) obs = np.random.random([3, 10]).astype('float32') output_np = agent.predict(obs) diff --git a/parl/remote/decorator.py b/parl/remote/decorator.py new file mode 100644 index 0000000..08ae715 --- /dev/null +++ b/parl/remote/decorator.py @@ -0,0 +1,203 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def virtual(cls, location='client'):
    """
    Wrap a normal class as a remote class that can run on a different machine.

    Two wrappers are provided. `ClientWrapper` owns the real instance and
    executes the requested functions locally. `ServerWrapper` is a proxy that
    forwards every function call to the connected client through ZeroMQ
    sockets.

    Args:
        cls(class): the class to wrap. When `location` is 'server', `cls` has
            to be a class already decorated by `@virtual`, because the server
            wrapper reads the `unwrapped` attribute of its instances.
        location(str): specify which wrapper to use, available locations:
            client/server. Users are expected to use `client`.

    Returns:
        `ClientWrapper` if `location` is 'client', otherwise `ServerWrapper`.

    Raises:
        AssertionError: if `location` is neither 'client' nor 'server'.
    """
    assert location in ['client', 'server'], \
        'Remote Class has to be placed at client side or server side.'

    class ClientWrapper(object):
        """
        Wrapper for remote class at client side. After the decoration,
        the initial class can be driven by requests coming from server side.
        """

        def __init__(self, *args, **kwargs):
            """
            Args:
                args, kwargs: arguments for the initialisation of the
                    initial class.
            """
            self.unwrapped = cls(*args, **kwargs)
            # The attribute used by `connect_server` is `connect_socket`;
            # initialise it under that exact name.
            self.connect_socket = None
            self.reply_socket = None

        def create_reply_socket(self):
            """
            Create the socket server living in client side. This server keeps
            running and waits for requests (e.g. call a function) from server
            side.

            Returns:
                (socket, client_ip, free_port): the bound REP socket, the
                local IP reported by `get_ip_address` and the bound port.
                The process exits with status 1 if no port in [6000, 8000)
                can be bound.
            """
            # Imported here because the module-level imports do not provide
            # `sys`, which is required for the failure path below.
            import sys

            client_ip = get_ip_address()
            context = zmq.Context()
            socket = context.socket(zmq.REP)
            for port in range(6000, 8000):
                try:
                    socket.bind("tcp://*:{}".format(port))
                except zmq.error.ZMQError:
                    logger.warn(
                        "[create_reply_socket]cannot bind port:{}, retry".
                        format(port))
                    continue
                logger.info("[create_reply_socket] free_port:{}".format(port))
                return socket, client_ip, port

            logger.error("cannot find any available port from 6000 to 8000")
            sys.exit(1)

        def connect_server(self, server_ip, server_port):
            """
            Create the connection between client side and server side.

            The client first opens its own reply socket, then announces
            "<local_ip>:<local_port> <client_id>" to the server and waits for
            the acknowledgement.

            Args:
                server_ip(str): the ip of the server.
                server_port(int): the connection port of the server.
            """
            self.reply_socket, local_ip, local_port = \
                self.create_reply_socket()

            logger.info("connecting {}:{}".format(server_ip, server_port))
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            socket.connect("tcp://{}:{}".format(server_ip, server_port))
            client_id = np.random.randint(int(1e18))
            logger.info("client_id:{}".format(client_id))
            socket.send_string('{}:{} {}'.format(local_ip, local_port,
                                                 client_id))
            message = socket.recv_string()
            logger.info("[connect_server] done, message from server:{}".format(
                message))
            self.connect_socket = socket

        def __getattr__(self, attr):
            """
            Forward the call to the function of the initial class.

            Only invoked when normal attribute lookup fails; the returned
            function wrapper calls the attribute of the unwrapped instance.

            Raises:
                AttributeError: if the unwrapped instance has no such attribute.
            """
            # `unwrapped` missing means the instance is not initialised yet
            # (e.g. created through `__new__`); looking it up again here would
            # recurse without end.
            if attr == 'unwrapped':
                raise AttributeError(attr)
            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    return getattr(self.unwrapped, attr)(*args, **kw)

                return wrapper
            raise AttributeError(attr)

        def remote_run(self, server_ip, server_port):
            """
            Connect to the server and serve its requests of running functions.
            This function never returns.

            Args:
                server_ip(str): server's ip
                server_port(int): server's port
            """
            self.connect_server(server_ip, server_port)
            while True:
                # Each request takes two round trips: the function name,
                # acknowledged by "OK", then the serialized arguments,
                # answered by the serialized return value.
                function_name = self.reply_socket.recv_string()
                self.reply_socket.send_string("OK")
                data = self.reply_socket.recv()
                args, kw = loads_argument(data)
                ret = getattr(self.unwrapped, function_name)(*args, **kw)
                ret = dumps_return(ret)
                self.reply_socket.send(ret)

    class ServerWrapper(object):
        """
        Wrapper for remote class at server side.
        """

        def __init__(self, *args, **kwargs):
            """
            Args:
                args, kwargs: arguments used to initialize the initial class
            """
            self.unwrapped = (cls(*args, **kwargs)).unwrapped
            self.command_socket = None
            self.client_id = None
            # Serializes the two-step request/reply exchange between threads.
            self.internal_lock = threading.Lock()

        def connect_client(self, client_info):
            """
            Build another connection with the client to send commands to it.

            Args:
                client_info(str): "<client_address> <client_id>" as sent by
                    the client in `connect_server`.
            """
            client_address, client_id = client_info.split(' ')
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            logger.info(
                "[connect_client] client_address:{}".format(client_address))
            socket.connect("tcp://{}".format(client_address))
            self.command_socket = socket
            self.client_id = client_id

        def __getattr__(self, attr):
            """
            Run the function at client side. This is also implemented through
            a function wrapper.

            Args:
                attr(str): a function name specifying which function to run.

            Raises:
                AttributeError: if the initial class has no such attribute.
            """
            # Same guard as in the client wrapper: avoid endless recursion
            # when `unwrapped` has not been set.
            if attr == 'unwrapped':
                raise AttributeError(attr)
            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    # `with` releases the lock even if a send/recv or the
                    # (de)serialization raises, so other threads are not
                    # blocked forever.
                    with self.internal_lock:
                        self.command_socket.send_string(attr)
                        self.command_socket.recv_string()
                        data = dumps_argument(*args, **kw)
                        self.command_socket.send(data)
                        ret = self.command_socket.recv()
                        return loads_return(ret)

                return wrapper
            # The attribute protocol expects AttributeError here; any other
            # exception type breaks `hasattr` and `getattr(obj, name, default)`.
            raise AttributeError(attr)

    if location == 'client':
        return ClientWrapper
    else:
        return ServerWrapper
class Server(object):
    """
    Base class for network communication.

    Usage:

    ```python
        server = Server(port=8001)
        server.register_client(Simulator)
        remote_client = server.get_client()
    ```
    """

    def __init__(self, port):
        """
        Bind a REP socket on `port` and start the thread that accepts clients.

        Args:
            port(int): a local port used for network communication.
        """
        # Imported here because the module-level imports do not provide
        # `zmq`, which this constructor needs.
        import zmq

        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind("tcp://*:{}".format(port))
        self.socket = socket
        self.pool = queue.Queue()
        self.cls = None
        t = threading.Thread(target=self.wait_for_connection)
        t.start()

    def wait_for_connection(self):
        """
        A never-ending function that keeps waiting for connections from remote
        clients. It puts every connected remote client into an internal client
        pool, and clients can be obtained by calling `get_client`.

        Note that this function is started in a thread by `__init__`.
        """
        # NOTE(review): `self.cls` is still None until `register_client` is
        # called; a client connecting before that would fail below. Confirm
        # that callers always register the class before clients connect.
        while True:
            # The client announces itself as "<ip>:<port> <client_id>".
            client_info = self.socket.recv_string()
            client_id = client_info.split(' ')[1]
            new_client = virtual(self.cls, location='server')()
            self.socket.send_string('Hello World! Client:{}'.format(client_id))
            new_client.connect_client(client_info)
            self.pool.put(new_client)

    def get_client(self):
        """
        A blocking function to obtain a connected client.

        Returns:
            remote_client(self.cls): a **remote** instance taken from the
            internal pool; blocks until one is available.
        """
        return self.pool.get()

    def register_client(self, cls):
        """
        Declare the type of remote class.
        Let the server know which class to use as a remote client.

        Args:
            cls(Class): A class decorated by @virtual.
        """
        self.cls = cls
def dumps_argument(*args, **kwargs):
    """
    Serialize the arguments passed to a function.

    Args:
        *args, **kwargs: positional and keyword arguments of a function call.

    Returns:
        The buffer produced by pyarrow for the pair `[args, kwargs]`.
    """
    call_arguments = [args, kwargs]
    serialized = pyarrow.serialize(call_arguments)
    return serialized.to_buffer()


def loads_argument(data):
    """
    Restore the arguments serialized by `dumps_argument`.

    Args:
        data: the output of `dumps_argument`.

    Returns:
        The deserialized `[args, kwargs]` pair given to `dumps_argument`.
    """
    call_arguments = pyarrow.deserialize(data)
    return call_arguments


def dumps_return(data):
    """
    Serialize the return value of a function.

    Args:
        data: the output of a function.

    Returns:
        The buffer produced by pyarrow for `data`.
    """
    serialized = pyarrow.serialize(data)
    return serialized.to_buffer()


def loads_return(data):
    """
    Restore the value serialized by `dumps_return`.

    Args:
        data: the output of `dumps_return`.

    Returns:
        The deserialized data.
    """
    returned_value = pyarrow.deserialize(data)
    return returned_value
def get_ip_address(probe_host="8.8.8.8", probe_port=80):
    """
    Get the IP address of the host.

    The address is read from a datagram socket after connecting it to
    `probe_host`, i.e. it is the local address the system selects to reach
    that host.

    Args:
        probe_host(str): address used to select the outgoing interface.
            Defaults to the previously hard-coded "8.8.8.8"; pass a host that
            is reachable in your network when there is no route to it.
        probe_port(int): port used together with `probe_host`.

    Returns:
        local_ip(str): the local IP address of the probing socket.

    Raises:
        OSError: if the socket cannot be connected to the probe address.
    """
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect((probe_host, probe_port))
        return s.getsockname()[0]
    finally:
        # Close the socket even when `connect` fails, so that no file
        # descriptor is leaked.
        s.close()